aco: add SDWA_instruction

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-By: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4002>
This commit is contained in:
Rhys Perry 2019-12-04 20:18:05 +00:00 committed by Daniel Schürmann
parent 00312f3c95
commit b84d59af50
5 changed files with 207 additions and 7 deletions

View File

@ -547,7 +547,7 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
/* first emit the instruction without the DPP operand */
Operand dpp_op = instr->operands[0];
instr->operands[0] = Operand(PhysReg{250}, v1);
instr->format = (Format) ((uint32_t) instr->format & ~(1 << 14));
instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::DPP);
emit_instruction(ctx, out, instr);
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
uint32_t encoding = (0xF & dpp->row_mask) << 28;
@ -561,6 +561,47 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
encoding |= (0xFF) & dpp_op.physReg();
out.push_back(encoding);
return;
} else if (instr->isSDWA()) {
/* first emit the instruction without the SDWA operand */
Operand sdwa_op = instr->operands[0];
instr->operands[0] = Operand(PhysReg{249}, v1);
instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::SDWA);
emit_instruction(ctx, out, instr);
SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
uint32_t encoding = 0;
if ((uint16_t)instr->format & (uint16_t)Format::VOPC) {
if (instr->definitions[0].physReg() != vcc) {
encoding |= instr->definitions[0].physReg() << 8;
encoding |= 1 << 15;
}
encoding |= (sdwa->clamp ? 1 : 0) << 13;
} else {
encoding |= (uint32_t)(sdwa->dst_sel & sdwa_asuint) << 8;
uint32_t dst_u = sdwa->dst_sel & sdwa_sext ? 1 : 0;
encoding |= dst_u << 11;
encoding |= (sdwa->clamp ? 1 : 0) << 13;
encoding |= sdwa->omod << 14;
}
encoding |= (uint32_t)(sdwa->sel[0] & sdwa_asuint) << 16;
encoding |= sdwa->sel[0] & sdwa_sext ? 1 << 19 : 0;
encoding |= sdwa->abs[0] << 21;
encoding |= sdwa->neg[0] << 20;
if (instr->operands.size() >= 2) {
encoding |= (uint32_t)(sdwa->sel[1] & sdwa_asuint) << 24;
encoding |= sdwa->sel[1] & sdwa_sext ? 1 << 27 : 0;
encoding |= sdwa->abs[1] << 29;
encoding |= sdwa->neg[1] << 28;
}
encoding |= 0xFF & sdwa_op.physReg();
encoding |= (sdwa_op.physReg() < 256) << 23;
if (instr->operands.size() >= 2)
encoding |= (instr->operands[1].physReg() < 256) << 31;
out.push_back(encoding);
} else {
unreachable("unimplemented instruction format");
}

View File

