ac: add LLVM build functions for subgroup intrinsics

Co-authored-by: Connor Abbott <cwabbott0@gmail.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
This commit is contained in:
Daniel Schürmann 2018-03-06 15:03:36 +01:00 committed by Bas Nieuwenhuizen
parent d19f20e793
commit d5f7ebda3e
2 changed files with 485 additions and 1 deletions

View File

@ -2507,3 +2507,459 @@ void ac_apply_fmask_to_sample(struct ac_llvm_context *ac, LLVMValueRef fmask,
addr[sample_chan] = LLVMBuildSelect(ac->builder, tmp, final_sample,
addr[sample_chan], "");
}
/**
 * Emits one readlane/readfirstlane call on a value without any type handling.
 * @param ctx
 * @param src - value to read; the intrinsic is built with src's own LLVM type
 * @param lane - lane id, or NULL to build "llvm.amdgcn.readfirstlane"
 * @return result of the intrinsic call
 */
static LLVMValueRef
_ac_build_readlane(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef lane)
{
	const char *intr_name;
	unsigned num_args;

	if (lane == NULL) {
		/* readfirstlane only takes the source operand. */
		intr_name = "llvm.amdgcn.readfirstlane";
		num_args = 1;
	} else {
		intr_name = "llvm.amdgcn.readlane";
		num_args = 2;
	}

	/* The barrier receives src by address, so the argument list is only
	 * assembled once it has run.
	 */
	ac_build_optimization_barrier(ctx, &src);

	LLVMValueRef args[2] = { src, lane };
	return ac_build_intrinsic(ctx, intr_name, LLVMTypeOf(src), args, num_args,
				  AC_FUNC_ATTR_READNONE |
				  AC_FUNC_ATTR_CONVERGENT);
}
/**
 * Builds the "llvm.amdgcn.readlane" or "llvm.amdgcn.readfirstlane" intrinsic.
 *
 * The source is converted to an integer first. A 32-bit value is read with a
 * single call; any other width must be a multiple of 32 bits and is read one
 * i32 component at a time. The result is cast back to the type of src.
 * @param ctx
 * @param src
 * @param lane - id of the lane or NULL for the first active lane
 * @return value of the lane
 */
LLVMValueRef
ac_build_readlane(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef lane)
{
	LLVMTypeRef orig_type = LLVMTypeOf(src);
	LLVMValueRef as_int = ac_to_integer(ctx, src);
	unsigned width = LLVMGetIntTypeWidth(LLVMTypeOf(as_int));

	if (width == 32) {
		LLVMValueRef single = _ac_build_readlane(ctx, as_int, lane);
		return LLVMBuildBitCast(ctx->builder, single, orig_type, "");
	}

	assert(width % 32 == 0);
	unsigned num_dwords = width / 32;
	LLVMTypeRef dwords_type = LLVMVectorType(ctx->i32, num_dwords);
	LLVMValueRef in_dwords =
		LLVMBuildBitCast(ctx->builder, as_int, dwords_type, "");
	LLVMValueRef out_dwords = LLVMGetUndef(dwords_type);

	for (unsigned i = 0; i < num_dwords; i++) {
		LLVMValueRef idx = LLVMConstInt(ctx->i32, i, 0);
		LLVMValueRef dword =
			LLVMBuildExtractElement(ctx->builder, in_dwords, idx, "");

		dword = _ac_build_readlane(ctx, dword, lane);
		out_dwords = LLVMBuildInsertElement(ctx->builder, out_dwords,
						    dword, idx, "");
	}

	return LLVMBuildBitCast(ctx->builder, out_dwords, orig_type, "");
}
/**
 * Emulates a lane write with a compare and a select.
 * @param ctx
 * @param src - value kept wherever the thread id differs from lane
 * @param value - value chosen where the thread id equals lane
 * @param lane - lane id compared against ac_get_thread_id()
 * @return the select between value and src
 */
LLVMValueRef
ac_build_writelane(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef value, LLVMValueRef lane)
{
	/* TODO: Use the actual instruction when LLVM adds an intrinsic for it.
	 */
	LLVMValueRef thread_id = ac_get_thread_id(ctx);
	LLVMValueRef is_target =
		LLVMBuildICmp(ctx->builder, LLVMIntEQ, lane, thread_id, "");

	return LLVMBuildSelect(ctx->builder, is_target, value, src, "");
}
/**
 * Chains "llvm.amdgcn.mbcnt.lo" and "llvm.amdgcn.mbcnt.hi" over a mask.
 * @param ctx
 * @param mask - value that is bitcast to a vector of two i32 (so presumably
 *               64 bits wide); element 0 feeds mbcnt.lo, element 1 mbcnt.hi
 * @return i32 result of the mbcnt.hi call
 */
