aco: add p_extract/p_insert

These will let us make the SDWA optimizer much simpler than if we were to
recognize combinations of shift/and/bfe.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3151>
This commit is contained in:
Rhys Perry 2020-08-12 14:35:15 +01:00 committed by Marge Bot
parent e9d1643288
commit 2f94353735
4 changed files with 207 additions and 7 deletions

View File

@ -1994,6 +1994,103 @@ void lower_to_hw_instr(Program* program)
Operand(reg.advance(4), s1), Operand(0u), Operand(scc, s1));
break;
}
/* Lowers the p_extract pseudo-instruction to hardware instructions.
 * Operands: [0] data, [1] field index, [2] field width in bits,
 * [3] sign-extend flag. The result is the field starting at bit
 * index * bits of the data operand, written to definition 0 with zero or
 * sign extension.
 */
case aco_opcode::p_extract:
{
/* Index, width and sign-extend flag must all be compile-time constants. */
assert(instr->operands[1].isConstant());
assert(instr->operands[2].isConstant());
assert(instr->operands[3].isConstant());
/* An SGPR result must come with a second definition fixed to SCC; the SOP2
 * instructions emitted below are all given an SCC definition. */
if (instr->definitions[0].regClass() == s1)
assert(instr->definitions.size() >= 2 && instr->definitions[1].physReg() == scc);
Definition dst = instr->definitions[0];
Operand op = instr->operands[0];
unsigned bits = instr->operands[2].constantValue();
unsigned index = instr->operands[1].constantValue();
/* Bit position of the field inside the source. */
unsigned offset = index * bits;
/* Any value other than the constant 0 requests sign extension. */
bool signext = !instr->operands[3].constantEquals(0);
if (dst.regClass() == s1) {
if (offset == (32 - bits)) {
/* Field occupies the topmost bits: a single right shift (arithmetic when
 * sign-extending, logical otherwise) is enough. */
bld.sop2(signext ? aco_opcode::s_ashr_i32 : aco_opcode::s_lshr_b32,
dst, bld.def(s1, scc), op, Operand(offset));
} else if (offset == 0 && signext && (bits == 8 || bits == 16)) {
/* Signed field at bit 0: use the dedicated s_sext_i32_i8/i16 opcode. */
bld.sop1(bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, dst, op);
} else {
/* General case: bit-field extract. The control operand carries the width
 * shifted left by 16, OR'ed with the offset. */
bld.sop2(signext ? aco_opcode::s_bfe_i32 : aco_opcode::s_bfe_u32,
dst, bld.def(s1, scc), op, Operand((bits << 16) | offset));
}
} else if (dst.regClass() == v1 || ctx.program->chip_class <= GFX7) {
/* Taken for a v1 destination, and for every non-s1 destination on GFX7 and
 * older (NOTE(review): presumably because SDWA is not available there -
 * confirm). Both registers must start on a dword boundary. */
assert(op.physReg().byte() == 0 && dst.physReg().byte() == 0);
if (offset == (32 - bits) && op.regClass() != s1) {
/* Topmost field: a single right shift. Skipped when the source is s1,
 * in which case the bit-field extract below is used instead. */
bld.vop2(signext ? aco_opcode::v_ashrrev_i32 : aco_opcode::v_lshrrev_b32,
dst, Operand(offset), op);
} else {
bld.vop3(signext ? aco_opcode::v_bfe_i32 : aco_opcode::v_bfe_u32,
dst, op, Operand(offset), Operand(bits));
}
} else if (dst.regClass() == v2b) {
/* 16-bit VGPR destination: emit an SDWA v_mov_b32 whose source selector
 * picks the wanted field. */
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(
aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)};
/* Rebase the source to the start of its dword and widen it to 4 bytes;
 * the byte position is folded into the selector below. */
sdwa->operands[0] = Operand(op.physReg().advance(-op.physReg().byte()),
RegClass::get(op.regClass().type(), 4));
sdwa->definitions[0] = dst;
/* Selector = first byte selector + byte offset of the source register +
 * field index. NOTE(review): this adds index in byte units, so it looks
 * like it assumes bits == 8 here - confirm. */
sdwa->sel[0] = sdwa_ubyte0 + op.physReg().byte() + index;
if (signext)
sdwa->sel[0] |= sdwa_sext;
sdwa->dst_sel = sdwa_uword;
bld.insert(std::move(sdwa));
}
/* NOTE(review): there is no final else, so any other destination register
 * class emits nothing. */
