aco: refactor selection of mad/fma

In the future, whether we need to use fma will depend on which
multiplication is chosen.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769>
This commit is contained in:
Rhys Perry 2022-01-17 17:33:25 +00:00 committed by Marge Bot
parent e12bee3cb7
commit eeef1bbe65
1 changed files with 16 additions and 16 deletions

View File

@ -3549,25 +3549,15 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->opcode == aco_opcode::v_subrev_f16;
bool mad64 = instr->opcode == aco_opcode::v_add_f64;
if (mad16 || mad32 || mad64) {
bool need_fma =
mad32 ? (ctx.fp_mode.denorm32 != 0 || ctx.program->chip_class >= GFX10_3)
: (ctx.fp_mode.denorm16_64 != 0 || ctx.program->chip_class >= GFX10 || mad64);
if (need_fma && instr->definitions[0].isPrecise())
return;
if (need_fma && mad32 && !ctx.program->dev.has_fast_fma32)
return;
Instruction* mul_instr = nullptr;
unsigned add_op_idx = 0;
uint32_t uses = UINT32_MAX;
bool emit_fma = false;
/* find the 'best' mul instruction to combine with the add */
for (unsigned i = 0; i < 2; i++) {
if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
continue;
/* check precision requirements */
ssa_info& info = ctx.info[instr->operands[i].tempId()];
if (need_fma && info.instr->definitions[0].isPrecise())
continue;
/* no clamp/omod allowed between mul and add */
if (info.instr->isVOP3() && (info.instr->vop3().clamp || info.instr->vop3().omod))
@ -3577,7 +3567,16 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
continue;
bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
if (legacy && need_fma && ctx.program->chip_class < GFX10_3)
bool has_fma = mad16 || mad64 || (legacy && ctx.program->chip_class >= GFX10_3) ||
(mad32 && !legacy && ctx.program->dev.has_fast_fma32);
bool has_mad = (mad32 && ctx.program->chip_class < GFX10_3) ||
(mad16 && ctx.program->chip_class <= GFX9);
bool can_use_fma = has_fma && !info.instr->definitions[0].isPrecise() &&
!instr->definitions[0].isPrecise();
bool can_use_mad =
has_mad && (mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
if (!can_use_fma && !can_use_mad)
continue;
Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
@ -3595,6 +3594,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
mul_instr = info.instr;
add_op_idx = 1 - i;
uses = ctx.uses[instr->operands[i].tempId()];
emit_fma = !can_use_mad;
}
if (mul_instr) {
@ -3644,12 +3644,12 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->opcode == aco_opcode::v_subrev_f16)
neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
aco_opcode mad_op = need_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
assert(need_fma == (ctx.program->chip_class >= GFX10_3));
mad_op = need_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
} else if (mad16) {
mad_op = need_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16
: aco_opcode::v_fma_f16)
: (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16
: aco_opcode::v_mad_f16);