nir/opt_shrink_vectors: Round to supported vec size

The set of supported vector sizes in NIR has holes in it. For example, we
support vec5 and vec8, but not vec6 or vec7. However, this pass did not take
that into account, and would happily shrink a vec8 down to a vec7, causing NIR
validation to fail. Instead, the pass should round up to the next supported
vector size.

Fixes NIR validation fail in OpenCL's test_basic hiloeo subtest.

v2: Clamp -> round rename.

Signed-off-by: Alyssa Rosenzweig <alyssa@collabora.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17194>
This commit is contained in:
Alyssa Rosenzweig 2022-06-22 15:58:16 -04:00 committed by Marge Bot
parent 21b3a23404
commit befc68ec33
1 changed files with 24 additions and 0 deletions

View File

@ -45,6 +45,18 @@
#include "nir.h"
#include "nir_builder.h"
#include "util/u_math.h"
/*
* Round up a vector size to a vector size that's valid in NIR. At present, NIR
* supports only vec2-5, vec8, and vec16. Attempting to generate other sizes
* will fail validation.
*/
static unsigned
round_up_components(unsigned n)
{
return (n > 5) ? util_next_power_of_two(n) : n;
}
static bool
shrink_dest_to_read_mask(nir_ssa_def *def)
@ -66,6 +78,10 @@ shrink_dest_to_read_mask(nir_ssa_def *def)
if (!mask)
return false;
unsigned rounded = round_up_components(last_bit);
assert(rounded <= def->num_components);
last_bit = rounded;
if (def->num_components > last_bit) {
def->num_components = last_bit;
return true;
@ -179,6 +195,10 @@ opt_shrink_vectors_alu(nir_builder *b, nir_alu_instr *instr)
unsigned last_bit = util_last_bit(mask);
unsigned num_components = util_bitcount(mask);
unsigned rounded = round_up_components(num_components);
assert(rounded <= def->num_components);
num_components = rounded;
/* return, if there is nothing to do */
if (mask == 0 || num_components == def->num_components)
return false;
@ -293,6 +313,10 @@ opt_shrink_vectors_load_const(nir_load_const_instr *instr)
}
}
unsigned rounded = round_up_components(num_components);
assert(rounded <= def->num_components);
num_components = rounded;
if (num_components == def->num_components)
return false;