LLVMValueRef
ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask)
{
	LLVMTypeRef v2i32 = LLVMVectorType(ctx->i32, 2);
	LLVMValueRef halves = LLVMBuildBitCast(ctx->builder, mask, v2i32, "");
	LLVMValueRef lo = LLVMBuildExtractElement(ctx->builder, halves,
						  ctx->i32_0, "");
	LLVMValueRef hi = LLVMBuildExtractElement(ctx->builder, halves,
						  ctx->i32_1, "");

	/* The low-half count starts from zero ... */
	LLVMValueRef lo_args[] = { lo, ctx->i32_0 };
	LLVMValueRef count =
		ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.lo", ctx->i32,
				   lo_args, 2, AC_FUNC_ATTR_READNONE);

	/* ... and is passed on as the second operand of the high half. */
	LLVMValueRef hi_args[] = { hi, count };
	return ac_build_intrinsic(ctx, "llvm.amdgcn.mbcnt.hi", ctx->i32,
				  hi_args, 2, AC_FUNC_ATTR_READNONE);
}
/*
 * Values for the control operand of "llvm.amdgcn.update.dpp.i32".
 *
 * The entries with a leading underscore are base encodings: the helpers
 * below OR an operand into them (_dpp_row_rr has no helper in this file).
 * The remaining entries are used as they are.
 */
enum dpp_ctrl {
	_dpp_quad_perm = 0x000,
	_dpp_row_sl = 0x100,
	_dpp_row_sr = 0x110,
	_dpp_row_rr = 0x120,
	dpp_wf_sl1 = 0x130,
	dpp_wf_rl1 = 0x134,
	dpp_wf_sr1 = 0x138,
	dpp_wf_rr1 = 0x13C,
	dpp_row_mirror = 0x140,
	dpp_row_half_mirror = 0x141,
	dpp_row_bcast15 = 0x142,
	dpp_row_bcast31 = 0x143
};

/* Packs four 2-bit selectors (each must be < 4) into a quad_perm control
 * value: lane0 occupies bits 0-1, lane1 bits 2-3, lane2 bits 4-5 and lane3
 * bits 6-7.
 */
static inline enum dpp_ctrl
dpp_quad_perm(unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3)
{
	const unsigned selectors[4] = { lane0, lane1, lane2, lane3 };
	unsigned packed = _dpp_quad_perm;

	for (unsigned i = 0; i < 4; i++) {
		assert(selectors[i] < 4);
		packed |= selectors[i] << (2 * i);
	}

	return (enum dpp_ctrl)packed;
}

/* Combines the _dpp_row_sl base with an amount in the range 1..15. */
static inline enum dpp_ctrl
dpp_row_sl(unsigned amount)
{
	assert(amount >= 1 && amount <= 15);
	return (enum dpp_ctrl)(_dpp_row_sl | amount);
}

/* Combines the _dpp_row_sr base with an amount in the range 1..15. */
static inline enum dpp_ctrl
dpp_row_sr(unsigned amount)
{
	assert(amount >= 1 && amount <= 15);
	return (enum dpp_ctrl)(_dpp_row_sr | amount);
}
/**
 * Emits a single "llvm.amdgcn.update.dpp.i32" call.
 * @param ctx
 * @param old - first operand; its LLVM type is also used as the return type
 * @param src - second operand
 * @param dpp_ctrl - control value, passed as an i32 constant
 * @param row_mask - passed as an i32 constant
 * @param bank_mask - passed as an i32 constant
 * @param bound_ctrl - passed as an i1 constant
 * @return result of the intrinsic call
 */
