nir: introduce new nir_alu_alu_width() with nir_vectorize_cb callback
This function allows to only scalarize instructions down to a desired vectorization width. nir_lower_alu_to_scalar() was changed to use the new function with a width of 1. Swizzles outside vectorization width are considered and reduce the target width. This prevents ending up with code like vec2 16 ssa_2 = iadd ssa_0.xz, ssa_1.xz which requires to emit shuffle code in backends and usually is not beneficial. Reviewed-by: Alyssa Rosenzweig <alyssa@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/13080>
This commit is contained in:
parent
bd151a256e
commit
be01e8711b
|
@ -128,7 +128,7 @@ files_libnir = files(
|
|||
'nir_loop_analyze.c',
|
||||
'nir_loop_analyze.h',
|
||||
'nir_lower_alu.c',
|
||||
'nir_lower_alu_to_scalar.c',
|
||||
'nir_lower_alu_width.c',
|
||||
'nir_lower_alpha_test.c',
|
||||
'nir_lower_amul.c',
|
||||
'nir_lower_array_deref_of_vec.c',
|
||||
|
|
|
@ -4801,6 +4801,7 @@ bool nir_lower_flrp(nir_shader *shader, unsigned lowering_mask,
|
|||
bool nir_scale_fdiv(nir_shader *shader);
|
||||
|
||||
bool nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *data);
|
||||
bool nir_lower_alu_width(nir_shader *shader, nir_vectorize_cb cb, const void *data);
|
||||
bool nir_lower_bool_to_bitsize(nir_shader *shader);
|
||||
bool nir_lower_bool_to_float(nir_shader *shader);
|
||||
bool nir_lower_bool_to_int32(nir_shader *shader);
|
||||
|
|
|
@ -24,15 +24,20 @@
|
|||
#include "nir.h"
|
||||
#include "nir_builder.h"
|
||||
|
||||
struct alu_to_scalar_data {
|
||||
nir_instr_filter_cb cb;
|
||||
struct alu_width_data {
|
||||
nir_vectorize_cb cb;
|
||||
const void *data;
|
||||
};
|
||||
|
||||
/** @file nir_lower_alu_to_scalar.c
|
||||
/** @file nir_lower_alu_width.c
|
||||
*
|
||||
* Replaces nir_alu_instr operations with more than one channel used in the
|
||||
* arguments with individual per-channel operations.
|
||||
*
|
||||
* Optionally, a callback function which returns the max vectorization width
|
||||
* per instruction can be provided.
|
||||
*
|
||||
* The max vectorization width must be a power of 2.
|
||||
*/
|
||||
|
||||
static bool
|
||||
|
@ -52,6 +57,36 @@ inst_is_vector_alu(const nir_instr *instr, const void *_state)
|
|||
nir_op_infos[alu->op].input_sizes[0] > 1;
|
||||
}
|
||||
|
||||
/* Checks whether all operands of an ALU instruction are swizzled
|
||||
* within the targeted vectorization width.
|
||||
*
|
||||
* The assumption here is that a vecN instruction can only swizzle
|
||||
* within the first N channels of the values it consumes, irrespective
|
||||
* of the capabilities of the instruction which produced those values.
|
||||
* If we assume values are packed consistently (i.e., they always start
|
||||
* at the beginning of a hardware register), we can actually access any
|
||||
* aligned group of N channels so long as we stay within the group.
|
||||
* This means for a vectorization width of 4 that only swizzles from
|
||||
* either [xyzw] or [abcd] etc are allowed. For a width of 2 these are
|
||||
* swizzles from either [xy] or [zw] etc.
|
||||
*/
|
||||
static bool
|
||||
alu_is_swizzled_in_bounds(const nir_alu_instr *alu, unsigned width)
|
||||
{
|
||||
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
|
||||
if (nir_op_infos[alu->op].input_sizes[i] == 1)
|
||||
continue;
|
||||
|
||||
unsigned mask = ~(width - 1);
|
||||
for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
|
||||
if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void
|
||||
nir_alu_ssa_dest_init(nir_alu_instr *alu, unsigned num_components,
|
||||
unsigned bit_size)
|
||||
|
@ -140,9 +175,9 @@ lower_fdot(nir_alu_instr *alu, nir_builder *builder)
|
|||
}
|
||||
|
||||
static nir_ssa_def *
|
||||
lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
|
||||
lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
|
||||
{
|
||||
struct alu_to_scalar_data *data = _data;
|
||||
struct alu_width_data *data = _data;
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
unsigned num_src = nir_op_infos[alu->op].num_inputs;
|
||||
unsigned i, chan;
|
||||
|
@ -152,8 +187,15 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
|
|||
|
||||
b->exact = alu->exact;
|
||||
|
||||
if (data->cb && !data->cb(instr, data->data))
|
||||
return NULL;
|
||||
unsigned num_components = alu->dest.dest.ssa.num_components;
|
||||
unsigned target_width = 1;
|
||||
|
||||
if (data->cb) {
|
||||
target_width = data->cb(instr, data->data);
|
||||
assert(util_is_power_of_two_or_zero(target_width));
|
||||
if (target_width == 0)
|
||||
return NULL;
|
||||
}
|
||||
|
||||
#define LOWER_REDUCTION(name, chan, merge) \
|
||||
case name##2: \
|
||||
|
@ -319,37 +361,78 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
|
|||
break;
|
||||
}
|
||||
|
||||
if (alu->dest.dest.ssa.num_components == 1)
|
||||
if (num_components == 1)
|
||||
return NULL;
|
||||
|
||||
unsigned num_components = alu->dest.dest.ssa.num_components;
|
||||
nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS] = { NULL };
|
||||
if (num_components <= target_width) {
|
||||
/* If the ALU instr is swizzled outside the target width,
|
||||
* reduce the target width.
|
||||
*/
|
||||
if (alu_is_swizzled_in_bounds(alu, target_width))
|
||||
return NULL;
|
||||
else
|
||||
target_width = DIV_ROUND_UP(num_components, 2);
|
||||
}
|
||||
|
||||
for (chan = 0; chan < num_components; chan++) {
|
||||
nir_alu_instr *vec = nir_alu_instr_create(b->shader, nir_op_vec(num_components));
|
||||
|
||||
for (chan = 0; chan < num_components; chan += target_width) {
|
||||
unsigned components = MIN2(target_width, num_components - chan);
|
||||
nir_alu_instr *lower = nir_alu_instr_create(b->shader, alu->op);
|
||||
|
||||
for (i = 0; i < num_src; i++) {
|
||||
nir_alu_src_copy(&lower->src[i], &alu->src[i]);
|
||||
|
||||
/* We only handle same-size-as-dest (input_sizes[] == 0) or scalar
|
||||
* args (input_sizes[] == 1).
|
||||
*/
|
||||
assert(nir_op_infos[alu->op].input_sizes[i] < 2);
|
||||
unsigned src_chan = (nir_op_infos[alu->op].input_sizes[i] == 1 ?
|
||||
0 : chan);
|
||||
|
||||
nir_alu_src_copy(&lower->src[i], &alu->src[i]);
|
||||
for (int j = 0; j < NIR_MAX_VEC_COMPONENTS; j++)
|
||||
lower->src[i].swizzle[j] = alu->dest.write_mask & (1 << chan) ?
|
||||
alu->src[i].swizzle[src_chan] : 0;
|
||||
for (int j = 0; j < components; j++) {
|
||||
unsigned src_chan = nir_op_infos[alu->op].input_sizes[i] == 1 ? 0 : chan + j;
|
||||
lower->src[i].swizzle[j] = alu->src[i].swizzle[src_chan];
|
||||
}
|
||||
}
|
||||
|
||||
nir_alu_ssa_dest_init(lower, 1, alu->dest.dest.ssa.bit_size);
|
||||
nir_alu_ssa_dest_init(lower, components, alu->dest.dest.ssa.bit_size);
|
||||
lower->dest.saturate = alu->dest.saturate;
|
||||
comps[chan] = &lower->dest.dest.ssa;
|
||||
lower->exact = alu->exact;
|
||||
|
||||
for (i = 0; i < components; i++) {
|
||||
vec->src[chan + i].src = nir_src_for_ssa(&lower->dest.dest.ssa);
|
||||
vec->src[chan + i].swizzle[0] = i;
|
||||
}
|
||||
|
||||
nir_builder_instr_insert(b, &lower->instr);
|
||||
}
|
||||
|
||||
return nir_vec(b, comps, num_components);
|
||||
return nir_builder_alu_instr_finish_and_insert(b, vec);
|
||||
}
|
||||
|
||||
bool
|
||||
nir_lower_alu_width(nir_shader *shader, nir_vectorize_cb cb, const void *_data)
|
||||
{
|
||||
struct alu_width_data data = {
|
||||
.cb = cb,
|
||||
.data = _data,
|
||||
};
|
||||
|
||||
return nir_shader_lower_instructions(shader,
|
||||
inst_is_vector_alu,
|
||||
lower_alu_instr_width,
|
||||
&data);
|
||||
}
|
||||
|
||||
struct alu_to_scalar_data {
|
||||
nir_instr_filter_cb cb;
|
||||
const void *data;
|
||||
};
|
||||
|
||||
static uint8_t
|
||||
scalar_cb(const nir_instr *instr, const void *data)
|
||||
{
|
||||
/* return vectorization-width = 1 for filtered instructions */
|
||||
const struct alu_to_scalar_data *filter = data;
|
||||
return filter->cb(instr, filter->data) ? 1 : 0;
|
||||
}
|
||||
|
||||
bool
|
||||
|
@ -360,8 +443,6 @@ nir_lower_alu_to_scalar(nir_shader *shader, nir_instr_filter_cb cb, const void *
|
|||
.data = _data,
|
||||
};
|
||||
|
||||
return nir_shader_lower_instructions(shader,
|
||||
inst_is_vector_alu,
|
||||
lower_alu_instr_scalar,
|
||||
&data);
|
||||
return nir_lower_alu_width(shader, cb ? scalar_cb : NULL, &data);
|
||||
}
|
||||
|
Loading…
Reference in New Issue