diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 281e290f447..9fa4c443c62 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4636,6 +4636,26 @@ nir_variable_is_in_block(const nir_variable *var)
    return nir_variable_is_in_ubo(var) || nir_variable_is_in_ssbo(var);
 }
 
+typedef struct nir_unsigned_upper_bound_config {
+   unsigned min_subgroup_size;
+   unsigned max_subgroup_size;
+   unsigned max_work_group_invocations;
+   unsigned max_work_group_count[3];
+   unsigned max_work_group_size[3];
+
+   uint32_t vertex_attrib_max[32];
+} nir_unsigned_upper_bound_config;
+
+uint32_t
+nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
+                         nir_ssa_scalar scalar,
+                         const nir_unsigned_upper_bound_config *config);
+
+bool
+nir_addition_might_overflow(nir_shader *shader, struct hash_table *range_ht,
+                            nir_ssa_scalar ssa, unsigned const_val,
+                            const nir_unsigned_upper_bound_config *config);
+
 #ifdef __cplusplus
 } /* extern "C" */
 #endif
diff --git a/src/compiler/nir/nir_range_analysis.c b/src/compiler/nir/nir_range_analysis.c
index 548123be4ce..87ee919710d 100644
--- a/src/compiler/nir/nir_range_analysis.c
+++ b/src/compiler/nir/nir_range_analysis.c
@@ -1086,3 +1086,420 @@ nir_analyze_range(struct hash_table *range_ht,
    return analyze_expression(instr, src, range_ht,
                              nir_alu_src_type(instr, src));
 }
+
+static uint32_t bitmask(uint32_t size) {
+   return size >= 32 ? 0xffffffffu : ((uint32_t)1 << size) - 1u;
+}
+
+static uint64_t mul_clamp(uint32_t a, uint32_t b)
+{
+   /* clamp to UINT32_MAX + 1 on 32-bit overflow */
+   if (a != 0 && (a * b) / a != b)
+      return (uint64_t)UINT32_MAX + 1;
+   else
+      return a * b;
+}
+
+static unsigned
+search_phi_bcsel(nir_ssa_scalar scalar, nir_ssa_scalar *buf, unsigned buf_size, struct set *visited)
+{
+   if (_mesa_set_search(visited, scalar.def))
+      return 0;
+   _mesa_set_add(visited, scalar.def);
+
+   if (scalar.def->parent_instr->type == nir_instr_type_phi) {
+      nir_phi_instr *phi = nir_instr_as_phi(scalar.def->parent_instr);
+      unsigned num_sources_left = exec_list_length(&phi->srcs);
+      unsigned total_added = 0;
+      nir_foreach_phi_src(src, phi) {
+         unsigned added = search_phi_bcsel(
+            (nir_ssa_scalar){src->src.ssa, 0}, buf + total_added, buf_size - num_sources_left, visited);
+         buf_size -= added;
+         total_added += added;
+         num_sources_left--;
+      }
+      return total_added;
+   }
+
+   if (nir_ssa_scalar_is_alu(scalar)) {
+      nir_op op = nir_ssa_scalar_alu_op(scalar);
+
+      if ((op == nir_op_bcsel || op == nir_op_b32csel) && buf_size >= 2) {
+         /* chase the selected values (sources 1 and 2); source 0 is the condition */
+         nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(scalar, 1);
+         nir_ssa_scalar src2 = nir_ssa_scalar_chase_alu_src(scalar, 2);
+
+         unsigned added = search_phi_bcsel(src1, buf, buf_size - 1, visited);
+         buf_size -= added;
+         added += search_phi_bcsel(src2, buf + added, buf_size, visited);
+         return added;
+      }
+   }
+
+   buf[0] = scalar;
+   return 1;
+}
+
+static nir_variable *
+lookup_input(nir_shader *shader, unsigned driver_location)
+{
+   nir_foreach_variable(var, &shader->inputs) {
+      if (driver_location == var->data.driver_location)
+         return var;
+   }
+   return NULL;
+}
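+
+/* Returns a conservative upper bound (never smaller than the actual maximum)
+ * for the value of an unsigned SSA scalar of at most 32 bits. Bounds are
+ * cached in range_ht, keyed by SSA def and component, so callers should
+ * reuse the same table across queries. */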
+uint32_t
+nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
+                         nir_ssa_scalar scalar,
+                         const nir_unsigned_upper_bound_config *config)
+{
+   assert(scalar.def->bit_size <= 32);
+
+   if (nir_ssa_scalar_is_const(scalar))
+      return nir_ssa_scalar_as_uint(scalar);
+
+   /* keys can't be 0, so we have to add 1 to the index */
+   void *key = (void*)(((uintptr_t)(scalar.def->index + 1) << 4) | scalar.comp);
+   struct hash_entry *he = _mesa_hash_table_search(range_ht, key);
+   if (he != NULL)
+      return (uintptr_t)he->data;
+
+   uint32_t max = bitmask(scalar.def->bit_size);
+
+   if (scalar.def->parent_instr->type == nir_instr_type_intrinsic) {
+      uint32_t res = max;
+      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(scalar.def->parent_instr);
+      switch (intrin->intrinsic) {
+      case nir_intrinsic_load_local_invocation_index:
+         if (shader->info.cs.local_size_variable) {
+            res = config->max_work_group_invocations - 1;
+         } else {
+            res = (shader->info.cs.local_size[0] *
+                   shader->info.cs.local_size[1] *
+                   shader->info.cs.local_size[2]) - 1u;
+         }
+         break;
+      case nir_intrinsic_load_local_invocation_id:
+         if (shader->info.cs.local_size_variable)
+            res = config->max_work_group_size[scalar.comp] - 1u;
+         else
+            res = shader->info.cs.local_size[scalar.comp] - 1u;
+         break;
+      case nir_intrinsic_load_work_group_id:
+         res = config->max_work_group_count[scalar.comp] - 1u;
+         break;
+      case nir_intrinsic_load_num_work_groups:
+         res = config->max_work_group_count[scalar.comp];
+         break;
+      case nir_intrinsic_load_global_invocation_id:
+         if (shader->info.cs.local_size_variable) {
+            res = mul_clamp(config->max_work_group_size[scalar.comp],
+                            config->max_work_group_count[scalar.comp]) - 1u;
+         } else {
+            res = (shader->info.cs.local_size[scalar.comp] *
+                   config->max_work_group_count[scalar.comp]) - 1u;
+         }
+         break;
+      case nir_intrinsic_load_subgroup_invocation:
+      case nir_intrinsic_first_invocation:
+      case nir_intrinsic_mbcnt_amd:
+         res = config->max_subgroup_size - 1;
+         break;
+      case nir_intrinsic_load_subgroup_size:
+         res = config->max_subgroup_size;
+         break;
+      case nir_intrinsic_load_subgroup_id:
+      case nir_intrinsic_load_num_subgroups: {
+         uint32_t work_group_size = config->max_work_group_invocations;
+         if (!shader->info.cs.local_size_variable) {
+            work_group_size = shader->info.cs.local_size[0] *
+                              shader->info.cs.local_size[1] *
+                              shader->info.cs.local_size[2];
+         }
+         res = (work_group_size + config->min_subgroup_size - 1) / config->min_subgroup_size;
+         if (intrin->intrinsic == nir_intrinsic_load_subgroup_id)
+            res--;
+         break;
+      }
+      case nir_intrinsic_load_input: {
+         if (shader->info.stage == MESA_SHADER_VERTEX && nir_src_is_const(intrin->src[0])) {
+            nir_variable *var = lookup_input(shader, nir_intrinsic_base(intrin));
+            if (var) {
+               int loc = var->data.location - VERT_ATTRIB_GENERIC0;
+               if (loc >= 0)
+                  res = config->vertex_attrib_max[loc];
+            }
+         }
+         break;
+      }
+      case nir_intrinsic_reduce:
+      case nir_intrinsic_inclusive_scan:
+      case nir_intrinsic_exclusive_scan: {
+         nir_op op = nir_intrinsic_reduction_op(intrin);
+         if (op == nir_op_umin || op == nir_op_umax || op == nir_op_imin || op == nir_op_imax)
+            res = nir_unsigned_upper_bound(shader, range_ht,
+                                           (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
+         break;
+      }
+      case nir_intrinsic_read_first_invocation:
+      case nir_intrinsic_read_invocation:
+      case nir_intrinsic_shuffle:
+      case nir_intrinsic_shuffle_xor:
+      case nir_intrinsic_shuffle_up:
+      case nir_intrinsic_shuffle_down:
+      case nir_intrinsic_quad_broadcast:
+      case nir_intrinsic_quad_swap_horizontal:
+      case nir_intrinsic_quad_swap_vertical:
+      case nir_intrinsic_quad_swap_diagonal:
+      case nir_intrinsic_quad_swizzle_amd:
+      case nir_intrinsic_masked_swizzle_amd:
+         res = nir_unsigned_upper_bound(shader, range_ht,
+                                        (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
+         break;
+      case nir_intrinsic_write_invocation_amd: {
+         uint32_t src0 = nir_unsigned_upper_bound(shader, range_ht,
+                                                  (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
+         uint32_t src1 = nir_unsigned_upper_bound(shader, range_ht,
+                                                  (nir_ssa_scalar){intrin->src[1].ssa, 0}, config);
+         res = MAX2(src0, src1);
+         break;
+      }
+      default:
+         break;
+      }
+      if (res != max)
+         _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
+      return res;
+   }
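+
+   /* Phis: an acyclic phi is bounded by the maximum over its sources. A
+    * cyclic phi (loop back-edge) is first seeded with the type maximum so
+    * that recursive lookups stay conservative, then bounded by the maximum
+    * over the phi/bcsel leaves collected by search_phi_bcsel(). */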
+   if (scalar.def->parent_instr->type == nir_instr_type_phi) {
+      bool cyclic = false;
+      nir_foreach_phi_src(src, nir_instr_as_phi(scalar.def->parent_instr)) {
+         /* if the phi's block dominates a predecessor, that edge is a back-edge */
+         if (nir_block_dominates(scalar.def->parent_instr->block, src->pred)) {
+            cyclic = true;
+            break;
+         }
+      }
+
+      uint32_t res = 0;
+      if (cyclic) {
+         _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)max);
+
+         struct set *visited = _mesa_pointer_set_create(NULL);
+         nir_ssa_scalar defs[64];
+         unsigned def_count = search_phi_bcsel(scalar, defs, 64, visited);
+         _mesa_set_destroy(visited, NULL);
+
+         for (unsigned i = 0; i < def_count; i++)
+            res = MAX2(res, nir_unsigned_upper_bound(shader, range_ht, defs[i], config));
+      } else {
+         nir_foreach_phi_src(src, nir_instr_as_phi(scalar.def->parent_instr)) {
+            res = MAX2(res, nir_unsigned_upper_bound(
+               shader, range_ht, (nir_ssa_scalar){src->src.ssa, 0}, config));
+         }
+      }
+
+      _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
+      return res;
+   }
+
+   if (nir_ssa_scalar_is_alu(scalar)) {
+      nir_op op = nir_ssa_scalar_alu_op(scalar);
+
+      switch (op) {
+      case nir_op_umin:
+      case nir_op_imin:
+      case nir_op_imax:
+      case nir_op_umax:
+      case nir_op_iand:
+      case nir_op_ior:
+      case nir_op_ixor:
+      case nir_op_ishl:
+      case nir_op_imul:
+      case nir_op_ushr:
+      case nir_op_ishr:
+      case nir_op_iadd:
+      case nir_op_umod:
+      case nir_op_udiv:
+      case nir_op_bcsel:
+      case nir_op_b32csel:
+      case nir_op_imax3:
+      case nir_op_imin3:
+      case nir_op_umax3:
+      case nir_op_umin3:
+      case nir_op_ubfe:
+      case nir_op_bfm:
+      case nir_op_f2u32:
+      case nir_op_fmul:
+         break;
+      default:
+         return max;
+      }
+
+      uint32_t src0 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 0), config);
+      uint32_t src1 = max, src2 = max;
+      if (nir_op_infos[op].num_inputs > 1)
+         src1 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 1), config);
+      if (nir_op_infos[op].num_inputs > 2)
+         src2 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 2), config);
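+
+      /* Combine the source bounds. Each case may over-estimate, but must
+       * never produce a value below the real upper bound. */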
+      uint32_t res = max;
+      switch (op) {
+      case nir_op_umin:
+         res = src0 < src1 ? src0 : src1;
+         break;
+      case nir_op_imin:
+      case nir_op_imax:
+      case nir_op_umax:
+         res = src0 > src1 ? src0 : src1;
+         break;
+      case nir_op_iand:
+         res = bitmask(util_last_bit64(src0)) & bitmask(util_last_bit64(src1));
+         break;
+      case nir_op_ior:
+      case nir_op_ixor:
+         res = bitmask(util_last_bit64(src0)) | bitmask(util_last_bit64(src1));
+         break;
+      case nir_op_ishl:
+         if (util_last_bit64(src0) + src1 > scalar.def->bit_size)
+            res = max; /* overflow */
+         else
+            res = src0 << MIN2(src1, scalar.def->bit_size - 1u);
+         break;
+      case nir_op_imul:
+         if (src0 != 0 && (src0 * src1) / src0 != src1)
+            res = max;
+         else
+            res = src0 * src1;
+         break;
+      case nir_op_ushr: {
+         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
+         if (nir_ssa_scalar_is_const(src1_scalar))
+            res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
+         else
+            res = src0;
+         break;
+      }
+      case nir_op_ishr: {
+         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
+         if (src0 <= 2147483647 && nir_ssa_scalar_is_const(src1_scalar))
+            res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
+         else
+            res = src0;
+         break;
+      }
+      case nir_op_iadd:
+         if (src0 + src1 < src0)
+            res = max; /* overflow */
+         else
+            res = src0 + src1;
+         break;
+      case nir_op_umod:
+         res = src1 ? src1 - 1 : 0;
+         break;
+      case nir_op_udiv: {
+         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
+         if (nir_ssa_scalar_is_const(src1_scalar))
+            res = nir_ssa_scalar_as_uint(src1_scalar) ? src0 / nir_ssa_scalar_as_uint(src1_scalar) : 0;
+         else
+            res = src0;
+         break;
+      }
+      case nir_op_bcsel:
+      case nir_op_b32csel:
+         res = src1 > src2 ? src1 : src2;
+         break;
+      case nir_op_imax3:
+      case nir_op_imin3:
+      case nir_op_umax3:
+         src0 = src0 > src1 ? src0 : src1;
+         res = src0 > src2 ? src0 : src2;
+         break;
+      case nir_op_umin3:
+         src0 = src0 < src1 ? src0 : src1;
+         res = src0 < src2 ? src0 : src2;
+         break;
+      case nir_op_ubfe:
+         res = bitmask(MIN2(src2, scalar.def->bit_size));
+         break;
+      case nir_op_bfm: {
+         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
+         if (nir_ssa_scalar_is_const(src1_scalar)) {
+            src0 = MIN2(src0, 31); /* amount of bits */
+            src1 = nir_ssa_scalar_as_uint(src1_scalar) & 0x1fu; /* offset */
+            res = bitmask(src0) << src1;
+         } else {
+            src0 = MIN2(src0, 31);
+            src1 = MIN2(src1, 31);
+            res = bitmask(MIN2(src0 + src1, 32));
+         }
+         break;
+      }
+      /* limited floating-point support for f2u32(fmul(load_input(), <constant>)) */
+      case nir_op_f2u32:
+         /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
+         if (src0 < 0x7f800000u) {
+            float val;
+            memcpy(&val, &src0, 4);
+            res = (uint32_t)val;
+         }
+         break;
+      case nir_op_fmul:
+         /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
+         if (src0 < 0x7f800000u && src1 < 0x7f800000u) {
+            float src0_f, src1_f;
+            memcpy(&src0_f, &src0, 4);
+            memcpy(&src1_f, &src1, 4);
+            /* not a proper rounding-up multiplication, but should be good enough */
+            float max_f = ceilf(src0_f) * ceilf(src1_f);
+            memcpy(&res, &max_f, 4);
+         }
+         break;
+      default:
+         res = max;
+         break;
+      }
+      _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
+      return res;
+   }
+
+   return max;
+}
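+
+/* Returns false only if iadd(ssa, const_val) provably cannot wrap around
+ * 32 bits. The iadd(imul(a, #b), #c) and iadd(iand(a, #b), #c) patterns are
+ * checked first, before falling back to the generic upper bound of ssa. */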
+bool
+nir_addition_might_overflow(nir_shader *shader, struct hash_table *range_ht,
+                            nir_ssa_scalar ssa, unsigned const_val,
+                            const nir_unsigned_upper_bound_config *config)
+{
+   if (nir_ssa_scalar_is_alu(ssa)) {
+      nir_op alu_op = nir_ssa_scalar_alu_op(ssa);
+
+      /* iadd(imul(a, #b), #c) */
+      if (alu_op == nir_op_imul || alu_op == nir_op_ishl) {
+         nir_ssa_scalar mul_src0 = nir_ssa_scalar_chase_alu_src(ssa, 0);
+         nir_ssa_scalar mul_src1 = nir_ssa_scalar_chase_alu_src(ssa, 1);
+         uint32_t stride = 1;
+         if (nir_ssa_scalar_is_const(mul_src0))
+            stride = nir_ssa_scalar_as_uint(mul_src0);
+         else if (nir_ssa_scalar_is_const(mul_src1))
+            stride = nir_ssa_scalar_as_uint(mul_src1);
+
+         if (alu_op == nir_op_ishl)
+            stride = 1u << (stride % 32u);
+
+         if (!stride || const_val <= UINT32_MAX - (UINT32_MAX / stride * stride))
+            return false;
+      }
+
+      /* iadd(iand(a, #b), #c) */
+      if (alu_op == nir_op_iand) {
+         nir_ssa_scalar and_src0 = nir_ssa_scalar_chase_alu_src(ssa, 0);
+         nir_ssa_scalar and_src1 = nir_ssa_scalar_chase_alu_src(ssa, 1);
+         uint32_t mask = 0xffffffff;
+         if (nir_ssa_scalar_is_const(and_src0))
+            mask = nir_ssa_scalar_as_uint(and_src0);
+         else if (nir_ssa_scalar_is_const(and_src1))
+            mask = nir_ssa_scalar_as_uint(and_src1);
+         /* no overflow if const_val fits entirely below the mask's lowest set bit */
+         if (mask == 0 || const_val < (1u << (ffs(mask) - 1)))
+            return false;
+      }
+   }
+
+   uint32_t ub = nir_unsigned_upper_bound(shader, range_ht, ssa, config);
+   return const_val + ub < const_val;
+}
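
A minimal usage sketch, not part of the patch: the helper name, the limits in
the config, and the NULL ralloc context below are illustrative assumptions; a
real backend would fill the config from its hardware caps.

static bool
addition_cannot_wrap(nir_shader *shader, nir_ssa_def *def, unsigned offset)
{
   nir_unsigned_upper_bound_config config = {
      .min_subgroup_size = 32,                      /* example limits */
      .max_subgroup_size = 64,
      .max_work_group_invocations = 2048,
      .max_work_group_count = {65535, 65535, 65535},
      .max_work_group_size = {2048, 2048, 2048},
   };
   /* no information about vertex inputs: assume the full 32-bit range */
   for (unsigned i = 0; i < 32; i++)
      config.vertex_attrib_max[i] = 0xffffffffu;

   /* the table caches per-scalar bounds, so reuse it across many queries */
   struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);

   nir_ssa_scalar scalar = (nir_ssa_scalar){def, 0};

   /* largest value 'def' can take at runtime under the limits above */
   uint32_t bound = nir_unsigned_upper_bound(shader, range_ht, scalar, &config);

   /* whether 'def + offset' can wrap around 32 bits */
   bool might_wrap = nir_addition_might_overflow(shader, range_ht, scalar,
                                                 offset, &config);

   _mesa_hash_table_destroy(range_ht, NULL);
   return !might_wrap && bound <= UINT32_MAX - offset;
}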