spirv/alu: Add support for the NoContraction decoration
This commit is contained in:
parent
00fa795cd3
commit
fbb9e1f008
|
@ -305,6 +305,17 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
|
|||
}
|
||||
}
|
||||
|
||||
static void
|
||||
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
|
||||
const struct vtn_decoration *dec, void *_void)
|
||||
{
|
||||
assert(dec->scope == VTN_DEC_DECORATION);
|
||||
if (dec->decoration != SpvDecorationNoContraction)
|
||||
return;
|
||||
|
||||
b->nb.exact = true;
|
||||
}
|
||||
|
||||
void
|
||||
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
||||
const uint32_t *w, unsigned count)
|
||||
|
@ -313,15 +324,39 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
|||
const struct glsl_type *type =
|
||||
vtn_value(b, w[1], vtn_value_type_type)->type->type;
|
||||
|
||||
vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
|
||||
|
||||
/* Collect the various SSA sources */
|
||||
const unsigned num_inputs = count - 3;
|
||||
struct vtn_ssa_value *vtn_src[4] = { NULL, };
|
||||
for (unsigned i = 0; i < num_inputs; i++)
|
||||
for (unsigned i = 0; i < num_inputs; i++) {
|
||||
vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
|
||||
|
||||
/* The way SPIR-V defines the NoContraction decoration is ridiculous.
|
||||
* It expressly says in the SPIR-V spec:
|
||||
*
|
||||
* "For example, if applied to an OpFMul, that multiply can’t be
|
||||
* combined with an addition to yield a fused multiply-add
|
||||
* operation."
|
||||
*
|
||||
* Technically, this means we would have to either rewrite NIR with
|
||||
* another silly "don't fuse me" flag or we would have to propagate
|
||||
* the NoContraction decoration to all consumers of a value which
|
||||
* would make it far more infectious than anyone intended.
|
||||
*
|
||||
* Instead, we take a short-cut by simply looking at the sources and
|
||||
* seeing if any of them have it. That should be good enough.
|
||||
*
|
||||
* See also issue #17 on the SPIR-V gitlab
|
||||
*/
|
||||
vtn_foreach_decoration(b, vtn_untyped_value(b, w[i + 3]),
|
||||
handle_no_contraction, NULL);
|
||||
}
|
||||
|
||||
if (glsl_type_is_matrix(vtn_src[0]->type) ||
|
||||
(num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
|
||||
vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
|
||||
b->nb.exact = false;
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -347,7 +382,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
|||
nir_imm_int(&b->nb, NIR_FALSE),
|
||||
NULL, NULL);
|
||||
}
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpAll:
|
||||
if (src[0]->num_components == 1) {
|
||||
|
@ -363,73 +398,73 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
|||
nir_imm_int(&b->nb, NIR_TRUE),
|
||||
NULL, NULL);
|
||||
}
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpOuterProduct: {
|
||||
for (unsigned i = 0; i < src[1]->num_components; i++) {
|
||||
val->ssa->elems[i]->def =
|
||||
nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
|
||||
}
|
||||
return;
|
||||
break;
|
||||
}
|
||||
|
||||
case SpvOpDot:
|
||||
val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpIAddCarry:
|
||||
assert(glsl_type_is_struct(val->ssa->type));
|
||||
val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
|
||||
val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpISubBorrow:
|
||||
assert(glsl_type_is_struct(val->ssa->type));
|
||||
val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
|
||||
val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpUMulExtended:
|
||||
assert(glsl_type_is_struct(val->ssa->type));
|
||||
val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
|
||||
val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpSMulExtended:
|
||||
assert(glsl_type_is_struct(val->ssa->type));
|
||||
val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
|
||||
val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpFwidth:
|
||||
val->ssa->def = nir_fadd(&b->nb,
|
||||
nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
|
||||
nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
|
||||
return;
|
||||
break;
|
||||
case SpvOpFwidthFine:
|
||||
val->ssa->def = nir_fadd(&b->nb,
|
||||
nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
|
||||
nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
|
||||
return;
|
||||
break;
|
||||
case SpvOpFwidthCoarse:
|
||||
val->ssa->def = nir_fadd(&b->nb,
|
||||
nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
|
||||
nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpVectorTimesScalar:
|
||||
/* The builder will take care of splatting for us. */
|
||||
val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpIsNan:
|
||||
val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
|
||||
return;
|
||||
break;
|
||||
|
||||
case SpvOpIsInf:
|
||||
val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
|
||||
nir_imm_float(&b->nb, INFINITY));
|
||||
return;
|
||||
break;
|
||||
|
||||
default: {
|
||||
bool swap;
|
||||
|
@ -442,7 +477,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
|
|||
}
|
||||
|
||||
val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
|
||||
return;
|
||||
break;
|
||||
} /* default */
|
||||
}
|
||||
|
||||
b->nb.exact = false;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue