amd/common/gfx10: implement scan & reduce operations

Acked-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
This commit is contained in:
Nicolai Hähnle 2018-05-23 22:08:22 +02:00 committed by Marek Olšák
parent 7ba80c1d19
commit 227c29a80d
1 changed files with 104 additions and 8 deletions

View File

@ -3874,6 +3874,58 @@ ac_build_dpp(struct ac_llvm_context *ctx, LLVMValueRef old, LLVMValueRef src,
return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
}
static LLVMValueRef
_ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
bool exchange_rows, bool bound_ctrl)
{
LLVMValueRef args[6] = {
src,
src,
LLVMConstInt(ctx->i32, sel, false),
LLVMConstInt(ctx->i32, sel >> 32, false),
ctx->i1true, /* fi */
bound_ctrl ? ctx->i1true : ctx->i1false,
};
return ac_build_intrinsic(ctx, exchange_rows ? "llvm.amdgcn.permlanex16"
: "llvm.amdgcn.permlane16",
ctx->i32, args, 6,
AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
}
static LLVMValueRef
ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
bool exchange_rows, bool bound_ctrl)
{
LLVMTypeRef src_type = LLVMTypeOf(src);
src = ac_to_integer(ctx, src);
unsigned bits = LLVMGetIntTypeWidth(LLVMTypeOf(src));
LLVMValueRef ret;
if (bits == 32) {
ret = _ac_build_permlane16(ctx, src, sel, exchange_rows,
bound_ctrl);
} else {
assert(bits % 32 == 0);
LLVMTypeRef vec_type = LLVMVectorType(ctx->i32, bits / 32);
LLVMValueRef src_vector =
LLVMBuildBitCast(ctx->builder, src, vec_type, "");
ret = LLVMGetUndef(vec_type);
for (unsigned i = 0; i < bits / 32; i++) {
src = LLVMBuildExtractElement(ctx->builder, src_vector,
LLVMConstInt(ctx->i32, i,
0), "");
LLVMValueRef ret_comp =
_ac_build_permlane16(ctx, src, sel,
exchange_rows,
bound_ctrl);
ret = LLVMBuildInsertElement(ctx->builder, ret,
ret_comp,
LLVMConstInt(ctx->i32, i,
0), "");
}
}
return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
}
static inline unsigned
ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask)
{
@ -4037,10 +4089,18 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
*/
static LLVMValueRef
ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
unsigned maxprefix)
unsigned maxprefix, bool inclusive)
{
LLVMValueRef result, tmp;
result = src;
if (ctx->chip_class >= GFX10) {
result = inclusive ? src : identity;
} else {
if (inclusive)
result = src;
else
result = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
}
if (maxprefix <= 1)
return result;
tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
@ -4063,6 +4123,38 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
result = ac_build_alu_op(ctx, result, tmp, op);
if (maxprefix <= 16)
return result;
if (ctx->chip_class >= GFX10) {
/* dpp_row_bcast{15,31} are not supported on gfx10. */
LLVMBuilderRef builder = ctx->builder;
LLVMValueRef tid = ac_get_thread_id(ctx);
LLVMValueRef cc;
/* TODO-GFX10: Can we get better code-gen by putting this into
* a branch so that LLVM generates EXEC mask manipulations? */
if (inclusive)
tmp = result;
else
tmp = ac_build_alu_op(ctx, result, src, op);
tmp = ac_build_permlane16(ctx, tmp, ~(uint64_t)0, true, false);
tmp = ac_build_alu_op(ctx, result, tmp, op);
cc = LLVMBuildAnd(builder, tid, LLVMConstInt(ctx->i32, 16, false), "");
cc = LLVMBuildICmp(builder, LLVMIntNE, cc, ctx->i32_0, "");
result = LLVMBuildSelect(builder, cc, tmp, result, "");
if (maxprefix <= 32)
return result;
if (inclusive)
tmp = result;
else
tmp = ac_build_alu_op(ctx, result, src, op);
tmp = ac_build_readlane(ctx, tmp, LLVMConstInt(ctx->i32, 31, false));
tmp = ac_build_alu_op(ctx, result, tmp, op);
cc = LLVMBuildICmp(builder, LLVMIntUGE, tid,
LLVMConstInt(ctx->i32, 32, false), "");
result = LLVMBuildSelect(builder, cc, tmp, result, "");
return result;
}
tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
result = ac_build_alu_op(ctx, result, tmp, op);
if (maxprefix <= 32)
@ -4092,7 +4184,7 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
LLVMTypeOf(identity), "");
result = ac_build_scan(ctx, op, result, identity, 64);
result = ac_build_scan(ctx, op, result, identity, 64, true);
return ac_build_wwm(ctx, result);
}
@ -4116,8 +4208,7 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
LLVMTypeOf(identity), "");
result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 0xf, false);
result = ac_build_scan(ctx, op, result, identity, 64);
result = ac_build_scan(ctx, op, result, identity, 64, false);
return ac_build_wwm(ctx, result);
}
@ -4155,7 +4246,9 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
result = ac_build_alu_op(ctx, result, swap, op);
if (cluster_size == 16) return ac_build_wwm(ctx, result);
if (ctx->chip_class >= GFX8 && cluster_size != 32)
if (ctx->chip_class >= GFX10)
swap = ac_build_permlane16(ctx, result, 0, true, false);
else if (ctx->chip_class >= GFX8 && cluster_size != 32)
swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
else
swap = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x1f, 0, 0x10));
@ -4163,7 +4256,10 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
if (cluster_size == 32) return ac_build_wwm(ctx, result);
if (ctx->chip_class >= GFX8) {
swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
if (ctx->chip_class >= GFX10)
swap = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, false));
else
swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
result = ac_build_alu_op(ctx, result, swap, op);
result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 63, 0));
return ac_build_wwm(ctx, result);
@ -4242,7 +4338,7 @@ ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
ac_build_optimization_barrier(ctx, &tmp);
bbs[1] = LLVMGetInsertBlock(builder);
phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves);
phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves, true);
}
ac_build_endif(ctx, 1001);