static LLVMValueRef
_ac_build_dpp(struct ac_llvm_context *ctx, LLVMValueRef old, LLVMValueRef src,
	      enum dpp_ctrl dpp_ctrl, unsigned row_mask, unsigned bank_mask,
	      bool bound_ctrl)
{
	LLVMValueRef args[6] = {
		old,
		src,
		LLVMConstInt(ctx->i32, dpp_ctrl, 0),
		LLVMConstInt(ctx->i32, row_mask, 0),
		LLVMConstInt(ctx->i32, bank_mask, 0),
		LLVMConstInt(ctx->i1, bound_ctrl, 0),
	};

	return ac_build_intrinsic(ctx, "llvm.amdgcn.update.dpp.i32",
				  LLVMTypeOf(old), args, 6,
				  AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
}
/**
 * Builds a DPP update for a value of any width that is a multiple of 32 bits.
 *
 * Both src and old are converted to integers. A 32-bit value needs a single
 * call; wider values are split into i32 components that are processed
 * separately and reassembled. The result is cast back to the type of src.
 * @param ctx
 * @param old - value passed as the "old" operand of each call
 * @param src - value passed as the source operand of each call
 * @param dpp_ctrl
 * @param row_mask
 * @param bank_mask
 * @param bound_ctrl
 * @return combined result, with the original type of src
 */
static LLVMValueRef
ac_build_dpp(struct ac_llvm_context *ctx, LLVMValueRef old, LLVMValueRef src,
	     enum dpp_ctrl dpp_ctrl, unsigned row_mask, unsigned bank_mask,
	     bool bound_ctrl)
{
	LLVMTypeRef orig_type = LLVMTypeOf(src);
	LLVMValueRef src_int = ac_to_integer(ctx, src);
	LLVMValueRef old_int = ac_to_integer(ctx, old);
	unsigned width = LLVMGetIntTypeWidth(LLVMTypeOf(src_int));

	if (width == 32) {
		LLVMValueRef single = _ac_build_dpp(ctx, old_int, src_int,
						    dpp_ctrl, row_mask,
						    bank_mask, bound_ctrl);
		return LLVMBuildBitCast(ctx->builder, single, orig_type, "");
	}

	assert(width % 32 == 0);
	unsigned num_dwords = width / 32;
	LLVMTypeRef dwords_type = LLVMVectorType(ctx->i32, num_dwords);
	LLVMValueRef src_dwords =
		LLVMBuildBitCast(ctx->builder, src_int, dwords_type, "");
	LLVMValueRef old_dwords =
		LLVMBuildBitCast(ctx->builder, old_int, dwords_type, "");
	LLVMValueRef out_dwords = LLVMGetUndef(dwords_type);

	for (unsigned i = 0; i < num_dwords; i++) {
		LLVMValueRef idx = LLVMConstInt(ctx->i32, i, 0);
		LLVMValueRef src_dword =
			LLVMBuildExtractElement(ctx->builder, src_dwords, idx, "");
		LLVMValueRef old_dword =
			LLVMBuildExtractElement(ctx->builder, old_dwords, idx, "");
		LLVMValueRef out_dword =
			_ac_build_dpp(ctx, old_dword, src_dword, dpp_ctrl,
				      row_mask, bank_mask, bound_ctrl);

		out_dwords = LLVMBuildInsertElement(ctx->builder, out_dwords,
						    out_dword, idx, "");
	}

	return LLVMBuildBitCast(ctx->builder, out_dwords, orig_type, "");
}
/* Packs three 5-bit masks into one pattern word: and_mask in bits 0-4,
 * or_mask in bits 5-9 and xor_mask in bits 10-14. Every mask must be < 32.
 */
static inline unsigned
ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask)
{
	/* All three are below 32 exactly when their union is. */
	assert((and_mask | or_mask | xor_mask) < 32);

	unsigned pattern = xor_mask << 10;
	pattern |= or_mask << 5;
	pattern |= and_mask;
	return pattern;
}
/**
 * Emits a single "llvm.amdgcn.ds.swizzle" call.
 * @param ctx
 * @param src - value to swizzle; its LLVM type is used as the return type
 * @param mask - swizzle pattern, passed as an i32 constant
 * @return result of the intrinsic call
 */
