aco: handle SDWA in the optimizer
Apply SGPRs/modifiers when possible and try not to break when SDWA instructions are encountered. No shader-db changes. Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7349>
This commit is contained in:
parent
ecc5b59a70
commit
1761379481
|
@ -619,8 +619,10 @@ bool can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
|
||||||
instr->opcode != aco_opcode::v_readfirstlane_b32;
|
instr->opcode != aco_opcode::v_readfirstlane_b32;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool can_apply_sgprs(aco_ptr<Instruction>& instr)
|
bool can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
|
||||||
{
|
{
|
||||||
|
if (instr->isSDWA() && ctx.program->chip_class < GFX9)
|
||||||
|
return false;
|
||||||
return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
|
return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
|
||||||
instr->opcode != aco_opcode::v_readlane_b32 &&
|
instr->opcode != aco_opcode::v_readlane_b32 &&
|
||||||
instr->opcode != aco_opcode::v_readlane_b32_e64 &&
|
instr->opcode != aco_opcode::v_readlane_b32_e64 &&
|
||||||
|
@ -891,7 +893,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
|
||||||
info = ctx.info[info.temp.id()];
|
info = ctx.info[info.temp.id()];
|
||||||
}
|
}
|
||||||
/* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
|
/* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
|
||||||
if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(instr) && instr->operands.size() == 1) {
|
if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) && instr->operands.size() == 1) {
|
||||||
instr->operands[i].setTemp(info.temp);
|
instr->operands[i].setTemp(info.temp);
|
||||||
info = ctx.info[info.temp.id()];
|
info = ctx.info[info.temp.id()];
|
||||||
}
|
}
|
||||||
|
@ -900,12 +902,19 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
|
||||||
unsigned can_use_mod = instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
|
unsigned can_use_mod = instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
|
||||||
can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
|
can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
|
||||||
|
|
||||||
if (info.is_abs() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) {
|
if (instr->isSDWA())
|
||||||
if (!instr->isDPP())
|
can_use_mod = can_use_mod && (static_cast<SDWA_instruction*>(instr.get())->sel[i] & sdwa_asuint) == sdwa_udword;
|
||||||
|
else
|
||||||
|
can_use_mod = can_use_mod && (instr->isDPP() || can_use_VOP3(ctx, instr));
|
||||||
|
|
||||||
|
if (info.is_abs() && can_use_mod) {
|
||||||
|
if (!instr->isDPP() && !instr->isSDWA())
|
||||||
to_VOP3(ctx, instr);
|
to_VOP3(ctx, instr);
|
||||||
instr->operands[i] = Operand(info.temp);
|
instr->operands[i] = Operand(info.temp);
|
||||||
if (instr->isDPP())
|
if (instr->isDPP())
|
||||||
static_cast<DPP_instruction*>(instr.get())->abs[i] = true;
|
static_cast<DPP_instruction*>(instr.get())->abs[i] = true;
|
||||||
|
else if (instr->isSDWA())
|
||||||
|
static_cast<SDWA_instruction*>(instr.get())->abs[i] = true;
|
||||||
else
|
else
|
||||||
static_cast<VOP3A_instruction*>(instr.get())->abs[i] = true;
|
static_cast<VOP3A_instruction*>(instr.get())->abs[i] = true;
|
||||||
}
|
}
|
||||||
|
@ -917,12 +926,14 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
|
||||||
instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
|
instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
|
||||||
instr->operands[i].setTemp(info.temp);
|
instr->operands[i].setTemp(info.temp);
|
||||||
continue;
|
continue;
|
||||||
} else if (info.is_neg() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) {
|
} else if (info.is_neg() && can_use_mod) {
|
||||||
if (!instr->isDPP())
|
if (!instr->isDPP() && !instr->isSDWA())
|
||||||
to_VOP3(ctx, instr);
|
to_VOP3(ctx, instr);
|
||||||
instr->operands[i].setTemp(info.temp);
|
instr->operands[i].setTemp(info.temp);
|
||||||
if (instr->isDPP())
|
if (instr->isDPP())
|
||||||
static_cast<DPP_instruction*>(instr.get())->neg[i] = true;
|
static_cast<DPP_instruction*>(instr.get())->neg[i] = true;
|
||||||
|
else if (instr->isSDWA())
|
||||||
|
static_cast<SDWA_instruction*>(instr.get())->neg[i] = true;
|
||||||
else
|
else
|
||||||
static_cast<VOP3A_instruction*>(instr.get())->neg[i] = true;
|
static_cast<VOP3A_instruction*>(instr.get())->neg[i] = true;
|
||||||
continue;
|
continue;
|
||||||
|
@ -932,7 +943,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
|
||||||
(!instr->isSDWA() || ctx.program->chip_class >= GFX9)) {
|
(!instr->isSDWA() || ctx.program->chip_class >= GFX9)) {
|
||||||
Operand op = get_constant_op(ctx, info, bits);
|
Operand op = get_constant_op(ctx, info, bits);
|
||||||
perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
|
perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
|
||||||
if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) {
|
if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
|
||||||
|
instr->opcode == aco_opcode::v_writelane_b32) {
|
||||||
instr->operands[i] = op;
|
instr->operands[i] = op;
|
||||||
continue;
|
continue;
|
||||||
} else if (!instr->isVOP3() && can_swap_operands(instr)) {
|
} else if (!instr->isVOP3() && can_swap_operands(instr)) {
|
||||||
|
@ -1641,6 +1653,8 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
|
||||||
neg[i] = vop3->neg[0];
|
neg[i] = vop3->neg[0];
|
||||||
abs[i] = vop3->abs[0];
|
abs[i] = vop3->abs[0];
|
||||||
opsel |= (vop3->opsel & 1) << i;
|
opsel |= (vop3->opsel & 1) << i;
|
||||||
|
} else if (op_instr[i]->isSDWA()) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
Temp op0 = op_instr[i]->operands[0].getTemp();
|
Temp op0 = op_instr[i]->operands[0].getTemp();
|
||||||
|
@ -1715,6 +1729,8 @@ bool combine_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& instr)
|
||||||
Instruction *cmp = follow_operand(ctx, instr->operands[1], true);
|
Instruction *cmp = follow_operand(ctx, instr->operands[1], true);
|
||||||
if (!nan_test || !cmp)
|
if (!nan_test || !cmp)
|
||||||
return false;
|
return false;
|
||||||
|
if (nan_test->isSDWA() || cmp->isSDWA())
|
||||||
|
return false;
|
||||||
|
|
||||||
if (get_f32_cmp(cmp->opcode) == expected_nan_test)
|
if (get_f32_cmp(cmp->opcode) == expected_nan_test)
|
||||||
std::swap(nan_test, cmp);
|
std::swap(nan_test, cmp);
|
||||||
|
@ -1785,6 +1801,8 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& in
|
||||||
|
|
||||||
if (!nan_test || !cmp)
|
if (!nan_test || !cmp)
|
||||||
return false;
|
return false;
|
||||||
|
if (nan_test->isSDWA() || cmp->isSDWA())
|
||||||
|
return false;
|
||||||
|
|
||||||
aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
|
aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
|
||||||
if (get_f32_cmp(cmp->opcode) == expected_nan_test)
|
if (get_f32_cmp(cmp->opcode) == expected_nan_test)
|
||||||
|
@ -1906,6 +1924,18 @@ bool combine_inverse_comparison(opt_ctx &ctx, aco_ptr<Instruction>& instr)
|
||||||
new_vop3->omod = cmp_vop3->omod;
|
new_vop3->omod = cmp_vop3->omod;
|
||||||
new_vop3->opsel = cmp_vop3->opsel;
|
new_vop3->opsel = cmp_vop3->opsel;
|
||||||
new_instr = new_vop3;
|
new_instr = new_vop3;
|
||||||
|
} else if (cmp->isSDWA()) {
|
||||||
|
SDWA_instruction *new_sdwa = create_instruction<SDWA_instruction>(
|
||||||
|
new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1);
|
||||||
|
SDWA_instruction *cmp_sdwa = static_cast<SDWA_instruction*>(cmp);
|
||||||
|
memcpy(new_sdwa->abs, cmp_sdwa->abs, sizeof(new_sdwa->abs));
|
||||||
|
memcpy(new_sdwa->sel, cmp_sdwa->sel, sizeof(new_sdwa->sel));
|
||||||
|
memcpy(new_sdwa->neg, cmp_sdwa->neg, sizeof(new_sdwa->neg));
|
||||||
|
new_sdwa->dst_sel = cmp_sdwa->dst_sel;
|
||||||
|
new_sdwa->dst_preserve = cmp_sdwa->dst_preserve;
|
||||||
|
new_sdwa->clamp = cmp_sdwa->clamp;
|
||||||
|
new_sdwa->omod = cmp_sdwa->omod;
|
||||||
|
new_instr = new_sdwa;
|
||||||
} else {
|
} else {
|
||||||
new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
|
new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
|
||||||
}
|
}
|
||||||
|
@ -1942,6 +1972,9 @@ bool match_op3_for_vop3(opt_ctx &ctx, aco_opcode op1, aco_opcode op2,
|
||||||
VOP3A_instruction *op1_vop3 = op1_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op1_instr) : NULL;
|
VOP3A_instruction *op1_vop3 = op1_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op1_instr) : NULL;
|
||||||
VOP3A_instruction *op2_vop3 = op2_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op2_instr) : NULL;
|
VOP3A_instruction *op2_vop3 = op2_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op2_instr) : NULL;
|
||||||
|
|
||||||
|
if (op1_instr->isSDWA() || op2_instr->isSDWA())
|
||||||
|
return false;
|
||||||
|
|
||||||
/* don't support inbetween clamp/omod */
|
/* don't support inbetween clamp/omod */
|
||||||
if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
|
if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
|
||||||
return false;
|
return false;
|
||||||
|
@ -2431,7 +2464,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
|
||||||
/* Applying two sgprs require making it VOP3, so don't do it unless it's
|
/* Applying two sgprs require making it VOP3, so don't do it unless it's
|
||||||
* definitively beneficial.
|
* definitively beneficial.
|
||||||
* TODO: this is too conservative because later the use count could be reduced to 1 */
|
* TODO: this is too conservative because later the use count could be reduced to 1 */
|
||||||
if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3())
|
if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() && !instr->isSDWA())
|
||||||
break;
|
break;
|
||||||
|
|
||||||
Temp sgpr = ctx.info[sgpr_info_id].temp;
|
Temp sgpr = ctx.info[sgpr_info_id].temp;
|
||||||
|
@ -2439,7 +2472,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
|
||||||
if (new_sgpr && num_sgprs >= max_sgprs)
|
if (new_sgpr && num_sgprs >= max_sgprs)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (sgpr_idx == 0 || instr->isVOP3()) {
|
if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA()) {
|
||||||
instr->operands[sgpr_idx] = Operand(sgpr);
|
instr->operands[sgpr_idx] = Operand(sgpr);
|
||||||
} else if (can_swap_operands(instr)) {
|
} else if (can_swap_operands(instr)) {
|
||||||
instr->operands[sgpr_idx] = instr->operands[0];
|
instr->operands[sgpr_idx] = instr->operands[0];
|
||||||
|
@ -2461,22 +2494,20 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool apply_omod_clamp_helper(opt_ctx &ctx, aco_ptr<Instruction>& instr, ssa_info& def_info)
|
template <typename T>
|
||||||
|
bool apply_omod_clamp_helper(opt_ctx &ctx, T *instr, ssa_info& def_info)
|
||||||
{
|
{
|
||||||
to_VOP3(ctx, instr);
|
if (!def_info.is_clamp() && (instr->clamp || instr->omod))
|
||||||
|
|
||||||
if (!def_info.is_clamp() && (static_cast<VOP3A_instruction*>(instr.get())->clamp ||
|
|
||||||
static_cast<VOP3A_instruction*>(instr.get())->omod))
|
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (def_info.is_omod2())
|
if (def_info.is_omod2())
|
||||||
static_cast<VOP3A_instruction*>(instr.get())->omod = 1;
|
instr->omod = 1;
|
||||||
else if (def_info.is_omod4())
|
else if (def_info.is_omod4())
|
||||||
static_cast<VOP3A_instruction*>(instr.get())->omod = 2;
|
instr->omod = 2;
|
||||||
else if (def_info.is_omod5())
|
else if (def_info.is_omod5())
|
||||||
static_cast<VOP3A_instruction*>(instr.get())->omod = 3;
|
instr->omod = 3;
|
||||||
else if (def_info.is_clamp())
|
else if (def_info.is_clamp())
|
||||||
static_cast<VOP3A_instruction*>(instr.get())->clamp = true;
|
instr->clamp = true;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -2488,11 +2519,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
|
||||||
!instr_info.can_use_output_modifiers[(int)instr->opcode])
|
!instr_info.can_use_output_modifiers[(int)instr->opcode])
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (!can_use_VOP3(ctx, instr))
|
bool can_vop3 = can_use_VOP3(ctx, instr);
|
||||||
|
if (!instr->isSDWA() && !can_vop3)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
/* omod has no effect if denormals are enabled */
|
/* omod has no effect if denormals are enabled */
|
||||||
bool can_use_omod = (instr->definitions[0].bytes() == 4 ? block.fp_mode.denorm32 : block.fp_mode.denorm16_64) == 0;
|
bool can_use_omod = (instr->definitions[0].bytes() == 4 ? block.fp_mode.denorm32 : block.fp_mode.denorm16_64) == 0;
|
||||||
|
can_use_omod = can_use_omod && (can_vop3 || ctx.program->chip_class >= GFX9); /* SDWA omod is GFX9+ */
|
||||||
|
|
||||||
ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
|
ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
|
||||||
|
|
||||||
uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
|
uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
|
||||||
|
@ -2506,8 +2540,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
|
||||||
/* MADs/FMAs are created later, so we don't have to update the original add */
|
/* MADs/FMAs are created later, so we don't have to update the original add */
|
||||||
assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
|
assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
|
||||||
|
|
||||||
if (!apply_omod_clamp_helper(ctx, instr, def_info))
|
if (instr->isSDWA()) {
|
||||||
return false;
|
if (!apply_omod_clamp_helper(ctx, static_cast<SDWA_instruction *>(instr.get()), def_info))
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
to_VOP3(ctx, instr);
|
||||||
|
if (!apply_omod_clamp_helper(ctx, static_cast<VOP3A_instruction *>(instr.get()), def_info))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
std::swap(instr->definitions[0], def_info.instr->definitions[0]);
|
std::swap(instr->definitions[0], def_info.instr->definitions[0]);
|
||||||
ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
|
ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
|
||||||
|
@ -2525,7 +2565,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (instr->isVALU()) {
|
if (instr->isVALU()) {
|
||||||
if (can_apply_sgprs(instr))
|
if (can_apply_sgprs(ctx, instr))
|
||||||
apply_sgprs(ctx, instr);
|
apply_sgprs(ctx, instr);
|
||||||
while (apply_omod_clamp(ctx, block, instr)) ;
|
while (apply_omod_clamp(ctx, block, instr)) ;
|
||||||
}
|
}
|
||||||
|
@ -2534,6 +2574,9 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
|
||||||
instr->definitions[0].setHint(vcc);
|
instr->definitions[0].setHint(vcc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (instr->isSDWA())
|
||||||
|
return;
|
||||||
|
|
||||||
/* TODO: There are still some peephole optimizations that could be done:
|
/* TODO: There are still some peephole optimizations that could be done:
|
||||||
* - abs(a - b) -> s_absdiff_i32
|
* - abs(a - b) -> s_absdiff_i32
|
||||||
* - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
|
* - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
|
||||||
|
@ -2557,6 +2600,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
|
||||||
return;
|
return;
|
||||||
if (mul_instr->isVOP3() && static_cast<VOP3A_instruction*>(mul_instr)->clamp)
|
if (mul_instr->isVOP3() && static_cast<VOP3A_instruction*>(mul_instr)->clamp)
|
||||||
return;
|
return;
|
||||||
|
if (mul_instr->isSDWA())
|
||||||
|
return;
|
||||||
|
|
||||||
/* convert to mul(neg(a), b) */
|
/* convert to mul(neg(a), b) */
|
||||||
ctx.uses[mul_instr->definitions[0].tempId()]--;
|
ctx.uses[mul_instr->definitions[0].tempId()]--;
|
||||||
|
@ -2639,6 +2684,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
|
||||||
// TODO: would be better to check this before selecting a mul instr?
|
// TODO: would be better to check this before selecting a mul instr?
|
||||||
if (!check_vop3_operands(ctx, 3, op))
|
if (!check_vop3_operands(ctx, 3, op))
|
||||||
return;
|
return;
|
||||||
|
if (mul_instr->isSDWA())
|
||||||
|
return;
|
||||||
|
|
||||||
if (mul_instr->isVOP3()) {
|
if (mul_instr->isVOP3()) {
|
||||||
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);
|
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);
|
||||||
|
|
Loading…
Reference in New Issue