From faa2a894876a387c8945bb46f6ce71f495db1d44 Mon Sep 17 00:00:00 2001
From: Georg Lehmann
Date: Fri, 19 Nov 2021 16:28:52 +0100
Subject: [PATCH] aco: Implement usub_sat.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Georg Lehmann
Reviewed-by: Timur Kristóf
Part-of:
---
 .../compiler/aco_instruction_selection.cpp    | 96 +++++++++++++++++++
 1 file changed, 96 insertions(+)

diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index e09c2c281c9..b98fcaf85b6 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -1345,6 +1345,25 @@ uadd32_sat(Builder& bld, Definition dst, Temp src0, Temp src1)
    return dst.getTemp();
 }
 
+Temp
+usub32_sat(Builder& bld, Definition dst, Temp src0, Temp src1)
+{ // Emits the 32-bit VALU sequence for usub_sat(src0, src1) into dst; the GFX8+ path returns dst's Temp.
+   if (bld.program->gfx_level < GFX8) { // pre-GFX8 path: no clamp bit is used, the result is selected explicitly
+      Builder::Result sub = bld.vsub32(bld.def(v1), src0, src1, true); // def(0) = difference, def(1) = carry-out
+      return bld.vop2_e64(aco_opcode::v_cndmask_b32, dst, sub.def(0).getTemp(), Operand::c32(0u), // choose between the difference and constant 0
+                          sub.def(1).getTemp()); // carry-out is the select mask (presumably picks 0 on underflow -- confirm against ISA docs)
+   }
+
+   Builder::Result sub(NULL); // assigned by exactly one of the two branches below
+   if (bld.program->gfx_level >= GFX9) {
+      sub = bld.vop2_e64(aco_opcode::v_sub_u32, dst, src0, src1); // GFX9+: subtract with no carry definition
+   } else {
+      sub = bld.vop2_e64(aco_opcode::v_sub_co_u32, dst, bld.def(bld.lm), src0, src1); // GFX8: subtract with an extra lane-mask definition that is not read afterwards
+   }
+   sub.instr->vop3().clamp = 1; // set the VOP3 clamp bit on the subtract (presumably what saturates the result -- confirm against ISA docs)
+   return dst.getTemp();
+}
+
 void
 visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
 {
@@ -2082,6 +2101,83 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       }
       break;
    }
