amd/llvm: Add Subgroup Scan functions for SI
The idea of this implementation is taken from the ROCm Device Libs: https://github.com/RadeonOpenCompute/ROCm-Device-Libs/blob/master/ockl/src/wfredscan.cl

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
This commit is contained in:
parent
fca2d3ce3f
commit
0cbcfc071e
|
@ -4042,8 +4042,6 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
|
|||
/**
|
||||
* \param maxprefix specifies that the result only needs to be correct for a
|
||||
* prefix of this many threads
|
||||
*
|
||||
* TODO: add inclusive and exclusive scan functions for GFX6.
|
||||
*/
|
||||
static LLVMValueRef
|
||||
ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
|
||||
|
@ -4051,13 +4049,84 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
|
|||
{
|
||||
LLVMValueRef result, tmp;
|
||||
|
||||
if (ctx->chip_class >= GFX10) {
|
||||
result = inclusive ? src : identity;
|
||||
} else {
|
||||
if (!inclusive)
|
||||
src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
|
||||
if (inclusive) {
|
||||
result = src;
|
||||
} else if (ctx->chip_class >= GFX10) {
|
||||
result = identity;
|
||||
} else if (ctx->chip_class >= GFX8) {
|
||||
src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
|
||||
result = src;
|
||||
} else {
|
||||
/* wavefront shift_right by 1 on SI/CI */
|
||||
LLVMValueRef active, tmp1, tmp2;
|
||||
LLVMValueRef tid = ac_get_thread_id(ctx);
|
||||
tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
|
||||
tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
|
||||
LLVMConstInt(ctx->i32, 0x4, 0), "");
|
||||
tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
|
||||
tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
|
||||
LLVMConstInt(ctx->i32, 0x8, 0), "");
|
||||
tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
|
||||
tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
|
||||
LLVMConstInt(ctx->i32, 0x10, 0), "");
|
||||
tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
|
||||
tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
|
||||
tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
|
||||
src = LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
|
||||
result = src;
|
||||
}
|
||||
|
||||
if (ctx->chip_class <= GFX7) {
|
||||
assert(maxprefix == 64);
|
||||
LLVMValueRef tid = ac_get_thread_id(ctx);
|
||||
LLVMValueRef active;
|
||||
tmp = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x1e, 0x00, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
|
||||
LLVMBuildAnd(ctx->builder, tid, ctx->i32_1, ""),
|
||||
ctx->i32_0, "");
|
||||
tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
|
||||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x1c, 0x01, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 2, 0), ""),
|
||||
ctx->i32_0, "");
|
||||
tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
|
||||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x18, 0x03, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 4, 0), ""),
|
||||
ctx->i32_0, "");
|
||||
tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
|
||||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x10, 0x07, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 8, 0), ""),
|
||||
ctx->i32_0, "");
|
||||
tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
|
||||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x00, 0x0f, 0x00));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 16, 0), ""),
|
||||
ctx->i32_0, "");
|
||||
tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
|
||||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
tmp = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, 0));
|
||||
active = LLVMBuildICmp(ctx->builder, LLVMIntNE,
|
||||
LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 32, 0), ""),
|
||||
ctx->i32_0, "");
|
||||
tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, "");
|
||||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
return result;
|
||||
}
|
||||
|
||||
if (maxprefix <= 1)
|
||||
return result;
|
||||
tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
|
||||
|
|
Loading…
Reference in New Issue