nir,amd: remove trinary_minmax opcodes

These consist of the variations nir_op_{i|u|f}{min|max|med}3 which are either
lowered in the backend (LLVM) anyway or can be recombined by the backend (ACO).

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6421>
This commit is contained in:
Daniel Schürmann 2020-06-18 15:14:20 +01:00 committed by Marge Bot
parent 1fa43a4a8e
commit a79dad950b
9 changed files with 20 additions and 239 deletions

View File

@ -1793,84 +1793,6 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
}
break;
}
case nir_op_fmax3: {
if (dst.regClass() == v2b) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f16, dst, false);
} else if (dst.regClass() == v1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_fmin3: {
if (dst.regClass() == v2b) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f16, dst, false);
} else if (dst.regClass() == v1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_fmed3: {
if (dst.regClass() == v2b) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f16, dst, false);
} else if (dst.regClass() == v1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_umax3: {
if (dst.size() == 1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_u32, dst);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_umin3: {
if (dst.size() == 1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_u32, dst);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_umed3: {
if (dst.size() == 1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_u32, dst);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_imax3: {
if (dst.size() == 1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_i32, dst);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_imin3: {
if (dst.size() == 1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_i32, dst);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_imed3: {
if (dst.size() == 1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_i32, dst);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_cube_face_coord: {
Temp in = get_alu_src(ctx, instr->src[0], 3);
Temp src[3] = { emit_extract_vector(ctx, in, 0, v1),

View File

@ -600,9 +600,6 @@ void init_context(isel_context *ctx, nir_shader *shader)
case nir_op_fsub:
case nir_op_fmax:
case nir_op_fmin:
case nir_op_fmax3:
case nir_op_fmin3:
case nir_op_fmed3:
case nir_op_fneg:
case nir_op_fabs:
case nir_op_fsat:

View File

@ -2727,54 +2727,6 @@ void ac_build_waitcnt(struct ac_llvm_context *ctx, unsigned wait_flags)
ctx->voidt, args, 1, 0);
}
LLVMValueRef ac_build_fmed3(struct ac_llvm_context *ctx, LLVMValueRef src0,
LLVMValueRef src1, LLVMValueRef src2,
unsigned bitsize)
{
LLVMValueRef result;
if (bitsize == 64 || (bitsize == 16 && ctx->chip_class <= GFX8)) {
/* Lower 64-bit fmed because LLVM doesn't expose an intrinsic,
* or lower 16-bit fmed because it's only supported on GFX9+.
*/
LLVMValueRef min1, min2, max1;
min1 = ac_build_fmin(ctx, src0, src1);
max1 = ac_build_fmax(ctx, src0, src1);
min2 = ac_build_fmin(ctx, max1, src2);
result = ac_build_fmax(ctx, min2, min1);
} else {
LLVMTypeRef type;
char *intr;
if (bitsize == 16) {
intr = "llvm.amdgcn.fmed3.f16";
type = ctx->f16;
} else {
assert(bitsize == 32);
intr = "llvm.amdgcn.fmed3.f32";
type = ctx->f32;
}
LLVMValueRef params[] = {
src0,
src1,
src2,
};
result = ac_build_intrinsic(ctx, intr, type, params, 3,
AC_FUNC_ATTR_READNONE);
}
if (ctx->chip_class < GFX9 && bitsize == 32) {
/* Only pre-GFX9 chips do not flush denorms. */
result = ac_build_canonicalize(ctx, result, bitsize);
}
return result;
}
LLVMValueRef ac_build_fract(struct ac_llvm_context *ctx, LLVMValueRef src0,
unsigned bitsize)
{

View File

@ -1174,57 +1174,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
break;
}
case nir_op_fmin3:
result = emit_intrin_2f_param(&ctx->ac, "llvm.minnum",
ac_to_float_type(&ctx->ac, def_type), src[0], src[1]);
result = emit_intrin_2f_param(&ctx->ac, "llvm.minnum",
ac_to_float_type(&ctx->ac, def_type), result, src[2]);
break;
case nir_op_umin3:
result = ac_build_umin(&ctx->ac, src[0], src[1]);
result = ac_build_umin(&ctx->ac, result, src[2]);
break;
case nir_op_imin3:
result = ac_build_imin(&ctx->ac, src[0], src[1]);
result = ac_build_imin(&ctx->ac, result, src[2]);
break;
case nir_op_fmax3:
result = emit_intrin_2f_param(&ctx->ac, "llvm.maxnum",
ac_to_float_type(&ctx->ac, def_type), src[0], src[1]);
result = emit_intrin_2f_param(&ctx->ac, "llvm.maxnum",
ac_to_float_type(&ctx->ac, def_type), result, src[2]);
break;
case nir_op_umax3:
result = ac_build_umax(&ctx->ac, src[0], src[1]);
result = ac_build_umax(&ctx->ac, result, src[2]);
break;
case nir_op_imax3:
result = ac_build_imax(&ctx->ac, src[0], src[1]);
result = ac_build_imax(&ctx->ac, result, src[2]);
break;
case nir_op_fmed3: {
src[0] = ac_to_float(&ctx->ac, src[0]);
src[1] = ac_to_float(&ctx->ac, src[1]);
src[2] = ac_to_float(&ctx->ac, src[2]);
result = ac_build_fmed3(&ctx->ac, src[0], src[1], src[2],
instr->dest.dest.ssa.bit_size);
break;
}
case nir_op_imed3: {
LLVMValueRef tmp1 = ac_build_imin(&ctx->ac, src[0], src[1]);
LLVMValueRef tmp2 = ac_build_imax(&ctx->ac, src[0], src[1]);
tmp2 = ac_build_imin(&ctx->ac, tmp2, src[2]);
result = ac_build_imax(&ctx->ac, tmp1, tmp2);
break;
}
case nir_op_umed3: {
LLVMValueRef tmp1 = ac_build_umin(&ctx->ac, src[0], src[1]);
LLVMValueRef tmp2 = ac_build_umax(&ctx->ac, src[0], src[1]);
tmp2 = ac_build_umin(&ctx->ac, tmp2, src[2]);
result = ac_build_umax(&ctx->ac, tmp1, tmp2);
break;
}
default:
fprintf(stderr, "Unknown NIR alu instr: ");
nir_print_instr(&instr->instr, stderr);

View File

@ -838,12 +838,6 @@ nir_lower_int64_op_to_options_mask(nir_op opcode)
case nir_op_imax:
case nir_op_umin:
case nir_op_umax:
case nir_op_imin3:
case nir_op_imax3:
case nir_op_umin3:
case nir_op_umax3:
case nir_op_imed3:
case nir_op_umed3:
return nir_lower_minmax64;
case nir_op_iabs:
return nir_lower_iabs64;
@ -944,18 +938,6 @@ lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state)
return lower_umin64(b, src[0], src[1]);
case nir_op_umax:
return lower_umax64(b, src[0], src[1]);
case nir_op_imin3:
return lower_imin64(b, src[0], lower_imin64(b, src[1], src[2]));
case nir_op_imax3:
return lower_imax64(b, src[0], lower_imax64(b, src[1], src[2]));
case nir_op_umin3:
return lower_umin64(b, src[0], lower_umin64(b, src[1], src[2]));
case nir_op_umax3:
return lower_umax64(b, src[0], lower_umax64(b, src[1], src[2]));
case nir_op_imed3:
return lower_imax64(b, lower_imin64(b, lower_imax64(b, src[0], src[1]), src[2]), lower_imin64(b, src[0], src[1]));
case nir_op_umed3:
return lower_umax64(b, lower_umin64(b, lower_umax64(b, src[0], src[1]), src[2]), lower_umin64(b, src[0], src[1]));
case nir_op_iabs:
return lower_iabs64(b, src[0]);
case nir_op_ineg:

View File

@ -950,22 +950,8 @@ triop("flrp", tfloat, "", "src0 * (1 - src2) + src1 * src2")
# component on vectors). There are two versions, one for floating point
# bools (0.0 vs 1.0) and one for integer bools (0 vs ~0).
triop("fcsel", tfloat32, "", "(src0 != 0.0f) ? src1 : src2")
# 3 way min/max/med
triop("fmin3", tfloat, "", "fminf(src0, fminf(src1, src2))")
triop("imin3", tint, "", "MIN2(src0, MIN2(src1, src2))")
triop("umin3", tuint, "", "MIN2(src0, MIN2(src1, src2))")
triop("fmax3", tfloat, "", "fmaxf(src0, fmaxf(src1, src2))")
triop("imax3", tint, "", "MAX2(src0, MAX2(src1, src2))")
triop("umax3", tuint, "", "MAX2(src0, MAX2(src1, src2))")
triop("fmed3", tfloat, "", "fmaxf(fminf(fmaxf(src0, src1), src2), fminf(src0, src1))")
triop("imed3", tint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
triop("umed3", tuint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
opcode("bcsel", 0, tuint, [0, 0, 0],
[tbool1, tuint, tuint], False, "", "src0 ? src1 : src2")
opcode("b8csel", 0, tuint, [0, 0, 0],

View File

@ -1153,10 +1153,6 @@ optimizations.extend([
(('bcsel', a, ('bcsel', b, c, d), d), ('bcsel', ('iand', a, b), c, d)),
(('bcsel', a, b, ('bcsel', c, b, d)), ('bcsel', ('ior', a, c), b, d)),
(('fmin3@64', a, b, c), ('fmin@64', a, ('fmin@64', b, c))),
(('fmax3@64', a, b, c), ('fmax@64', a, ('fmax@64', b, c))),
(('fmed3@64', a, b, c), ('fmax@64', ('fmin@64', ('fmax@64', a, b), c), ('fmin@64', a, b))),
# Misc. lowering
(('fmod', a, b), ('fsub', a, ('fmul', b, ('ffloor', ('fdiv', a, b)))), 'options->lower_fmod'),
(('frem', a, b), ('fsub', a, ('fmul', b, ('ftrunc', ('fdiv', a, b)))), 'options->lower_fmod'),

View File

@ -1319,10 +1319,6 @@ nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
case nir_op_udiv:
case nir_op_bcsel:
case nir_op_b32csel:
case nir_op_imax3:
case nir_op_imin3:
case nir_op_umax3:
case nir_op_umin3:
case nir_op_ubfe:
case nir_op_bfm:
case nir_op_f2u32:
@ -1405,16 +1401,6 @@ nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
case nir_op_b32csel:
res = src1 > src2 ? src1 : src2;
break;
case nir_op_imax3:
case nir_op_imin3:
case nir_op_umax3:
src0 = src0 > src1 ? src0 : src1;
res = src0 > src2 ? src0 : src2;
break;
case nir_op_umin3:
src0 = src0 < src1 ? src0 : src1;
res = src0 < src2 ? src0 : src2;
break;
case nir_op_ubfe:
res = bitmask(MIN2(src2, scalar.def->bit_size));
break;

View File

@ -126,34 +126,45 @@ vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ex
for (unsigned i = 0; i < num_inputs; i++)
src[i] = vtn_get_nir_ssa(b, w[i + 5]);
/* place constants at src[1-2] for easier constant-folding */
for (unsigned i = 1; i <= 2; i++) {
if (nir_src_as_const_value(nir_src_for_ssa(src[0]))) {
nir_ssa_def* tmp = src[i];
src[i] = src[0];
src[0] = tmp;
}
}
nir_ssa_def *def;
switch ((enum ShaderTrinaryMinMaxAMD)ext_opcode) {
case FMin3AMD:
def = nir_fmin3(nb, src[0], src[1], src[2]);
def = nir_fmin(nb, src[0], nir_fmin(nb, src[1], src[2]));
break;
case UMin3AMD:
def = nir_umin3(nb, src[0], src[1], src[2]);
def = nir_umin(nb, src[0], nir_umin(nb, src[1], src[2]));
break;
case SMin3AMD:
def = nir_imin3(nb, src[0], src[1], src[2]);
def = nir_imin(nb, src[0], nir_imin(nb, src[1], src[2]));
break;
case FMax3AMD:
def = nir_fmax3(nb, src[0], src[1], src[2]);
def = nir_fmax(nb, src[0], nir_fmax(nb, src[1], src[2]));
break;
case UMax3AMD:
def = nir_umax3(nb, src[0], src[1], src[2]);
def = nir_umax(nb, src[0], nir_umax(nb, src[1], src[2]));
break;
case SMax3AMD:
def = nir_imax3(nb, src[0], src[1], src[2]);
def = nir_imax(nb, src[0], nir_imax(nb, src[1], src[2]));
break;
case FMid3AMD:
def = nir_fmed3(nb, src[0], src[1], src[2]);
def = nir_fmin(nb, nir_fmax(nb, src[0], nir_fmin(nb, src[1], src[2])),
nir_fmax(nb, src[1], src[2]));
break;
case UMid3AMD:
def = nir_umed3(nb, src[0], src[1], src[2]);
def = nir_umin(nb, nir_umax(nb, src[0], nir_umin(nb, src[1], src[2])),
nir_umax(nb, src[1], src[2]));
break;
case SMid3AMD:
def = nir_imed3(nb, src[0], src[1], src[2]);
def = nir_imin(nb, nir_imax(nb, src[0], nir_imin(nb, src[1], src[2])),
nir_imax(nb, src[1], src[2]));
break;
default:
unreachable("unknown opcode\n");