diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index e2624b268c9..7af0e5a8fc3 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -6895,6 +6895,7 @@ void ngg_visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *
 }
 
 void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream);
+void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr);
 
 void ngg_visit_set_vertex_and_primitive_count(isel_context *ctx, nir_intrinsic_instr *instr)
 {
@@ -6908,7 +6909,7 @@ void ngg_visit_set_vertex_and_primitive_i
       ngg_gs_clear_primflags(ctx, vtx_cnt, stream);
    }
 
-   /* TODO: also take the primitive count into use */
+   ngg_gs_write_shader_query(ctx, instr);
 }
 
 void visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *instr)
@@ -11165,6 +11166,66 @@ void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream)
    end_loop(ctx, &lc);
 }
 
+void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+   /* Each subgroup uses a single GDS atomic to collect the total number of primitives.
+    * TODO: Consider using primitive compaction at the end instead.
+    */
+
+   unsigned total_vtx_per_prim = gs_outprim_vertices(ctx->shader->info.gs.output_primitive);
+   if_context ic_shader_query;
+   Builder bld(ctx->program, ctx->block);
+
+   Temp shader_query = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), get_arg(ctx, ctx->args->ngg_gs_state), Operand(0u));
+   begin_uniform_if_then(ctx, &ic_shader_query, shader_query);
+   bld.reset(ctx->block);
+
+   Temp gs_vtx_cnt = get_ssa_temp(ctx, instr->src[0].ssa);
+   Temp gs_prm_cnt = get_ssa_temp(ctx, instr->src[1].ssa);
+   Temp sg_prm_cnt;
+
+   /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
+    * GS emits points, line strips or triangle strips.
+    * Real primitives are points, lines or triangles.
+    */
+   if (nir_src_is_const(instr->src[0]) && nir_src_is_const(instr->src[1])) {
+      unsigned gs_vtx_cnt = nir_src_as_uint(instr->src[0]);
+      unsigned gs_prm_cnt = nir_src_as_uint(instr->src[1]);
+      Temp prm_cnt = bld.copy(bld.def(s1), Operand(gs_vtx_cnt - gs_prm_cnt * (total_vtx_per_prim - 1u)));
+      Temp thread_cnt = bld.sop1(Builder::s_bcnt1_i32, bld.def(s1), bld.def(s1, scc), Operand(exec, bld.lm));
+      sg_prm_cnt = bld.sop2(aco_opcode::s_mul_i32, bld.def(s1), prm_cnt, thread_cnt);
+   } else {
+      Temp prm_cnt = gs_vtx_cnt;
+      if (total_vtx_per_prim > 1)
+         prm_cnt = bld.vop3(aco_opcode::v_mad_i32_i24, bld.def(v1), gs_prm_cnt, Operand(-1u * (total_vtx_per_prim - 1)), gs_vtx_cnt);
+
+      /* Reduction calculates the primitive count for the entire subgroup. */
+      sg_prm_cnt = bld.tmp(s1);
+      aco_ptr<Pseudo_reduction_instruction> red_instr
+         {create_reduction_instr(ctx, aco_opcode::p_reduce, ReduceOp::iadd32, Definition(sg_prm_cnt), prm_cnt)};
+      red_instr->cluster_size = ctx->program->wave_size;
+      bld.insert(std::move(red_instr));
+   }
+
+   Temp first_lane = bld.sop1(Builder::s_ff1_i32, bld.def(s1), Operand(exec, bld.lm));
+   Temp is_first_lane = bld.sop2(Builder::s_lshl, bld.def(bld.lm), bld.def(s1, scc),
+                                 Operand(1u, ctx->program->wave_size == 64), first_lane);
+
+   if_context ic_last_lane;
+   begin_divergent_if_then(ctx, &ic_last_lane, is_first_lane);
+   bld.reset(ctx->block);
+
+   Temp gds_addr = bld.copy(bld.def(v1), Operand(0u));
+   Operand m = bld.m0((Temp)bld.sopk(aco_opcode::s_movk_i32, bld.def(s1, m0), 0x100));
+   bld.ds(aco_opcode::ds_add_u32, gds_addr, as_vgpr(ctx, sg_prm_cnt), m, 0u, 0u, true);
+
+   begin_divergent_if_else(ctx, &ic_last_lane);
+   end_divergent_if(ctx, &ic_last_lane);
+
+   begin_uniform_if_else(ctx, &ic_shader_query);
+   end_uniform_if(ctx, &ic_shader_query);
+}
+
 Temp ngg_gs_load_prim_flag_0(isel_context *ctx, Temp tid_in_tg, Temp max_vtxcnt, Temp vertex_lds_addr)
 {
    if_context ic;
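
Note on the arithmetic in ngg_gs_write_shader_query: a GS emits points, line strips or triangle strips, so vtx_cnt emitted vertices grouped into strip_cnt strips decompose into vtx_cnt - strip_cnt * (N - 1) real primitives, where N is the vertex count of one real primitive (1 for points, 2 for lines, 3 for triangles). Both the constant-folded path and the v_mad_i32_i24 compute exactly this. A minimal standalone sketch of the formula follows; the helper name is hypothetical and not part of the patch:

    #include <cassert>

    /* Mirrors the patch's prm_cnt computation: vtx_cnt vertices emitted as
     * strip_cnt strips yield this many real primitives. total_vtx_per_prim
     * is 1 for points, 2 for lines, 3 for triangles.
     * Hypothetical helper, for illustration only. */
    static unsigned real_prim_count(unsigned vtx_cnt, unsigned strip_cnt,
                                    unsigned total_vtx_per_prim)
    {
       return vtx_cnt - strip_cnt * (total_vtx_per_prim - 1u);
    }

    int main()
    {
       assert(real_prim_count(5, 1, 3) == 3); /* one 5-vertex triangle strip -> 3 triangles */
       assert(real_prim_count(6, 2, 2) == 4); /* two 3-vertex line strips -> 4 lines */
       assert(real_prim_count(4, 4, 1) == 4); /* points map 1:1 -> 4 points */
       return 0;
    }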
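
The lane election before the GDS atomic (s_ff1_i32 on EXEC, then s_lshl) builds a one-hot mask so that exactly one active lane adds the subgroup's primitive count to the query result. A rough scalar model of that bit trick, assuming a wave64 EXEC mask and the GCC/Clang __builtin_ctzll builtin:

    #include <cassert>
    #include <cstdint>

    /* Scalar model of the election: find the lowest set bit of EXEC
     * (s_ff1_i32) and turn it into a one-hot mask (s_lshl). EXEC is
     * never zero here, since at least the current lane is active. */
    static uint64_t first_lane_mask(uint64_t exec_mask)
    {
       unsigned first_lane = (unsigned)__builtin_ctzll(exec_mask);
       return 1ull << first_lane;
    }

    int main()
    {
       /* Lanes 3, 5 and 9 active: only lane 3 performs the ds_add_u32. */
       assert(first_lane_mask(0x228) == 0x8);
       return 0;
    }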