nir/lower_subgroups: Properly lower masks when subgroup_size == 0

Instead of building a constant mask (which depends on knowing the
subgroup size), we build an expression.  Because the pass uses the
nir_shader_lower_instructions helper, subgroup lowering will be run on
any newly emitted instructions as well as the previously existing
instructions.  In particular, if the subgroup size is known, the newly
emitted subgroup_size intrinsic will get turned into a constant and a
later constant folding pass will clean it up.

Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
This commit is contained in:
Jason Ekstrand 2019-07-10 22:20:00 -05:00
parent 256e6c2d94
commit 799f0f7b28
1 changed files with 11 additions and 5 deletions

View File

@ -292,6 +292,15 @@ lower_subgroups_filter(const nir_instr *instr, const void *_options)
return instr->type == nir_instr_type_intrinsic;
}
static nir_ssa_def *
build_subgroup_mask(nir_builder *b, unsigned bit_size,
const nir_lower_subgroups_options *options)
{
return nir_ushr(b, nir_imm_intN_t(b, ~0ull, bit_size),
nir_isub(b, nir_imm_int(b, bit_size),
nir_load_subgroup_size(b)));
}
static nir_ssa_def *
lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
{
@ -343,9 +352,6 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
const unsigned bit_size = MAX2(options->ballot_bit_size,
intrin->dest.ssa.bit_size);
assert(options->subgroup_size <= 64);
uint64_t group_mask = ~0ull >> (64 - options->subgroup_size);
nir_ssa_def *count = nir_load_subgroup_invocation(b);
nir_ssa_def *val;
switch (intrin->intrinsic) {
@ -354,11 +360,11 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
break;
case nir_intrinsic_load_subgroup_ge_mask:
val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~0ull, bit_size), count),
nir_imm_intN_t(b, group_mask, bit_size));
build_subgroup_mask(b, bit_size, options));
break;
case nir_intrinsic_load_subgroup_gt_mask:
val = nir_iand(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count),
nir_imm_intN_t(b, group_mask, bit_size));
build_subgroup_mask(b, bit_size, options));
break;
case nir_intrinsic_load_subgroup_le_mask:
val = nir_inot(b, nir_ishl(b, nir_imm_intN_t(b, ~1ull, bit_size), count));