static LLVMValueRef
_ac_build_ds_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src, unsigned mask)
{
	LLVMValueRef args[2] = {
		src,
		LLVMConstInt(ctx->i32, mask, 0),
	};

	return ac_build_intrinsic(ctx, "llvm.amdgcn.ds.swizzle",
				  LLVMTypeOf(src), args, 2,
				  AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
}
/**
 * Builds the "llvm.amdgcn.ds.swizzle" intrinsic for a value of any width
 * that is a multiple of 32 bits.
 *
 * The source is converted to an integer; values wider than 32 bits are
 * swizzled one i32 component at a time with the same mask. The result is
 * cast back to the type of src.
 * @param ctx
 * @param src
 * @param mask - swizzle pattern applied to every component
 * @return swizzled value, with the original type of src
 */
LLVMValueRef
ac_build_ds_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src, unsigned mask)
{
	LLVMTypeRef orig_type = LLVMTypeOf(src);
	LLVMValueRef as_int = ac_to_integer(ctx, src);
	unsigned width = LLVMGetIntTypeWidth(LLVMTypeOf(as_int));

	if (width == 32) {
		LLVMValueRef single = _ac_build_ds_swizzle(ctx, as_int, mask);
		return LLVMBuildBitCast(ctx->builder, single, orig_type, "");
	}

	assert(width % 32 == 0);
	unsigned num_dwords = width / 32;
	LLVMTypeRef dwords_type = LLVMVectorType(ctx->i32, num_dwords);
	LLVMValueRef in_dwords =
		LLVMBuildBitCast(ctx->builder, as_int, dwords_type, "");
	LLVMValueRef out_dwords = LLVMGetUndef(dwords_type);

	for (unsigned i = 0; i < num_dwords; i++) {
		LLVMValueRef idx = LLVMConstInt(ctx->i32, i, 0);
		LLVMValueRef dword =
			LLVMBuildExtractElement(ctx->builder, in_dwords, idx, "");

		dword = _ac_build_ds_swizzle(ctx, dword, mask);
		out_dwords = LLVMBuildInsertElement(ctx->builder, out_dwords,
						    dword, idx, "");
	}

	return LLVMBuildBitCast(ctx->builder, out_dwords, orig_type, "");
}
/**
 * Wraps a value in the "llvm.amdgcn.wwm.<type>" intrinsic.
 * @param ctx
 * @param src - value to wrap; the type suffix of the intrinsic name is
 *              derived from its LLVM type
 * @return result of the intrinsic call, same type as src
 */
static LLVMValueRef
ac_build_wwm(struct ac_llvm_context *ctx, LLVMValueRef src)
{
	char type_suffix[8];
	char intr_name[32];
	LLVMTypeRef type = LLVMTypeOf(src);

	ac_build_type_name_for_intr(type, type_suffix, sizeof(type_suffix));
	snprintf(intr_name, sizeof(intr_name), "llvm.amdgcn.wwm.%s", type_suffix);

	LLVMValueRef args[1] = { src };
	return ac_build_intrinsic(ctx, intr_name, type, args, 1,
				  AC_FUNC_ATTR_READNONE);
}
/**
 * Builds the "llvm.amdgcn.set.inactive.<type>" intrinsic.
 *
 * Both operands are converted to integers before the call and the result is
 * cast back to the original type of src.
 * @param ctx
 * @param src - first operand of the intrinsic
 * @param inactive - second operand of the intrinsic
 * @return result of the intrinsic, with the original type of src
 */
static LLVMValueRef
ac_build_set_inactive(struct ac_llvm_context *ctx, LLVMValueRef src,
		      LLVMValueRef inactive)
{
	char type_suffix[8];
	char intr_name[32];
	LLVMTypeRef orig_type = LLVMTypeOf(src);
	LLVMValueRef src_int = ac_to_integer(ctx, src);
	LLVMValueRef inactive_int = ac_to_integer(ctx, inactive);
	LLVMTypeRef int_type = LLVMTypeOf(src_int);

	ac_build_type_name_for_intr(int_type, type_suffix, sizeof(type_suffix));
	snprintf(intr_name, sizeof(intr_name), "llvm.amdgcn.set.inactive.%s",
		 type_suffix);

	LLVMValueRef args[2] = { src_int, inactive_int };
	LLVMValueRef result =
		ac_build_intrinsic(ctx, intr_name, int_type, args, 2,
				   AC_FUNC_ATTR_READNONE |
				   AC_FUNC_ATTR_CONVERGENT);

	return LLVMBuildBitCast(ctx->builder, result, orig_type, "");
}
/**
 * Returns the identity constant of a reduction operation, i.e. the value
 * that leaves the other operand unchanged when combined with it.
 * @param ctx
 * @param op - reduction operation (add, mul, min, max, and, or, xor)
 * @param type_size - size of the reduced value in bytes; must be 4 or 8
 * @return constant of the matching 32-bit or 64-bit integer/float type
 */