break;
}
/* Lowers the p_insert pseudo-instruction to hardware instructions.
 * Operands: [0] data, [1] field index, [2] field width in bits.
 * The low `bits` bits of the data operand are placed at bit index * bits of
 * the result.
 */
case aco_opcode::p_insert:
{
/* Index and width must be compile-time constants. */
assert(instr->operands[1].isConstant());
assert(instr->operands[2].isConstant());
/* An SGPR result must come with a second definition fixed to SCC; every
 * SOP2 instruction emitted below is given an SCC definition. */
if (instr->definitions[0].regClass() == s1)
assert(instr->definitions.size() >= 2 && instr->definitions[1].physReg() == scc);
Definition dst = instr->definitions[0];
Operand op = instr->operands[0];
unsigned bits = instr->operands[2].constantValue();
unsigned index = instr->operands[1].constantValue();
/* Bit position of the field inside the result. */
unsigned offset = index * bits;
if (dst.regClass() == s1) {
if (offset == (32 - bits)) {
/* Field lands in the topmost bits: the left shift alone pushes the
 * unwanted upper source bits out of the register. */
bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), op, Operand(offset));
} else if (offset == 0) {
/* Field stays at bit 0: only the masking is needed, done as a bit-field
 * extract of `bits` bits at offset 0. */
bld.sop2(aco_opcode::s_bfe_u32, dst, bld.def(s1, scc), op, Operand(bits << 16));
} else {
/* General case: mask with a bit-field extract, then shift the result
 * into place, reading back the register that was just written. */
bld.sop2(aco_opcode::s_bfe_u32, dst, bld.def(s1, scc), op, Operand(bits << 16));
bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), Operand(dst.physReg(), s1), Operand(offset));
}
} else if (dst.regClass() == v1 || ctx.program->chip_class <= GFX7) {
/* Taken for a v1 destination, and for every non-s1 destination on GFX7 and
 * older. */
if (offset == (dst.bytes() * 8u - bits)) {
/* Field lands in the topmost bits of the destination: shift only. */
bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand(offset), op);
} else if (offset == 0) {
/* Field stays at bit 0: mask only. */
bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand(0u), Operand(bits));
} else if (program->chip_class >= GFX9 || (op.regClass() != s1 && program->chip_class >= GFX8)) {
/* Single SDWA v_mov_b32: the destination selector places the value in the
 * requested byte or word. On GFX8 this path is only used when the source
 * is not s1. */
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)};
sdwa->operands[0] = op;
sdwa->definitions[0] = dst;
sdwa->sel[0] = sdwa_udword;
/* offset / bits equals index, i.e. the number of the byte/word to write. */
sdwa->dst_sel = (bits == 8 ? sdwa_ubyte0 : sdwa_uword0) + (offset / bits);
bld.insert(std::move(sdwa));
} else {
/* Fallback without SDWA: mask, then shift the result into place. */
bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand(0u), Operand(bits));
bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand(offset), Operand(dst.physReg(), v1));
}
} else {
/* Remaining case: a 16-bit VGPR destination, handled with an SDWA
 * v_mov_b32 that writes a single byte. */
assert(dst.regClass() == v2b);
aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(
aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)};
sdwa->operands[0] = op;
/* The definition is rebased to the start of the containing dword and
 * declared as v1; the byte position is folded into dst_sel instead. */
sdwa->definitions[0] = Definition(dst.physReg().advance(-dst.physReg().byte()), v1);
sdwa->sel[0] = sdwa_uword;
sdwa->dst_sel = sdwa_ubyte0 + dst.physReg().byte() + index;
/* NOTE(review): presumably keeps the bytes of the destination dword that
 * dst_sel does not select - confirm against the SDWA documentation. */
