nir/opt_deref: Add an optimization for bitcasts

LLVM loves to take advantage of the fact that vec3s in OpenCL are 16B
aligned, so it can just read/write them as vec4s.  This is questionably
legal, except that it uses an xyz write-mask when it does so.  The
result is a LOT of vec4->vec3 casts on loads and stores.  This
optimization detects this case, as well as other bit-cast cases, and
rewrites them to get rid of the cast.
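
For illustration (a hedged example, not part of this commit), a kernel
as simple as the following is enough to hit the pattern: LLVM turns the
float3 accesses into 16B vec4 loads/stores with an xyz write-mask,
which then show up in NIR as vec4->vec3 casts:

   /* Hypothetical OpenCL C kernel, for illustration only */
   __kernel void
   copy3(__global float3 *dst, __global const float3 *src)
   {
      size_t i = get_global_id(0);
      dst[i] = src[i];   /* vec4 load + vec4 store with an xyz mask */
   }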

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6871>
Jason Ekstrand 2020-09-25 16:03:36 -05:00 committed by Marge Bot
parent 80e6ac3341
commit 9190f82d57
1 changed file with 159 additions and 0 deletions

@@ -1106,6 +1106,145 @@ opt_deref_ptr_as_array(nir_builder *b, nir_deref_instr *deref)
   return true;
}

static bool
is_vector_bitcast_deref(nir_deref_instr *cast,
                        nir_component_mask_t mask,
                        bool is_write)
{
   if (cast->deref_type != nir_deref_type_cast)
      return false;

   /* Don't throw away useful alignment information */
   if (cast->cast.align_mul > 0)
      return false;

   /* It has to be a cast of another deref */
   nir_deref_instr *parent = nir_src_as_deref(cast->parent);
   if (parent == NULL)
      return false;

   /* Don't bother with 1-bit types */
   unsigned cast_bit_size = glsl_get_bit_size(cast->type);
   unsigned parent_bit_size = glsl_get_bit_size(parent->type);
   if (cast_bit_size == 1 || parent_bit_size == 1)
      return false;

   /* A strided vector type means it's not tightly packed */
   if (glsl_get_explicit_stride(cast->type) ||
       glsl_get_explicit_stride(parent->type))
      return false;

   assert(cast_bit_size > 0 && cast_bit_size % 8 == 0);
   assert(parent_bit_size > 0 && parent_bit_size % 8 == 0);
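
   /* Illustrative example: with an xyz mask and a 32-bit cast type,
    * bytes_used is util_last_bit(0x7) * 4 = 12, which fits in a 16-byte
    * 32-bit vec4 parent, so a vec4->vec3 cast passes this check.
    */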
   unsigned bytes_used = util_last_bit(mask) * (cast_bit_size / 8);
   unsigned parent_bytes = glsl_get_vector_elements(parent->type) *
                           (parent_bit_size / 8);
   if (bytes_used > parent_bytes)
      return false;

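   /* Illustrative note: for writes, the mask must also be expressible at
    * the parent's bit size; e.g. a y-only (0x2) mask on 32-bit data has
    * no exact 64-bit equivalent, so such a cast is rejected below.
    */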
   if (is_write && !nir_component_mask_can_reinterpret(mask, cast_bit_size,
                                                       parent_bit_size))
      return false;

   return true;
}

static nir_ssa_def *
resize_vector(nir_builder *b, nir_ssa_def *data, unsigned num_components)
{
   if (num_components == data->num_components)
      return data;

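   /* Illustrative note: swizzle entries beyond data->num_components stay
    * 0, so growing a vector duplicates component zero; callers mask off
    * those extra channels via read or write masks.
    */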
   unsigned swiz[NIR_MAX_VEC_COMPONENTS] = { 0, };
   for (unsigned i = 0; i < MIN2(num_components, data->num_components); i++)
      swiz[i] = i;

   return nir_swizzle(b, data, swiz, num_components);
}

static bool
opt_load_vec_deref(nir_builder *b, nir_intrinsic_instr *load)
{
   nir_deref_instr *deref = nir_src_as_deref(load->src[0]);
   nir_component_mask_t read_mask =
      nir_ssa_def_components_read(&load->dest.ssa);

   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are
    * vec4-aligned and so it can just read/write them as vec4s.  This
    * results in a LOT of vec4->vec3 casts on loads and stores.
    */
   if (is_vector_bitcast_deref(deref, read_mask, false)) {
      const unsigned old_num_comps = load->dest.ssa.num_components;
      const unsigned old_bit_size = load->dest.ssa.bit_size;
      nir_deref_instr *parent = nir_src_as_deref(deref->parent);
      const unsigned new_num_comps = glsl_get_vector_elements(parent->type);
      const unsigned new_bit_size = glsl_get_bit_size(parent->type);

      /* Stomp it to reference the parent */
      nir_instr_rewrite_src(&load->instr, &load->src[0],
                            nir_src_for_ssa(&parent->dest.ssa));
      assert(load->dest.is_ssa);
      load->dest.ssa.bit_size = new_bit_size;
      load->dest.ssa.num_components = new_num_comps;
      load->num_components = new_num_comps;
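
      /* Descriptive note: emit fix-up code after the rewritten load so
       * existing users still see the original type, i.e. bitcast back to
       * the old bit size if needed, then resize to the old component
       * count.
       */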
      b->cursor = nir_after_instr(&load->instr);
      nir_ssa_def *data = &load->dest.ssa;
      if (old_bit_size != new_bit_size)
         data = nir_bitcast_vector(b, &load->dest.ssa, old_bit_size);
      data = resize_vector(b, data, old_num_comps);

      nir_ssa_def_rewrite_uses_after(&load->dest.ssa, nir_src_for_ssa(data),
                                     data->parent_instr);
      return true;
   }

   return false;
}

static bool
opt_store_vec_deref(nir_builder *b, nir_intrinsic_instr *store)
{
   nir_deref_instr *deref = nir_src_as_deref(store->src[0]);
   nir_component_mask_t write_mask = nir_intrinsic_write_mask(store);

   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are
    * vec4-aligned and so it can just read/write them as vec4s.  This
    * results in a LOT of vec4->vec3 casts on loads and stores.
    */
   if (is_vector_bitcast_deref(deref, write_mask, true)) {
      assert(store->src[1].is_ssa);
      nir_ssa_def *data = store->src[1].ssa;
      const unsigned old_bit_size = data->bit_size;
      nir_deref_instr *parent = nir_src_as_deref(deref->parent);
      const unsigned new_num_comps = glsl_get_vector_elements(parent->type);
      const unsigned new_bit_size = glsl_get_bit_size(parent->type);

      nir_instr_rewrite_src(&store->instr, &store->src[0],
                            nir_src_for_ssa(&parent->dest.ssa));

      /* Restrict things down as needed so the bitcast doesn't fail */
      data = nir_channels(b, data, (1 << util_last_bit(write_mask)) - 1);
      if (old_bit_size != new_bit_size)
         data = nir_bitcast_vector(b, data, new_bit_size);
      data = resize_vector(b, data, new_num_comps);
      nir_instr_rewrite_src(&store->instr, &store->src[1],
                            nir_src_for_ssa(data));
      store->num_components = new_num_comps;
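
      /* Illustrative example: an xyz (0x7) mask on 32-bit data becomes
       * 0x3f for a 16-bit parent, since each 32-bit channel covers two
       * 16-bit channels.
       */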
      /* Adjust the write mask */
      write_mask = nir_component_mask_reinterpret(write_mask, old_bit_size,
                                                  new_bit_size);
      nir_intrinsic_set_write_mask(store, write_mask);
      return true;
   }

   return false;
}

bool
nir_opt_deref_impl(nir_function_impl *impl)
{

@@ -1139,6 +1278,26 @@ nir_opt_deref_impl(nir_function_impl *impl)
            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_load_deref:
               if (opt_load_vec_deref(&b, intrin))
                  progress = true;
               break;

            case nir_intrinsic_store_deref:
               if (opt_store_vec_deref(&b, intrin))
                  progress = true;
               break;

            default:
               /* Do nothing */
               break;
            }
            break;
         }

         default:
            /* Do nothing */
            break;