aco: try to use fma instead of mad when denormals are enabled

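v_mad_f32 doesn't support denormals, but v_fma_f32 does. When 32-bit denormals are enabled, combine mul+add into v_fma_f32 instead of skipping the optimization, provided neither the add nor the mul result is marked precise and the hardware has fast v_fma_f32.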

No fossil-db changes.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5245>
This commit is contained in:
Rhys Perry 2020-05-15 14:03:15 +01:00 committed by Marge Bot
parent 6cb42cdd8f
commit 1b10764e50
4 changed files with 54 additions and 20 deletions

View File

@ -1256,6 +1256,10 @@ setup_isel_context(Program* program,
setup_xnack(program);
program->sram_ecc_enabled = args->options->family == CHIP_ARCTURUS;
/* apparently gfx702 also has fast v_fma_f32 but I can't find a family for that */
program->has_fast_fma32 = program->chip_class >= GFX9;
if (args->options->family == CHIP_TAHITI || args->options->family == CHIP_CARRIZO || args->options->family == CHIP_HAWAII)
program->has_fast_fma32 = true;
return ctx;
}

View File

@ -1451,6 +1451,7 @@ public:
bool xnack_enabled = false;
bool sram_ecc_enabled = false;
bool has_fast_fma32 = false;
bool needs_vcc = false;
bool needs_flat_scr = false;

View File

@ -2410,37 +2410,44 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
return;
}
/* combine mul+add -> mad */
else if ((instr->opcode == aco_opcode::v_add_f32 ||
instr->opcode == aco_opcode::v_sub_f32 ||
instr->opcode == aco_opcode::v_subrev_f32) &&
block.fp_mode.denorm32 == 0) {
//TODO: we could use fma instead when denormals are enabled if the NIR isn't marked as precise
bool mad32 = instr->opcode == aco_opcode::v_add_f32 ||
instr->opcode == aco_opcode::v_sub_f32 ||
instr->opcode == aco_opcode::v_subrev_f32;
if (mad32) {
bool need_fma = block.fp_mode.denorm32 != 0;
if (need_fma && instr->definitions[0].isPrecise())
return;
if (need_fma && !ctx.program->has_fast_fma32)
return;
uint32_t uses_src0 = UINT32_MAX;
uint32_t uses_src1 = UINT32_MAX;
Instruction* mul_instr = nullptr;
unsigned add_op_idx;
/* check if any of the operands is a multiplication */
if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_mul())
ssa_info *op0_info = instr->operands[0].isTemp() ? &ctx.info[instr->operands[0].tempId()] : NULL;
ssa_info *op1_info = instr->operands[1].isTemp() ? &ctx.info[instr->operands[1].tempId()] : NULL;
if (op0_info && op0_info->is_mul() && (!need_fma || !op0_info->instr->definitions[0].isPrecise()))
uses_src0 = ctx.uses[instr->operands[0].tempId()];
if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_mul())
if (op1_info && op1_info->is_mul() && (!need_fma || !op1_info->instr->definitions[0].isPrecise()))
uses_src1 = ctx.uses[instr->operands[1].tempId()];
/* find the 'best' mul instruction to combine with the add */
if (uses_src0 < uses_src1) {
mul_instr = ctx.info[instr->operands[0].tempId()].instr;
mul_instr = op0_info->instr;
add_op_idx = 1;
} else if (uses_src1 < uses_src0) {
mul_instr = ctx.info[instr->operands[1].tempId()].instr;
mul_instr = op1_info->instr;
add_op_idx = 0;
} else if (uses_src0 != UINT32_MAX) {
/* tiebreaker: quite random what to pick */
if (ctx.info[instr->operands[0].tempId()].instr->operands[0].isLiteral()) {
mul_instr = ctx.info[instr->operands[1].tempId()].instr;
if (op0_info->instr->operands[0].isLiteral()) {
mul_instr = op1_info->instr;
add_op_idx = 0;
} else {
mul_instr = ctx.info[instr->operands[0].tempId()].instr;
mul_instr = op0_info->instr;
add_op_idx = 1;
}
}
@ -2498,7 +2505,9 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
else if (instr->opcode == aco_opcode::v_subrev_f32)
neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(aco_opcode::v_mad_f32, Format::VOP3A, 3, 1)};
aco_opcode mad_op = need_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(mad_op, Format::VOP3A, 3, 1)};
for (unsigned i = 0; i < 3; i++)
{
mad->operands[i] = op[i];
@ -2706,7 +2715,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
}
mad_info* mad_info = NULL;
if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
/* re-check mad instructions */
if (ctx.uses[mad_info->mul_temp_id]) {
@ -2720,6 +2729,10 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
}
/* check literals */
else if (!instr->usesModifiers()) {
/* FMA can only take literals on GFX10+ */
if (instr->opcode == aco_opcode::v_fma_f32 && ctx.program->chip_class < GFX10)
return;
bool sgpr_used = false;
uint32_t literal_idx = 0;
uint32_t literal_uses = UINT32_MAX;
@ -2881,17 +2894,21 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
return;
/* apply literals on MAD */
if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
if (info->check_literal &&
(ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) {
aco_ptr<Instruction> new_mad;
aco_opcode new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
if (instr->opcode == aco_opcode::v_fma_f32)
new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
new_mad.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 1));
if (info->literal_idx == 2) { /* add literal -> madak */
new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
new_mad->operands[0] = instr->operands[0];
new_mad->operands[1] = instr->operands[1];
} else { /* mul literal -> madmk */
new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
new_mad->operands[0] = instr->operands[1 - info->literal_idx];
new_mad->operands[1] = instr->operands[2];
}

View File

@ -1734,7 +1734,8 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
Operand op = Operand();
if (!def.isFixed() && instr->opcode == aco_opcode::p_parallelcopy)
op = instr->operands[i];
else if (instr->opcode == aco_opcode::v_mad_f32 && !instr->usesModifiers())
else if ((instr->opcode == aco_opcode::v_mad_f32 ||
(instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10)) && !instr->usesModifiers())
op = instr->operands[2];
if (op.isTemp() && op.isFirstKillBeforeDef() && def.regClass() == op.regClass()) {
@ -2009,7 +2010,8 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
}
/* try to optimize v_mad_f32 -> v_mac_f32 */
if (instr->opcode == aco_opcode::v_mad_f32 &&
if ((instr->opcode == aco_opcode::v_mad_f32 ||
(instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10)) &&
instr->operands[2].isTemp() &&
instr->operands[2].isKillBeforeDef() &&
instr->operands[2].getTemp().type() == RegType::vgpr &&
@ -2022,13 +2024,23 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
instr->operands[2].physReg() == ctx.assignments[it->second].reg ||
register_file.test(ctx.assignments[it->second].reg, instr->operands[2].bytes())) {
instr->format = Format::VOP2;
instr->opcode = aco_opcode::v_mac_f32;
switch (instr->opcode) {
case aco_opcode::v_mad_f32:
instr->opcode = aco_opcode::v_mac_f32;
break;
case aco_opcode::v_fma_f32:
instr->opcode = aco_opcode::v_fmac_f32;
break;
default:
break;
}
}
}
/* handle definitions which must have the same register as an operand */
if (instr->opcode == aco_opcode::v_interp_p2_f32 ||
instr->opcode == aco_opcode::v_mac_f32 ||
instr->opcode == aco_opcode::v_fmac_f32 ||
instr->opcode == aco_opcode::v_writelane_b32 ||
instr->opcode == aco_opcode::v_writelane_b32_e64) {
instr->definitions[0].setFixed(instr->operands[2].physReg());