aco: add SDWA_instruction
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Reviewed-By: Timur Kristóf <timur.kristof@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4002>
This commit is contained in:
parent
00312f3c95
commit
b84d59af50
|
@ -547,7 +547,7 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
|
||||||
/* first emit the instruction without the DPP operand */
|
/* first emit the instruction without the DPP operand */
|
||||||
Operand dpp_op = instr->operands[0];
|
Operand dpp_op = instr->operands[0];
|
||||||
instr->operands[0] = Operand(PhysReg{250}, v1);
|
instr->operands[0] = Operand(PhysReg{250}, v1);
|
||||||
instr->format = (Format) ((uint32_t) instr->format & ~(1 << 14));
|
instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::DPP);
|
||||||
emit_instruction(ctx, out, instr);
|
emit_instruction(ctx, out, instr);
|
||||||
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
|
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
|
||||||
uint32_t encoding = (0xF & dpp->row_mask) << 28;
|
uint32_t encoding = (0xF & dpp->row_mask) << 28;
|
||||||
|
@ -561,6 +561,47 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
|
||||||
encoding |= (0xFF) & dpp_op.physReg();
|
encoding |= (0xFF) & dpp_op.physReg();
|
||||||
out.push_back(encoding);
|
out.push_back(encoding);
|
||||||
return;
|
return;
|
||||||
|
} else if (instr->isSDWA()) {
|
||||||
|
/* first emit the instruction without the SDWA operand */
|
||||||
|
Operand sdwa_op = instr->operands[0];
|
||||||
|
instr->operands[0] = Operand(PhysReg{249}, v1);
|
||||||
|
instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::SDWA);
|
||||||
|
emit_instruction(ctx, out, instr);
|
||||||
|
|
||||||
|
SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
|
||||||
|
uint32_t encoding = 0;
|
||||||
|
|
||||||
|
if ((uint16_t)instr->format & (uint16_t)Format::VOPC) {
|
||||||
|
if (instr->definitions[0].physReg() != vcc) {
|
||||||
|
encoding |= instr->definitions[0].physReg() << 8;
|
||||||
|
encoding |= 1 << 15;
|
||||||
|
}
|
||||||
|
encoding |= (sdwa->clamp ? 1 : 0) << 13;
|
||||||
|
} else {
|
||||||
|
encoding |= (uint32_t)(sdwa->dst_sel & sdwa_asuint) << 8;
|
||||||
|
uint32_t dst_u = sdwa->dst_sel & sdwa_sext ? 1 : 0;
|
||||||
|
encoding |= dst_u << 11;
|
||||||
|
encoding |= (sdwa->clamp ? 1 : 0) << 13;
|
||||||
|
encoding |= sdwa->omod << 14;
|
||||||
|
}
|
||||||
|
|
||||||
|
encoding |= (uint32_t)(sdwa->sel[0] & sdwa_asuint) << 16;
|
||||||
|
encoding |= sdwa->sel[0] & sdwa_sext ? 1 << 19 : 0;
|
||||||
|
encoding |= sdwa->abs[0] << 21;
|
||||||
|
encoding |= sdwa->neg[0] << 20;
|
||||||
|
|
||||||
|
if (instr->operands.size() >= 2) {
|
||||||
|
encoding |= (uint32_t)(sdwa->sel[1] & sdwa_asuint) << 24;
|
||||||
|
encoding |= sdwa->sel[1] & sdwa_sext ? 1 << 27 : 0;
|
||||||
|
encoding |= sdwa->abs[1] << 29;
|
||||||
|
encoding |= sdwa->neg[1] << 28;
|
||||||
|
}
|
||||||
|
|
||||||
|
encoding |= 0xFF & sdwa_op.physReg();
|
||||||
|
encoding |= (sdwa_op.physReg() < 256) << 23;
|
||||||
|
if (instr->operands.size() >= 2)
|
||||||
|
encoding |= (instr->operands[1].physReg() < 256) << 31;
|
||||||
|
out.push_back(encoding);
|
||||||
} else {
|
} else {
|
||||||
unreachable("unimplemented instruction format");
|
unreachable("unimplemented instruction format");
|
||||||
}
|
}
|
||||||
|
|
|
@ -169,6 +169,11 @@ constexpr Format asVOP3(Format format) {
|
||||||
return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format);
|
return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
constexpr Format asSDWA(Format format) {
|
||||||
|
assert(format == Format::VOP1 || format == Format::VOP2 || format == Format::VOPC);
|
||||||
|
return (Format) ((uint32_t) Format::SDWA | (uint32_t) format);
|
||||||
|
}
|
||||||
|
|
||||||
enum class RegType {
|
enum class RegType {
|
||||||
none = 0,
|
none = 0,
|
||||||
sgpr,
|
sgpr,
|
||||||
|
@ -841,6 +846,55 @@ struct DPP_instruction : public Instruction {
|
||||||
bool bound_ctrl : 1;
|
bool bound_ctrl : 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum sdwa_sel : uint8_t {
|
||||||
|
/* masks */
|
||||||
|
sdwa_wordnum = 0x1,
|
||||||
|
sdwa_bytenum = 0x3,
|
||||||
|
sdwa_asuint = 0x7,
|
||||||
|
|
||||||
|
/* flags */
|
||||||
|
sdwa_isword = 0x4,
|
||||||
|
sdwa_sext = 0x8,
|
||||||
|
|
||||||
|
/* specific values */
|
||||||
|
sdwa_ubyte0 = 0,
|
||||||
|
sdwa_ubyte1 = 1,
|
||||||
|
sdwa_ubyte2 = 2,
|
||||||
|
sdwa_ubyte3 = 3,
|
||||||
|
sdwa_uword0 = sdwa_isword | 0,
|
||||||
|
sdwa_uword1 = sdwa_isword | 1,
|
||||||
|
sdwa_udword = 6,
|
||||||
|
|
||||||
|
sdwa_sbyte0 = sdwa_ubyte0 | sdwa_sext,
|
||||||
|
sdwa_sbyte1 = sdwa_ubyte1 | sdwa_sext,
|
||||||
|
sdwa_sbyte2 = sdwa_ubyte2 | sdwa_sext,
|
||||||
|
sdwa_sbyte3 = sdwa_ubyte3 | sdwa_sext,
|
||||||
|
sdwa_sword0 = sdwa_uword0 | sdwa_sext,
|
||||||
|
sdwa_sword1 = sdwa_uword1 | sdwa_sext,
|
||||||
|
sdwa_sdword = sdwa_udword | sdwa_sext,
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sub-Dword Addressing Format:
|
||||||
|
* This format can be used for VOP1, VOP2 or VOPC instructions.
|
||||||
|
*
|
||||||
|
* omod and SGPR/constant operands are only available on GFX9+. For VOPC,
|
||||||
|
* the definition doesn't have to be VCC on GFX9+.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
struct SDWA_instruction : public Instruction {
|
||||||
|
/* these destination modifiers aren't available with VOPC except for
|
||||||
|
* clamp on GFX8 */
|
||||||
|
unsigned dst_sel:4;
|
||||||
|
bool dst_preserve:1;
|
||||||
|
bool clamp:1;
|
||||||
|
unsigned omod:2; /* GFX9+ */
|
||||||
|
|
||||||
|
unsigned sel[2];
|
||||||
|
bool neg[2];
|
||||||
|
bool abs[2];
|
||||||
|
};
|
||||||
|
|
||||||
struct Interp_instruction : public Instruction {
|
struct Interp_instruction : public Instruction {
|
||||||
uint8_t attribute;
|
uint8_t attribute;
|
||||||
uint8_t component;
|
uint8_t component;
|
||||||
|
|
|
@ -140,6 +140,9 @@ void print_asm(Program *program, std::vector<uint32_t>& binary,
|
||||||
if (!l && program->chip_class == GFX9 && ((binary[pos] & 0xffff8000) == 0xd1348000)) { /* not actually an invalid instruction */
|
if (!l && program->chip_class == GFX9 && ((binary[pos] & 0xffff8000) == 0xd1348000)) { /* not actually an invalid instruction */
|
||||||
out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_add_u32_e64 + clamp";
|
out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_add_u32_e64 + clamp";
|
||||||
new_pos = pos + 2;
|
new_pos = pos + 2;
|
||||||
|
} else if (program->chip_class == GFX10 && l == 4 && ((binary[pos] & 0xfe0001ff) == 0x020000f9)) {
|
||||||
|
out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_cndmask_b32 + sdwa";
|
||||||
|
new_pos = pos + 2;
|
||||||
} else if (!l) {
|
} else if (!l) {
|
||||||
out << std::left << std::setw(align_width) << std::setfill(' ') << "(invalid instruction)";
|
out << std::left << std::setw(align_width) << std::setfill(' ') << "(invalid instruction)";
|
||||||
new_pos = pos + 1;
|
new_pos = pos + 1;
|
||||||
|
|
|
@ -528,7 +528,38 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
|
||||||
if (dpp->bound_ctrl)
|
if (dpp->bound_ctrl)
|
||||||
fprintf(output, " bound_ctrl:1");
|
fprintf(output, " bound_ctrl:1");
|
||||||
} else if ((int)instr->format & (int)Format::SDWA) {
|
} else if ((int)instr->format & (int)Format::SDWA) {
|
||||||
fprintf(output, " (printing unimplemented)");
|
SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
|
||||||
|
switch (sdwa->omod) {
|
||||||
|
case 1:
|
||||||
|
fprintf(output, " *2");
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
fprintf(output, " *4");
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
fprintf(output, " *0.5");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (sdwa->clamp)
|
||||||
|
fprintf(output, " clamp");
|
||||||
|
switch (sdwa->dst_sel & sdwa_asuint) {
|
||||||
|
case sdwa_udword:
|
||||||
|
break;
|
||||||
|
case sdwa_ubyte0:
|
||||||
|
case sdwa_ubyte1:
|
||||||
|
case sdwa_ubyte2:
|
||||||
|
case sdwa_ubyte3:
|
||||||
|
fprintf(output, " dst_sel:%sbyte%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
|
||||||
|
sdwa->dst_sel & sdwa_bytenum);
|
||||||
|
break;
|
||||||
|
case sdwa_uword0:
|
||||||
|
case sdwa_uword1:
|
||||||
|
fprintf(output, " dst_sel:%sword%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
|
||||||
|
sdwa->dst_sel & sdwa_wordnum);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (sdwa->dst_preserve)
|
||||||
|
fprintf(output, " dst_preserve");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -546,23 +577,33 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
|
||||||
if (instr->operands.size()) {
|
if (instr->operands.size()) {
|
||||||
bool abs[instr->operands.size()];
|
bool abs[instr->operands.size()];
|
||||||
bool neg[instr->operands.size()];
|
bool neg[instr->operands.size()];
|
||||||
|
uint8_t sel[instr->operands.size()];
|
||||||
if ((int)instr->format & (int)Format::VOP3A) {
|
if ((int)instr->format & (int)Format::VOP3A) {
|
||||||
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(instr);
|
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(instr);
|
||||||
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
||||||
abs[i] = vop3->abs[i];
|
abs[i] = vop3->abs[i];
|
||||||
neg[i] = vop3->neg[i];
|
neg[i] = vop3->neg[i];
|
||||||
|
sel[i] = sdwa_udword;
|
||||||
}
|
}
|
||||||
} else if (instr->isDPP()) {
|
} else if (instr->isDPP()) {
|
||||||
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
|
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
|
||||||
assert(instr->operands.size() <= 2);
|
|
||||||
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
||||||
abs[i] = dpp->abs[i];
|
abs[i] = i < 2 ? dpp->abs[i] : false;
|
||||||
neg[i] = dpp->neg[i];
|
neg[i] = i < 2 ? dpp->neg[i] : false;
|
||||||
|
sel[i] = sdwa_udword;
|
||||||
|
}
|
||||||
|
} else if (instr->isSDWA()) {
|
||||||
|
SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
|
||||||
|
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
||||||
|
abs[i] = i < 2 ? sdwa->abs[i] : false;
|
||||||
|
neg[i] = i < 2 ? sdwa->neg[i] : false;
|
||||||
|
sel[i] = i < 2 ? sdwa->sel[i] : sdwa_udword;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
||||||
abs[i] = false;
|
abs[i] = false;
|
||||||
neg[i] = false;
|
neg[i] = false;
|
||||||
|
sel[i] = sdwa_udword;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
for (unsigned i = 0; i < instr->operands.size(); ++i) {
|
||||||
|
@ -575,7 +616,20 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
|
||||||
fprintf(output, "-");
|
fprintf(output, "-");
|
||||||
if (abs[i])
|
if (abs[i])
|
||||||
fprintf(output, "|");
|
fprintf(output, "|");
|
||||||
|
if (sel[i] & sdwa_sext)
|
||||||
|
fprintf(output, "sext(");
|
||||||
print_operand(&instr->operands[i], output);
|
print_operand(&instr->operands[i], output);
|
||||||
|
if (sel[i] & sdwa_sext)
|
||||||
|
fprintf(output, ")");
|
||||||
|
if ((sel[i] & sdwa_asuint) == sdwa_udword) {
|
||||||
|
/* print nothing */
|
||||||
|
} else if (sel[i] & sdwa_isword) {
|
||||||
|
unsigned index = sel[i] & sdwa_wordnum;
|
||||||
|
fprintf(output, "[%u:%u]", index * 16, index * 16 + 15);
|
||||||
|
} else {
|
||||||
|
unsigned index = sel[i] & sdwa_bytenum;
|
||||||
|
fprintf(output, "[%u:%u]", index * 8, index * 8 + 7);
|
||||||
|
}
|
||||||
if (abs[i])
|
if (abs[i])
|
||||||
fprintf(output, "|");
|
fprintf(output, "|");
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,6 +93,50 @@ void validate(Program* program, FILE * output)
|
||||||
"Format cannot have VOP3A/VOP3B applied", instr.get());
|
"Format cannot have VOP3A/VOP3B applied", instr.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* check SDWA */
|
||||||
|
if (instr->isSDWA()) {
|
||||||
|
check(base_format == Format::VOP2 ||
|
||||||
|
base_format == Format::VOP1 ||
|
||||||
|
base_format == Format::VOPC,
|
||||||
|
"Format cannot have SDWA applied", instr.get());
|
||||||
|
|
||||||
|
check(program->chip_class >= GFX8, "SDWA is GFX8+ only", instr.get());
|
||||||
|
|
||||||
|
SDWA_instruction *sdwa = static_cast<SDWA_instruction*>(instr.get());
|
||||||
|
check(sdwa->omod == 0 || program->chip_class >= GFX9, "SDWA omod only supported on GFX9+", instr.get());
|
||||||
|
if (base_format == Format::VOPC) {
|
||||||
|
check(sdwa->clamp == false || program->chip_class == GFX8, "SDWA VOPC clamp only supported on GFX8", instr.get());
|
||||||
|
check((instr->definitions[0].isFixed() && instr->definitions[0].physReg() == vcc) ||
|
||||||
|
program->chip_class >= GFX9,
|
||||||
|
"SDWA+VOPC definition must be fixed to vcc on GFX8", instr.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (instr->operands.size() >= 3) {
|
||||||
|
check(instr->operands[2].isFixed() && instr->operands[2].physReg() == vcc,
|
||||||
|
"3rd operand must be fixed to vcc with SDWA", instr.get());
|
||||||
|
}
|
||||||
|
if (instr->definitions.size() >= 2) {
|
||||||
|
check(instr->definitions[1].isFixed() && instr->definitions[1].physReg() == vcc,
|
||||||
|
"2nd definition must be fixed to vcc with SDWA", instr.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
check(instr->opcode != aco_opcode::v_madmk_f32 &&
|
||||||
|
instr->opcode != aco_opcode::v_madak_f32 &&
|
||||||
|
instr->opcode != aco_opcode::v_madmk_f16 &&
|
||||||
|
instr->opcode != aco_opcode::v_madak_f16 &&
|
||||||
|
instr->opcode != aco_opcode::v_readfirstlane_b32 &&
|
||||||
|
instr->opcode != aco_opcode::v_clrexcp &&
|
||||||
|
instr->opcode != aco_opcode::v_swap_b32,
|
||||||
|
"SDWA can't be used with this opcode", instr.get());
|
||||||
|
if (program->chip_class != GFX8) {
|
||||||
|
check(instr->opcode != aco_opcode::v_mac_f32 &&
|
||||||
|
instr->opcode != aco_opcode::v_mac_f16 &&
|
||||||
|
instr->opcode != aco_opcode::v_fmac_f32 &&
|
||||||
|
instr->opcode != aco_opcode::v_fmac_f16,
|
||||||
|
"SDWA can't be used with this opcode", instr.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* check for undefs */
|
/* check for undefs */
|
||||||
for (unsigned i = 0; i < instr->operands.size(); i++) {
|
for (unsigned i = 0; i < instr->operands.size(); i++) {
|
||||||
if (instr->operands[i].isUndefined()) {
|
if (instr->operands[i].isUndefined()) {
|
||||||
|
@ -137,6 +181,10 @@ void validate(Program* program, FILE * output)
|
||||||
if (program->chip_class >= GFX10 && !is_shift64)
|
if (program->chip_class >= GFX10 && !is_shift64)
|
||||||
const_bus_limit = 2;
|
const_bus_limit = 2;
|
||||||
|
|
||||||
|
uint32_t scalar_mask = instr->isVOP3() ? 0x7 : 0x5;
|
||||||
|
if (instr->isSDWA())
|
||||||
|
scalar_mask = program->chip_class >= GFX9 ? 0x7 : 0x4;
|
||||||
|
|
||||||
check(instr->definitions[0].getTemp().type() == RegType::vgpr ||
|
check(instr->definitions[0].getTemp().type() == RegType::vgpr ||
|
||||||
(int) instr->format & (int) Format::VOPC ||
|
(int) instr->format & (int) Format::VOPC ||
|
||||||
instr->opcode == aco_opcode::v_readfirstlane_b32 ||
|
instr->opcode == aco_opcode::v_readfirstlane_b32 ||
|
||||||
|
@ -158,7 +206,7 @@ void validate(Program* program, FILE * output)
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
|
if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
|
||||||
check(i != 1 || instr->isVOP3(), "Wrong source position for SGPR argument", instr.get());
|
check(scalar_mask & (1 << i), "Wrong source position for SGPR argument", instr.get());
|
||||||
|
|
||||||
if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
|
if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
|
||||||
if (num_sgprs < 2)
|
if (num_sgprs < 2)
|
||||||
|
@ -167,7 +215,7 @@ void validate(Program* program, FILE * output)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (op.isConstant() && !op.isLiteral())
|
if (op.isConstant() && !op.isLiteral())
|
||||||
check(i == 0 || instr->isVOP3(), "Wrong source position for constant argument", instr.get());
|
check(scalar_mask & (1 << i), "Wrong source position for constant argument", instr.get());
|
||||||
}
|
}
|
||||||
check(num_sgprs + (literal.isUndefined() ? 0 : 1) <= const_bus_limit, "Too many SGPRs/literals", instr.get());
|
check(num_sgprs + (literal.isUndefined() ? 0 : 1) <= const_bus_limit, "Too many SGPRs/literals", instr.get());
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue