From 4d83306a9aabb5f9ea7e6a54d0e25c0f82805965 Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Mon, 28 Dec 2020 15:45:58 -0800 Subject: [PATCH] nir: Update saturated float->int/uint conversion algorithm The mantissa for a float doesn't contain enough data to accurately represent the min/max values for some destination types. Instead of clamping before converting, clamp after converting when coming from floats. This improves conformance of CL conversions, specifically for float -> long/ulong with int64 emulation enabled. Refactors the limit determination from the clamp, so we can determine limits for the dest type (int/uint) in both the source (float) and dest type. The limit as a float is used for comparison, while the limit as a dest type is used for bcsel. Important note is that the comparison is inverted to fge instead of flt, so the bcsel chooses the direct int/uint over the converted float in the case where the comparison comes up equal, but the conversion can't produce the exact min/max value. Reviewed-by: Jason Ekstrand Part-of: --- src/compiler/nir/nir_conversion_builder.h | 133 ++++++++++++++-------- 1 file changed, 87 insertions(+), 46 deletions(-) diff --git a/src/compiler/nir/nir_conversion_builder.h b/src/compiler/nir/nir_conversion_builder.h index 78e41bfb690..c124e2650f0 100644 --- a/src/compiler/nir/nir_conversion_builder.h +++ b/src/compiler/nir/nir_conversion_builder.h @@ -222,28 +222,26 @@ nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b) } /** - * Clamp the source value into the widest representatble range of the - * destination type with cmp + bcsel. + * Retrieves limits used for clamping a value of the src type into + * the widest representable range of the dst type via cmp + bcsel */ -static inline nir_ssa_def * -nir_clamp_to_type_range(nir_builder *b, - nir_ssa_def *src, nir_alu_type src_type, - nir_alu_type dest_type) +static inline void +nir_get_clamp_limits(nir_builder *b, + nir_alu_type src_type, + nir_alu_type dest_type, + nir_ssa_def **low, nir_ssa_def **high) { - assert(nir_alu_type_get_type_size(src_type) == 0 || - nir_alu_type_get_type_size(src_type) == src->bit_size); - src_type |= src->bit_size; - if (nir_alu_type_range_contains_type_range(dest_type, src_type)) - return src; - /* Split types from bit sizes */ nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); + unsigned src_bit_size = nir_alu_type_get_type_size(src_type); unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); - assert(dest_bit_size != 0); + assert(dest_bit_size != 0 && src_bit_size != 0); + + *low = NULL; + *high = NULL; /* limits of the destination type, expressed in the source type */ - nir_ssa_def *low = NULL, *high = NULL; switch (dest_base_type) { case nir_type_int: { int64_t ilow, ihigh; @@ -256,14 +254,14 @@ nir_clamp_to_type_range(nir_builder *b, } if (src_base_type == nir_type_int) { - low = nir_imm_intN_t(b, ilow, src->bit_size); - high = nir_imm_intN_t(b, ihigh, src->bit_size); + *low = nir_imm_intN_t(b, ilow, src_bit_size); + *high = nir_imm_intN_t(b, ihigh, src_bit_size); } else if (src_base_type == nir_type_uint) { - assert(src->bit_size >= dest_bit_size); - high = nir_imm_intN_t(b, ihigh, src->bit_size); + assert(src_bit_size >= dest_bit_size); + *high = nir_imm_intN_t(b, ihigh, src_bit_size); } else { - low = nir_imm_floatN_t(b, ilow, src->bit_size); - high = nir_imm_floatN_t(b, ihigh, src->bit_size); + *low = nir_imm_floatN_t(b, ilow, src_bit_size); + *high = nir_imm_floatN_t(b, ihigh, src_bit_size); } break; } @@ -271,12 +269,12 @@ nir_clamp_to_type_range(nir_builder *b, uint64_t uhigh = dest_bit_size == 64 ? ~0ull : (1ull << dest_bit_size) - 1; if (src_base_type != nir_type_float) { - low = nir_imm_intN_t(b, 0, src->bit_size); - if (src_base_type == nir_type_uint || src->bit_size > dest_bit_size) - high = nir_imm_intN_t(b, uhigh, src->bit_size); + *low = nir_imm_intN_t(b, 0, src_bit_size); + if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size) + *high = nir_imm_intN_t(b, uhigh, src_bit_size); } else { - low = nir_imm_floatN_t(b, 0.0f, src->bit_size); - high = nir_imm_floatN_t(b, uhigh, src->bit_size); + *low = nir_imm_floatN_t(b, 0.0f, src_bit_size); + *high = nir_imm_floatN_t(b, uhigh, src_bit_size); } break; } @@ -302,29 +300,29 @@ nir_clamp_to_type_range(nir_builder *b, switch (src_base_type) { case nir_type_int: { int64_t src_ilow, src_ihigh; - if (src->bit_size == 64) { + if (src_bit_size == 64) { src_ilow = INT64_MIN; src_ihigh = INT64_MAX; } else { - src_ilow = -(1ll << (src->bit_size - 1)); - src_ihigh = (1ll << (src->bit_size - 1)) - 1; + src_ilow = -(1ll << (src_bit_size - 1)); + src_ihigh = (1ll << (src_bit_size - 1)) - 1; } if (src_ilow < flow) - low = nir_imm_intN_t(b, flow, src->bit_size); + *low = nir_imm_intN_t(b, flow, src_bit_size); if (src_ihigh > fhigh) - high = nir_imm_intN_t(b, fhigh, src->bit_size); + *high = nir_imm_intN_t(b, fhigh, src_bit_size); break; } case nir_type_uint: { - uint64_t src_uhigh = src->bit_size == 64 ? - ~0ull : (1ull << src->bit_size) - 1; + uint64_t src_uhigh = src_bit_size == 64 ? + ~0ull : (1ull << src_bit_size) - 1; if (src_uhigh > fhigh) - high = nir_imm_intN_t(b, fhigh, src->bit_size); + *high = nir_imm_intN_t(b, fhigh, src_bit_size); break; } case nir_type_float: - low = nir_imm_floatN_t(b, flow, src->bit_size); - high = nir_imm_floatN_t(b, fhigh, src->bit_size); + *low = nir_imm_floatN_t(b, flow, src_bit_size); + *high = nir_imm_floatN_t(b, fhigh, src_bit_size); break; default: unreachable("Clamping from unknown type"); @@ -335,9 +333,34 @@ nir_clamp_to_type_range(nir_builder *b, unreachable("clamping to unknown type"); break; } +} + +/** + * Clamp the value into the widest representatble range of the + * destination type with cmp + bcsel. + * + * val/val_type: The variables used for bcsel + * src/src_type: The variables used for comparison + * dest_type: The type which determines the range used for comparison + */ +static inline nir_ssa_def * +nir_clamp_to_type_range(nir_builder *b, + nir_ssa_def *val, nir_alu_type val_type, + nir_ssa_def *src, nir_alu_type src_type, + nir_alu_type dest_type) +{ + assert(nir_alu_type_get_type_size(src_type) == 0 || + nir_alu_type_get_type_size(src_type) == src->bit_size); + src_type |= src->bit_size; + if (nir_alu_type_range_contains_type_range(dest_type, src_type)) + return val; + + /* limits of the destination type, expressed in the source type */ + nir_ssa_def *low = NULL, *high = NULL; + nir_get_clamp_limits(b, src_type, dest_type, &low, &high); nir_ssa_def *low_cond = NULL, *high_cond = NULL; - switch (src_base_type) { + switch (nir_alu_type_get_base_type(src_type)) { case nir_type_int: low_cond = low ? nir_ilt(b, src, low) : NULL; high_cond = high ? nir_ilt(b, high, src) : NULL; @@ -347,18 +370,23 @@ nir_clamp_to_type_range(nir_builder *b, high_cond = high ? nir_ult(b, high, src) : NULL; break; case nir_type_float: - low_cond = low ? nir_flt(b, src, low) : NULL; - high_cond = high ? nir_flt(b, high, src) : NULL; + low_cond = low ? nir_fge(b, low, src) : NULL; + high_cond = high ? nir_fge(b, src, high) : NULL; break; default: unreachable("clamping from unknown type"); } - nir_ssa_def *res = src; - if (low_cond) - res = nir_bcsel(b, low_cond, low, res); - if (high_cond) - res = nir_bcsel(b, high_cond, high, res); + nir_ssa_def *val_low = low, *val_high = high; + if (val_type != src_type) { + nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high); + } + + nir_ssa_def *res = val; + if (low_cond && val_low) + res = nir_bcsel(b, low_cond, val_low, res); + if (high_cond && val_high) + res = nir_bcsel(b, high_cond, val_high, res); return res; } @@ -425,6 +453,14 @@ nir_convert_with_rounding(nir_builder *b, !nir_alu_type_range_contains_type_range(dest_type, src_type); round = nir_simplify_conversion_rounding(src_type, dest_type, round); + /* For float -> int/uint conversions, we might not be able to represent + * the destination range in the source float accurately. For these cases, + * do the comparison in float range, but the bcsel in the destination range. + */ + bool clamp_after_conversion = clamp && + src_base_type == nir_type_float && + dest_base_type != nir_type_float; + /* * If we don't care about rounding and clamping, we can just use NIR's * built-in ops. There is also a special case for SPIR-V in shaders, where @@ -452,8 +488,8 @@ nir_convert_with_rounding(nir_builder *b, nir_ssa_def *dest = src; /* clamp the result into range */ - if (clamp) - dest = nir_clamp_to_type_range(b, dest, src_type, dest_type); + if (clamp && !clamp_after_conversion) + dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type); /* round with selected rounding mode */ if (!trivial_convert && round != nir_rounding_mode_undef) { @@ -472,7 +508,12 @@ nir_convert_with_rounding(nir_builder *b, /* now we can convert the value */ nir_op op = nir_type_conversion_op(src_type, dest_type, round); - return nir_build_alu(b, op, dest, NULL, NULL, NULL); + dest = nir_build_alu(b, op, dest, NULL, NULL, NULL); + + if (clamp_after_conversion) + dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type); + + return dest; } #ifdef __cplusplus