nir: introduce new nir_lower_alu_width() with nir_vectorize_cb callback

This function allows scalarizing instructions only down to a desired
vectorization width. nir_lower_alu_to_scalar() was changed to use the
new function with a width of 1.

Swizzles that reach outside the vectorization width are taken into
account and reduce the target width.

This prevents ending up with code like
  vec2 16 ssa_2 = iadd ssa_0.xz, ssa_1.xz

which would require backends to emit shuffle code
and is usually not beneficial.
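
For illustration, a minimal sketch of how a backend might drive the new
pass follows; only the nir_lower_alu_width()/nir_vectorize_cb interface
is part of this change, the callback name and the width heuristic are
made up:

  /* Hypothetical backend callback: keep 16-bit ALU ops as vec2 and
   * scalarize everything else. Returning 0 leaves the instruction
   * untouched; any non-zero width must be a power of two. */
  static uint8_t
  backend_vectorize_cb(const nir_instr *instr, const void *data)
  {
     if (instr->type != nir_instr_type_alu)
        return 0;

     const nir_alu_instr *alu = nir_instr_as_alu(instr);
     return alu->dest.dest.ssa.bit_size == 16 ? 2 : 1;
  }

  nir_lower_alu_width(shader, backend_vectorize_cb, NULL);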

Reviewed-by: Alyssa Rosenzweig <alyssa@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13080>
Authored by Daniel Schürmann on 2021-07-06 19:08:04 +02:00; committed by Marge Bot
parent bd151a256e
commit be01e8711b
3 changed files with 108 additions and 26 deletions

src/compiler/nir/meson.build

@@ -128,7 +128,7 @@ files_libnir = files(
'nir_loop_analyze.c',
'nir_loop_analyze.h',
'nir_lower_alu.c',
'nir_lower_alu_to_scalar.c',
'nir_lower_alu_width.c',
'nir_lower_alpha_test.c',
'nir_lower_amul.c',
'nir_lower_array_deref_of_vec.c',

src/compiler/nir/nir.h

@@ -4801,6 +4801,7 @@ bool nir_lower_flrp(nir_shader *shader, unsigned lowering_mask,
bool nir_scale_fdiv(nir_shader *shader);
bool nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *data);
bool nir_lower_alu_width(nir_shader *shader, nir_vectorize_cb cb, const void *data);
bool nir_lower_bool_to_bitsize(nir_shader *shader);
bool nir_lower_bool_to_float(nir_shader *shader);
bool nir_lower_bool_to_int32(nir_shader *shader);

src/compiler/nir/nir_lower_alu_width.c (renamed from nir_lower_alu_to_scalar.c)

@@ -24,15 +24,20 @@
#include "nir.h"
#include "nir_builder.h"
struct alu_to_scalar_data {
nir_instr_filter_cb cb;
struct alu_width_data {
nir_vectorize_cb cb;
const void *data;
};
/** @file nir_lower_alu_to_scalar.c
/** @file nir_lower_alu_width.c
*
* Replaces nir_alu_instr operations with more than one channel used in the
* arguments with individual per-channel operations.
*
* Optionally, a callback function which returns the max vectorization width
* per instruction can be provided.
*
* The max vectorization width must be a power of 2.
*/
static bool
@@ -52,6 +57,36 @@ inst_is_vector_alu(const nir_instr *instr, const void *_state)
nir_op_infos[alu->op].input_sizes[0] > 1;
}
/* Checks whether all operands of an ALU instruction are swizzled
* within the targeted vectorization width.
*
* The assumption here is that a vecN instruction can only swizzle
* within the first N channels of the values it consumes, irrespective
* of the capabilities of the instruction which produced those values.
* If we assume values are packed consistently (i.e., they always start
* at the beginning of a hardware register), we can actually access any
* aligned group of N channels so long as we stay within the group.
* This means for a vectorization width of 4 that only swizzles from
* either [xyzw] or [abcd] etc are allowed. For a width of 2 these are
* swizzles from either [xy] or [zw] etc.
*/
static bool
alu_is_swizzled_in_bounds(const nir_alu_instr *alu, unsigned width)
{
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
if (nir_op_infos[alu->op].input_sizes[i] == 1)
continue;
unsigned mask = ~(width - 1);
for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
return false;
}
}
return true;
}
static void
nir_alu_ssa_dest_init(nir_alu_instr *alu, unsigned num_components,
unsigned bit_size)
@@ -140,9 +175,9 @@ lower_fdot(nir_alu_instr *alu, nir_builder *builder)
}
static nir_ssa_def *
lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
{
struct alu_to_scalar_data *data = _data;
struct alu_width_data *data = _data;
nir_alu_instr *alu = nir_instr_as_alu(instr);
unsigned num_src = nir_op_infos[alu->op].num_inputs;
unsigned i, chan;
@@ -152,8 +187,15 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
b->exact = alu->exact;
if (data->cb && !data->cb(instr, data->data))
return NULL;
unsigned num_components = alu->dest.dest.ssa.num_components;
unsigned target_width = 1;
if (data->cb) {
target_width = data->cb(instr, data->data);
assert(util_is_power_of_two_or_zero(target_width));
if (target_width == 0)
return NULL;
}
#define LOWER_REDUCTION(name, chan, merge) \
case name##2: \
@@ -319,37 +361,78 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
break;
}
if (alu->dest.dest.ssa.num_components == 1)
if (num_components == 1)
return NULL;
unsigned num_components = alu->dest.dest.ssa.num_components;
nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS] = { NULL };
if (num_components <= target_width) {
/* If the ALU instr is swizzled outside the target width,
* reduce the target width.
*/
if (alu_is_swizzled_in_bounds(alu, target_width))
return NULL;
else
target_width = DIV_ROUND_UP(num_components, 2);
}
for (chan = 0; chan < num_components; chan++) {
nir_alu_instr *vec = nir_alu_instr_create(b->shader, nir_op_vec(num_components));
for (chan = 0; chan < num_components; chan += target_width) {
unsigned components = MIN2(target_width, num_components - chan);
nir_alu_instr *lower = nir_alu_instr_create(b->shader, alu->op);
for (i = 0; i < num_src; i++) {
nir_alu_src_copy(&lower->src[i], &alu->src[i]);
/* We only handle same-size-as-dest (input_sizes[] == 0) or scalar
* args (input_sizes[] == 1).
*/
assert(nir_op_infos[alu->op].input_sizes[i] < 2);
unsigned src_chan = (nir_op_infos[alu->op].input_sizes[i] == 1 ?
0 : chan);
nir_alu_src_copy(&lower->src[i], &alu->src[i]);
for (int j = 0; j < NIR_MAX_VEC_COMPONENTS; j++)
lower->src[i].swizzle[j] = alu->dest.write_mask & (1 << chan) ?
alu->src[i].swizzle[src_chan] : 0;
for (int j = 0; j < components; j++) {
unsigned src_chan = nir_op_infos[alu->op].input_sizes[i] == 1 ? 0 : chan + j;
lower->src[i].swizzle[j] = alu->src[i].swizzle[src_chan];
}
}
nir_alu_ssa_dest_init(lower, 1, alu->dest.dest.ssa.bit_size);
nir_alu_ssa_dest_init(lower, components, alu->dest.dest.ssa.bit_size);
lower->dest.saturate = alu->dest.saturate;
comps[chan] = &lower->dest.dest.ssa;
lower->exact = alu->exact;
for (i = 0; i < components; i++) {
vec->src[chan + i].src = nir_src_for_ssa(&lower->dest.dest.ssa);
vec->src[chan + i].swizzle[0] = i;
}
nir_builder_instr_insert(b, &lower->instr);
}
return nir_vec(b, comps, num_components);
return nir_builder_alu_instr_finish_and_insert(b, vec);
}
bool
nir_lower_alu_width(nir_shader *shader, nir_vectorize_cb cb, const void *_data)
{
struct alu_width_data data = {
.cb = cb,
.data = _data,
};
return nir_shader_lower_instructions(shader,
inst_is_vector_alu,
lower_alu_instr_width,
&data);
}
struct alu_to_scalar_data {
nir_instr_filter_cb cb;
const void *data;
};
static uint8_t
scalar_cb(const nir_instr *instr, const void *data)
{
/* return vectorization-width = 1 for filtered instructions */
const struct alu_to_scalar_data *filter = data;
return filter->cb(instr, filter->data) ? 1 : 0;
}
bool
@@ -360,8 +443,6 @@ nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *
.data = _data,
};
return nir_shader_lower_instructions(shader,
inst_is_vector_alu,
lower_alu_instr_scalar,
&data);
return nir_lower_alu_width(shader, cb ? scalar_cb : NULL, &data);
}
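
To make the swizzle rule in alu_is_swizzled_in_bounds() concrete, here is a
small standalone sketch of the same arithmetic (plain C, independent of the
NIR data structures; the helper name is made up). With a width of 2 the
aligned channel groups are [xy] and [zw], so .xy and .zw sources stay in
bounds while .xz crosses groups and makes the pass split the instruction
further:

#include <stdbool.h>
#include <stdio.h>

/* Restatement of the check from alu_is_swizzled_in_bounds() on a bare
 * swizzle array: all components must address the same aligned group of
 * `width` channels. */
static bool
swizzle_in_bounds(const unsigned *swizzle, unsigned num_components,
                  unsigned width)
{
   unsigned mask = ~(width - 1);
   for (unsigned j = 1; j < num_components; j++) {
      if ((swizzle[0] & mask) != (swizzle[j] & mask))
         return false;
   }
   return true;
}

int
main(void)
{
   unsigned xy[] = { 0, 1 }; /* .xy */
   unsigned zw[] = { 2, 3 }; /* .zw */
   unsigned xz[] = { 0, 2 }; /* .xz */

   printf(".xy, width 2: %d\n", swizzle_in_bounds(xy, 2, 2)); /* 1: kept vec2 */
   printf(".zw, width 2: %d\n", swizzle_in_bounds(zw, 2, 2)); /* 1: kept vec2 */
   printf(".xz, width 2: %d\n", swizzle_in_bounds(xz, 2, 2)); /* 0: split further */
   return 0;
}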