zink: vectorize io loads/stores when possible

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28723>
Mike Blumenkrantz 2024-03-21 14:52:25 -04:00 committed by Marge Bot
parent 3c673919c3
commit 6fe0cfdc09
1 changed file with 271 additions and 16 deletions


@@ -5675,18 +5675,6 @@ rework_io_vars(nir_shader *nir, nir_variable_mode mode, struct zink_shader *zs)
   loop_io_var_mask(nir, mode, false, false, mask);
}

/* can't scalarize these */
static bool
skip_scalarize(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
   return !sem.fb_fetch_output && sem.num_slots == 1;
}

static int
zink_type_size(const struct glsl_type *type, bool bindless)
{
@@ -5800,6 +5788,274 @@ fix_vertex_input_locations(nir_shader *nir)
   return nir_shader_intrinsics_pass(nir, fix_vertex_input_locations_instr, nir_metadata_all, NULL);
}

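/* state for merging scalar i/o on a single location back into a vector op: 'base' is the
 * instr being scanned from, 'merge' maps each component to the scalar instr that provides
 * it, 'component_mask' tracks which components have been found, and 'deletions' collects
 * the merged-away instrs so they can be removed after the pass
 */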
struct trivial_revectorize_state {
   bool has_xfb;
   uint32_t component_mask;
   nir_intrinsic_instr *base;
   nir_intrinsic_instr *next_emit_vertex;
   nir_intrinsic_instr *merge[NIR_MAX_VEC_COMPONENTS];
   struct set *deletions;
};

/* always skip xfb; scalarized xfb is preferred */
static bool
intr_has_xfb(nir_intrinsic_instr *intr)
{
   if (!nir_intrinsic_has_io_xfb(intr))
      return false;
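   /* io_xfb holds xfb info for components 0-1 and io_xfb2 for components 2-3 */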
   for (unsigned i = 0; i < 2; i++) {
      if (nir_intrinsic_io_xfb(intr).out[i].num_components || nir_intrinsic_io_xfb2(intr).out[i].num_components) {
         return true;
      }
   }
   return false;
}

/* helper to avoid vectorizing i/o for different vertices */
static nir_intrinsic_instr *
find_next_emit_vertex(nir_intrinsic_instr *intr)
{
   bool found = false;
   nir_foreach_instr_safe(instr, intr->instr.block) {
      if (instr->type == nir_instr_type_intrinsic) {
         nir_intrinsic_instr *test_intr = nir_instr_as_intrinsic(instr);
         if (!found && test_intr != intr)
            continue;
         if (!found) {
            assert(intr == test_intr);
            found = true;
            continue;
         }
         if (test_intr->intrinsic == nir_intrinsic_emit_vertex)
            return test_intr;
      }
   }
   return NULL;
}

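/* a candidate matches the base instr if it uses the same intrinsic, location, num_slots,
 * and alu type, has matching srcs (src[0] of non-input intrinsics, i.e. the stored value,
 * is not compared), covers components not already found, has no xfb, and, for geometry
 * shaders, shares the same next emit_vertex
 */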
/* scan for vectorizable instrs on a given location */
static bool
trivial_revectorize_intr_scan(nir_shader *nir, nir_intrinsic_instr *intr, struct trivial_revectorize_state *state)
{
   nir_intrinsic_instr *base = state->base;
   if (intr == base)
      return false;
   if (intr->intrinsic != base->intrinsic)
      return false;
   if (_mesa_set_search(state->deletions, intr))
      return false;
   bool is_load = false;
   bool is_input = false;
   bool is_interp = false;
   filter_io_instr(intr, &is_load, &is_input, &is_interp);
   nir_io_semantics base_sem = nir_intrinsic_io_semantics(base);
   nir_io_semantics test_sem = nir_intrinsic_io_semantics(intr);
   nir_alu_type base_type = is_load ? nir_intrinsic_dest_type(base) : nir_intrinsic_src_type(base);
   nir_alu_type test_type = is_load ? nir_intrinsic_dest_type(intr) : nir_intrinsic_src_type(intr);
   int c = nir_intrinsic_component(intr);
   /* already detected */
   if (state->component_mask & BITFIELD_BIT(c))
      return false;
   /* not a match */
   if (base_sem.location != test_sem.location || base_sem.num_slots != test_sem.num_slots || base_type != test_type)
      return false;
   /* only vectorize when all srcs match */
   for (unsigned i = !is_input; i < nir_intrinsic_infos[intr->intrinsic].num_srcs; i++) {
      if (!nir_srcs_equal(intr->src[i], base->src[i]))
         return false;
   }
   /* never match xfb */
   state->has_xfb |= intr_has_xfb(intr);
   if (state->has_xfb)
      return false;
   if (nir->info.stage == MESA_SHADER_GEOMETRY) {
      /* only match same vertex */
      if (state->next_emit_vertex != find_next_emit_vertex(intr))
         return false;
   }
   uint32_t mask = is_load ? BITFIELD_RANGE(c, intr->num_components) : (nir_intrinsic_write_mask(intr) << c);
   state->component_mask |= mask;
   u_foreach_bit(component, mask)
      state->merge[component] = intr;
   return true;
}

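/* per-instr driver: for each scalar i/o intrinsic, collect mergeable instrs from the
 * same block (stopping at any emit_vertex), then rewrite every contiguous run of
 * components as a single vector load/store and mark the leftover scalars for deletion
 */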
static bool
trivial_revectorize_scan(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   bool is_load = false;
   bool is_input = false;
   bool is_interp = false;
   if (!filter_io_instr(intr, &is_load, &is_input, &is_interp))
      return false;
   if (intr->num_components != 1)
      return false;
   nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
   if (!is_input || b->shader->info.stage != MESA_SHADER_VERTEX) {
      /* always ignore compact arrays */
      switch (sem.location) {
      case VARYING_SLOT_CLIP_DIST0:
      case VARYING_SLOT_CLIP_DIST1:
      case VARYING_SLOT_CULL_DIST0:
      case VARYING_SLOT_CULL_DIST1:
      case VARYING_SLOT_TESS_LEVEL_INNER:
      case VARYING_SLOT_TESS_LEVEL_OUTER:
         return false;
      default: break;
      }
   }
   /* always ignore to-be-deleted instrs */
   if (_mesa_set_search(data, intr))
      return false;
   /* never vectorize xfb */
   if (intr_has_xfb(intr))
      return false;
   int ic = nir_intrinsic_component(intr);
   uint32_t mask = is_load ? BITFIELD_RANGE(ic, intr->num_components) : (nir_intrinsic_write_mask(intr) << ic);
   /* already vectorized */
   if (util_bitcount(mask) == 4)
      return false;
   struct trivial_revectorize_state state = {
      .component_mask = mask,
      .base = intr,
      /* avoid clobbering i/o for different vertices */
      .next_emit_vertex = b->shader->info.stage == MESA_SHADER_GEOMETRY ? find_next_emit_vertex(intr) : NULL,
      .deletions = data,
   };
   u_foreach_bit(bit, mask)
      state.merge[bit] = intr;
   bool progress = false;
   nir_foreach_instr(instr, intr->instr.block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;
      nir_intrinsic_instr *test_intr = nir_instr_as_intrinsic(instr);
      /* no matching across vertex emission */
      if (test_intr->intrinsic == nir_intrinsic_emit_vertex)
         break;
      progress |= trivial_revectorize_intr_scan(b->shader, test_intr, &state);
   }
   if (!progress || state.has_xfb)
      return false;
   /* verify nothing crazy happened */
   assert(state.component_mask);
   for (unsigned i = 0; i < 4; i++) {
      assert(!state.merge[i] || !intr_has_xfb(state.merge[i]));
   }
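   /* merge each contiguous run of components separately: e.g. (hypothetical mask)
    * a component_mask of 0x7 merges components 0-2 into one vec3 op, while a run of
    * a single component is left as the original scalar instr
    */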
   unsigned first_component = ffs(state.component_mask) - 1;
   unsigned num_components = util_bitcount(state.component_mask);
   unsigned num_contiguous = 0;
   uint32_t contiguous_mask = 0;
   for (unsigned i = 0; i < num_components; i++) {
      unsigned c = i + first_component;
      /* calc mask of contiguous components to vectorize */
      if (state.component_mask & BITFIELD_BIT(c)) {
         num_contiguous++;
         contiguous_mask |= BITFIELD_BIT(c);
      }
      /* on the first gap or the last component, vectorize */
      if (!(state.component_mask & BITFIELD_BIT(c)) || i == num_components - 1) {
         if (num_contiguous > 1) {
            /* reindex to enable easy src/dest index comparison */
            nir_index_ssa_defs(nir_shader_get_entrypoint(b->shader));
            /* determine the first/last instr to use for the base (vectorized) load/store */
            unsigned first_c = ffs(contiguous_mask) - 1;
            nir_intrinsic_instr *base = NULL;
            unsigned test_idx = is_load ? UINT32_MAX : 0;
            for (unsigned j = 0; j < num_contiguous; j++) {
               unsigned merge_c = j + first_c;
               nir_intrinsic_instr *merge_intr = state.merge[merge_c];
               /* avoid breaking ssa ordering by using:
                * - first instr for vectorized load
                * - last instr for vectorized store
                * this guarantees all srcs have been seen
                */
               if ((is_load && merge_intr->def.index < test_idx) ||
                   (!is_load && merge_intr->src[0].ssa->index >= test_idx)) {
                  test_idx = is_load ? merge_intr->def.index : merge_intr->src[0].ssa->index;
                  base = merge_intr;
               }
            }
            assert(base);
            /* update instr components */
            nir_intrinsic_set_component(base, nir_intrinsic_component(state.merge[first_c]));
            unsigned orig_components = base->num_components;
            base->num_components = num_contiguous;
            /* do rewrites after loads and before stores */
            b->cursor = is_load ? nir_after_instr(&base->instr) : nir_before_instr(&base->instr);
            if (is_load) {
               base->def.num_components = num_contiguous;
               /* iterate the contiguous loaded components and rewrite merged dests */
               for (unsigned j = 0; j < num_contiguous; j++) {
                  unsigned merge_c = j + first_c;
                  nir_intrinsic_instr *merge_intr = state.merge[merge_c];
                  /* detect if the merged instr loaded multiple components and use swizzle mask for rewrite */
                  unsigned use_components = merge_intr == base ? orig_components : merge_intr->def.num_components;
                  nir_def *swiz = nir_channels(b, &base->def, BITFIELD_RANGE(j, use_components));
                  nir_def_rewrite_uses_after(&merge_intr->def, swiz, merge_intr == base ? swiz->parent_instr : &merge_intr->instr);
                  j += use_components - 1;
               }
            } else {
               nir_def *comp[NIR_MAX_VEC_COMPONENTS];
               /* generate swizzled vec of store components and rewrite store src */
               for (unsigned j = 0; j < num_contiguous; j++) {
                  unsigned merge_c = j + first_c;
                  nir_intrinsic_instr *merge_intr = state.merge[merge_c];
                  /* detect if the merged instr stored multiple components and extract them for rewrite */
                  unsigned use_components = merge_intr == base ? orig_components : merge_intr->num_components;
                  for (unsigned k = 0; k < use_components; k++)
                     comp[j + k] = nir_channel(b, merge_intr->src[0].ssa, k);
                  j += use_components - 1;
               }
               nir_def *val = nir_vec(b, comp, num_contiguous);
               nir_src_rewrite(&base->src[0], val);
               nir_intrinsic_set_write_mask(base, BITFIELD_MASK(num_contiguous));
            }
            /* deleting instructions during a foreach explodes the compiler, so delete later */
            for (unsigned j = 0; j < num_contiguous; j++) {
               unsigned merge_c = j + first_c;
               nir_intrinsic_instr *merge_intr = state.merge[merge_c];
               if (merge_intr != base)
                  _mesa_set_add(data, &merge_intr->instr);
            }
         }
         contiguous_mask = 0;
         num_contiguous = 0;
      }
   }
   return true;
}

/* attempt to revectorize scalar i/o, ignoring xfb and "hard stuff" */
static bool
trivial_revectorize(nir_shader *nir)
{
   struct set deletions;
   if (nir->info.stage > MESA_SHADER_FRAGMENT)
      return false;
   _mesa_set_init(&deletions, NULL, _mesa_hash_pointer, _mesa_key_pointer_equal);
   bool progress = nir_shader_intrinsics_pass(nir, trivial_revectorize_scan, nir_metadata_dominance, &deletions);
   /* now it's safe to delete */
   set_foreach_remove(&deletions, entry) {
      nir_instr *instr = (void*)entry->key;
      nir_instr_remove(instr);
   }
   ralloc_free(deletions.table);
   return progress;
}

struct zink_shader *
zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
{
@@ -5855,10 +6111,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
      NIR_PASS_V(nir, nir_lower_alu_vec8_16_srcs);
   }
   nir_variable_mode scalarize = nir_var_shader_in;
   if (nir->info.stage != MESA_SHADER_FRAGMENT)
      scalarize |= nir_var_shader_out;
   NIR_PASS_V(nir, nir_lower_io_to_scalar, scalarize, skip_scalarize, NULL);
   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_shader_out, NULL, NULL);
   optimize_nir(nir, NULL, true);
   nir_foreach_variable_with_modes(var, nir, nir_var_shader_in | nir_var_shader_out) {
      if (glsl_type_is_image(var->type) || glsl_type_is_sampler(var->type)) {
@@ -5871,6 +6124,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
   NIR_PASS_V(nir, fix_vertex_input_locations);
   nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
   scan_nir(screen, nir, ret);
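   /* io was scalarized above for optimization: vectorize ALU ops, then merge scalar io
    * back into vector ops before the io variables are reworked
    */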
   NIR_PASS_V(nir, nir_opt_vectorize, NULL, NULL);
   NIR_PASS_V(nir, trivial_revectorize);
   if (nir->info.io_lowered) {
      rework_io_vars(nir, nir_var_shader_in, ret);
      rework_io_vars(nir, nir_var_shader_out, ret);