diff --git a/src/compiler/nir/nir_deref.c b/src/compiler/nir/nir_deref.c
index 96ac786c0e0..35a624ee378 100644
--- a/src/compiler/nir/nir_deref.c
+++ b/src/compiler/nir/nir_deref.c
@@ -1106,6 +1106,145 @@ opt_deref_ptr_as_array(nir_builder *b, nir_deref_instr *deref)
    return true;
 }
 
+static bool
+is_vector_bitcast_deref(nir_deref_instr *cast,
+                        nir_component_mask_t mask,
+                        bool is_write)
+{
+   if (cast->deref_type != nir_deref_type_cast)
+      return false;
+
+   /* Don't throw away useful alignment information */
+   if (cast->cast.align_mul > 0)
+      return false;
+
+   /* It has to be a cast of another deref */
+   nir_deref_instr *parent = nir_src_as_deref(cast->parent);
+   if (parent == NULL)
+      return false;
+
+   /* Don't bother with 1-bit types */
+   unsigned cast_bit_size = glsl_get_bit_size(cast->type);
+   unsigned parent_bit_size = glsl_get_bit_size(parent->type);
+   if (cast_bit_size == 1 || parent_bit_size == 1)
+      return false;
+
+   /* A strided vector type means it's not tightly packed */
+   if (glsl_get_explicit_stride(cast->type) ||
+       glsl_get_explicit_stride(parent->type))
+      return false;
+
+   assert(cast_bit_size > 0 && cast_bit_size % 8 == 0);
+   assert(parent_bit_size > 0 && parent_bit_size % 8 == 0);
+   unsigned bytes_used = util_last_bit(mask) * (cast_bit_size / 8);
+   unsigned parent_bytes = glsl_get_vector_elements(parent->type) *
+                           (parent_bit_size / 8);
+   if (bytes_used > parent_bytes)
+      return false;
+
+   if (is_write && !nir_component_mask_can_reinterpret(mask, cast_bit_size,
+                                                       parent_bit_size))
+      return false;
+
+   return true;
+}
+
+static nir_ssa_def *
+resize_vector(nir_builder *b, nir_ssa_def *data, unsigned num_components)
+{
+   if (num_components == data->num_components)
+      return data;
+
+   unsigned swiz[NIR_MAX_VEC_COMPONENTS] = { 0, };
+   for (unsigned i = 0; i < MIN2(num_components, data->num_components); i++)
+      swiz[i] = i;
+
+   return nir_swizzle(b, data, swiz, num_components);
+}
+
+static bool
+opt_load_vec_deref(nir_builder *b, nir_intrinsic_instr *load)
+{
+   nir_deref_instr *deref = nir_src_as_deref(load->src[0]);
+   nir_component_mask_t read_mask =
+      nir_ssa_def_components_read(&load->dest.ssa);
+
+   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are
+    * vec4-aligned and so it can just read/write them as vec4s. This
+    * results in a LOT of vec4->vec3 casts on loads and stores.
+    */
+   if (is_vector_bitcast_deref(deref, read_mask, false)) {
+      const unsigned old_num_comps = load->dest.ssa.num_components;
+      const unsigned old_bit_size = load->dest.ssa.bit_size;
+
+      nir_deref_instr *parent = nir_src_as_deref(deref->parent);
+      const unsigned new_num_comps = glsl_get_vector_elements(parent->type);
+      const unsigned new_bit_size = glsl_get_bit_size(parent->type);
+
+      /* Stomp it to reference the parent */
+      nir_instr_rewrite_src(&load->instr, &load->src[0],
+                            nir_src_for_ssa(&parent->dest.ssa));
+      assert(load->dest.is_ssa);
+      load->dest.ssa.bit_size = new_bit_size;
+      load->dest.ssa.num_components = new_num_comps;
+      load->num_components = new_num_comps;
+
+      b->cursor = nir_after_instr(&load->instr);
+      nir_ssa_def *data = &load->dest.ssa;
+      if (old_bit_size != new_bit_size)
+         data = nir_bitcast_vector(b, &load->dest.ssa, old_bit_size);
+      data = resize_vector(b, data, old_num_comps);
+
+      nir_ssa_def_rewrite_uses_after(&load->dest.ssa, nir_src_for_ssa(data),
+                                     data->parent_instr);
+      return true;
+   }
+
+   return false;
+}
+
+static bool
+opt_store_vec_deref(nir_builder *b, nir_intrinsic_instr *store)
+{
+   nir_deref_instr *deref = nir_src_as_deref(store->src[0]);
+   nir_component_mask_t write_mask = nir_intrinsic_write_mask(store);
+
+   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are
+    * vec4-aligned and so it can just read/write them as vec4s. This
+    * results in a LOT of vec4->vec3 casts on loads and stores.
+    */
+   if (is_vector_bitcast_deref(deref, write_mask, true)) {
+      assert(store->src[1].is_ssa);
+      nir_ssa_def *data = store->src[1].ssa;
+
+      const unsigned old_bit_size = data->bit_size;
+
+      nir_deref_instr *parent = nir_src_as_deref(deref->parent);
+      const unsigned new_num_comps = glsl_get_vector_elements(parent->type);
+      const unsigned new_bit_size = glsl_get_bit_size(parent->type);
+
+      nir_instr_rewrite_src(&store->instr, &store->src[0],
+                            nir_src_for_ssa(&parent->dest.ssa));
+
+      /* Restrict things down as needed so the bitcast doesn't fail */
+      data = nir_channels(b, data, (1 << util_last_bit(write_mask)) - 1);
+      if (old_bit_size != new_bit_size)
+         data = nir_bitcast_vector(b, data, new_bit_size);
+      data = resize_vector(b, data, new_num_comps);
+      nir_instr_rewrite_src(&store->instr, &store->src[1],
+                            nir_src_for_ssa(data));
+      store->num_components = new_num_comps;
+
+      /* Adjust the write mask */
+      write_mask = nir_component_mask_reinterpret(write_mask, old_bit_size,
+                                                  new_bit_size);
+      nir_intrinsic_set_write_mask(store, write_mask);
+      return true;
+   }
+
+   return false;
+}
+
 bool
 nir_opt_deref_impl(nir_function_impl *impl)
 {
@@ -1139,6 +1278,26 @@ nir_opt_deref_impl(nir_function_impl *impl)
             break;
          }
 
+         case nir_instr_type_intrinsic: {
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            switch (intrin->intrinsic) {
+            case nir_intrinsic_load_deref:
+               if (opt_load_vec_deref(&b, intrin))
+                  progress = true;
+               break;
+
+            case nir_intrinsic_store_deref:
+               if (opt_store_vec_deref(&b, intrin))
+                  progress = true;
+               break;
+
+            default:
+               /* Do nothing */
+               break;
+            }
+            break;
+         }
+
          default:
             /* Do nothing */
             break;
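For intuition, the load-path rewrite amounts to loading the parent vector directly and then swizzling the result back down to the size the original load produced. The before/after below is illustrative pseudocode in roughly nir_print style, not exact compiler output:

   /* Before: a vec4->vec3 cast feeds the load.
    *
    *    ssa_1 = deref_cast (vec3 *)ssa_0    // ssa_0 is a vec4-typed deref
    *    ssa_2 = load_deref ssa_1            // 3 components
    *
    * After opt_load_vec_deref(): the load targets the vec4 parent and
    * resize_vector() trims the value back down; the now-dead cast is
    * cleaned up by later dead-code elimination.
    *
    *    ssa_2 = load_deref ssa_0            // 4 components
    *    ssa_3 = vec3 ssa_2.x, ssa_2.y, ssa_2.z
    *    // every use of the old vec3 result now reads ssa_3
    */

The store path is symmetric: the data is trimmed to the written channels (nir_channels), bitcast if the bit sizes differ, padded back out to the parent's width (resize_vector), and the write mask is re-expressed at the parent's bit size.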
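Two pieces of mask arithmetic make the transformation safe: is_vector_bitcast_deref() only folds a cast when every byte touched through it also exists in the parent vector, and the store path re-expresses the write mask at the parent's bit size with nir_component_mask_reinterpret(). A minimal standalone sketch of that arithmetic, using local stand-ins (last_bit, fits_in_parent, mask_widen) rather than the Mesa helpers:

   #include <assert.h>
   #include <stdbool.h>
   #include <stdio.h>

   /* 1-based index of the highest set bit; a stand-in for util_last_bit(). */
   static unsigned
   last_bit(unsigned mask)
   {
      unsigned n = 0;
      for (; mask; mask >>= 1)
         n++;
      return n;
   }

   /* Mirror of the byte-bounds check in is_vector_bitcast_deref(): the
    * cast may only be folded if the bytes read/written through it (highest
    * used component times the component size) fit in the parent vector.
    */
   static bool
   fits_in_parent(unsigned mask, unsigned cast_bit_size,
                  unsigned parent_comps, unsigned parent_bit_size)
   {
      unsigned bytes_used = last_bit(mask) * (cast_bit_size / 8);
      unsigned parent_bytes = parent_comps * (parent_bit_size / 8);
      return bytes_used <= parent_bytes;
   }

   /* Sketch of nir_component_mask_reinterpret() semantics for the case
    * where the old components are wider than the new ones: each old
    * channel covers ratio = old/new new channels, so 0b0101 at 64 bits
    * becomes 0b00110011 at 32 bits.
    */
   static unsigned
   mask_widen(unsigned mask, unsigned old_bit_size, unsigned new_bit_size)
   {
      assert(old_bit_size % new_bit_size == 0);
      unsigned ratio = old_bit_size / new_bit_size;
      unsigned out = 0;
      for (unsigned i = 0; i < last_bit(mask); i++) {
         if (mask & (1u << i))
            out |= ((1u << ratio) - 1u) << (i * ratio);
      }
      return out;
   }

   int
   main(void)
   {
      /* A vec3 float load through a cast of a vec4 float deref: 12 bytes
       * used, 16 available, so opt_load_vec_deref() can fold the cast. */
      printf("foldable: %d\n", fits_in_parent(0x7, 32, 4, 32));

      /* Write mask 0b01 on a 64-bit store re-expressed over a 32-bit
       * parent becomes 0b0011. */
      printf("mask: 0x%x\n", mask_widen(0x1, 64, 32));
      return 0;
   }

The widening direction shown is the one opt_store_vec_deref() hits when a wider-typed view is stored over a narrower parent; the opposite (narrowing) direction additionally requires whole groups of channels to be set or clear together, which is roughly what the nir_component_mask_can_reinterpret() guard in the write path checks.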