diff --git a/src/microsoft/compiler/dxil_nir.c b/src/microsoft/compiler/dxil_nir.c index 350576aee98..97a28515565 100644 --- a/src/microsoft/compiler/dxil_nir.c +++ b/src/microsoft/compiler/dxil_nir.c @@ -1902,3 +1902,69 @@ dxil_nir_lower_ubo_array_one_to_static(nir_shader *s) return progress; } + +static bool +is_fquantize2f16(const nir_instr *instr, const void *data) +{ + if (instr->type != nir_instr_type_alu) + return false; + + nir_alu_instr *alu = nir_instr_as_alu(instr); + return alu->op == nir_op_fquantize2f16; +} + +static nir_ssa_def * +lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data) +{ + /* + * SpvOpQuantizeToF16 documentation says: + * + * " + * If Value is an infinity, the result is the same infinity. + * If Value is a NaN, the result is a NaN, but not necessarily the same NaN. + * If Value is positive with a magnitude too large to represent as a 16-bit + * floating-point value, the result is positive infinity. If Value is negative + * with a magnitude too large to represent as a 16-bit floating-point value, + * the result is negative infinity. If the magnitude of Value is too small to + * represent as a normalized 16-bit floating-point value, the result may be + * either +0 or -0. + * " + * + * which we turn into: + * + * if (val < MIN_FLOAT16) + * return -INFINITY; + * else if (val > MAX_FLOAT16) + * return -INFINITY; + * else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0) + * return -0.0f; + * else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0) + * return +0.0f; + * else + * return round(val); + */ + nir_alu_instr *alu = nir_instr_as_alu(instr); + nir_ssa_def *src = + nir_ssa_for_src(b, alu->src[0].src, nir_src_num_components(alu->src[0].src)); + + nir_ssa_def *neg_inf_cond = + nir_flt(b, src, nir_imm_float(b, -65504.0f)); + nir_ssa_def *pos_inf_cond = + nir_flt(b, nir_imm_float(b, 65504.0f), src); + nir_ssa_def *zero_cond = + nir_flt(b, nir_fabs(b, src), nir_imm_float(b, ldexpf(1.0, -14))); + nir_ssa_def *zero = nir_iand_imm(b, src, 1 << 31); + nir_ssa_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13)); + + nir_ssa_def *res = + nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round); + res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res); + res = nir_bcsel(b, zero_cond, zero, res); + return res; +} + +bool +dxil_nir_lower_fquantize2f16(nir_shader *s) +{ + return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL); +} diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h index 1d04fe73c95..1d271347966 100644 --- a/src/microsoft/compiler/dxil_nir.h +++ b/src/microsoft/compiler/dxil_nir.h @@ -35,6 +35,7 @@ extern "C" { bool dxil_nir_lower_8bit_conv(nir_shader *shader); bool dxil_nir_lower_16bit_conv(nir_shader *shader); bool dxil_nir_lower_x2b(nir_shader *shader); +bool dxil_nir_lower_fquantize2f16(nir_shader *shader); bool dxil_nir_lower_ubo_to_temp(nir_shader *shader); bool dxil_nir_lower_loads_stores_to_dxil(nir_shader *shader); bool dxil_nir_lower_atomics_to_dxil(nir_shader *shader); diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c index b8641233fa9..c8a17c4ef9e 100644 --- a/src/microsoft/compiler/nir_to_dxil.c +++ b/src/microsoft/compiler/nir_to_dxil.c @@ -5710,6 +5710,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts, ctx->mod.major_version = 6; ctx->mod.minor_version = 1; + NIR_PASS_V(s, dxil_nir_lower_fquantize2f16); NIR_PASS_V(s, nir_lower_frexp); NIR_PASS_V(s, nir_lower_flrp, 16 | 32 | 64, true); NIR_PASS_V(s, nir_lower_io, nir_var_shader_in | nir_var_shader_out, type_size_vec4, nir_lower_io_lower_64bit_to_32);