+   case nir_op_usub_sat: { // lowering is chosen by the destination register class: s1, v2b, v1, s2 or v2
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = get_alu_src(ctx, instr->src[1]);
+      if (dst.regClass() == s1) { // scalar, one dword
+         Temp tmp = bld.tmp(s1), carry = bld.tmp(s1);
+         bld.sop2(aco_opcode::s_sub_u32, Definition(tmp), bld.scc(Definition(carry)), src0, src1); // raw difference in tmp, SCC result captured in carry
+         bld.sop2(aco_opcode::s_cselect_b32, Definition(dst), Operand::c32(0), tmp, bld.scc(carry)); // select between constant 0 and tmp based on carry
+         break;
+      } else if (dst.regClass() == v2b) { // vector, 16-bit
+         Instruction* sub_instr;
+         if (ctx->program->gfx_level >= GFX10) {
+            sub_instr = bld.vop3(aco_opcode::v_sub_u16_e64, Definition(dst), src0, src1).instr; // GFX10+: VOP3 opcode, sources passed through unchanged
+         } else {
+            aco_opcode op = aco_opcode::v_sub_u16;
+            if (src1.type() == RegType::sgpr) { // SGPR in the second slot: swap operands and use the reversed subtract instead
+               std::swap(src0, src1);
+               op = aco_opcode::v_subrev_u16;
+            }
+            sub_instr = bld.vop2_e64(op, Definition(dst), src0, as_vgpr(ctx, src1)).instr; // second operand is passed through as_vgpr()
+         }
+         sub_instr->vop3().clamp = 1; // clamp bit set on whichever subtract was emitted above
+         break;
+      } else if (dst.regClass() == v1) { // vector, one dword: delegate to the 32-bit helper
+         usub32_sat(bld, Definition(dst), src0, as_vgpr(ctx, src1));
+         break;
+      }
+
+      assert(src0.size() == 2 && src1.size() == 2); // everything below handles two-dword sources only
+      Temp src00 = bld.tmp(src0.type(), 1); // first half of src0 (presumably the low dword -- confirm p_split_vector ordering)
+      Temp src01 = bld.tmp(src0.type(), 1); // second half of src0
+      bld.pseudo(aco_opcode::p_split_vector, Definition(src00), Definition(src01), src0);
+      Temp src10 = bld.tmp(src1.type(), 1); // first half of src1
+      Temp src11 = bld.tmp(src1.type(), 1); // second half of src1
+      bld.pseudo(aco_opcode::p_split_vector, Definition(src10), Definition(src11), src1);
+
+      if (dst.regClass() == s2) { // scalar, two dwords
+         Temp carry0 = bld.tmp(s1);
+         Temp carry1 = bld.tmp(s1);
+
+         Temp no_sat0 =
+            bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.scc(Definition(carry0)), src00, src10); // first halves; SCC result -> carry0
+         Temp no_sat1 = bld.sop2(aco_opcode::s_subb_u32, bld.def(s1), bld.scc(Definition(carry1)), // second halves, consuming carry0
+                                 src01, src11, bld.scc(carry0)); // SCC result of this step -> carry1
+
+         Temp no_sat = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), no_sat0, no_sat1); // reassemble the unsaturated 64-bit difference
+
+         bld.sop2(aco_opcode::s_cselect_b64, Definition(dst), Operand::c64(0ull), no_sat, // select between constant 0 and the difference
+                  bld.scc(carry1)); // decided by the carry of the second-half subtract
+      } else if (dst.regClass() == v2) { // vector, two dwords
+         Temp no_sat0 = bld.tmp(v1); // unsaturated difference of the first halves
+         Temp dst0 = bld.tmp(v1); // final first half
+         Temp dst1 = bld.tmp(v1); // final second half
+
+         Temp carry0 = bld.vsub32(Definition(no_sat0), src00, src10, true).def(1).getTemp(); // first-half subtract; keep its carry-out definition
+         Temp carry1;
+
+         if (ctx->program->gfx_level >= GFX8) { // GFX8+: second half uses the clamp bit directly
+            carry1 = bld.tmp(bld.lm);
+            bld.vop2_e64(aco_opcode::v_subb_co_u32, Definition(dst1), Definition(carry1), // writes dst1 and the carry-out into carry1
+                         as_vgpr(ctx, src01), as_vgpr(ctx, src11), carry0) // both sources passed through as_vgpr(); carry0 is the carry-in
+               .instr->vop3()
+               .clamp = 1; // clamp bit on the second-half subtract
+         } else { // pre-GFX8: plain subtract, then explicit select for the second half
+            Temp no_sat1 = bld.tmp(v1);
+            carry1 = bld.vsub32(Definition(no_sat1), src01, src11, true, carry0).def(1).getTemp(); // second-half subtract with carry0 as carry-in
+            bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst1), no_sat1, Operand::c32(0u), // select between no_sat1 and constant 0
+                         carry1);
+         }
+
+         bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst0), no_sat0, Operand::c32(0u), // first half is always selected explicitly, on both paths
+                      carry1); // uses the carry of the second-half subtract, not carry0
+         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), dst0, dst1); // combine both halves into dst
+      } else { // any other destination register class is reported as unimplemented
+         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
+      }
+      break;
+   }
    case nir_op_imul: {
       if (dst.bytes() <= 2 && ctx->program->gfx_level >= GFX10) {
          emit_vop3a_instruction(ctx, instr, aco_opcode::v_mul_lo_u16_e64, dst);