ac/nir: Add remappability to tess and ESGS I/O lowering passes.

This will be used by radeonsi to map common I/O locations to fixed
slots agreed on by different shader stages.

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16418>
Timur Kristóf, 2022-05-12 15:48:24 +02:00 (committed by Marge Bot)
parent 666dbbf1a3
commit f7f2770e72
6 changed files with 81 additions and 40 deletions

src/amd/common/ac_nir.c

@@ -35,6 +35,37 @@ ac_nir_load_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_
return nir_load_vector_arg_amd(b, num_components, .base = arg.arg_index);
}
+/**
+ * This function takes an I/O intrinsic like load/store_input,
+ * and emits a sequence that calculates the full offset of that instruction,
+ * including a stride to the base and component offsets.
+ */
+nir_ssa_def *
+ac_nir_calc_io_offset(nir_builder *b,
+                      nir_intrinsic_instr *intrin,
+                      nir_ssa_def *base_stride,
+                      unsigned component_stride,
+                      ac_nir_map_io_driver_location map_io)
+{
+   unsigned base = nir_intrinsic_base(intrin);
+   unsigned semantic = nir_intrinsic_io_semantics(intrin).location;
+   unsigned mapped_driver_location = map_io ? map_io(semantic) : base;
+
+   /* base is the driver_location, which is in slots (1 slot = 4x4 bytes) */
+   nir_ssa_def *base_op = nir_imul_imm(b, base_stride, mapped_driver_location);
+
+   /* offset should be interpreted in relation to the base,
+    * so the instruction effectively reads/writes another input/output
+    * when it has an offset
+    */
+   nir_ssa_def *offset_op = nir_imul(b, base_stride, nir_ssa_for_src(b, *nir_get_io_offset_src(intrin), 1));
+
+   /* component is in bytes */
+   unsigned const_op = nir_intrinsic_component(intrin) * component_stride;
+
+   return nir_iadd_imm_nuw(b, nir_iadd_nuw(b, base_op, offset_op), const_op);
+}
bool
ac_nir_lower_indirect_derefs(nir_shader *shader,
enum amd_gfx_level gfx_level)
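
The arithmetic above can also be read as plain scalar math. The helper below is an illustrative reference only; its name and the concrete numbers are assumptions, not part of this patch:

/* Reference model of ac_nir_calc_io_offset, assuming the indirect
 * offset source is a known constant. Strides are in bytes.
 */
static unsigned
ref_calc_io_offset(unsigned mapped_slot, unsigned indirect_slots,
                   unsigned component, unsigned base_stride,
                   unsigned component_stride)
{
   return base_stride * mapped_slot      /* remapped driver location */
        + base_stride * indirect_slots   /* dynamic offset of the intrinsic */
        + component * component_stride;  /* first accessed component */
}

With the 16-byte slot / 4-byte component strides used by the LS and ES lowering below, ref_calc_io_offset(3, 0, 2, 16, 4) yields 3 * 16 + 2 * 4 = 56 bytes.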

src/amd/common/ac_nir.h

@@ -48,6 +48,9 @@ enum
AC_EXP_PARAM_UNDEFINED = 255, /* deprecated, use AC_EXP_PARAM_DEFAULT_VAL_0000 instead */
};
+/* Maps I/O semantics to the actual location used by the lowering pass. */
+typedef unsigned (*ac_nir_map_io_driver_location)(unsigned semantic);
/* Forward declaration of nir_builder so we don't have to include nir_builder.h here */
struct nir_builder;
typedef struct nir_builder nir_builder;
@@ -55,21 +58,31 @@ typedef struct nir_builder nir_builder;
nir_ssa_def *
ac_nir_load_arg(nir_builder *b, const struct ac_shader_args *ac_args, struct ac_arg arg);
+nir_ssa_def *
+ac_nir_calc_io_offset(nir_builder *b,
+                      nir_intrinsic_instr *intrin,
+                      nir_ssa_def *base_stride,
+                      unsigned component_stride,
+                      ac_nir_map_io_driver_location map_io);
bool ac_nir_optimize_outputs(nir_shader *nir, bool sprite_tex_disallowed,
int8_t slot_remap[NUM_TOTAL_VARYING_SLOTS],
uint8_t param_export_index[NUM_TOTAL_VARYING_SLOTS]);
void
ac_nir_lower_ls_outputs_to_mem(nir_shader *ls,
+ac_nir_map_io_driver_location map,
bool tcs_in_out_eq,
uint64_t tcs_temp_only_inputs);
void
ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
bool tcs_in_out_eq);
void
ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
enum amd_gfx_level gfx_level,
bool tes_reads_tessfactors,
uint64_t tes_inputs_read,
@@ -80,16 +93,19 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
void
ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
unsigned num_reserved_tcs_outputs,
unsigned num_reserved_tcs_patch_outputs);
void
ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
enum amd_gfx_level gfx_level,
unsigned num_reserved_es_outputs);
void
ac_nir_lower_gs_inputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
enum amd_gfx_level gfx_level,
unsigned num_reserved_es_outputs);
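
As a sketch of how a driver might implement the new ac_nir_map_io_driver_location callback declared above (hypothetical code; radeonsi's actual mapping is not part of this commit):

#include "compiler/shader_enums.h" /* for VARYING_SLOT_* */

/* Hypothetical callback: assign generic varyings to fixed slots that
 * both the writing and the reading shader stage agree on.
 */
static unsigned
example_map_io(unsigned semantic)
{
   if (semantic >= VARYING_SLOT_VAR0 && semantic <= VARYING_SLOT_VAR31)
      return semantic - VARYING_SLOT_VAR0; /* VARn -> slot n */

   return 31; /* sketch only; a real driver maps every used semantic */
}

Because ac_nir_calc_io_offset falls back to nir_intrinsic_base() when the callback is NULL, existing callers keep their driver_location-based behavior.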

src/amd/common/ac_nir_lower_esgs_io_to_mem.c

@@ -44,6 +44,9 @@ typedef struct {
/* Which hardware generation we're dealing with */
enum amd_gfx_level gfx_level;
+/* I/O semantic -> real location used by lowering. */
+ac_nir_map_io_driver_location map_io;
/* Number of ES outputs for which memory should be reserved.
* When compacted, this should be the number of linked ES outputs.
*/
@@ -125,7 +128,7 @@ lower_es_output_store(nir_builder *b,
unsigned write_mask = nir_intrinsic_write_mask(intrin);
b->cursor = nir_before_instr(instr);
-nir_ssa_def *io_off = nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u);
+nir_ssa_def *io_off = ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io);
if (st->gfx_level <= GFX8) {
/* GFX6-8: ES is a separate HW stage, data is passed from ES to GS in VRAM. */
@@ -198,7 +201,7 @@ gs_per_vertex_input_offset(nir_builder *b,
: gs_per_vertex_input_vertex_offset_gfx6(b, vertex_src);
unsigned base_stride = st->gfx_level >= GFX9 ? 1 : 64 /* Wave size on GFX6-8 */;
-nir_ssa_def *io_off = nir_build_calc_io_offset(b, instr, nir_imm_int(b, base_stride * 4u), base_stride);
+nir_ssa_def *io_off = ac_nir_calc_io_offset(b, instr, nir_imm_int(b, base_stride * 4u), base_stride, st->map_io);
nir_ssa_def *off = nir_iadd(b, io_off, vertex_offset);
return nir_imul_imm(b, off, 4u);
}
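
Concrete numbers for the two addressing schemes above (illustrative values): on GFX9+ the ESGS data sits in LDS with tightly packed 4-dword slots, while on GFX6-8 it sits in VRAM swizzled by the 64-lane wave, so every dword is strided by the wave size.

/* Illustrative dword offset for an ESGS access, ignoring the
 * per-vertex offset that gs_per_vertex_input_offset adds afterwards.
 */
static unsigned
esgs_dword_offset(bool gfx9_plus, unsigned mapped_slot, unsigned component)
{
   unsigned base_stride = gfx9_plus ? 1 : 64; /* wave size on GFX6-8 */
   return (base_stride * 4) * mapped_slot + base_stride * component;
}

For mapped slot 2, component 1: 9 dwords (36 bytes) on GFX9+, but 576 dwords (2304 bytes) on GFX6-8.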
@@ -230,12 +233,14 @@ filter_load_per_vertex_input(const nir_instr *instr, UNUSED const void *state)
void
ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
enum amd_gfx_level gfx_level,
unsigned num_reserved_es_outputs)
{
lower_esgs_io_state state = {
.gfx_level = gfx_level,
.num_reserved_es_outputs = num_reserved_es_outputs,
+.map_io = map,
};
nir_shader_instructions_pass(shader,
@@ -246,12 +251,14 @@ ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
void
ac_nir_lower_gs_inputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
enum amd_gfx_level gfx_level,
unsigned num_reserved_es_outputs)
{
lower_esgs_io_state state = {
.gfx_level = gfx_level,
.num_reserved_es_outputs = num_reserved_es_outputs,
+.map_io = map,
};
nir_shader_lower_instructions(shader,

View File

@@ -123,6 +123,9 @@ typedef struct {
/* Which hardware generation we're dealing with */
enum amd_gfx_level gfx_level;
+/* I/O semantic -> real location used by lowering. */
+ac_nir_map_io_driver_location map_io;
/* True if merged VS+TCS (on GFX9+) has the same number
* of input and output patch size.
*/
@@ -239,7 +242,7 @@ lower_ls_output_store(nir_builder *b,
nir_ssa_def *vertex_idx = nir_load_local_invocation_index(b);
nir_ssa_def *base_off_var = nir_imul(b, vertex_idx, nir_load_lshs_vertex_stride_amd(b));
-nir_ssa_def *io_off = nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u);
+nir_ssa_def *io_off = ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io);
unsigned write_mask = nir_intrinsic_write_mask(intrin);
nir_ssa_def *off = nir_iadd_nuw(b, base_off_var, io_off);
@@ -299,7 +302,7 @@ hs_per_vertex_input_lds_offset(nir_builder *b,
nir_ssa_def *tcs_in_current_patch_offset = nir_imul(b, rel_patch_id, tcs_in_patch_stride);
-nir_ssa_def *io_offset = nir_build_calc_io_offset(b, instr, nir_imm_int(b, 16u), 4u);
+nir_ssa_def *io_offset = ac_nir_calc_io_offset(b, instr, nir_imm_int(b, 16u), 4u, st->map_io);
return nir_iadd_nuw(b, nir_iadd_nuw(b, tcs_in_current_patch_offset, vertex_index_off), io_offset);
}
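
The LDS address built above decomposes into patch base + vertex base + remapped I/O offset. A scalar model with assumed values:

/* Illustrative model of hs_per_vertex_input_lds_offset; all values in
 * bytes, one vec4 slot = 16 bytes, one component = 4 bytes.
 */
static unsigned
ref_tcs_input_lds_offset(unsigned rel_patch_id, unsigned patch_stride,
                         unsigned vertex_index_off, unsigned mapped_slot,
                         unsigned component)
{
   return rel_patch_id * patch_stride /* start of this patch's inputs */
        + vertex_index_off            /* start of this vertex */
        + mapped_slot * 16 + component * 4;
}

For example, patch 5 with a 96-byte patch stride, a 64-byte vertex offset, slot 1, component 3 gives 480 + 64 + 28 = 572 bytes.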
@@ -323,7 +326,7 @@ hs_output_lds_offset(nir_builder *b,
nir_ssa_def *output_patch0_offset = nir_imul(b, input_patch_size, tcs_num_patches);
nir_ssa_def *off = intrin
-? nir_build_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u)
+? ac_nir_calc_io_offset(b, intrin, nir_imm_int(b, 16u), 4u, st->map_io)
: nir_imm_int(b, 0);
nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
@@ -353,7 +356,7 @@ hs_per_vertex_output_vmem_offset(nir_builder *b,
nir_ssa_def *tcs_num_patches = nir_load_tcs_num_patches_amd(b);
nir_ssa_def *attr_stride = nir_imul(b, tcs_num_patches, nir_imul_imm(b, out_vertices_per_patch, 16u));
-nir_ssa_def *io_offset = nir_build_calc_io_offset(b, intrin, attr_stride, 4u);
+nir_ssa_def *io_offset = ac_nir_calc_io_offset(b, intrin, attr_stride, 4u, st->map_io);
nir_ssa_def *rel_patch_id = nir_load_tess_rel_patch_id_amd(b);
nir_ssa_def *patch_offset = nir_imul(b, rel_patch_id, nir_imul_imm(b, out_vertices_per_patch, 16u));
@@ -379,7 +382,7 @@ hs_per_patch_output_vmem_offset(nir_builder *b,
nir_ssa_def *per_patch_data_offset = nir_imul(b, tcs_num_patches, per_vertex_output_patch_size);
nir_ssa_def * off = intrin
-? nir_build_calc_io_offset(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u)
+? ac_nir_calc_io_offset(b, intrin, nir_imul_imm(b, tcs_num_patches, 16u), 4u, st->map_io)
: nir_imm_int(b, 0);
if (const_base_offset)
@@ -650,6 +653,7 @@ filter_any_input_access(const nir_instr *instr,
void
ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
bool tcs_in_out_eq,
uint64_t tcs_temp_only_inputs)
{
@@ -658,6 +662,7 @@ ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
lower_tess_io_state state = {
.tcs_in_out_eq = tcs_in_out_eq,
.tcs_temp_only_inputs = tcs_in_out_eq ? tcs_temp_only_inputs : 0,
+.map_io = map,
};
nir_shader_instructions_pass(shader,
@@ -668,12 +673,14 @@ ac_nir_lower_ls_outputs_to_mem(nir_shader *shader,
void
ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
bool tcs_in_out_eq)
{
assert(shader->info.stage == MESA_SHADER_TESS_CTRL);
lower_tess_io_state state = {
.tcs_in_out_eq = tcs_in_out_eq,
+.map_io = map,
};
nir_shader_lower_instructions(shader,
@@ -684,6 +691,7 @@ ac_nir_lower_hs_inputs_to_mem(nir_shader *shader,
void
ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
enum amd_gfx_level gfx_level,
bool tes_reads_tessfactors,
uint64_t tes_inputs_read,
@@ -702,6 +710,7 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
.tcs_num_reserved_outputs = num_reserved_tcs_outputs,
.tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs,
.tcs_out_patch_fits_subgroup = 32 % shader->info.tess.tcs_vertices_out == 0,
+.map_io = map,
};
nir_shader_lower_instructions(shader,
@@ -715,6 +724,7 @@ ac_nir_lower_hs_outputs_to_mem(nir_shader *shader,
void
ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
+ac_nir_map_io_driver_location map,
unsigned num_reserved_tcs_outputs,
unsigned num_reserved_tcs_patch_outputs)
{
@@ -723,6 +733,7 @@ ac_nir_lower_tes_inputs_to_mem(nir_shader *shader,
lower_tess_io_state state = {
.tcs_num_reserved_outputs = num_reserved_tcs_outputs,
.tcs_num_reserved_patch_outputs = num_reserved_tcs_patch_outputs,
+.map_io = map,
};
nir_shader_lower_instructions(shader,

src/amd/vulkan/radv_shader.c

@@ -1059,34 +1059,36 @@ radv_lower_io_to_mem(struct radv_device *device, struct radv_pipeline_stage *sta
if (nir->info.stage == MESA_SHADER_VERTEX) {
if (info->vs.as_ls) {
-NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, info->vs.tcs_in_out_eq,
+NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, NULL, info->vs.tcs_in_out_eq,
info->vs.tcs_temp_only_input_mask);
return true;
} else if (info->vs.as_es) {
-NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem,
+NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, NULL,
device->physical_device->rad_info.gfx_level, info->vs.num_linked_outputs);
return true;
}
} else if (nir->info.stage == MESA_SHADER_TESS_CTRL) {
-NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, info->vs.tcs_in_out_eq);
-NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, device->physical_device->rad_info.gfx_level,
+NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, NULL, info->vs.tcs_in_out_eq);
+NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, NULL,
+device->physical_device->rad_info.gfx_level,
info->tcs.tes_reads_tess_factors, info->tcs.tes_inputs_read,
info->tcs.tes_patch_inputs_read, info->tcs.num_linked_outputs,
info->tcs.num_linked_patch_outputs, true);
return true;
} else if (nir->info.stage == MESA_SHADER_TESS_EVAL) {
-NIR_PASS_V(nir, ac_nir_lower_tes_inputs_to_mem, info->tes.num_linked_inputs,
+NIR_PASS_V(nir, ac_nir_lower_tes_inputs_to_mem, NULL, info->tes.num_linked_inputs,
info->tes.num_linked_patch_inputs);
if (info->tes.as_es) {
-NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem,
+NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, NULL,
device->physical_device->rad_info.gfx_level, info->tes.num_linked_outputs);
}
return true;
} else if (nir->info.stage == MESA_SHADER_GEOMETRY) {
-NIR_PASS_V(nir, ac_nir_lower_gs_inputs_to_mem, device->physical_device->rad_info.gfx_level,
+NIR_PASS_V(nir, ac_nir_lower_gs_inputs_to_mem, NULL,
+device->physical_device->rad_info.gfx_level,
info->gs.num_linked_inputs);
return true;
} else if (nir->info.stage == MESA_SHADER_TASK) {
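
The RADV hunks above pass NULL for the new parameter, so ac_nir_calc_io_offset keeps using nir_intrinsic_base() (the pre-linked driver_location). A driver that wants remapping would pass a callback instead; a hypothetical call site, reusing the example_map_io sketch from earlier:

#include "ac_nir.h"

/* Hypothetical: lower VS-as-ES outputs either with semantic remapping
 * or with the driver_location behavior that RADV relies on.
 */
static void
lower_es_outputs(nir_shader *nir, enum amd_gfx_level gfx_level,
                 unsigned num_outputs, bool remap)
{
   ac_nir_lower_es_outputs_to_mem(nir, remap ? example_map_io : NULL,
                                  gfx_level, num_outputs);
}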

src/compiler/nir/nir_builder.h

@@ -1543,32 +1543,6 @@ nir_load_param(nir_builder *build, uint32_t param_idx)
return nir_build_load_param(build, param->num_components, param->bit_size, param_idx);
}
-/**
- * This function takes an I/O intrinsic like load/store_input,
- * and emits a sequence that calculates the full offset of that instruction,
- * including a stride to the base and component offsets.
- */
-static inline nir_ssa_def *
-nir_build_calc_io_offset(nir_builder *b,
-                         nir_intrinsic_instr *intrin,
-                         nir_ssa_def *base_stride,
-                         unsigned component_stride)
-{
-   /* base is the driver_location, which is in slots (1 slot = 4x4 bytes) */
-   nir_ssa_def *base_op = nir_imul_imm(b, base_stride, nir_intrinsic_base(intrin));
-
-   /* offset should be interpreted in relation to the base,
-    * so the instruction effectively reads/writes another input/output
-    * when it has an offset
-    */
-   nir_ssa_def *offset_op = nir_imul(b, base_stride, nir_ssa_for_src(b, *nir_get_io_offset_src(intrin), 1));
-
-   /* component is in bytes */
-   unsigned const_op = nir_intrinsic_component(intrin) * component_stride;
-
-   return nir_iadd_imm_nuw(b, nir_iadd_nuw(b, base_op, offset_op), const_op);
-}
/* calculate a `(1 << value) - 1` in ssa without overflows */
static inline nir_ssa_def *
nir_mask(nir_builder *b, nir_ssa_def *bits, unsigned dst_bit_size)