From dd73719856c0e571d2d0863609e2175a1f0f8de6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timur=20Krist=C3=B3f?=
Date: Thu, 1 Oct 2020 13:50:43 +0200
Subject: [PATCH] aco/ngg: Add shader query support to NGG GS.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

In each GS thread, we calculate the number of "real" primitives
that were emitted (points, lines, triangles, not strips).
Then we accumulate the number of "real" primitives emitted by
the entire threadgroup in GDS.

Signed-off-by: Timur Kristóf
Reviewed-by: Rhys Perry
Part-of:
---
 .../compiler/aco_instruction_selection.cpp    | 78 ++++++++++++++++++-
 1 file changed, 77 insertions(+), 1 deletion(-)

diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index e2624b268c9..7af0e5a8fc3 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -6895,6 +6895,7 @@ void ngg_visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *
 }
 
 void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream);
+void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr);
 
 void ngg_visit_set_vertex_and_primitive_count(isel_context *ctx, nir_intrinsic_instr *instr)
 {
@@ -6908,7 +6909,7 @@ void ngg_visit_set_vertex_and_primitive_count(isel_context *ctx, nir_intrinsic_i
       ngg_gs_clear_primflags(ctx, vtx_cnt, stream);
    }
 
-   /* TODO: also take the primitive count into use */
+   ngg_gs_write_shader_query(ctx, instr);
 }
 
 void visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *instr)
@@ -11165,6 +11166,81 @@ void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream)
    end_loop(ctx, &lc);
 }
 
+/* Emit code that adds the number of "real" primitives (points, lines, triangles)
+ * emitted by the current subgroup to a counter in GDS, for shader query support.
+ *
+ * instr is the intrinsic handled by ngg_visit_set_vertex_and_primitive_count():
+ * src[0] is the emitted vertex count and src[1] the emitted primitive count of each
+ * GS thread. The generated code only runs when bit 0 of ngg_gs_state is set.
+ *
+ * TODO: Consider using primitive compaction at the end instead.
+ */
+void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   const unsigned total_vtx_per_prim = gs_outprim_vertices(ctx->shader->info.gs.output_primitive);
+   if_context ic_shader_query;
+   Builder bld(ctx->program, ctx->block);
+
+   /* Skip everything (uniformly) unless the shader query bit is set in the NGG GS state. */
+   Temp shader_query = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), get_arg(ctx, ctx->args->ngg_gs_state), Operand(0u));
+   begin_uniform_if_then(ctx, &ic_shader_query, shader_query);
+   bld.reset(ctx->block);
+
+   Temp gs_vtx_cnt = get_ssa_temp(ctx, instr->src[0].ssa);
+   Temp gs_prm_cnt = get_ssa_temp(ctx, instr->src[1].ssa);
+   Temp sg_prm_cnt;
+
+   /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
+    * GS emits points, line strips or triangle strips.
+    * Real primitives are points, lines or triangles.
+    * A strip of N vertices holds N - (total_vtx_per_prim - 1) real primitives, hence:
+    * real primitives = vertices - strips * (total_vtx_per_prim - 1)
+    */
+   if (nir_src_is_const(instr->src[0]) && nir_src_is_const(instr->src[1])) {
+      /* Every active thread emits the same compile-time known amount:
+       * multiply it by the number of active threads instead of reducing.
+       */
+      const unsigned const_vtx_cnt = nir_src_as_uint(instr->src[0]);
+      const unsigned const_prm_cnt = nir_src_as_uint(instr->src[1]);
+      Temp prm_cnt = bld.copy(bld.def(s1), Operand(const_vtx_cnt - const_prm_cnt * (total_vtx_per_prim - 1u)));
+      Temp thread_cnt = bld.sop1(Builder::s_bcnt1_i32, bld.def(s1), bld.def(s1, scc), Operand(exec, bld.lm));
+      sg_prm_cnt = bld.sop2(aco_opcode::s_mul_i32, bld.def(s1), prm_cnt, thread_cnt);
+   } else {
+      Temp prm_cnt = gs_vtx_cnt;
+      if (total_vtx_per_prim > 1)
+         prm_cnt = bld.vop3(aco_opcode::v_mad_i32_i24, bld.def(v1), gs_prm_cnt, Operand(-1u * (total_vtx_per_prim - 1)), gs_vtx_cnt);
+
+      /* Reduction calculates the primitive count for the entire subgroup. */
+      sg_prm_cnt = bld.tmp(s1);
+      aco_ptr<Pseudo_reduction_instruction> red_instr
+         {create_reduction_instr(ctx, aco_opcode::p_reduce, ReduceOp::iadd32, Definition(sg_prm_cnt), prm_cnt)};
+      red_instr->cluster_size = ctx->program->wave_size;
+      bld.insert(std::move(red_instr));
+   }
+
+   /* Build an exec mask with only the first active lane, so that each subgroup
+    * issues a single GDS atomic.
+    */
+   Temp first_lane = bld.sop1(Builder::s_ff1_i32, bld.def(s1), Operand(exec, bld.lm));
+   Temp is_first_lane = bld.sop2(Builder::s_lshl, bld.def(bld.lm), bld.def(s1, scc),
+                                 Operand(1u, ctx->program->wave_size == 64), first_lane);
+
+   if_context ic_first_lane;
+   begin_divergent_if_then(ctx, &ic_first_lane, is_first_lane);
+   bld.reset(ctx->block);
+
+   /* Atomically add the subgroup's primitive count at GDS address 0. */
+   Temp gds_addr = bld.copy(bld.def(v1), Operand(0u));
+   Operand m = bld.m0((Temp)bld.sopk(aco_opcode::s_movk_i32, bld.def(s1, m0), 0x100));
+   bld.ds(aco_opcode::ds_add_u32, gds_addr, as_vgpr(ctx, sg_prm_cnt), m, 0u, 0u, true);
+
+   begin_divergent_if_else(ctx, &ic_first_lane);
+   end_divergent_if(ctx, &ic_first_lane);
+
+   begin_uniform_if_else(ctx, &ic_shader_query);
+   end_uniform_if(ctx, &ic_shader_query);
+}
+
 Temp ngg_gs_load_prim_flag_0(isel_context *ctx, Temp tid_in_tg, Temp max_vtxcnt, Temp vertex_lds_addr)
 {
    if_context ic;