spirv/alu: Add support for the NoContraction decoration

This commit is contained in:
Jason Ekstrand 2016-03-25 15:30:46 -07:00
parent 00fa795cd3
commit fbb9e1f008
1 changed files with 53 additions and 16 deletions

View File

@ -305,6 +305,17 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
}
}
/* Decoration callback handed to vtn_foreach_decoration(): turns on the
 * builder's "exact" flag when the visited decoration is
 * SpvDecorationNoContraction and ignores every other decoration.
 *
 * The val, member and _void arguments belong to the callback signature but
 * are not read here.  This function only ever sets the flag; the caller,
 * vtn_handle_alu(), is the one that clears it again once it is done.
 */
static void
handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
                      const struct vtn_decoration *dec, void *_void)
{
   /* This callback only expects VTN_DEC_DECORATION-scoped entries. */
   assert(dec->scope == VTN_DEC_DECORATION);

   if (dec->decoration == SpvDecorationNoContraction)
      b->nb.exact = true;
}
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@ -313,15 +324,39 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
const struct glsl_type *type =
vtn_value(b, w[1], vtn_value_type_type)->type->type;
vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
/* Collect the various SSA sources */
const unsigned num_inputs = count - 3;
struct vtn_ssa_value *vtn_src[4] = { NULL, };
for (unsigned i = 0; i < num_inputs; i++)
for (unsigned i = 0; i < num_inputs; i++) {
vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
/* The way SPIR-V defines the NoContraction decoration is ridiculous.
* It expressly says in the SPIR-V spec:
*
* "For example, if applied to an OpFMul, that multiply can't be
* combined with an addition to yield a fused multiply-add
* operation."
*
* Technically, this means we would have to either rewrite NIR with
* another silly "don't fuse me" flag or we would have to propagate
* the NoContraction decoration to all consumers of a value which
* would make it far more infectious than anyone intended.
*
* Instead, we take a short-cut by simply looking at the sources and
* see if any of them have it. That should be good enough.
*
* See also issue #17 on the SPIR-V gitlab
*/
vtn_foreach_decoration(b, vtn_untyped_value(b, w[i + 3]),
handle_no_contraction, NULL);
}
if (glsl_type_is_matrix(vtn_src[0]->type) ||
(num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
b->nb.exact = false;
return;
}
@ -347,7 +382,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
nir_imm_int(&b->nb, NIR_FALSE),
NULL, NULL);
}
return;
break;
case SpvOpAll:
if (src[0]->num_components == 1) {
@ -363,73 +398,73 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
nir_imm_int(&b->nb, NIR_TRUE),
NULL, NULL);
}
return;
break;
case SpvOpOuterProduct: {
for (unsigned i = 0; i < src[1]->num_components; i++) {
val->ssa->elems[i]->def =
nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
}
return;
break;
}
case SpvOpDot:
val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
return;
break;
case SpvOpIAddCarry:
assert(glsl_type_is_struct(val->ssa->type));
val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
return;
break;
case SpvOpISubBorrow:
assert(glsl_type_is_struct(val->ssa->type));
val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
return;
break;
case SpvOpUMulExtended:
assert(glsl_type_is_struct(val->ssa->type));
val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
return;
break;
case SpvOpSMulExtended:
assert(glsl_type_is_struct(val->ssa->type));
val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
return;
break;
case SpvOpFwidth:
val->ssa->def = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
return;
break;
case SpvOpFwidthFine:
val->ssa->def = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
return;
break;
case SpvOpFwidthCoarse:
val->ssa->def = nir_fadd(&b->nb,
nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
return;
break;
case SpvOpVectorTimesScalar:
/* The builder will take care of splatting for us. */
val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
return;
break;
case SpvOpIsNan:
val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
return;
break;
case SpvOpIsInf:
val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
nir_imm_float(&b->nb, INFINITY));
return;
break;
default: {
bool swap;
@ -442,7 +477,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
}
val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
return;
break;
} /* default */
}
b->nb.exact = false;
}