amd/common/gfx10: implement scan & reduce operations
Acked-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
This commit is contained in:
parent
7ba80c1d19
commit
227c29a80d
|
@ -3874,6 +3874,58 @@ ac_build_dpp(struct ac_llvm_context *ctx, LLVMValueRef old, LLVMValueRef src,
|
|||
return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
|
||||
}
|
||||
|
||||
static LLVMValueRef
|
||||
_ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
|
||||
bool exchange_rows, bool bound_ctrl)
|
||||
{
|
||||
LLVMValueRef args[6] = {
|
||||
src,
|
||||
src,
|
||||
LLVMConstInt(ctx->i32, sel, false),
|
||||
LLVMConstInt(ctx->i32, sel >> 32, false),
|
||||
ctx->i1true, /* fi */
|
||||
bound_ctrl ? ctx->i1true : ctx->i1false,
|
||||
};
|
||||
return ac_build_intrinsic(ctx, exchange_rows ? "llvm.amdgcn.permlanex16"
|
||||
: "llvm.amdgcn.permlane16",
|
||||
ctx->i32, args, 6,
|
||||
AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
|
||||
}
|
||||
|
||||
static LLVMValueRef
|
||||
ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
|
||||
bool exchange_rows, bool bound_ctrl)
|
||||
{
|
||||
LLVMTypeRef src_type = LLVMTypeOf(src);
|
||||
src = ac_to_integer(ctx, src);
|
||||
unsigned bits = LLVMGetIntTypeWidth(LLVMTypeOf(src));
|
||||
LLVMValueRef ret;
|
||||
if (bits == 32) {
|
||||
ret = _ac_build_permlane16(ctx, src, sel, exchange_rows,
|
||||
bound_ctrl);
|
||||
} else {
|
||||
assert(bits % 32 == 0);
|
||||
LLVMTypeRef vec_type = LLVMVectorType(ctx->i32, bits / 32);
|
||||
LLVMValueRef src_vector =
|
||||
LLVMBuildBitCast(ctx->builder, src, vec_type, "");
|
||||
ret = LLVMGetUndef(vec_type);
|
||||
for (unsigned i = 0; i < bits / 32; i++) {
|
||||
src = LLVMBuildExtractElement(ctx->builder, src_vector,
|
||||
LLVMConstInt(ctx->i32, i,
|
||||
0), "");
|
||||
LLVMValueRef ret_comp =
|
||||
_ac_build_permlane16(ctx, src, sel,
|
||||
exchange_rows,
|
||||
bound_ctrl);
|
||||
ret = LLVMBuildInsertElement(ctx->builder, ret,
|
||||
ret_comp,
|
||||
LLVMConstInt(ctx->i32, i,
|
||||
0), "");
|
||||
}
|
||||
}
|
||||
return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
|
||||
}
|
||||
|
||||
static inline unsigned
|
||||
ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask)
|
||||
{
|
||||
|
@ -4037,10 +4089,18 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
|
|||
*/
|
||||
static LLVMValueRef
|
||||
ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
|
||||
unsigned maxprefix)
|
||||
unsigned maxprefix, bool inclusive)
|
||||
{
|
||||
LLVMValueRef result, tmp;
|
||||
result = src;
|
||||
|
||||
if (ctx->chip_class >= GFX10) {
|
||||
result = inclusive ? src : identity;
|
||||
} else {
|
||||
if (inclusive)
|
||||
result = src;
|
||||
else
|
||||
result = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
|
||||
}
|
||||
if (maxprefix <= 1)
|
||||
return result;
|
||||
tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
|
||||
|
@ -4063,6 +4123,38 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
|
|||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
if (maxprefix <= 16)
|
||||
return result;
|
||||
|
||||
if (ctx->chip_class >= GFX10) {
|
||||
/* dpp_row_bcast{15,31} are not supported on gfx10. */
|
||||
LLVMBuilderRef builder = ctx->builder;
|
||||
LLVMValueRef tid = ac_get_thread_id(ctx);
|
||||
LLVMValueRef cc;
|
||||
/* TODO-GFX10: Can we get better code-gen by putting this into
|
||||
* a branch so that LLVM generates EXEC mask manipulations? */
|
||||
if (inclusive)
|
||||
tmp = result;
|
||||
else
|
||||
tmp = ac_build_alu_op(ctx, result, src, op);
|
||||
tmp = ac_build_permlane16(ctx, tmp, ~(uint64_t)0, true, false);
|
||||
tmp = ac_build_alu_op(ctx, result, tmp, op);
|
||||
cc = LLVMBuildAnd(builder, tid, LLVMConstInt(ctx->i32, 16, false), "");
|
||||
cc = LLVMBuildICmp(builder, LLVMIntNE, cc, ctx->i32_0, "");
|
||||
result = LLVMBuildSelect(builder, cc, tmp, result, "");
|
||||
if (maxprefix <= 32)
|
||||
return result;
|
||||
|
||||
if (inclusive)
|
||||
tmp = result;
|
||||
else
|
||||
tmp = ac_build_alu_op(ctx, result, src, op);
|
||||
tmp = ac_build_readlane(ctx, tmp, LLVMConstInt(ctx->i32, 31, false));
|
||||
tmp = ac_build_alu_op(ctx, result, tmp, op);
|
||||
cc = LLVMBuildICmp(builder, LLVMIntUGE, tid,
|
||||
LLVMConstInt(ctx->i32, 32, false), "");
|
||||
result = LLVMBuildSelect(builder, cc, tmp, result, "");
|
||||
return result;
|
||||
}
|
||||
|
||||
tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
|
||||
result = ac_build_alu_op(ctx, result, tmp, op);
|
||||
if (maxprefix <= 32)
|
||||
|
@ -4092,7 +4184,7 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
|
|||
get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
|
||||
result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
|
||||
LLVMTypeOf(identity), "");
|
||||
result = ac_build_scan(ctx, op, result, identity, 64);
|
||||
result = ac_build_scan(ctx, op, result, identity, 64, true);
|
||||
|
||||
return ac_build_wwm(ctx, result);
|
||||
}
|
||||
|
@ -4116,8 +4208,7 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
|
|||
get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
|
||||
result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
|
||||
LLVMTypeOf(identity), "");
|
||||
result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 0xf, false);
|
||||
result = ac_build_scan(ctx, op, result, identity, 64);
|
||||
result = ac_build_scan(ctx, op, result, identity, 64, false);
|
||||
|
||||
return ac_build_wwm(ctx, result);
|
||||
}
|
||||
|
@ -4155,7 +4246,9 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
|
|||
result = ac_build_alu_op(ctx, result, swap, op);
|
||||
if (cluster_size == 16) return ac_build_wwm(ctx, result);
|
||||
|
||||
if (ctx->chip_class >= GFX8 && cluster_size != 32)
|
||||
if (ctx->chip_class >= GFX10)
|
||||
swap = ac_build_permlane16(ctx, result, 0, true, false);
|
||||
else if (ctx->chip_class >= GFX8 && cluster_size != 32)
|
||||
swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
|
||||
else
|
||||
swap = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x1f, 0, 0x10));
|
||||
|
@ -4163,7 +4256,10 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
|
|||
if (cluster_size == 32) return ac_build_wwm(ctx, result);
|
||||
|
||||
if (ctx->chip_class >= GFX8) {
|
||||
swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
|
||||
if (ctx->chip_class >= GFX10)
|
||||
swap = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, false));
|
||||
else
|
||||
swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
|
||||
result = ac_build_alu_op(ctx, result, swap, op);
|
||||
result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 63, 0));
|
||||
return ac_build_wwm(ctx, result);
|
||||
|
@ -4242,7 +4338,7 @@ ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
|
|||
ac_build_optimization_barrier(ctx, &tmp);
|
||||
|
||||
bbs[1] = LLVMGetInsertBlock(builder);
|
||||
phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves);
|
||||
phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves, true);
|
||||
}
|
||||
ac_build_endif(ctx, 1001);
|
||||
|
||||
|
|
Loading…
Reference in New Issue