static LLVMValueRef
get_reduction_identity(struct ac_llvm_context *ctx, nir_op op, unsigned type_size)
{
	if (type_size == 4) {
		switch (op) {
		case nir_op_iadd: return ctx->i32_0;
		case nir_op_fadd: return ctx->f32_0;
		case nir_op_imul: return ctx->i32_1;
		case nir_op_fmul: return ctx->f32_1;
		case nir_op_imin: return LLVMConstInt(ctx->i32, INT32_MAX, 0);
		case nir_op_umin: return LLVMConstInt(ctx->i32, UINT32_MAX, 0);
		case nir_op_fmin: return LLVMConstReal(ctx->f32, INFINITY);
		case nir_op_imax: return LLVMConstInt(ctx->i32, INT32_MIN, 0);
		case nir_op_umax: return ctx->i32_0;
		case nir_op_fmax: return LLVMConstReal(ctx->f32, -INFINITY);
		case nir_op_iand: return LLVMConstInt(ctx->i32, -1, 0);
		case nir_op_ior: return ctx->i32_0;
		case nir_op_ixor: return ctx->i32_0;
		default:
			unreachable("bad reduction intrinsic");
		}
	} else { /* type_size == 8 bytes (64-bit) */
		/* Any other size used to fall through here silently and yield
		 * a 64-bit identity that does not match the value's type.
		 */
		assert(type_size == 8);
		switch (op) {
		case nir_op_iadd: return ctx->i64_0;
		case nir_op_fadd: return ctx->f64_0;
		case nir_op_imul: return ctx->i64_1;
		case nir_op_fmul: return ctx->f64_1;
		case nir_op_imin: return LLVMConstInt(ctx->i64, INT64_MAX, 0);
		case nir_op_umin: return LLVMConstInt(ctx->i64, UINT64_MAX, 0);
		case nir_op_fmin: return LLVMConstReal(ctx->f64, INFINITY);
		case nir_op_imax: return LLVMConstInt(ctx->i64, INT64_MIN, 0);
		case nir_op_umax: return ctx->i64_0;
		case nir_op_fmax: return LLVMConstReal(ctx->f64, -INFINITY);
		case nir_op_iand: return LLVMConstInt(ctx->i64, -1, 0);
		case nir_op_ior: return ctx->i64_0;
		case nir_op_ixor: return ctx->i64_0;
		default:
			unreachable("bad reduction intrinsic");
		}
	}
}
/**
 * Builds the binary operation that corresponds to a NIR reduction opcode.
 *
 * Integer min/max are built as a compare followed by a select; float min/max
 * use the "llvm.minnum"/"llvm.maxnum" intrinsics, in their f64 variant when
 * lhs is 8 bytes wide and their f32 variant otherwise.
 * @param ctx
 * @param lhs - left operand
 * @param rhs - right operand
 * @param op - one of the add/mul/min/max/and/or/xor opcodes handled below
 * @return the combined value
 */
