ac/nir: Fix workgroup ID in mesh shader waves other than the first.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15199>
This commit is contained in:
Timur Kristóf 2022-02-27 18:40:36 +01:00 committed by Marge Bot
parent 57775dd76a
commit 3a3bd9cff1
1 changed files with 87 additions and 2 deletions

View File

@ -110,6 +110,8 @@ typedef struct
unsigned api_workgroup_size;
unsigned hw_workgroup_size;
nir_ssa_def *workgroup_index;
struct {
/* Bitmask of components used: 4 bits per slot, 1 bit per component. */
uint32_t components_mask;
@ -2266,6 +2268,15 @@ lower_ms_load_per_primitive_output(nir_builder *b,
return ms_load_arrayed_output_intrin(b, intrin, s->num_per_primitive_outputs, s->prim_attr_lds_addr);
}
static nir_ssa_def *
lower_ms_load_workgroup_id(nir_builder *b,
UNUSED nir_intrinsic_instr *intrin,
lower_ngg_ms_state *s)
{
/* NV_mesh_shader: workgroup ID is 1 dimensional */
return nir_vec3(b, s->workgroup_index, nir_imm_int(b, 0), nir_imm_int(b, 0));
}
static nir_ssa_def *
update_ms_scoped_barrier(nir_builder *b,
nir_intrinsic_instr *intrin,
@ -2307,6 +2318,8 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
return lower_ms_store_per_primitive_output(b, intrin, s);
else if (intrin->intrinsic == nir_intrinsic_load_per_primitive_output)
return lower_ms_load_per_primitive_output(b, intrin, s);
else if (intrin->intrinsic == nir_intrinsic_load_workgroup_id)
return lower_ms_load_workgroup_id(b, intrin, s);
else if (intrin->intrinsic == nir_intrinsic_scoped_barrier)
return update_ms_scoped_barrier(b, intrin, s);
else
@ -2327,7 +2340,8 @@ filter_ms_intrinsic(const nir_instr *instr,
intrin->intrinsic == nir_intrinsic_load_per_vertex_output ||
intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
intrin->intrinsic == nir_intrinsic_scoped_barrier;
intrin->intrinsic == nir_intrinsic_scoped_barrier ||
intrin->intrinsic == nir_intrinsic_load_workgroup_id;
}
static void
@ -2368,6 +2382,76 @@ ms_emit_arrayed_outputs(nir_builder *b,
}
}
static void
emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
{
bool uses_workgroup_id =
BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID);
if (!uses_workgroup_id)
return;
b->cursor = nir_before_cf_list(&b->impl->body);
/* The HW doesn't support a proper workgroup index for vertex processing stages,
* so we use the vertex ID which is equivalent to the index of the current workgroup
* within the current dispatch.
*
* Due to the register programming of mesh shaders, this value is only filled for
* the first invocation of the first wave. To let other waves know, we use LDS.
*
* NV_mesh_shader: firstTask is emulated using first_vertex here.
*/
nir_ssa_def *workgroup_index = nir_load_vertex_id_zero_base(b);
if (s->api_workgroup_size <= s->wave_size) {
/* API workgroup is small, so we don't need to use LDS. */
workgroup_index = nir_read_first_invocation(b, workgroup_index);
workgroup_index = nir_iadd(b, workgroup_index, nir_load_first_vertex(b));
s->workgroup_index = workgroup_index;
return;
}
unsigned workgroup_index_lds_addr = s->numprims_lds_addr + 8;
nir_ssa_def *zero = nir_imm_int(b, 0);
nir_ssa_def *dont_care = nir_ssa_undef(b, 1, 32);
nir_ssa_def *loaded_workgroup_index = NULL;
/* Use elect to make sure only 1 invocation uses LDS. */
nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
{
nir_ssa_def *wave_id = nir_load_subgroup_id(b);
nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
{
nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
.memory_scope = NIR_SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL,
.memory_modes = nir_var_mem_shared);
}
nir_push_else(b, if_wave_0);
{
nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
.memory_scope = NIR_SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL,
.memory_modes = nir_var_mem_shared);
loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
}
nir_pop_if(b, if_wave_0);
workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
}
nir_pop_if(b, if_elected);
workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
workgroup_index = nir_read_first_invocation(b, workgroup_index);
workgroup_index = nir_iadd(b, workgroup_index, nir_load_first_vertex(b));
s->workgroup_index = workgroup_index;
}
static void
emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
{
@ -2623,7 +2707,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
/* LDS area for total number of output primitives and other info.
* DW0: number of primitives
* DW1: reserved for later use
* DW2: reserved for later use
* DW2: workgroup index within the current dispatch
* DW3: number of API workgroups in flight
*/
unsigned numprims_lds_addr = ALIGN(shader->info.shared_size, 16);
@ -2678,6 +2762,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
b->cursor = nir_before_cf_list(&impl->body);
handle_smaller_ms_api_workgroup(b, &state);
emit_ms_prelude(b, &state);
lower_ms_intrinsics(shader, &state);
emit_ms_finale(b, &state);