sdwa->dst_preserve = 1;
bld.insert(std::move(sdwa));
}
break;
}
default:
break;
}

View File

@ -320,6 +320,14 @@ opcode("p_bpermute")
opcode("p_constaddr")
# These don't have to be pseudo-ops, but it makes optimization easier to only
# have to consider two instructions.
# (src0 >> (index * bits)) & ((1 << bits) - 1) with optional sign extension
opcode("p_extract") # src1=index, src2=bits, src3=signext
# (src0 & ((1 << bits) - 1)) << (index * bits)
opcode("p_insert") # src1=index, src2=bits
# SOP2 instructions: 2 scalar inputs, 1 scalar output (+optional scc)
SOP2 = {
# GFX6, GFX7, GFX8, GFX9, GFX10, name

View File

@ -763,6 +763,8 @@ bool alu_can_accept_constant(aco_opcode opcode, unsigned operand)
case aco_opcode::v_readlane_b32:
case aco_opcode::v_readlane_b32_e64:
case aco_opcode::v_readfirstlane_b32:
case aco_opcode::p_extract:
case aco_opcode::p_insert:
return operand != 0;
default:
return true;
@ -1610,6 +1612,16 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
if (instr->operands[0].constantEquals(0x3f800000u))
ctx.info[instr->definitions[0].tempId()].set_canonicalized();
break;
case aco_opcode::p_extract: {
if (instr->operands[0].isTemp())
ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
break;
}
case aco_opcode::p_insert: {
if (instr->operands[0].isTemp())
ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
break;
}
default:
break;
}
@ -2210,6 +2222,70 @@ bool combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode
return false;
}
/* Tries to merge a v_or_b32 or an add with the instruction that produces one
 * of its operands, creating v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32.
 * Returns true when instr has been replaced by the fused instruction.
 */