static LLVMValueRef
ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs, nir_op op)
{
	bool is_64bit = ac_get_type_size(LLVMTypeOf(lhs)) == 8;
	LLVMValueRef operands[2] = { lhs, rhs };
	LLVMValueRef cond;

	switch (op) {
	case nir_op_iadd:
		return LLVMBuildAdd(ctx->builder, lhs, rhs, "");
	case nir_op_fadd:
		return LLVMBuildFAdd(ctx->builder, lhs, rhs, "");
	case nir_op_imul:
		return LLVMBuildMul(ctx->builder, lhs, rhs, "");
	case nir_op_fmul:
		return LLVMBuildFMul(ctx->builder, lhs, rhs, "");
	case nir_op_imin:
		cond = LLVMBuildICmp(ctx->builder, LLVMIntSLT, lhs, rhs, "");
		return LLVMBuildSelect(ctx->builder, cond, lhs, rhs, "");
	case nir_op_umin:
		cond = LLVMBuildICmp(ctx->builder, LLVMIntULT, lhs, rhs, "");
		return LLVMBuildSelect(ctx->builder, cond, lhs, rhs, "");
	case nir_op_fmin:
		return ac_build_intrinsic(ctx,
					  is_64bit ? "llvm.minnum.f64" : "llvm.minnum.f32",
					  is_64bit ? ctx->f64 : ctx->f32,
					  operands, 2, AC_FUNC_ATTR_READNONE);
	case nir_op_imax:
		cond = LLVMBuildICmp(ctx->builder, LLVMIntSGT, lhs, rhs, "");
		return LLVMBuildSelect(ctx->builder, cond, lhs, rhs, "");
	case nir_op_umax:
		cond = LLVMBuildICmp(ctx->builder, LLVMIntUGT, lhs, rhs, "");
		return LLVMBuildSelect(ctx->builder, cond, lhs, rhs, "");
	case nir_op_fmax:
		return ac_build_intrinsic(ctx,
					  is_64bit ? "llvm.maxnum.f64" : "llvm.maxnum.f32",
					  is_64bit ? ctx->f64 : ctx->f32,
					  operands, 2, AC_FUNC_ATTR_READNONE);
	case nir_op_iand:
		return LLVMBuildAnd(ctx->builder, lhs, rhs, "");
	case nir_op_ior:
		return LLVMBuildOr(ctx->builder, lhs, rhs, "");
	case nir_op_ixor:
		return LLVMBuildXor(ctx->builder, lhs, rhs, "");
	default:
		unreachable("bad reduction intrinsic");
	}
}
/* TODO: add inclusive and exclusive scan functions for SI chip class. */
/**
 * Builds the DPP sequence shared by the inclusive and exclusive scans.
 *
 * Seven DPP steps are emitted, each followed by combining the running
 * result with the DPP output through op. The first three steps read the
 * original src (row_sr 1..3); the remaining four read the running result
 * (row_sr 4, row_sr 8, row_bcast15, row_bcast31) with narrower bank or row
 * masks. identity is the "old" operand of every DPP call.
 * @param ctx
 * @param op - operation used to combine values
 * @param src - per-lane input value
 * @param identity - identity constant of op
 * @return the running result after the last step
 */
static LLVMValueRef
ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity)
{
	LLVMValueRef acc = src;
	LLVMValueRef shifted;

	/* Steps that always read the unmodified source. */
	for (unsigned amount = 1; amount <= 3; amount++) {
		shifted = ac_build_dpp(ctx, identity, src, dpp_row_sr(amount),
				       0xf, 0xf, false);
		acc = ac_build_alu_op(ctx, acc, shifted, op);
	}

	/* Steps that read the accumulated value. */
	shifted = ac_build_dpp(ctx, identity, acc, dpp_row_sr(4), 0xf, 0xe, false);
	acc = ac_build_alu_op(ctx, acc, shifted, op);

	shifted = ac_build_dpp(ctx, identity, acc, dpp_row_sr(8), 0xf, 0xc, false);
	acc = ac_build_alu_op(ctx, acc, shifted, op);

	shifted = ac_build_dpp(ctx, identity, acc, dpp_row_bcast15, 0xa, 0xf, false);
	acc = ac_build_alu_op(ctx, acc, shifted, op);

	shifted = ac_build_dpp(ctx, identity, acc, dpp_row_bcast31, 0xc, 0xf, false);
	acc = ac_build_alu_op(ctx, acc, shifted, op);

	return acc;
}
/**
 * Builds an inclusive scan of src with the given operation.
 *
 * The source passes through an optimization barrier, gets the identity of
 * op assigned via set.inactive, is scanned with ac_build_scan() and the
 * result is wrapped in the wwm intrinsic.
 * @param ctx
 * @param src - per-lane input value
 * @param op - operation used to combine values
 * @return the scan result
 */
LLVMValueRef
ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op)
{
	ac_build_optimization_barrier(ctx, &src);

	LLVMValueRef identity =
		get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
	LLVMValueRef masked = ac_build_set_inactive(ctx, src, identity);
	/* Operate in the identity's type (integer or float, as op requires). */
	LLVMValueRef value = LLVMBuildBitCast(ctx->builder, masked,
					      LLVMTypeOf(identity), "");
	LLVMValueRef scanned = ac_build_scan(ctx, op, value, identity);

	return ac_build_wwm(ctx, scanned);
}
/**
 * Builds an exclusive scan of src with the given operation.
 *
 * Same sequence as the inclusive scan, except that the value goes through
 * one extra DPP step (dpp_wf_sr1, with the identity as "old" operand)
 * before ac_build_scan() runs.
 * @param ctx
 * @param src - per-lane input value
 * @param op - operation used to combine values
 * @return the scan result
 */
