agx: Emit splits for intrinsics

This allows optimizing the extracts: if only individual components of a
vector destination are accessed, the trailing combine is dead code
eliminated.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16268>
Alyssa Rosenzweig 2022-04-12 19:46:13 -04:00
parent d06394095b
commit 4f78141c77
2 changed files with 72 additions and 59 deletions
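
The hunks below rely on agx_emit_split(), which this series introduces but which is not shown in this diff. A minimal sketch of what such a helper does, assuming only the agx_builder/agx_index/agx_p_extract API already visible in these files (a hypothetical reconstruction, not the actual Mesa helper):

   /* Peel each component of a freshly loaded vector into its own scalar, so
    * later copy-propagation and dead-code elimination can treat the channels
    * independently instead of seeing one opaque vector register. */
   static void
   emit_split_sketch(agx_builder *b, agx_index *dests, agx_index vec,
                     unsigned n)
   {
      for (unsigned i = 0; i < n; ++i)
         dests[i] = agx_p_extract(b, vec, i); /* one extract per channel */
   }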

src/asahi/compiler/agx_compile.c

@@ -221,8 +221,8 @@ agx_udiv_const(agx_builder *b, agx_index P, uint32_t Q)
 }
 
 /* AGX appears to lack support for vertex attributes. Lower to global loads. */
-static agx_instr *
-agx_emit_load_attr(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_attr(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    nir_src *offset_src = nir_get_io_offset_src(instr);
    assert(nir_src_is_const(*offset_src) && "no attribute indirects");
@@ -259,31 +259,24 @@ agx_emit_load_attr(agx_builder *b, nir_intrinsic_instr *instr)
    /* Load the data */
    assert(instr->num_components <= 4);
 
-   bool pad = ((attrib.nr_comps_minus_1 + 1) < instr->num_components);
-   agx_index real_dest = agx_dest_index(&instr->dest);
-   agx_index dest = pad ? agx_temp(b->shader, AGX_SIZE_32) : real_dest;
-   agx_device_load_to(b, dest, base, offset, attrib.format,
+   unsigned actual_comps = (attrib.nr_comps_minus_1 + 1);
+   agx_index vec = agx_vec_for_dest(b->shader, &instr->dest);
+   agx_device_load_to(b, vec, base, offset, attrib.format,
                       BITFIELD_MASK(attrib.nr_comps_minus_1 + 1), 0);
    agx_wait(b, 0);
 
-   if (pad) {
-      agx_index one = agx_mov_imm(b, 32, fui(1.0));
-      agx_index zero = agx_mov_imm(b, 32, 0);
-      agx_index channels[4] = { zero, zero, zero, one };
-
-      for (unsigned i = 0; i < (attrib.nr_comps_minus_1 + 1); ++i)
-         channels[i] = agx_p_extract(b, dest, i);
-
-      for (unsigned i = instr->num_components; i < 4; ++i)
-         channels[i] = agx_null();
-
-      agx_p_combine_to(b, real_dest, channels[0], channels[1], channels[2], channels[3]);
-   }
+   agx_emit_split(b, dests, vec, actual_comps);
 
-   return NULL;
+   agx_index one = agx_mov_imm(b, 32, fui(1.0));
+   agx_index zero = agx_mov_imm(b, 32, 0);
+   agx_index default_value[4] = { zero, zero, zero, one };
+
+   for (unsigned i = actual_comps; i < instr->num_components; ++i)
+      dests[i] = default_value[i];
 }
 
-static agx_instr *
-agx_emit_load_vary_flat(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_vary_flat(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    unsigned components = instr->num_components;
    assert(components >= 1 && components <= 4);
@@ -293,20 +286,15 @@ agx_emit_load_vary_flat(agx_builder *b, nir_intrinsic_instr *instr)
    unsigned imm_index = b->shader->varyings[nir_intrinsic_base(instr)];
    imm_index += nir_src_as_uint(*offset);
 
-   agx_index chan[4] = { agx_null() };
-
    for (unsigned i = 0; i < components; ++i) {
       /* vec3 for each vertex, unknown what first 2 channels are for */
       agx_index values = agx_ld_vary_flat(b, agx_immediate(imm_index + i), 1);
-      chan[i] = agx_p_extract(b, values, 2);
+      dests[i] = agx_p_extract(b, values, 2);
    }
-
-   return agx_p_combine_to(b, agx_dest_index(&instr->dest),
-                           chan[0], chan[1], chan[2], chan[3]);
 }
 
-static agx_instr *
-agx_emit_load_vary(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_vary(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    ASSERTED unsigned components = instr->num_components;
    ASSERTED nir_intrinsic_instr *parent = nir_src_as_intrinsic(instr->src[0]);
@@ -322,8 +310,9 @@ agx_emit_load_vary(agx_builder *b, nir_intrinsic_instr *instr)
    unsigned imm_index = b->shader->varyings[nir_intrinsic_base(instr)];
    imm_index += nir_src_as_uint(*offset) * 4;
 
-   return agx_ld_vary_to(b, agx_dest_index(&instr->dest),
-                         agx_immediate(imm_index), components, true);
+   agx_index vec = agx_vec_for_intr(b->shader, instr);
+   agx_ld_vary_to(b, vec, agx_immediate(imm_index), components, true);
+   agx_emit_split(b, dests, vec, components);
 }
 
 static agx_instr *
@@ -380,8 +369,8 @@ agx_emit_fragment_out(agx_builder *b, nir_intrinsic_instr *instr)
                        b->shader->key->fs.tib_formats[rt]);
 }
 
-static agx_instr *
-agx_emit_load_tile(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_tile(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
    const nir_variable *var =
       nir_find_variable_with_driver_location(b->shader->nir,
@@ -399,8 +388,9 @@ agx_emit_load_tile(agx_builder *b, nir_intrinsic_instr *instr)
    b->shader->did_writeout = true;
    b->shader->out->reads_tib = true;
 
-   return agx_ld_tile_to(b, agx_dest_index(&instr->dest),
-                         b->shader->key->fs.tib_formats[rt]);
+   agx_index vec = agx_vec_for_dest(b->shader, &instr->dest);
+   agx_ld_tile_to(b, vec, b->shader->key->fs.tib_formats[rt]);
+   agx_emit_split(b, dests, vec, 4);
 }
 
 static enum agx_format
@@ -415,7 +405,7 @@ agx_format_for_bits(unsigned bits)
 }
 
 static agx_instr *
-agx_emit_load_ubo(agx_builder *b, nir_intrinsic_instr *instr)
+agx_emit_load_ubo(agx_builder *b, agx_index dst, nir_intrinsic_instr *instr)
 {
    bool kernel_input = (instr->intrinsic == nir_intrinsic_load_kernel_input);
    nir_src *offset = nir_get_io_offset_src(instr);
@@ -439,31 +429,27 @@ agx_emit_load_ubo(agx_builder *b, nir_intrinsic_instr *instr)
    /* Load the data */
    assert(instr->num_components <= 4);
 
-   agx_device_load_to(b, agx_dest_index(&instr->dest),
-                      base, agx_src_index(offset),
+   agx_device_load_to(b, dst, base, agx_src_index(offset),
                       agx_format_for_bits(nir_dest_bit_size(instr->dest)),
                       BITFIELD_MASK(instr->num_components), 0);
+   agx_wait(b, 0);
+   agx_emit_cached_split(b, dst, instr->num_components);
 
-   return agx_wait(b, 0);
+   return NULL;
 }
 
-static agx_instr *
-agx_emit_load_frag_coord(agx_builder *b, nir_intrinsic_instr *instr)
+static void
+agx_emit_load_frag_coord(agx_builder *b, agx_index *dests, nir_intrinsic_instr *instr)
 {
-   agx_index xy[2];
-
    /* xy */
    for (unsigned i = 0; i < 2; ++i) {
-      xy[i] = agx_fadd(b, agx_convert(b, agx_immediate(AGX_CONVERT_U32_TO_F),
+      dests[i] = agx_fadd(b, agx_convert(b, agx_immediate(AGX_CONVERT_U32_TO_F),
                agx_get_sr(b, 32, AGX_SR_THREAD_POSITION_IN_GRID_X + i),
                AGX_ROUND_RTE), agx_immediate_f(0.5f));
    }
 
    /* Ordering by the ABI */
-   agx_index z = agx_ld_vary(b, agx_immediate(1), 1, false);
-   agx_index w = agx_ld_vary(b, agx_immediate(0), 1, false);
-
-   return agx_p_combine_to(b, agx_dest_index(&instr->dest),
-                           xy[0], xy[1], z, w);
+   dests[2] = agx_ld_vary(b, agx_immediate(1), 1, false); /* z */
+   dests[3] = agx_ld_vary(b, agx_immediate(0), 1, false); /* w */
 }
 
 static agx_instr *
@@ -500,6 +486,7 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
    agx_index dst = nir_intrinsic_infos[instr->intrinsic].has_dest ?
       agx_dest_index(&instr->dest) : agx_null();
    gl_shader_stage stage = b->shader->stage;
+   agx_index dests[4] = { agx_null() };
 
    switch (instr->intrinsic) {
    case nir_intrinsic_load_barycentric_pixel:
@@ -511,16 +498,19 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
       return NULL;
 
    case nir_intrinsic_load_interpolated_input:
       assert(stage == MESA_SHADER_FRAGMENT);
-      return agx_emit_load_vary(b, instr);
+      agx_emit_load_vary(b, dests, instr);
+      break;
 
    case nir_intrinsic_load_input:
       if (stage == MESA_SHADER_FRAGMENT)
-         return agx_emit_load_vary_flat(b, instr);
+         agx_emit_load_vary_flat(b, dests, instr);
       else if (stage == MESA_SHADER_VERTEX)
-         return agx_emit_load_attr(b, instr);
+         agx_emit_load_attr(b, dests, instr);
       else
          unreachable("Unsupported shader stage");
+      break;
 
    case nir_intrinsic_store_output:
       if (stage == MESA_SHADER_FRAGMENT)
          return agx_emit_fragment_out(b, instr);
@@ -531,14 +521,16 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
    case nir_intrinsic_load_output:
       assert(stage == MESA_SHADER_FRAGMENT);
-      return agx_emit_load_tile(b, instr);
+      agx_emit_load_tile(b, dests, instr);
+      break;
 
    case nir_intrinsic_load_ubo:
    case nir_intrinsic_load_kernel_input:
-      return agx_emit_load_ubo(b, instr);
+      return agx_emit_load_ubo(b, dst, instr);
 
    case nir_intrinsic_load_frag_coord:
-      return agx_emit_load_frag_coord(b, instr);
+      agx_emit_load_frag_coord(b, dests, instr);
+      break;
 
    case nir_intrinsic_discard:
       return agx_emit_discard(b, instr);
@@ -561,6 +553,14 @@ agx_emit_intrinsic(agx_builder *b, nir_intrinsic_instr *instr)
       fprintf(stderr, "Unhandled intrinsic %s\n", nir_intrinsic_infos[instr->intrinsic].name);
       unreachable("Unhandled intrinsic");
    }
+
+   /* If we got here, there is a vector destination for the intrinsic composed
+    * of separate scalars. Its components are specified separately in the dests
+    * array. We need to combine them so the vector destination itself is valid.
+    * If only individual components are accessed, this combine will be dead code
+    * eliminated.
+    */
+   return agx_emit_combine_to(b, dst, dests[0], dests[1], dests[2], dests[3]);
 }
 
 static agx_index
@@ -831,7 +831,7 @@ agx_emit_alu(agx_builder *b, nir_alu_instr *instr)
    case nir_op_vec2:
    case nir_op_vec3:
    case nir_op_vec4:
-      return agx_p_combine_to(b, dst, s0, s1, s2, s3);
+      return agx_emit_combine_to(b, dst, s0, s1, s2, s3);
 
    case nir_op_vec8:
    case nir_op_vec16:
@@ -966,14 +966,15 @@ agx_emit_tex(agx_builder *b, nir_tex_instr *instr)
       }
    }
 
-   agx_texture_sample_to(b, agx_dest_index(&instr->dest),
-                         coords, lod, texture, sampler, offset,
+   agx_index dst = agx_dest_index(&instr->dest);
+   agx_texture_sample_to(b, dst, coords, lod, texture, sampler, offset,
                          agx_tex_dim(instr->sampler_dim, instr->is_array),
                          agx_lod_mode_for_nir(instr->op),
                          0xF, /* TODO: wrmask */
                          0);
 
    agx_wait(b, 0);
+   agx_emit_cached_split(b, dst, 4);
 }
 
 /* NIR loops are treated as a pair of AGX loops:
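
The load_ubo and texture hunks above also call agx_emit_cached_split(), again defined outside this diff. By analogy with the plain split, the assumed behaviour is to split once and memoize the resulting scalars so repeated splits of the same vector reuse them. The one-entry memo and the raw memcmp comparison below (requiring <string.h>) are purely illustrative, standing in for whatever cache and index-equality helper the compiler actually uses:

   /* Hypothetical sketch: split `vec` into scalars and remember them, so a
    * second request for the same vector hands back the cached components
    * instead of emitting fresh extracts. */
   struct split_memo {
      agx_index vec;      /* vector most recently split */
      agx_index chan[4];  /* its cached per-component scalars */
   };

   static void
   emit_cached_split_sketch(agx_builder *b, struct split_memo *memo,
                            agx_index vec, unsigned n)
   {
      if (memcmp(&memo->vec, &vec, sizeof(vec)) != 0) { /* cache miss */
         for (unsigned i = 0; i < n; ++i)
            memo->chan[i] = agx_p_extract(b, vec, i);
         memo->vec = vec;
      }
   }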

src/asahi/compiler/agx_compiler.h

@@ -469,6 +469,18 @@ agx_dest_index(nir_dest *dst)
                     agx_size_for_bits(nir_dest_bit_size(*dst)));
 }
 
+static inline agx_index
+agx_vec_for_dest(agx_context *ctx, nir_dest *dest)
+{
+   return agx_temp(ctx, agx_size_for_bits(nir_dest_bit_size(*dest)));
+}
+
+static inline agx_index
+agx_vec_for_intr(agx_context *ctx, nir_intrinsic_instr *instr)
+{
+   return agx_vec_for_dest(ctx, &instr->dest);
+}
+
 /* Iterators for AGX IR */
 
 #define agx_foreach_block(ctx, v) \