spirv_to_nir: Cast RelaxedPrecision ALU op dests to mediump.

This is controlled by spirv_to_nir_options.relaxed_precision_alu, because
some drivers don't want it.

This gets us mostly 16-bit math on turnip in vk-5-normal.

Reviewed-by: Matt Turner <mattst88@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16465>
This commit is contained in:
Emma Anholt 2022-04-26 16:29:04 -07:00 committed by Marge Bot
parent 87d7431198
commit 260559050a
4 changed files with 193 additions and 3 deletions

View File

@ -75,6 +75,16 @@ struct spirv_to_nir_options {
*/
uint16_t float_controls_execution_mode;
/* True if RelaxedPrecision-decorated ALU result values should be performed
* with 16-bit math.
*/
bool mediump_16bit_alu;
/* When mediump_16bit_alu is set, determines whether nir_op_fddx/fddy can be
* performed in 16-bit math.
*/
bool mediump_16bit_derivatives;
struct spirv_supported_capabilities caps;
/* Address format for various kinds of pointers. */

View File

@ -153,6 +153,48 @@ mat_times_scalar(struct vtn_builder *b,
return dest;
}
nir_ssa_def *
vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
{
if (def->bit_size == 16)
return def;
switch (base_type) {
case GLSL_TYPE_FLOAT:
return nir_f2fmp(&b->nb, def);
case GLSL_TYPE_INT:
case GLSL_TYPE_UINT:
return nir_i2imp(&b->nb, def);
default:
unreachable("bad relaxed precision input type");
}
}
struct vtn_ssa_value *
vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
{
if (!src)
return src;
struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
if (src->transposed) {
srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
} else {
enum glsl_base_type base_type = glsl_get_base_type(src->type);
if (glsl_type_is_vector_or_scalar(src->type)) {
srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
} else {
assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
}
}
return srcmp;
}
static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
@ -465,6 +507,84 @@ handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
}
}
static void
vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
struct vtn_value *val, int member,
const struct vtn_decoration *dec, void *void_ctx)
{
bool *relaxed_precision = void_ctx;
switch (dec->decoration) {
case SpvDecorationRelaxedPrecision:
*relaxed_precision = true;
break;
default:
break;
}
}
bool
vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
{
bool result = false;
vtn_foreach_decoration(b, val,
vtn_value_is_relaxed_precision_cb, &result);
return result;
}
static bool
vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
{
if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
return false;
switch (opcode) {
case SpvOpDPdx:
case SpvOpDPdy:
case SpvOpDPdxFine:
case SpvOpDPdyFine:
case SpvOpDPdxCoarse:
case SpvOpDPdyCoarse:
case SpvOpFwidth:
case SpvOpFwidthFine:
case SpvOpFwidthCoarse:
return b->options->mediump_16bit_derivatives;
default:
return true;
}
}
static nir_ssa_def *
vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
{
if (def->bit_size != 16)
return def;
switch (base_type) {
case GLSL_TYPE_FLOAT:
return nir_f2f32(&b->nb, def);
case GLSL_TYPE_INT:
return nir_i2i32(&b->nb, def);
case GLSL_TYPE_UINT:
return nir_u2u32(&b->nb, def);
default:
unreachable("bad relaxed precision output type");
}
}
void
vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
{
enum glsl_base_type base_type = glsl_get_base_type(value->type);
if (glsl_type_is_vector_or_scalar(value->type)) {
value->def = vtn_mediump_upconvert(b, base_type, value->def);
} else {
for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
}
}
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@ -473,17 +593,25 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
vtn_handle_no_contraction(b, dest_val);
bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
/* Collect the various SSA sources */
const unsigned num_inputs = count - 3;
struct vtn_ssa_value *vtn_src[4] = { NULL, };
for (unsigned i = 0; i < num_inputs; i++)
for (unsigned i = 0; i < num_inputs; i++) {
vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
if (mediump_16bit)
vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
}
if (glsl_type_is_matrix(vtn_src[0]->type) ||
(num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
vtn_push_ssa_value(b, w[2],
vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
if (mediump_16bit)
vtn_mediump_upconvert_value(b, dest);
vtn_push_ssa_value(b, w[2], dest);
b->nb.exact = b->exact;
return;
}
@ -861,6 +989,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
break;
}
if (mediump_16bit)
vtn_mediump_upconvert_value(b, dest);
vtn_push_ssa_value(b, w[2], dest);
b->nb.exact = b->exact;

View File

@ -277,6 +277,41 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
{
struct nir_builder *nb = &b->nb;
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
bool mediump_16bit;
switch (entrypoint) {
case GLSLstd450PackSnorm4x8:
case GLSLstd450PackUnorm4x8:
case GLSLstd450PackSnorm2x16:
case GLSLstd450PackUnorm2x16:
case GLSLstd450PackHalf2x16:
case GLSLstd450PackDouble2x32:
case GLSLstd450UnpackSnorm4x8:
case GLSLstd450UnpackUnorm4x8:
case GLSLstd450UnpackSnorm2x16:
case GLSLstd450UnpackUnorm2x16:
case GLSLstd450UnpackHalf2x16:
case GLSLstd450UnpackDouble2x32:
/* Asking for relaxed precision snorm 4x8 pack results (for example)
* doesn't even make sense. The NIR opcodes have a fixed output size, so
* no trying to reduce precision.
*/
mediump_16bit = false;
break;
case GLSLstd450Frexp:
case GLSLstd450FrexpStruct:
case GLSLstd450Modf:
case GLSLstd450ModfStruct:
/* Not sure how to detect the ->elems[i] destinations on these in vtn_upconvert_value(). */
mediump_16bit = false;
break;
default:
mediump_16bit = b->options->mediump_16bit_alu && vtn_value_is_relaxed_precision(b, dest_val);
break;
}
/* Collect the various SSA sources */
unsigned num_inputs = count - 5;
@ -287,9 +322,14 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
continue;
src[i] = vtn_get_nir_ssa(b, w[i + 5]);
if (mediump_16bit) {
struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[i + 5]);
src[i] = vtn_mediump_downconvert(b, glsl_get_base_type(vtn_src->type), src[i]);
}
}
struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
vtn_handle_no_contraction(b, vtn_untyped_value(b, w[2]));
switch (entrypoint) {
case GLSLstd450Radians:
@ -589,6 +629,9 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
}
b->nb.exact = false;
if (mediump_16bit)
vtn_mediump_upconvert_value(b, dest);
vtn_push_ssa_value(b, w[2], dest);
}

View File

@ -1048,6 +1048,13 @@ SpvMemorySemanticsMask vtn_mode_to_memory_semantics(enum vtn_variable_mode mode)
void vtn_emit_memory_barrier(struct vtn_builder *b, SpvScope scope,
SpvMemorySemanticsMask semantics);
bool vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val);
nir_ssa_def *
vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def);
struct vtn_ssa_value *
vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src);
void vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value);
static inline int
cmp_uint32_t(const void *pa, const void *pb)
{