LLVMValueRef
ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op)
{
	ac_build_optimization_barrier(ctx, &src);

	LLVMValueRef identity =
		get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
	LLVMValueRef masked = ac_build_set_inactive(ctx, src, identity);
	LLVMValueRef value = LLVMBuildBitCast(ctx->builder, masked,
					      LLVMTypeOf(identity), "");

	/* The extra shift is what distinguishes this from the inclusive scan. */
	value = ac_build_dpp(ctx, identity, value, dpp_wf_sr1, 0xf, 0xf, false);

	LLVMValueRef scanned = ac_build_scan(ctx, op, value, identity);
	return ac_build_wwm(ctx, scanned);
}
/**
 * Builds a reduction of src with the given operation over lane clusters.
 *
 * The value is combined with progressively wider lane exchanges; after each
 * step the function returns (wrapped in wwm) if cluster_size matches the
 * width reached so far. On chips >= VI most exchanges use DPP, otherwise
 * ds_swizzle and readlane are used.
 * @param ctx
 * @param src - per-lane input value
 * @param op - operation used to combine values
 * @param cluster_size - 1 returns src untouched; 2, 4, 8, 16 and 32 stop
 *                       after the matching step; any other value runs the
 *                       complete sequence (presumably a 64-lane wave)
 * @return the reduced value
 */
LLVMValueRef
ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsigned cluster_size)
{
	if (cluster_size == 1) {
		return src;
	}

	ac_build_optimization_barrier(ctx, &src);

	LLVMValueRef identity =
		get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
	LLVMValueRef masked = ac_build_set_inactive(ctx, src, identity);
	LLVMValueRef acc = LLVMBuildBitCast(ctx->builder, masked,
					    LLVMTypeOf(identity), "");
	LLVMValueRef other;

	/* Clusters of 2. */
	other = ac_build_quad_swizzle(ctx, acc, 1, 0, 3, 2);
	acc = ac_build_alu_op(ctx, acc, other, op);
	if (cluster_size == 2) {
		return ac_build_wwm(ctx, acc);
	}

	/* Clusters of 4. */
	other = ac_build_quad_swizzle(ctx, acc, 2, 3, 0, 1);
	acc = ac_build_alu_op(ctx, acc, other, op);
	if (cluster_size == 4) {
		return ac_build_wwm(ctx, acc);
	}

	/* Clusters of 8. */
	if (ctx->chip_class >= VI) {
		other = ac_build_dpp(ctx, identity, acc, dpp_row_half_mirror,
				     0xf, 0xf, false);
	} else {
		other = ac_build_ds_swizzle(ctx, acc,
					    ds_pattern_bitmode(0x1f, 0, 0x04));
	}
	acc = ac_build_alu_op(ctx, acc, other, op);
	if (cluster_size == 8) {
		return ac_build_wwm(ctx, acc);
	}

	/* Clusters of 16. */
	if (ctx->chip_class >= VI) {
		other = ac_build_dpp(ctx, identity, acc, dpp_row_mirror,
				     0xf, 0xf, false);
	} else {
		other = ac_build_ds_swizzle(ctx, acc,
					    ds_pattern_bitmode(0x1f, 0, 0x08));
	}
	acc = ac_build_alu_op(ctx, acc, other, op);
	if (cluster_size == 16) {
		return ac_build_wwm(ctx, acc);
	}

	/* Clusters of 32: the DPP broadcast is only used when the sequence
	 * continues past this step; a cluster size of exactly 32 takes the
	 * swizzle path on every chip class.
	 */
	if (ctx->chip_class >= VI && cluster_size != 32) {
		other = ac_build_dpp(ctx, identity, acc, dpp_row_bcast15,
				     0xa, 0xf, false);
	} else {
		other = ac_build_ds_swizzle(ctx, acc,
					    ds_pattern_bitmode(0x1f, 0, 0x10));
	}
	acc = ac_build_alu_op(ctx, acc, other, op);
	if (cluster_size == 32) {
		return ac_build_wwm(ctx, acc);
	}

	/* Final step. */
	if (ctx->chip_class < VI) {
		/* Combine the values held by lane 0 and lane 32. */
		other = ac_build_readlane(ctx, acc, ctx->i32_0);
		acc = ac_build_readlane(ctx, acc, LLVMConstInt(ctx->i32, 32, 0));
		acc = ac_build_alu_op(ctx, acc, other, op);
		return ac_build_wwm(ctx, acc);
	}

	other = ac_build_dpp(ctx, identity, acc, dpp_row_bcast31, 0xc, 0xf, false);
	acc = ac_build_alu_op(ctx, acc, other, op);
	/* The result is read back from lane 63. */
	acc = ac_build_readlane(ctx, acc, LLVMConstInt(ctx->i32, 63, 0));
	return ac_build_wwm(ctx, acc);
}
/**
 * Builds a quad permute of src.
 *
 * On chips >= VI with LLVM >= 6.0 the permute is built as a DPP operation
 * with src as both operands; otherwise it is built as a ds_swizzle whose
 * pattern is the same selector bits with bit 15 set.
 * @param ctx
 * @param src - value to permute
 * @param lane0 - selector, must be < 4
 * @param lane1 - selector, must be < 4
 * @param lane2 - selector, must be < 4
 * @param lane3 - selector, must be < 4
 * @return the permuted value
 */
