radv: Lower mesh shader 3D workgroup ID to 1D index.

This allows future mesh shaders to use a 3D workgroup ID.
Also changes how the NV_mesh_shader first_task is emulated.
The new code moves the responsibility from ac_nir into radv.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17023>
This commit is contained in:
Timur Kristóf 2022-02-28 14:22:09 +01:00 committed by Marge Bot
parent e05f63f56c
commit b243e94f07
2 changed files with 48 additions and 22 deletions

View File

@ -2393,15 +2393,6 @@ ms_load_arrayed_output_intrin(nir_builder *b,
return regroup_load_val(b, load, bit_size);
}
static nir_ssa_def *
lower_ms_load_workgroup_id(nir_builder *b,
UNUSED nir_intrinsic_instr *intrin,
lower_ngg_ms_state *s)
{
/* NV_mesh_shader: workgroup ID is 1 dimensional */
return nir_vec3(b, s->workgroup_index, nir_imm_int(b, 0), nir_imm_int(b, 0));
}
static nir_ssa_def *
lower_ms_load_workgroup_index(nir_builder *b,
UNUSED nir_intrinsic_instr *intrin,
@ -2451,8 +2442,6 @@ lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
case nir_intrinsic_load_per_vertex_output:
case nir_intrinsic_load_per_primitive_output:
return ms_load_arrayed_output_intrin(b, intrin, s);
case nir_intrinsic_load_workgroup_id:
return lower_ms_load_workgroup_id(b, intrin, s);
case nir_intrinsic_scoped_barrier:
return update_ms_scoped_barrier(b, intrin, s);
case nir_intrinsic_load_workgroup_index:
@ -2477,7 +2466,6 @@ filter_ms_intrinsic(const nir_instr *instr,
intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
intrin->intrinsic == nir_intrinsic_load_per_primitive_output ||
intrin->intrinsic == nir_intrinsic_scoped_barrier ||
intrin->intrinsic == nir_intrinsic_load_workgroup_id ||
intrin->intrinsic == nir_intrinsic_load_workgroup_index;
}
@ -2545,17 +2533,12 @@ emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
*
* Due to the register programming of mesh shaders, this value is only filled for
* the first invocation of the first wave. To let other waves know, we use LDS.
*
* NV_mesh_shader: firstTask is emulated using first_vertex here.
*/
nir_ssa_def *workgroup_index = nir_load_vertex_id_zero_base(b);
if (s->api_workgroup_size <= s->wave_size) {
/* API workgroup is small, so we don't need to use LDS. */
workgroup_index = nir_read_first_invocation(b, workgroup_index);
workgroup_index = nir_iadd(b, workgroup_index, nir_load_first_vertex(b));
s->workgroup_index = workgroup_index;
s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
return;
}
@ -2592,10 +2575,7 @@ emit_ms_prelude(nir_builder *b, lower_ngg_ms_state *s)
nir_pop_if(b, if_elected);
workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
workgroup_index = nir_read_first_invocation(b, workgroup_index);
workgroup_index = nir_iadd(b, workgroup_index, nir_load_first_vertex(b));
s->workgroup_index = workgroup_index;
s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
}
static void

View File

@ -586,6 +586,39 @@ radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage *fs_s
return progress;
}
/* Emulates NV_mesh_shader first_task using first_vertex. */
static bool
radv_lower_ms_workgroup_id(nir_shader *nir)
{
nir_function_impl *impl = nir_shader_get_entrypoint(nir);
bool progress = false;
nir_builder b;
nir_builder_init(&b, impl);
nir_foreach_block(block, impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_load_workgroup_id)
continue;
progress = true;
b.cursor = nir_after_instr(instr);
nir_ssa_def *x = nir_channel(&b, &intrin->dest.ssa, 0);
nir_ssa_def *x_full = nir_iadd(&b, x, nir_load_first_vertex(&b));
nir_ssa_def *v = nir_vector_insert_imm(&b, &intrin->dest.ssa, x_full, 0);
nir_ssa_def_rewrite_uses_after(&intrin->dest.ssa, v, v->parent_instr);
}
}
nir_metadata preserved =
progress ? (nir_metadata_block_index | nir_metadata_dominance) : nir_metadata_all;
nir_metadata_preserve(impl, preserved);
return progress;
}
nir_shader *
radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_pipeline_stage *stage,
const struct radv_pipeline_key *key)
@ -809,6 +842,19 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_pipeline_
};
NIR_PASS(_, nir, nir_lower_compute_system_values, &csv_options);
if (nir->info.stage == MESA_SHADER_MESH) {
/* NV_mesh_shader: include first_task (aka. first_vertex) in workgroup ID. */
NIR_PASS(_, nir, radv_lower_ms_workgroup_id);
/* Mesh shaders only have a 1D "vertex index" which we use
* as "workgroup index" to emulate the 3D workgroup ID.
*/
nir_lower_compute_system_values_options o = {
.lower_workgroup_id_to_index = true,
};
NIR_PASS(_, nir, nir_lower_compute_system_values, &o);
}
/* Vulkan uses the separate-shader linking model */
nir->info.separate_shader = true;