aco: implement some 16-bit arithmetic instead of lowering

fossil-db (parallel-rdp, Navi):
Totals from 210 (30.75% of 683) affected shaders:
SGPRs: 9704 -> 10248 (+5.61%)
VGPRs: 5884 -> 5368 (-8.77%)
CodeSize: 1155564 -> 1098752 (-4.92%)
Instrs: 199927 -> 189940 (-5.00%)
Cycles: 20438392 -> 19860124 (-2.83%)

v2: use divergence analysis to determine which instructions to lower.

Co-Authored-by: Daniel Schürmann <daniel@schuermann.dev>
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4791>
This commit is contained in:
Rhys Perry 2020-04-27 21:17:26 +01:00 committed by Marge Bot
parent 8ed7cad75f
commit ef95ba8cdd
3 changed files with 92 additions and 27 deletions

View File

@ -391,7 +391,8 @@ public:
Result vadd32(Definition dst, Op a, Op b, bool carry_out=false, Op carry_in=Op(Operand(s2)), bool post_ra=false) {
if (!b.op.isTemp() || b.op.regClass().type() != RegType::vgpr)
std::swap(a, b);
assert((post_ra || b.op.hasRegClass()) && b.op.regClass().type() == RegType::vgpr);
if (!post_ra && (!b.op.hasRegClass() || b.op.regClass().type() == RegType::sgpr))
b = copy(def(v1), b);
if (!carry_in.op.isUndefined())
return vop2(aco_opcode::v_addc_co_u32, Definition(dst), hint_vcc(def(lm)), a, b, carry_in);
@ -411,7 +412,8 @@ public:
bool reverse = !b.op.isTemp() || b.op.regClass().type() != RegType::vgpr;
if (reverse)
std::swap(a, b);
assert(b.op.isTemp() && b.op.regClass().type() == RegType::vgpr);
if (!b.op.hasRegClass() || b.op.regClass().type() == RegType::sgpr)
b = copy(def(v1), b);
aco_opcode op;
Temp carry;

View File

@ -784,13 +784,13 @@ void emit_vop2_instruction_logic64(isel_context *ctx, nir_alu_instr *instr,
}
void emit_vop3a_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst,
bool flush_denorms = false, unsigned num_sources = 2)
bool flush_denorms = false, unsigned num_sources = 2, bool swap_srcs = false)
{
assert(num_sources == 2 || num_sources == 3);
Temp src[3] = { Temp(0, v1), Temp(0, v1), Temp(0, v1) };
bool has_sgpr = false;
for (unsigned i = 0; i < num_sources; i++) {
src[i] = get_alu_src(ctx, instr->src[i]);
src[i] = get_alu_src(ctx, instr->src[swap_srcs ? 1 - i : i]);
if (has_sgpr)
src[i] = as_vgpr(ctx, src[i]);
else
@ -1307,7 +1307,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_imax: {
if (dst.regClass() == v1) {
if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_i16_e64, dst);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i16, dst, true);
} else if (dst.regClass() == v1) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_max_i32, dst, true);
} else if (dst.regClass() == s1) {
emit_sop2_instruction(ctx, instr, aco_opcode::s_max_i32, dst, true);
@ -1317,7 +1321,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_umax: {
if (dst.regClass() == v1) {
if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_max_u16_e64, dst);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u16, dst, true);
} else if (dst.regClass() == v1) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_max_u32, dst, true);
} else if (dst.regClass() == s1) {
emit_sop2_instruction(ctx, instr, aco_opcode::s_max_u32, dst, true);
@ -1327,7 +1335,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_imin: {
if (dst.regClass() == v1) {
if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_i16_e64, dst);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i16, dst, true);
} else if (dst.regClass() == v1) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_min_i32, dst, true);
} else if (dst.regClass() == s1) {
emit_sop2_instruction(ctx, instr, aco_opcode::s_min_i32, dst, true);
@ -1337,7 +1349,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_umin: {
if (dst.regClass() == v1) {
if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_min_u16_e64, dst);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u16, dst, true);
} else if (dst.regClass() == v1) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_min_u32, dst, true);
} else if (dst.regClass() == s1) {
emit_sop2_instruction(ctx, instr, aco_opcode::s_min_u32, dst, true);
@ -1395,7 +1411,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_ushr: {
if (dst.regClass() == v1) {
if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshrrev_b16_e64, dst, false, 2, true);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b16, dst, false, true);
} else if (dst.regClass() == v1) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_lshrrev_b32, dst, false, true);
} else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) {
bld.vop3(aco_opcode::v_lshrrev_b64, Definition(dst),
@ -1412,7 +1432,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_ishl: {
if (dst.regClass() == v1) {
if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_lshlrev_b16_e64, dst, false, 2, true);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b16, dst, false, true);
} else if (dst.regClass() == v1) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b32, dst, false, true);
} else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) {
bld.vop3(aco_opcode::v_lshlrev_b64, Definition(dst),
@ -1429,7 +1453,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_ishr: {
if (dst.regClass() == v1) {
if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_ashrrev_i16_e64, dst, false, 2, true);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i16, dst, false, true);
} else if (dst.regClass() == v1) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_ashrrev_i32, dst, false, true);
} else if (dst.regClass() == v2 && ctx->program->chip_class >= GFX8) {
bld.vop3(aco_opcode::v_ashrrev_i64, Definition(dst),
@ -1499,6 +1527,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
if (dst.regClass() == s1) {
emit_sop2_instruction(ctx, instr, aco_opcode::s_add_u32, dst, true);
break;
} else if (dst.regClass() == v2b && ctx->program->chip_class < GFX10) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_add_u16, dst, true);
break;
} else if (dst.regClass() == v2b) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_add_u16_e64, dst);
break;
}
Temp src0 = get_alu_src(ctx, instr->src[0]);
@ -1539,6 +1573,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
bld.sop2(aco_opcode::s_add_u32, Definition(tmp), bld.scc(Definition(carry)),
src0, src1);
bld.sop2(aco_opcode::s_cselect_b32, Definition(dst), Operand((uint32_t) -1), tmp, bld.scc(carry));
} else if (dst.regClass() == v2b) {
Instruction *instr;
if (ctx->program->chip_class >= GFX10) {
instr = bld.vop3(aco_opcode::v_add_u16_e64, Definition(dst), src0, src1).instr;
} else {
if (src1.type() == RegType::sgpr)
std::swap(src0, src1);
instr = bld.vop2_e64(aco_opcode::v_add_u16, Definition(dst), src0, as_vgpr(ctx, src1)).instr;
}
static_cast<VOP3A_instruction*>(instr)->clamp = 1;
} else if (dst.regClass() == v1) {
if (ctx->options->chip_class >= GFX9) {
aco_ptr<VOP3A_instruction> add{create_instruction<VOP3A_instruction>(aco_opcode::v_add_u32, asVOP3(Format::VOP2), 2, 1)};
@ -1605,6 +1649,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
if (dst.regClass() == v1) {
bld.vsub32(Definition(dst), src0, src1);
break;
} else if (dst.regClass() == v2b) {
if (ctx->program->chip_class >= GFX10)
bld.vop3(aco_opcode::v_sub_u16_e64, Definition(dst), src0, src1);
else if (src1.type() == RegType::sgpr)
bld.vop2(aco_opcode::v_subrev_u16, Definition(dst), src1, as_vgpr(ctx, src0));
else
bld.vop2(aco_opcode::v_sub_u16, Definition(dst), src0, as_vgpr(ctx, src1));
break;
}
Temp src00 = bld.tmp(src0.type(), 1);
@ -1671,6 +1723,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
} else {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u32, dst);
}
} else if (dst.regClass() == v2b && ctx->program->chip_class >= GFX10) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u16_e64, dst);
} else if (dst.regClass() == v2b) {
emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_lo_u16, dst, true);
} else if (dst.regClass() == s1) {
emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_i32, dst, false);
} else {

View File

@ -2954,51 +2954,58 @@ lower_bit_size_callback(const nir_alu_instr *alu, void *_)
enum chip_class chip = device->physical_device->rad_info.chip_class;
if (alu->dest.dest.ssa.bit_size & (8 | 16)) {
unsigned bit_size = alu->dest.dest.ssa.bit_size;
switch (alu->op) {
case nir_op_iabs:
case nir_op_iadd:
case nir_op_iand:
case nir_op_bitfield_select:
case nir_op_udiv:
case nir_op_idiv:
case nir_op_imax:
case nir_op_umax:
case nir_op_imin:
case nir_op_umin:
case nir_op_umod:
case nir_op_imod:
case nir_op_imul:
case nir_op_imul_high:
case nir_op_umul_high:
case nir_op_ineg:
case nir_op_inot:
case nir_op_ior:
case nir_op_irem:
case nir_op_ishl:
case nir_op_ishr:
case nir_op_ushr:
case nir_op_isign:
case nir_op_isub:
case nir_op_ixor:
return 32;
case nir_op_imax:
case nir_op_umax:
case nir_op_imin:
case nir_op_umin:
case nir_op_ishr:
case nir_op_ushr:
case nir_op_ishl:
case nir_op_iadd:
case nir_op_uadd_sat:
case nir_op_isub:
case nir_op_imul:
return (bit_size == 8 ||
!(chip >= GFX8 && nir_dest_is_divergent(alu->dest.dest))) ? 32 : 0;
default:
return 0;
}
}
if (nir_src_bit_size(alu->src[0].src) & (8 | 16)) {
unsigned bit_size = nir_src_bit_size(alu->src[0].src);
switch (alu->op) {
case nir_op_bit_count:
case nir_op_find_lsb:
case nir_op_ufind_msb:
case nir_op_ieq:
case nir_op_ige:
case nir_op_uge:
case nir_op_ilt:
case nir_op_ult:
case nir_op_ine:
case nir_op_i2b1:
return 32;
case nir_op_ilt:
case nir_op_ige:
case nir_op_ieq:
case nir_op_ine:
case nir_op_ult:
case nir_op_uge:
return (bit_size == 8 ||
!(chip >= GFX8 && nir_dest_is_divergent(alu->dest.dest))) ? 32 : 0;
default:
return 0;
}