diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 48cf5e345b5..9a8d6d9fe2f 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -619,8 +619,10 @@ bool can_use_VOP3(opt_ctx& ctx, const aco_ptr& instr) instr->opcode != aco_opcode::v_readfirstlane_b32; } -bool can_apply_sgprs(aco_ptr& instr) +bool can_apply_sgprs(opt_ctx& ctx, aco_ptr& instr) { + if (instr->isSDWA() && ctx.program->chip_class < GFX9) + return false; return instr->opcode != aco_opcode::v_readfirstlane_b32 && instr->opcode != aco_opcode::v_readlane_b32 && instr->opcode != aco_opcode::v_readlane_b32_e64 && @@ -891,7 +893,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr) info = ctx.info[info.temp.id()]; } /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */ - if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(instr) && instr->operands.size() == 1) { + if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) && instr->operands.size() == 1) { instr->operands[i].setTemp(info.temp); info = ctx.info[info.temp.id()]; } @@ -900,12 +902,19 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr) unsigned can_use_mod = instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4; can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode]; - if (info.is_abs() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) { - if (!instr->isDPP()) + if (instr->isSDWA()) + can_use_mod = can_use_mod && (static_cast(instr.get())->sel[i] & sdwa_asuint) == sdwa_udword; + else + can_use_mod = can_use_mod && (instr->isDPP() || can_use_VOP3(ctx, instr)); + + if (info.is_abs() && can_use_mod) { + if (!instr->isDPP() && !instr->isSDWA()) to_VOP3(ctx, instr); instr->operands[i] = Operand(info.temp); if (instr->isDPP()) static_cast(instr.get())->abs[i] = true; + else if (instr->isSDWA()) + static_cast(instr.get())->abs[i] = true; else static_cast(instr.get())->abs[i] = true; } @@ -917,12 +926,14 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr) instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16; instr->operands[i].setTemp(info.temp); continue; - } else if (info.is_neg() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) { - if (!instr->isDPP()) + } else if (info.is_neg() && can_use_mod) { + if (!instr->isDPP() && !instr->isSDWA()) to_VOP3(ctx, instr); instr->operands[i].setTemp(info.temp); if (instr->isDPP()) static_cast(instr.get())->neg[i] = true; + else if (instr->isSDWA()) + static_cast(instr.get())->neg[i] = true; else static_cast(instr.get())->neg[i] = true; continue; @@ -932,7 +943,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr) (!instr->isSDWA() || ctx.program->chip_class >= GFX9)) { Operand op = get_constant_op(ctx, info, bits); perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get()); - if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) { + if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 || + instr->opcode == aco_opcode::v_writelane_b32) { instr->operands[i] = op; continue; } else if (!instr->isVOP3() && can_swap_operands(instr)) { @@ -1641,6 +1653,8 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr& instr) neg[i] = vop3->neg[0]; abs[i] = vop3->abs[0]; opsel |= (vop3->opsel & 1) << i; + } else if (op_instr[i]->isSDWA()) { + return false; } Temp op0 = op_instr[i]->operands[0].getTemp(); @@ -1715,6 +1729,8 @@ bool combine_comparison_ordering(opt_ctx &ctx, aco_ptr& instr) Instruction *cmp = follow_operand(ctx, instr->operands[1], true); if (!nan_test || !cmp) return false; + if (nan_test->isSDWA() || cmp->isSDWA()) + return false; if (get_f32_cmp(cmp->opcode) == expected_nan_test) std::swap(nan_test, cmp); @@ -1785,6 +1801,8 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr& in if (!nan_test || !cmp) return false; + if (nan_test->isSDWA() || cmp->isSDWA()) + return false; aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32; if (get_f32_cmp(cmp->opcode) == expected_nan_test) @@ -1906,6 +1924,18 @@ bool combine_inverse_comparison(opt_ctx &ctx, aco_ptr& instr) new_vop3->omod = cmp_vop3->omod; new_vop3->opsel = cmp_vop3->opsel; new_instr = new_vop3; + } else if (cmp->isSDWA()) { + SDWA_instruction *new_sdwa = create_instruction( + new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1); + SDWA_instruction *cmp_sdwa = static_cast(cmp); + memcpy(new_sdwa->abs, cmp_sdwa->abs, sizeof(new_sdwa->abs)); + memcpy(new_sdwa->sel, cmp_sdwa->sel, sizeof(new_sdwa->sel)); + memcpy(new_sdwa->neg, cmp_sdwa->neg, sizeof(new_sdwa->neg)); + new_sdwa->dst_sel = cmp_sdwa->dst_sel; + new_sdwa->dst_preserve = cmp_sdwa->dst_preserve; + new_sdwa->clamp = cmp_sdwa->clamp; + new_sdwa->omod = cmp_sdwa->omod; + new_instr = new_sdwa; } else { new_instr = create_instruction(new_opcode, Format::VOPC, 2, 1); } @@ -1942,6 +1972,9 @@ bool match_op3_for_vop3(opt_ctx &ctx, aco_opcode op1, aco_opcode op2, VOP3A_instruction *op1_vop3 = op1_instr->isVOP3() ? static_cast(op1_instr) : NULL; VOP3A_instruction *op2_vop3 = op2_instr->isVOP3() ? static_cast(op2_instr) : NULL; + if (op1_instr->isSDWA() || op2_instr->isSDWA()) + return false; + /* don't support inbetween clamp/omod */ if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod)) return false; @@ -2431,7 +2464,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) /* Applying two sgprs require making it VOP3, so don't do it unless it's * definitively beneficial. * TODO: this is too conservative because later the use count could be reduced to 1 */ - if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3()) + if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() && !instr->isSDWA()) break; Temp sgpr = ctx.info[sgpr_info_id].temp; @@ -2439,7 +2472,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) if (new_sgpr && num_sgprs >= max_sgprs) continue; - if (sgpr_idx == 0 || instr->isVOP3()) { + if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA()) { instr->operands[sgpr_idx] = Operand(sgpr); } else if (can_swap_operands(instr)) { instr->operands[sgpr_idx] = instr->operands[0]; @@ -2461,22 +2494,20 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) } } -bool apply_omod_clamp_helper(opt_ctx &ctx, aco_ptr& instr, ssa_info& def_info) +template +bool apply_omod_clamp_helper(opt_ctx &ctx, T *instr, ssa_info& def_info) { - to_VOP3(ctx, instr); - - if (!def_info.is_clamp() && (static_cast(instr.get())->clamp || - static_cast(instr.get())->omod)) + if (!def_info.is_clamp() && (instr->clamp || instr->omod)) return false; if (def_info.is_omod2()) - static_cast(instr.get())->omod = 1; + instr->omod = 1; else if (def_info.is_omod4()) - static_cast(instr.get())->omod = 2; + instr->omod = 2; else if (def_info.is_omod5()) - static_cast(instr.get())->omod = 3; + instr->omod = 3; else if (def_info.is_clamp()) - static_cast(instr.get())->clamp = true; + instr->clamp = true; return true; } @@ -2488,11 +2519,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr& instr) !instr_info.can_use_output_modifiers[(int)instr->opcode]) return false; - if (!can_use_VOP3(ctx, instr)) + bool can_vop3 = can_use_VOP3(ctx, instr); + if (!instr->isSDWA() && !can_vop3) return false; /* omod has no effect if denormals are enabled */ bool can_use_omod = (instr->definitions[0].bytes() == 4 ? block.fp_mode.denorm32 : block.fp_mode.denorm16_64) == 0; + can_use_omod = can_use_omod && (can_vop3 || ctx.program->chip_class >= GFX9); /* SDWA omod is GFX9+ */ + ssa_info& def_info = ctx.info[instr->definitions[0].tempId()]; uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5; @@ -2506,8 +2540,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr& instr) /* MADs/FMAs are created later, so we don't have to update the original add */ assert(!ctx.info[instr->definitions[0].tempId()].is_mad()); - if (!apply_omod_clamp_helper(ctx, instr, def_info)) - return false; + if (instr->isSDWA()) { + if (!apply_omod_clamp_helper(ctx, static_cast(instr.get()), def_info)) + return false; + } else { + to_VOP3(ctx, instr); + if (!apply_omod_clamp_helper(ctx, static_cast(instr.get()), def_info)) + return false; + } std::swap(instr->definitions[0], def_info.instr->definitions[0]); ctx.info[instr->definitions[0].tempId()].label &= label_clamp; @@ -2525,7 +2565,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr return; if (instr->isVALU()) { - if (can_apply_sgprs(instr)) + if (can_apply_sgprs(ctx, instr)) apply_sgprs(ctx, instr); while (apply_omod_clamp(ctx, block, instr)) ; } @@ -2534,6 +2574,9 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr instr->definitions[0].setHint(vcc); } + if (instr->isSDWA()) + return; + /* TODO: There are still some peephole optimizations that could be done: * - abs(a - b) -> s_absdiff_i32 * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32 @@ -2557,6 +2600,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr return; if (mul_instr->isVOP3() && static_cast(mul_instr)->clamp) return; + if (mul_instr->isSDWA()) + return; /* convert to mul(neg(a), b) */ ctx.uses[mul_instr->definitions[0].tempId()]--; @@ -2639,6 +2684,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr // TODO: would be better to check this before selecting a mul instr? if (!check_vop3_operands(ctx, 3, op)) return; + if (mul_instr->isSDWA()) + return; if (mul_instr->isVOP3()) { VOP3A_instruction* vop3 = static_cast (mul_instr);