LLVMValueRef
ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
		      unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3)
{
	unsigned perm = dpp_quad_perm(lane0, lane1, lane2, lane3);

	if (!(ctx->chip_class >= VI && HAVE_LLVM >= 0x0600)) {
		return ac_build_ds_swizzle(ctx, src, (1 << 15) | perm);
	}

	return ac_build_dpp(ctx, src, src, perm, 0xf, 0xf, false);
}
/**
 * Builds the "llvm.amdgcn.ds.bpermute" intrinsic.
 * @param ctx
 * @param src - value passed as the second operand of the intrinsic
 * @param index - multiplied by 4 and passed as the first operand
 * @return i32 result of the intrinsic call
 */
LLVMValueRef
ac_build_shuffle(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef index)
{
	LLVMValueRef four = LLVMConstInt(ctx->i32, 4, 0);
	LLVMValueRef scaled_index = LLVMBuildMul(ctx->builder, index, four, "");
	LLVMValueRef args[2] = { scaled_index, src };

	return ac_build_intrinsic(ctx, "llvm.amdgcn.ds.bpermute", ctx->i32,
				  args, 2,
				  AC_FUNC_ATTR_READNONE |
				  AC_FUNC_ATTR_CONVERGENT);
}

View File

@ -27,7 +27,7 @@
#include <stdbool.h>
#include <llvm-c/TargetMachine.h>
#include "compiler/nir/nir.h"
#include "amd_family.h"
#ifdef __cplusplus
@ -417,6 +417,34 @@ LLVMValueRef ac_unpack_param(struct ac_llvm_context *ctx, LLVMValueRef param,
void ac_apply_fmask_to_sample(struct ac_llvm_context *ac, LLVMValueRef fmask,
LLVMValueRef *addr, bool is_array_tex);
/* Builds "llvm.amdgcn.ds.swizzle" on src with the constant pattern mask;
 * values wider than 32 bits are handled per i32 component. */
LLVMValueRef
ac_build_ds_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src, unsigned mask);
/* Builds "llvm.amdgcn.readlane", or "llvm.amdgcn.readfirstlane" when lane
 * is NULL. */
LLVMValueRef
ac_build_readlane(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef lane);
/* Selects value where the thread id equals lane and src elsewhere. */
LLVMValueRef
ac_build_writelane(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef value, LLVMValueRef lane);
/* Chains "llvm.amdgcn.mbcnt.lo" and "llvm.amdgcn.mbcnt.hi" over the two
 * i32 halves of mask. */
LLVMValueRef
ac_build_mbcnt(struct ac_llvm_context *ctx, LLVMValueRef mask);
/* Builds an inclusive scan of src using op. */
LLVMValueRef
ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op);
/* Builds an exclusive scan of src using op. */
LLVMValueRef
ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op);
/* Builds a reduction of src using op over clusters of cluster_size lanes;
 * a cluster_size of 1 returns src unchanged. */
LLVMValueRef
ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsigned cluster_size);
/* Builds a quad permute of src; each of lane0..lane3 must be < 4. */
LLVMValueRef
ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src,
unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3);
/* Builds "llvm.amdgcn.ds.bpermute" with index multiplied by 4. */
LLVMValueRef
ac_build_shuffle(struct ac_llvm_context *ctx, LLVMValueRef src, LLVMValueRef index);
#ifdef __cplusplus
}
#endif