@ -169,6 +169,11 @@ constexpr Format asVOP3(Format format) {
return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format);
};
constexpr Format asSDWA(Format format) {
assert(format == Format::VOP1 || format == Format::VOP2 || format == Format::VOPC);
return (Format) ((uint32_t) Format::SDWA | (uint32_t) format);
}
enum class RegType {
none = 0,
sgpr,
@ -841,6 +846,55 @@ struct DPP_instruction : public Instruction {
bool bound_ctrl : 1;
};
enum sdwa_sel : uint8_t {
/* masks */
sdwa_wordnum = 0x1,
sdwa_bytenum = 0x3,
sdwa_asuint = 0x7,
/* flags */
sdwa_isword = 0x4,
sdwa_sext = 0x8,
/* specific values */
sdwa_ubyte0 = 0,
sdwa_ubyte1 = 1,
sdwa_ubyte2 = 2,
sdwa_ubyte3 = 3,
sdwa_uword0 = sdwa_isword | 0,
sdwa_uword1 = sdwa_isword | 1,
sdwa_udword = 6,
sdwa_sbyte0 = sdwa_ubyte0 | sdwa_sext,
sdwa_sbyte1 = sdwa_ubyte1 | sdwa_sext,
sdwa_sbyte2 = sdwa_ubyte2 | sdwa_sext,
sdwa_sbyte3 = sdwa_ubyte3 | sdwa_sext,
sdwa_sword0 = sdwa_uword0 | sdwa_sext,
sdwa_sword1 = sdwa_uword1 | sdwa_sext,
sdwa_sdword = sdwa_udword | sdwa_sext,
};
/**
* Sub-Dword Addressing Format:
* This format can be used for VOP1, VOP2 or VOPC instructions.
*
* omod and SGPR/constant operands are only available on GFX9+. For VOPC,
* the definition doesn't have to be VCC on GFX9+.
*
*/
struct SDWA_instruction : public Instruction {
/* these destination modifiers aren't available with VOPC except for
* clamp on GFX8 */
unsigned dst_sel:4;
bool dst_preserve:1;
bool clamp:1;
unsigned omod:2; /* GFX9+ */
unsigned sel[2];
bool neg[2];
bool abs[2];
};
struct Interp_instruction : public Instruction {
uint8_t attribute;
uint8_t component;

View File

@ -140,6 +140,9 @@ void print_asm(Program *program, std::vector<uint32_t>& binary,
if (!l && program->chip_class == GFX9 && ((binary[pos] & 0xffff8000) == 0xd1348000)) { /* not actually an invalid instruction */
out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_add_u32_e64 + clamp";
new_pos = pos + 2;
} else if (program->chip_class == GFX10 && l == 4 && ((binary[pos] & 0xfe0001ff) == 0x020000f9)) {
out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_cndmask_b32 + sdwa";
new_pos = pos + 2;
} else if (!l) {
out << std::left << std::setw(align_width) << std::setfill(' ') << "(invalid instruction)";
new_pos = pos + 1;

View File

@ -528,7 +528,38 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
if (dpp->bound_ctrl)
fprintf(output, " bound_ctrl:1");
} else if ((int)instr->format & (int)Format::SDWA) {
fprintf(output, " (printing unimplemented)");
SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
switch (sdwa->omod) {
case 1:
fprintf(output, " *2");
break;
case 2:
fprintf(output, " *4");
break;
case 3:
fprintf(output, " *0.5");
break;
}
if (sdwa->clamp)
fprintf(output, " clamp");
switch (sdwa->dst_sel & sdwa_asuint) {
case sdwa_udword:
break;
case sdwa_ubyte0:
case sdwa_ubyte1:
case sdwa_ubyte2:
case sdwa_ubyte3:
fprintf(output, " dst_sel:%sbyte%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
sdwa->dst_sel & sdwa_bytenum);
break;
case sdwa_uword0:
case sdwa_uword1:
fprintf(output, " dst_sel:%sword%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
sdwa->dst_sel & sdwa_wordnum);
break;
}
if (sdwa->dst_preserve)
fprintf(output, " dst_preserve");
}
}
@ -546,23 +577,33 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
if (instr->operands.size()) {
bool abs[instr->operands.size()];
bool neg[instr->operands.size()];
uint8_t sel[instr->operands.size()];
if ((int)instr->format & (int)Format::VOP3A) {
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(instr);
for (unsigned i = 0; i < instr->operands.size(); ++i) {
abs[i] = vop3->abs[i];
neg[i] = vop3->neg[i];
sel[i] = sdwa_udword;
}
} else if (instr->isDPP()) {
DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
assert(instr->operands.size() <= 2);
for (unsigned i = 0; i < instr->operands.size(); ++i) {
abs[i] = dpp->abs[i];
neg[i] = dpp->neg[i];
abs[i] = i < 2 ? dpp->abs[i] : false;
neg[i] = i < 2 ? dpp->neg[i] : false;
sel[i] = sdwa_udword;
}
} else if (instr->isSDWA()) {
SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
for (unsigned i = 0; i < instr->operands.size(); ++i) {
abs[i] = i < 2 ? sdwa->abs[i] : false;
neg[i] = i < 2 ? sdwa->neg[i] : false;
sel[i] = i < 2 ? sdwa->sel[i] : sdwa_udword;
}
} else {
for (unsigned i = 0; i < instr->operands.size(); ++i) {
abs[i] = false;
neg[i] = false;
sel[i] = sdwa_udword;
}
}
for (unsigned i = 0; i < instr->operands.size(); ++i) {
@ -575,7 +616,20 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
fprintf(output, "-");
if (abs[i])
fprintf(output, "|");
if (sel[i] & sdwa_sext)
fprintf(output, "sext(");
print_operand(&instr->operands[i], output);
if (sel[i] & sdwa_sext)
fprintf(output, ")");
if ((sel[i] & sdwa_asuint) == sdwa_udword) {
/* print nothing */
} else if (sel[i] & sdwa_isword) {
unsigned index = sel[i] & sdwa_wordnum;
fprintf(output, "[%u:%u]", index * 16, index * 16 + 15);
} else {
unsigned index = sel[i] & sdwa_bytenum;
fprintf(output, "[%u:%u]", index * 8, index * 8 + 7);
}
if (abs[i])
fprintf(output, "|");
}

View File

@ -93,6 +93,50 @@ void validate(Program* program, FILE * output)
"Format cannot have VOP3A/VOP3B applied", instr.get());
}
/* check SDWA */
if (instr->isSDWA()) {
check(base_format == Format::VOP2 ||
base_format == Format::VOP1 ||
base_format == Format::VOPC,
"Format cannot have SDWA applied", instr.get());
check(program->chip_class >= GFX8, "SDWA is GFX8+ only", instr.get());
SDWA_instruction *sdwa = static_cast<SDWA_instruction*>(instr.get());
check(sdwa->omod == 0 || program->chip_class >= GFX9, "SDWA omod only supported on GFX9+", instr.get());
if (base_format == Format::VOPC) {
check(sdwa->clamp == false || program->chip_class == GFX8, "SDWA VOPC clamp only supported on GFX8", instr.get());
check((instr->definitions[0].isFixed() && instr->definitions[0].physReg() == vcc) ||
program->chip_class >= GFX9,
"SDWA+VOPC definition must be fixed to vcc on GFX8", instr.get());
}
if (instr->operands.size() >= 3) {
check(instr->operands[2].isFixed() && instr->operands[2].physReg() == vcc,
"3rd operand must be fixed to vcc with SDWA", instr.get());
}
if (instr->definitions.size() >= 2) {
check(instr->definitions[1].isFixed() && instr->definitions[1].physReg() == vcc,
"2nd definition must be fixed to vcc with SDWA", instr.get());
}
check(instr->opcode != aco_opcode::v_madmk_f32 &&
instr->opcode != aco_opcode::v_madak_f32 &&
instr->opcode != aco_opcode::v_madmk_f16 &&
instr->opcode != aco_opcode::v_madak_f16 &&
instr->opcode != aco_opcode::v_readfirstlane_b32 &&
instr->opcode != aco_opcode::v_clrexcp &&
instr->opcode != aco_opcode::v_swap_b32,
"SDWA can't be used with this opcode", instr.get());
if (program->chip_class != GFX8) {
check(instr->opcode != aco_opcode::v_mac_f32 &&
instr->opcode != aco_opcode::v_mac_f16 &&
instr->opcode != aco_opcode::v_fmac_f32 &&
instr->opcode != aco_opcode::v_fmac_f16,
"SDWA can't be used with this opcode", instr.get());
}
}
/* check for undefs */
for (unsigned i = 0; i < instr->operands.size(); i++) {
if (instr->operands[i].isUndefined()) {
@ -137,6 +181,10 @@ void validate(Program* program, FILE * output)
if (program->chip_class >= GFX10 && !is_shift64)
const_bus_limit = 2;
uint32_t scalar_mask = instr->isVOP3() ? 0x7 : 0x5;
if (instr->isSDWA())
scalar_mask = program->chip_class >= GFX9 ? 0x7 : 0x4;
check(instr->definitions[0].getTemp().type() == RegType::vgpr ||
(int) instr->format & (int) Format::VOPC ||
instr->opcode == aco_opcode::v_readfirstlane_b32 ||
@ -158,7 +206,7 @@ void validate(Program* program, FILE * output)
continue;
}
if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
check(i != 1 || instr->isVOP3(), "Wrong source position for SGPR argument", instr.get());
check(scalar_mask & (1 << i), "Wrong source position for SGPR argument", instr.get());
if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
if (num_sgprs < 2)
@ -167,7 +215,7 @@ void validate(Program* program, FILE * output)
}
if (op.isConstant() && !op.isLiteral())
check(i == 0 || instr->isVOP3(), "Wrong source position for constant argument", instr.get());
check(scalar_mask & (1 << i), "Wrong source position for constant argument", instr.get());
}
check(num_sgprs + (literal.isUndefined() ? 0 : 1) <= const_bus_limit, "Too many SGPRs/literals", instr.get());
}