aco/ngg: Add shader query support to NGG GS.

In each GS thread, we calculate the number of "real" primitives that
were emitted (points, lines, triangles, not strips). Then we
accumulate the number of "real" primitives emitted by the
entire threadgroup in GDS.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6964>
This commit is contained in:
Timur Kristóf 2020-10-01 13:50:43 +02:00
parent df62c8fbea
commit dd73719856
1 changed files with 62 additions and 1 deletions

View File

@ -6895,6 +6895,7 @@ void ngg_visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *
}
/* Forward declarations: both helpers are defined further down in this file. */
void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream);
void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr);
void ngg_visit_set_vertex_and_primitive_count(isel_context *ctx, nir_intrinsic_instr *instr)
{
@ -6908,7 +6909,7 @@ void ngg_visit_set_vertex_and_primitive_count(isel_context *ctx, nir_intrinsic_i
ngg_gs_clear_primflags(ctx, vtx_cnt, stream);
}
/* TODO: also take the primitive count into use */
ngg_gs_write_shader_query(ctx, instr);
}
void visit_emit_vertex_with_counter(isel_context *ctx, nir_intrinsic_instr *instr)
@ -11165,6 +11166,66 @@ void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream)
end_loop(ctx, &lc);
}
/* Emits the code that adds the number of "real" primitives (points, lines or
 * triangles, as opposed to strips) emitted by the current subgroup to a
 * counter in GDS.
 *
 * ctx:   instruction selection context; instructions are appended to ctx->block.
 * instr: the intrinsic being visited; src[0] is the emitted vertex count and
 *        src[1] is the emitted primitive count of the GS thread.
 *
 * The generated code only runs when bit 0 of the ngg_gs_state argument is set.
 */
void ngg_gs_write_shader_query(isel_context *ctx, nir_intrinsic_instr *instr)
{
   /* Each subgroup uses a single GDS atomic to collect the total number of primitives.
    * TODO: Consider using primitive compaction at the end instead.
    */
   unsigned total_vtx_per_prim = gs_outprim_vertices(ctx->shader->info.gs.output_primitive);
   if_context ic_shader_query;
   Builder bld(ctx->program, ctx->block);

   /* SCC = bit 0 of ngg_gs_state; skip everything below when it is clear. */
   Temp shader_query = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc),
                                get_arg(ctx, ctx->args->ngg_gs_state), Operand(0u));
   begin_uniform_if_then(ctx, &ic_shader_query, shader_query);
   bld.reset(ctx->block);

   Temp gs_vtx_cnt = get_ssa_temp(ctx, instr->src[0].ssa);
   Temp gs_prm_cnt = get_ssa_temp(ctx, instr->src[1].ssa);
   Temp sg_prm_cnt;

   /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
    * GS emits points, line strips or triangle strips.
    * Real primitives are points, lines or triangles.
    *
    * A strip of V vertices holds V - (total_vtx_per_prim - 1) real primitives,
    * hence: real = vertices - strips * (total_vtx_per_prim - 1).
    */
   if (nir_src_is_const(instr->src[0]) && nir_src_is_const(instr->src[1])) {
      /* Both counts are compile-time constants: every active lane emitted the
       * same amount, so multiply the per-thread count by the active lane count.
       */
      unsigned const_vtx_cnt = nir_src_as_uint(instr->src[0]);
      unsigned const_prm_cnt = nir_src_as_uint(instr->src[1]);
      Temp prm_cnt = bld.copy(bld.def(s1), Operand(const_vtx_cnt - const_prm_cnt * (total_vtx_per_prim - 1u)));
      Temp thread_cnt = bld.sop1(Builder::s_bcnt1_i32, bld.def(s1), bld.def(s1, scc), Operand(exec, bld.lm));
      sg_prm_cnt = bld.sop2(aco_opcode::s_mul_i32, bld.def(s1), prm_cnt, thread_cnt);
   } else {
      Temp prm_cnt = gs_vtx_cnt;
      if (total_vtx_per_prim > 1) {
         /* prm_cnt = gs_prm_cnt * -(total_vtx_per_prim - 1) + gs_vtx_cnt */
         prm_cnt = bld.vop3(aco_opcode::v_mad_i32_i24, bld.def(v1), gs_prm_cnt,
                            Operand(-1u * (total_vtx_per_prim - 1)), gs_vtx_cnt);
      } else {
         /* Points: the vertex count is used as-is. It may be uniform and thus
          * live in an SGPR, but the reduction below works on a per-lane value,
          * so make sure it is in a VGPR.
          */
         prm_cnt = as_vgpr(ctx, prm_cnt);
      }

      /* Reduction calculates the primitive count for the entire subgroup. */
      sg_prm_cnt = bld.tmp(s1);
      aco_ptr<Pseudo_reduction_instruction> red_instr
         {create_reduction_instr(ctx, aco_opcode::p_reduce, ReduceOp::iadd32, Definition(sg_prm_cnt), prm_cnt)};
      red_instr->cluster_size = ctx->program->wave_size;
      bld.insert(std::move(red_instr));
   }

   /* Build a lane mask that has only the first active lane set, so that
    * exactly one lane of the subgroup performs the atomic add.
    */
   Temp first_lane = bld.sop1(Builder::s_ff1_i32, bld.def(s1), Operand(exec, bld.lm));
   Temp is_first_lane = bld.sop2(Builder::s_lshl, bld.def(bld.lm), bld.def(s1, scc),
                                 Operand(1u, ctx->program->wave_size == 64), first_lane);

   if_context ic_first_lane;
   begin_divergent_if_then(ctx, &ic_first_lane, is_first_lane);
   bld.reset(ctx->block);

   /* Add the subgroup's primitive count to the counter at address 0.
    * NOTE(review): m0 = 0x100 and the trailing `true` presumably select the
    * GDS limit and the GDS variant of the DS instruction - confirm.
    */
   Temp gds_addr = bld.copy(bld.def(v1), Operand(0u));
   Operand m = bld.m0((Temp)bld.sopk(aco_opcode::s_movk_i32, bld.def(s1, m0), 0x100));
   bld.ds(aco_opcode::ds_add_u32, gds_addr, as_vgpr(ctx, sg_prm_cnt), m, 0u, 0u, true);

   begin_divergent_if_else(ctx, &ic_first_lane);
   end_divergent_if(ctx, &ic_first_lane);

   begin_uniform_if_else(ctx, &ic_shader_query);
   end_uniform_if(ctx, &ic_shader_query);
}
Temp ngg_gs_load_prim_flag_0(isel_context *ctx, Temp tid_in_tg, Temp max_vtxcnt, Temp vertex_lds_addr)
{
if_context ic;