diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index c6c8931a426..7a16fc176c9 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -53,11 +53,10 @@ struct mad_info { aco_ptr add_instr; uint32_t mul_temp_id; uint32_t literal_idx; - bool needs_vop3; bool check_literal; - mad_info(aco_ptr instr, uint32_t id, bool vop3) - : add_instr(std::move(instr)), mul_temp_id(id), needs_vop3(vop3), check_literal(false) {} + mad_info(aco_ptr instr, uint32_t id) + : add_instr(std::move(instr)), mul_temp_id(id), check_literal(false) {} }; enum Label { @@ -2194,7 +2193,6 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr bool abs[3] = {false, false, false}; unsigned omod = 0; bool clamp = false; - bool need_vop3 = false; op[0] = mul_instr->operands[0]; op[1] = mul_instr->operands[1]; op[2] = instr->operands[add_op_idx]; @@ -2202,18 +2200,12 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr if (!check_vop3_operands(ctx, 3, op)) return; - for (unsigned i = 0; i < 3; i++) { - if (!(i == 0 || (op[i].isTemp() && op[i].getTemp().type() == RegType::vgpr))) - need_vop3 = true; - } - if (mul_instr->isVOP3()) { VOP3A_instruction* vop3 = static_cast (mul_instr); neg[0] = vop3->neg[0]; neg[1] = vop3->neg[1]; abs[0] = vop3->abs[0]; abs[1] = vop3->abs[1]; - need_vop3 = true; /* we cannot use these modifiers between mul and add */ if (vop3->clamp || vop3->omod) return; @@ -2243,15 +2235,11 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr } /* neg of the multiplication result */ neg[1] = neg[1] ^ vop3->neg[1 - add_op_idx]; - need_vop3 = true; } - if (instr->opcode == aco_opcode::v_sub_f32) { + if (instr->opcode == aco_opcode::v_sub_f32) neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true; - need_vop3 = true; - } else if (instr->opcode == aco_opcode::v_subrev_f32) { + else if (instr->opcode == aco_opcode::v_subrev_f32) neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true; - need_vop3 = true; - } aco_ptr mad{create_instruction(aco_opcode::v_mad_f32, Format::VOP3A, 3, 1)}; for (unsigned i = 0; i < 3; i++) @@ -2265,7 +2253,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr mad->definitions[0] = instr->definitions[0]; /* mark this ssa_def to be re-checked for profitability and literals */ - ctx.mad_infos.emplace_back(std::move(instr), mul_instr->definitions[0].tempId(), need_vop3); + ctx.mad_infos.emplace_back(std::move(instr), mul_instr->definitions[0].tempId()); ctx.info[mad->definitions[0].tempId()].set_mad(mad.get(), ctx.mad_infos.size() - 1); instr.reset(mad.release()); return; @@ -2353,48 +2341,55 @@ void select_instruction(opt_ctx &ctx, aco_ptr& instr) } } - /* re-check mad instructions */ + mad_info* mad_info = NULL; if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) { - mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val]; - /* first, check profitability */ - if (ctx.uses[info->mul_temp_id]) { - ctx.uses[info->mul_temp_id]++; + mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val]; + /* re-check mad instructions */ + if (ctx.uses[mad_info->mul_temp_id]) { + ctx.uses[mad_info->mul_temp_id]++; if (instr->operands[0].isTemp()) ctx.uses[instr->operands[0].tempId()]--; if (instr->operands[1].isTemp()) ctx.uses[instr->operands[1].tempId()]--; - instr.swap(info->add_instr); - - /* second, check possible literals */ - } else if (!info->needs_vop3) { + instr.swap(mad_info->add_instr); + mad_info = NULL; + } + /* check literals */ + else if (!instr->usesModifiers()) { + bool sgpr_used = false; uint32_t literal_idx = 0; uint32_t literal_uses = UINT32_MAX; for (unsigned i = 0; i < instr->operands.size(); i++) { + if (instr->operands[i].isConstant() && i > 0) { + literal_uses = UINT32_MAX; + break; + } if (!instr->operands[i].isTemp()) continue; - /* if one of the operands is sgpr, we cannot add a literal somewhere else */ - if (instr->operands[i].getTemp().type() == RegType::sgpr) { + /* if one of the operands is sgpr, we cannot add a literal somewhere else on pre-GFX10 or operands other than the 1st */ + if (instr->operands[i].getTemp().type() == RegType::sgpr && (i > 0 || ctx.program->chip_class < GFX10)) { if (ctx.info[instr->operands[i].tempId()].is_literal()) { literal_uses = ctx.uses[instr->operands[i].tempId()]; literal_idx = i; } else { literal_uses = UINT32_MAX; } - break; - } - else if (ctx.info[instr->operands[i].tempId()].is_literal() && - ctx.uses[instr->operands[i].tempId()] < literal_uses) { + sgpr_used = true; + /* don't break because we still need to check constants */ + } else if (!sgpr_used && + ctx.info[instr->operands[i].tempId()].is_literal() && + ctx.uses[instr->operands[i].tempId()] < literal_uses) { literal_uses = ctx.uses[instr->operands[i].tempId()]; literal_idx = i; } } if (literal_uses < threshold) { ctx.uses[instr->operands[literal_idx].tempId()]--; - info->check_literal = true; - info->literal_idx = literal_idx; + mad_info->check_literal = true; + mad_info->literal_idx = literal_idx; + return; } - return; } } @@ -2480,31 +2475,28 @@ void apply_literals(opt_ctx &ctx, aco_ptr& instr) return; /* apply literals on MAD */ - bool literals_applied = false; if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) { mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val]; - if (!info->needs_vop3) { + if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) { aco_ptr new_mad; - if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) { - if (info->literal_idx == 2) { /* add literal -> madak */ - new_mad.reset(create_instruction(aco_opcode::v_madak_f32, Format::VOP2, 3, 1)); - new_mad->operands[0] = instr->operands[0]; - new_mad->operands[1] = instr->operands[1]; - } else { /* mul literal -> madmk */ - new_mad.reset(create_instruction(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1)); - new_mad->operands[0] = instr->operands[1 - info->literal_idx]; - new_mad->operands[1] = instr->operands[2]; - } - new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val); - new_mad->definitions[0] = instr->definitions[0]; - instr.swap(new_mad); + if (info->literal_idx == 2) { /* add literal -> madak */ + new_mad.reset(create_instruction(aco_opcode::v_madak_f32, Format::VOP2, 3, 1)); + new_mad->operands[0] = instr->operands[0]; + new_mad->operands[1] = instr->operands[1]; + } else { /* mul literal -> madmk */ + new_mad.reset(create_instruction(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1)); + new_mad->operands[0] = instr->operands[1 - info->literal_idx]; + new_mad->operands[1] = instr->operands[2]; } - literals_applied = true; + new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val); + new_mad->definitions[0] = instr->definitions[0]; + ctx.instructions.emplace_back(std::move(new_mad)); + return; } } - /* apply literals on SALU/VALU */ - if (!literals_applied && (instr->isSALU() || instr->isVALU())) { + /* apply literals on other SALU/VALU */ + if (instr->isSALU() || instr->isVALU()) { for (unsigned i = 0; i < instr->operands.size(); i++) { Operand op = instr->operands[i]; if (op.isTemp() && ctx.info[op.tempId()].is_literal() && ctx.uses[op.tempId()] == 0) {