bool combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   const bool is_or = instr->opcode == aco_opcode::v_or_b32;
   const aco_opcode lshl_fused = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;

   /* Plain and/shift producers come first; the and-or form is only tried for v_or_b32. */
   if (is_or) {
      if (combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2))
         return true;
      if (combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2))
         return true;
   }
   if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, lshl_fused, "120", 1 | 2) ||
       combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, lshl_fused, "210", 1 | 2))
      return true;

   /* The p_extract/p_insert patterns below are not applied to SDWA instructions. */
   if (instr->isSDWA())
      return false;

   /* Patterns handled by the loop:
    *   v_or_b32(p_extract(a, 0, 8/16, 0), b)  -> v_and_or_b32(a, 0xff/0xffff, b)
    *   v_or_b32(p_insert(a, 0, 8/16), b)      -> v_and_or_b32(a, 0xff/0xffff, b)
    *   v_or_b32(p_insert(a, 24/16, 8/16), b)  -> v_lshl_or_b32(a, 24/16, b)
    *   v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_u32(a, 24/16, b)
    */
   for (unsigned idx = 0; idx < 2; idx++) {
      Instruction *producer = follow_operand(ctx, instr->operands[idx]);
      if (!producer)
         continue;

      aco_opcode fused_op;
      Operand srcs[3];

      const bool is_insert = producer->opcode == aco_opcode::p_insert;
      if (is_insert &&
          (producer->operands[1].constantValue() + 1) * producer->operands[2].constantValue() == 32) {
         /* Insert into the topmost field: equivalent to a left shift by index * bits. */
         fused_op = lshl_fused;
         srcs[1] = Operand(producer->operands[1].constantValue() * producer->operands[2].constantValue());
      } else {
         /* Only an OR can absorb a mask; an add has no and-add counterpart here. */
         if (!is_or)
            continue;
         const bool zero_extending =
            is_insert ||
            (producer->opcode == aco_opcode::p_extract && producer->operands[3].constantEquals(0));
         /* The field must be zero-extended and sit at index 0 to be a plain AND. */
         if (!zero_extending || !producer->operands[1].constantEquals(0))
            continue;
         fused_op = aco_opcode::v_and_or_b32;
         srcs[1] = Operand(producer->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
      }

      srcs[0] = producer->operands[0];
      srcs[2] = instr->operands[!idx];

      if (!check_vop3_operands(ctx, 3, srcs))
         continue;

      bool neg[3] = {};
      bool abs[3] = {};
      uint8_t opsel = 0;
      uint8_t omod = 0;
      /* Carry over the clamp modifier if the original was already VOP3. */
      bool clamp = instr->isVOP3() && instr->vop3().clamp;

      /* The producer's result is no longer read by instr. */
      ctx.uses[instr->operands[idx].tempId()]--;
      create_vop3_for_op3(ctx, fused_op, instr, srcs, neg, abs, opsel, clamp, omod);
      return true;
   }

   return false;
}
bool combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode minmax3)
{
if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2))
@ -3198,10 +3274,7 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
} else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->chip_class >= GFX9) {
if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_or_b32, "120", 1 | 2)) ;
else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_or_b32, "210", 1 | 2);
else combine_add_or_then_and_lshl(ctx, instr) ;
} else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->chip_class >= GFX10) {
if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012", 1 | 2)) ;
else combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32, "012", 1 | 2);
@ -3215,9 +3288,8 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_add_u32, "120", 1 | 2)) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_add_u32, "210", 1 | 2)) ;
else combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16, aco_opcode::v_mad_u32_u16, "120", 1 | 2) ;
else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16, aco_opcode::v_mad_u32_u16, "120", 1 | 2)) ;
else combine_add_or_then_and_lshl(ctx, instr) ;
}
} else if (instr->opcode == aco_opcode::v_add_co_u32 ||
instr->opcode == aco_opcode::v_add_co_u32_e64) {

View File

@ -376,6 +376,29 @@ bool validate_ir(Program* program)
check(instr->definitions[0].size() == op.size(), "Operand sizes must match Definition size", instr.get());
}
check(instr->operands.size() == block.linear_preds.size(), "Number of Operands does not match number of predecessors", instr.get());
} else if (instr->opcode == aco_opcode::p_extract || instr->opcode == aco_opcode::p_insert) {
/* Validation of p_extract / p_insert.
 * Operands: [0] data, [1] index, [2] size in bits, [3] sign-extend flag
 * (the flag is only checked for p_extract). */
check(instr->operands[0].isTemp(),
"Data operand must be temporary", instr.get());
check(instr->operands[1].isConstant(), "Index must be constant", instr.get());
if (instr->opcode == aco_opcode::p_extract)
check(instr->operands[3].isConstant(), "Sign-extend flag must be constant", instr.get());
/* An SGPR result is only allowed from an SGPR source. */
check(instr->definitions[0].getTemp().type() != RegType::sgpr ||
instr->operands[0].getTemp().type() == RegType::sgpr,
"Can't extract/insert VGPR to SGPR", instr.get());
/* The size-match rule is only enforced for VGPR sources. */
if (instr->operands[0].getTemp().type() == RegType::vgpr)
check(instr->operands[0].bytes() == instr->definitions[0].bytes(),
"Sizes of operand and definition must match", instr.get());
/* SGPR results need a second definition that is fixed to SCC. */
if (instr->definitions[0].getTemp().type() == RegType::sgpr)
check(instr->definitions.size() >= 2 && instr->definitions[1].isFixed() && instr->definitions[1].physReg() == scc, "SGPR extract/insert needs a SCC definition", instr.get());
/* NOTE(review): operand 2 has no separate isConstant() check in this branch
 * before constantValue() is read below. */
check(instr->operands[2].constantEquals(8) || instr->operands[2].constantEquals(16), "Size must be 8 or 16", instr.get());
check(instr->operands[2].constantValue() < instr->operands[0].getTemp().bytes() * 8u, "Size must be smaller than source", instr.get());
/* Number of size-wide components in the source; MAX2 guards against a
 * division by zero when the size operand is 0. */
unsigned comp = instr->operands[0].bytes() * 8u / MAX2(instr->operands[2].constantValue(), 1);
check(instr->operands[1].constantValue() < comp, "Index must be in-bounds", instr.get());
}
break;
}