gallivm/nir: handle subgroup reduction across all types

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/11816>
This commit is contained in:
Dave Airlie 2021-09-07 14:51:48 +10:00 committed by Marge Bot
parent 3a27e406ed
commit 143167f2a0
1 changed files with 78 additions and 8 deletions

View File

@ -2030,36 +2030,106 @@ static void emit_reduce(struct lp_build_nir_context *bld_base, LLVMValueRef src,
switch (reduction_op) {
case nir_op_fmin: {
LLVMValueRef flt_max = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), INFINITY) :
lp_build_const_float(gallivm, INFINITY);
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), INFINITY) : lp_build_const_float(gallivm, INFINITY));
store_val = LLVMBuildBitCast(builder, flt_max, int_bld->elem_type, "");
break;
}
case nir_op_fmax: {
LLVMValueRef flt_min = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), -INFINITY) :
lp_build_const_float(gallivm, -INFINITY);
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), -INFINITY) : lp_build_const_float(gallivm, -INFINITY));
store_val = LLVMBuildBitCast(builder, flt_min, int_bld->elem_type, "");
break;
}
case nir_op_fmul: {
LLVMValueRef flt_one = bit_size == 64 ? LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), 1.0) :
lp_build_const_float(gallivm, 1.0);
(bit_size == 16 ? LLVMConstReal(LLVMHalfTypeInContext(gallivm->context), 1.0) : lp_build_const_float(gallivm, 1.0));
store_val = LLVMBuildBitCast(builder, flt_one, int_bld->elem_type, "");
break;
}
case nir_op_umin:
store_val = lp_build_const_int32(gallivm, UINT_MAX);
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), UINT8_MAX, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), UINT16_MAX, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, UINT_MAX);
break;
case 64:
store_val = lp_build_const_int64(gallivm, UINT64_MAX);
break;
}
break;
case nir_op_imin:
store_val = lp_build_const_int32(gallivm, INT_MAX);
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MAX, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MAX, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, INT_MAX);
break;
case 64:
store_val = lp_build_const_int64(gallivm, INT64_MAX);
break;
}
break;
case nir_op_imax:
store_val = lp_build_const_int32(gallivm, INT_MIN);
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), INT8_MIN, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), INT16_MIN, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, INT_MIN);
break;
case 64:
store_val = lp_build_const_int64(gallivm, INT64_MIN);
break;
}
break;
case nir_op_imul:
store_val = lp_build_const_int32(gallivm, 1);
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 1, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 1, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, 1);
break;
case 64:
store_val = lp_build_const_int64(gallivm, 1);
break;
}
break;
case nir_op_iand:
store_val = lp_build_const_int32(gallivm, 0xffffffff);
switch (bit_size) {
case 8:
store_val = LLVMConstInt(LLVMInt8TypeInContext(gallivm->context), 0xff, 0);
break;
case 16:
store_val = LLVMConstInt(LLVMInt16TypeInContext(gallivm->context), 0xffff, 0);
break;
case 32:
default:
store_val = lp_build_const_int32(gallivm, 0xffffffff);
break;
case 64:
store_val = lp_build_const_int64(gallivm, 0xffffffffffffffffLL);
break;
}
break;
default:
break;