microsoft/compiler: Lower fquantize2f16

As far as I can't tell, there's no native operation doing this
equivalent of fquantize2f16. Let's lower this operation to

   if (val < MIN_FLOAT16)
      return -INFINITY;
   else if (val > MAX_FLOAT16)
      return -INFINITY;
   else if (fabs(val) < SMALLER_NORMALIZED_FLOAT16)
      return 0;
   else
      return val;

which matches the definition of OpQuantizeToF16:

"
If Value is an infinity, the result is the same infinity.
If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
If Value is positive with a magnitude too large to represent as a 16-bit
floating-point value, the result is positive infinity. If Value is negative
with a magnitude too large to represent as a 16-bit floating-point value,
the result is negative infinity. If the magnitude of Value is too small to
represent as a normalized 16-bit floating-point value, the result may be
either +0 or -0.
"

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16959>
This commit is contained in:
Boris Brezillon 2022-02-14 06:58:02 -08:00 committed by Marge Bot
parent 279f32e042
commit b12417a2c7
3 changed files with 68 additions and 0 deletions

View File

@ -1902,3 +1902,69 @@ dxil_nir_lower_ubo_array_one_to_static(nir_shader *s)
return progress;
}
static bool
is_fquantize2f16(const nir_instr *instr, const void *data)
{
if (instr->type != nir_instr_type_alu)
return false;
nir_alu_instr *alu = nir_instr_as_alu(instr);
return alu->op == nir_op_fquantize2f16;
}
static nir_ssa_def *
lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data)
{
/*
* SpvOpQuantizeToF16 documentation says:
*
* "
* If Value is an infinity, the result is the same infinity.
* If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
* If Value is positive with a magnitude too large to represent as a 16-bit
* floating-point value, the result is positive infinity. If Value is negative
* with a magnitude too large to represent as a 16-bit floating-point value,
* the result is negative infinity. If the magnitude of Value is too small to
* represent as a normalized 16-bit floating-point value, the result may be
* either +0 or -0.
* "
*
* which we turn into:
*
* if (val < MIN_FLOAT16)
* return -INFINITY;
* else if (val > MAX_FLOAT16)
* return -INFINITY;
* else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0)
* return -0.0f;
* else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0)
* return +0.0f;
* else
* return round(val);
*/
nir_alu_instr *alu = nir_instr_as_alu(instr);
nir_ssa_def *src =
nir_ssa_for_src(b, alu->src[0].src, nir_src_num_components(alu->src[0].src));
nir_ssa_def *neg_inf_cond =
nir_flt(b, src, nir_imm_float(b, -65504.0f));
nir_ssa_def *pos_inf_cond =
nir_flt(b, nir_imm_float(b, 65504.0f), src);
nir_ssa_def *zero_cond =
nir_flt(b, nir_fabs(b, src), nir_imm_float(b, ldexpf(1.0, -14)));
nir_ssa_def *zero = nir_iand_imm(b, src, 1 << 31);
nir_ssa_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13));
nir_ssa_def *res =
nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round);
res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res);
res = nir_bcsel(b, zero_cond, zero, res);
return res;
}
bool
dxil_nir_lower_fquantize2f16(nir_shader *s)
{
return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL);
}

View File

@ -35,6 +35,7 @@ extern "C" {
bool dxil_nir_lower_8bit_conv(nir_shader *shader);
bool dxil_nir_lower_16bit_conv(nir_shader *shader);
bool dxil_nir_lower_x2b(nir_shader *shader);
bool dxil_nir_lower_fquantize2f16(nir_shader *shader);
bool dxil_nir_lower_ubo_to_temp(nir_shader *shader);
bool dxil_nir_lower_loads_stores_to_dxil(nir_shader *shader);
bool dxil_nir_lower_atomics_to_dxil(nir_shader *shader);

View File

@ -5710,6 +5710,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
ctx->mod.major_version = 6;
ctx->mod.minor_version = 1;
NIR_PASS_V(s, dxil_nir_lower_fquantize2f16);
NIR_PASS_V(s, nir_lower_frexp);
NIR_PASS_V(s, nir_lower_flrp, 16 | 32 | 64, true);
NIR_PASS_V(s, nir_lower_io, nir_var_shader_in | nir_var_shader_out, type_size_vec4, nir_lower_io_lower_64bit_to_32);