radv/ac: Add core Float64 support.
Signed-off-by: Bas Nieuwenhuizen <basni@google.com> Reviewed-by: Dave Airlie <airlied@redhat.com>
This commit is contained in:
parent
01e18b21d1
commit
29577b2123
|
@ -119,6 +119,7 @@ struct nir_to_llvm_context {
|
|||
LLVMTypeRef v3i32;
|
||||
LLVMTypeRef v4i32;
|
||||
LLVMTypeRef v8i32;
|
||||
LLVMTypeRef f64;
|
||||
LLVMTypeRef f32;
|
||||
LLVMTypeRef f16;
|
||||
LLVMTypeRef v2f32;
|
||||
|
@ -313,34 +314,78 @@ static LLVMValueRef get_shared_memory_ptr(struct nir_to_llvm_context *ctx,
|
|||
return ptr;
|
||||
}
|
||||
|
||||
static LLVMTypeRef to_integer_type_scalar(struct nir_to_llvm_context *ctx, LLVMTypeRef t)
|
||||
{
|
||||
if (t == ctx->f16 || t == ctx->i16)
|
||||
return ctx->i16;
|
||||
else if (t == ctx->f32 || t == ctx->i32)
|
||||
return ctx->i32;
|
||||
else if (t == ctx->f64 || t == ctx->i64)
|
||||
return ctx->i64;
|
||||
else
|
||||
unreachable("Unhandled integer size");
|
||||
}
|
||||
|
||||
static LLVMTypeRef to_integer_type(struct nir_to_llvm_context *ctx, LLVMTypeRef t)
|
||||
{
|
||||
if (LLVMGetTypeKind(t) == LLVMVectorTypeKind) {
|
||||
LLVMTypeRef elem_type = LLVMGetElementType(t);
|
||||
return LLVMVectorType(to_integer_type_scalar(ctx, elem_type),
|
||||
LLVMGetVectorSize(t));
|
||||
}
|
||||
return to_integer_type_scalar(ctx, t);
|
||||
}
|
||||
|
||||
static LLVMValueRef to_integer(struct nir_to_llvm_context *ctx, LLVMValueRef v)
|
||||
{
|
||||
LLVMTypeRef type = LLVMTypeOf(v);
|
||||
if (type == ctx->f32) {
|
||||
return LLVMBuildBitCast(ctx->builder, v, ctx->i32, "");
|
||||
} else if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) {
|
||||
LLVMTypeRef elem_type = LLVMGetElementType(type);
|
||||
if (elem_type == ctx->f32) {
|
||||
LLVMTypeRef nt = LLVMVectorType(ctx->i32, LLVMGetVectorSize(type));
|
||||
return LLVMBuildBitCast(ctx->builder, v, nt, "");
|
||||
}
|
||||
return LLVMBuildBitCast(ctx->builder, v, to_integer_type(ctx, type), "");
|
||||
}
|
||||
|
||||
static LLVMTypeRef to_float_type_scalar(struct nir_to_llvm_context *ctx, LLVMTypeRef t)
|
||||
{
|
||||
if (t == ctx->i16 || t == ctx->f16)
|
||||
return ctx->f16;
|
||||
else if (t == ctx->i32 || t == ctx->f32)
|
||||
return ctx->f32;
|
||||
else if (t == ctx->i64 || t == ctx->f64)
|
||||
return ctx->f64;
|
||||
else
|
||||
unreachable("Unhandled float size");
|
||||
}
|
||||
|
||||
static LLVMTypeRef to_float_type(struct nir_to_llvm_context *ctx, LLVMTypeRef t)
|
||||
{
|
||||
if (LLVMGetTypeKind(t) == LLVMVectorTypeKind) {
|
||||
LLVMTypeRef elem_type = LLVMGetElementType(t);
|
||||
return LLVMVectorType(to_float_type_scalar(ctx, elem_type),
|
||||
LLVMGetVectorSize(t));
|
||||
}
|
||||
return v;
|
||||
return to_float_type_scalar(ctx, t);
|
||||
}
|
||||
|
||||
static LLVMValueRef to_float(struct nir_to_llvm_context *ctx, LLVMValueRef v)
|
||||
{
|
||||
LLVMTypeRef type = LLVMTypeOf(v);
|
||||
if (type == ctx->i32) {
|
||||
return LLVMBuildBitCast(ctx->builder, v, ctx->f32, "");
|
||||
} else if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) {
|
||||
LLVMTypeRef elem_type = LLVMGetElementType(type);
|
||||
if (elem_type == ctx->i32) {
|
||||
LLVMTypeRef nt = LLVMVectorType(ctx->f32, LLVMGetVectorSize(type));
|
||||
return LLVMBuildBitCast(ctx->builder, v, nt, "");
|
||||
}
|
||||
}
|
||||
return v;
|
||||
return LLVMBuildBitCast(ctx->builder, v, to_float_type(ctx, type), "");
|
||||
}
|
||||
|
||||
static int get_elem_bits(struct nir_to_llvm_context *ctx, LLVMTypeRef type)
|
||||
{
|
||||
if (LLVMGetTypeKind(type) == LLVMVectorTypeKind)
|
||||
type = LLVMGetElementType(type);
|
||||
|
||||
if (LLVMGetTypeKind(type) == LLVMIntegerTypeKind)
|
||||
return LLVMGetIntTypeWidth(type);
|
||||
|
||||
if (type == ctx->f16)
|
||||
return 16;
|
||||
if (type == ctx->f32)
|
||||
return 32;
|
||||
if (type == ctx->f64)
|
||||
return 64;
|
||||
|
||||
unreachable("Unhandled type kind in get_elem_bits");
|
||||
}
|
||||
|
||||
static LLVMValueRef unpack_param(struct nir_to_llvm_context *ctx,
|
||||
|
@ -710,6 +755,7 @@ static void setup_types(struct nir_to_llvm_context *ctx)
|
|||
ctx->v8i32 = LLVMVectorType(ctx->i32, 8);
|
||||
ctx->f32 = LLVMFloatTypeInContext(ctx->context);
|
||||
ctx->f16 = LLVMHalfTypeInContext(ctx->context);
|
||||
ctx->f64 = LLVMDoubleTypeInContext(ctx->context);
|
||||
ctx->v2f32 = LLVMVectorType(ctx->f32, 2);
|
||||
ctx->v4f32 = LLVMVectorType(ctx->f32, 4);
|
||||
ctx->v16i8 = LLVMVectorType(ctx->i8, 16);
|
||||
|
@ -894,35 +940,47 @@ static LLVMValueRef emit_float_cmp(struct nir_to_llvm_context *ctx,
|
|||
|
||||
static LLVMValueRef emit_intrin_1f_param(struct nir_to_llvm_context *ctx,
|
||||
const char *intrin,
|
||||
LLVMTypeRef result_type,
|
||||
LLVMValueRef src0)
|
||||
{
|
||||
char name[64];
|
||||
LLVMValueRef params[] = {
|
||||
to_float(ctx, src0),
|
||||
};
|
||||
return ac_emit_llvm_intrinsic(&ctx->ac, intrin, ctx->f32, params, 1, AC_FUNC_ATTR_READNONE);
|
||||
|
||||
sprintf(name, "%s.f%d", intrin, get_elem_bits(ctx, result_type));
|
||||
return ac_emit_llvm_intrinsic(&ctx->ac, name, result_type, params, 1, AC_FUNC_ATTR_READNONE);
|
||||
}
|
||||
|
||||
static LLVMValueRef emit_intrin_2f_param(struct nir_to_llvm_context *ctx,
|
||||
const char *intrin,
|
||||
LLVMTypeRef result_type,
|
||||
LLVMValueRef src0, LLVMValueRef src1)
|
||||
{
|
||||
char name[64];
|
||||
LLVMValueRef params[] = {
|
||||
to_float(ctx, src0),
|
||||
to_float(ctx, src1),
|
||||
};
|
||||
return ac_emit_llvm_intrinsic(&ctx->ac, intrin, ctx->f32, params, 2, AC_FUNC_ATTR_READNONE);
|
||||
|
||||
sprintf(name, "%s.f%d", intrin, get_elem_bits(ctx, result_type));
|
||||
return ac_emit_llvm_intrinsic(&ctx->ac, name, result_type, params, 2, AC_FUNC_ATTR_READNONE);
|
||||
}
|
||||
|
||||
static LLVMValueRef emit_intrin_3f_param(struct nir_to_llvm_context *ctx,
|
||||
const char *intrin,
|
||||
LLVMTypeRef result_type,
|
||||
LLVMValueRef src0, LLVMValueRef src1, LLVMValueRef src2)
|
||||
{
|
||||
char name[64];
|
||||
LLVMValueRef params[] = {
|
||||
to_float(ctx, src0),
|
||||
to_float(ctx, src1),
|
||||
to_float(ctx, src2),
|
||||
};
|
||||
return ac_emit_llvm_intrinsic(&ctx->ac, intrin, ctx->f32, params, 3, AC_FUNC_ATTR_READNONE);
|
||||
|
||||
sprintf(name, "%s.f%d", intrin, get_elem_bits(ctx, result_type));
|
||||
return ac_emit_llvm_intrinsic(&ctx->ac, name, result_type, params, 3, AC_FUNC_ATTR_READNONE);
|
||||
}
|
||||
|
||||
static LLVMValueRef emit_bcsel(struct nir_to_llvm_context *ctx,
|
||||
|
@ -1345,6 +1403,7 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr)
|
|||
LLVMValueRef src[4], result = NULL;
|
||||
unsigned num_components = instr->dest.dest.ssa.num_components;
|
||||
unsigned src_components;
|
||||
LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.dest.ssa);
|
||||
|
||||
assert(nir_op_infos[instr->op].num_inputs <= ARRAY_SIZE(src));
|
||||
switch (instr->op) {
|
||||
|
@ -1410,7 +1469,8 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr)
|
|||
src[0] = to_float(ctx, src[0]);
|
||||
src[1] = to_float(ctx, src[1]);
|
||||
result = ac_emit_fdiv(&ctx->ac, src[0], src[1]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.floor.f32", result);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.floor",
|
||||
to_float_type(ctx, def_type), result);
|
||||
result = LLVMBuildFMul(ctx->builder, src[1] , result, "");
|
||||
result = LLVMBuildFSub(ctx->builder, src[0], result, "");
|
||||
break;
|
||||
|
@ -1491,7 +1551,8 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr)
|
|||
result = emit_float_cmp(ctx, LLVMRealUGE, src[0], src[1]);
|
||||
break;
|
||||
case nir_op_fabs:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.fabs.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.fabs",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_iabs:
|
||||
result = emit_iabs(ctx, src[0]);
|
||||
|
@ -1516,50 +1577,64 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr)
|
|||
result = emit_fsign(ctx, src[0]);
|
||||
break;
|
||||
case nir_op_ffloor:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.floor.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.floor",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_ftrunc:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.trunc.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.trunc",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_fceil:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.ceil.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.ceil",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_fround_even:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.rint.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.rint",
|
||||
to_float_type(ctx, def_type),src[0]);
|
||||
break;
|
||||
case nir_op_ffract:
|
||||
result = emit_ffract(ctx, src[0]);
|
||||
break;
|
||||
case nir_op_fsin:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.sin.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.sin",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_fcos:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.cos.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.cos",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_fsqrt:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.sqrt.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.sqrt",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_fexp2:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.exp2.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.exp2",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_flog2:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.log2.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.log2",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
break;
|
||||
case nir_op_frsq:
|
||||
result = emit_intrin_1f_param(ctx, "llvm.sqrt.f32", src[0]);
|
||||
result = emit_intrin_1f_param(ctx, "llvm.sqrt",
|
||||
to_float_type(ctx, def_type), src[0]);
|
||||
result = ac_emit_fdiv(&ctx->ac, ctx->f32one, result);
|
||||
break;
|
||||
case nir_op_fpow:
|
||||
result = emit_intrin_2f_param(ctx, "llvm.pow.f32", src[0], src[1]);
|
||||
result = emit_intrin_2f_param(ctx, "llvm.pow",
|
||||
to_float_type(ctx, def_type), src[0], src[1]);
|
||||
break;
|
||||
case nir_op_fmax:
|
||||
result = emit_intrin_2f_param(ctx, "llvm.maxnum.f32", src[0], src[1]);
|
||||
result = emit_intrin_2f_param(ctx, "llvm.maxnum",
|
||||
to_float_type(ctx, def_type), src[0], src[1]);
|
||||
break;
|
||||
case nir_op_fmin:
|
||||
result = emit_intrin_2f_param(ctx, "llvm.minnum.f32", src[0], src[1]);
|
||||
result = emit_intrin_2f_param(ctx, "llvm.minnum",
|
||||
to_float_type(ctx, def_type), src[0], src[1]);
|
||||
break;
|
||||
case nir_op_ffma:
|
||||
result = emit_intrin_3f_param(ctx, "llvm.fma.f32", src[0], src[1], src[2]);
|
||||
result = emit_intrin_3f_param(ctx, "llvm.fma",
|
||||
to_float_type(ctx, def_type), src[0], src[1], src[2]);
|
||||
break;
|
||||
case nir_op_ibitfield_extract:
|
||||
result = emit_bitfield_extract(ctx, "llvm.AMDGPU.bfe.i32", src);
|
||||
|
@ -1583,19 +1658,29 @@ static void visit_alu(struct nir_to_llvm_context *ctx, nir_alu_instr *instr)
|
|||
src[i] = to_integer(ctx, src[i]);
|
||||
result = ac_build_gather_values(&ctx->ac, src, num_components);
|
||||
break;
|
||||
case nir_op_d2i:
|
||||
case nir_op_f2i:
|
||||
src[0] = to_float(ctx, src[0]);
|
||||
result = LLVMBuildFPToSI(ctx->builder, src[0], ctx->i32, "");
|
||||
result = LLVMBuildFPToSI(ctx->builder, src[0], def_type, "");
|
||||
break;
|
||||
case nir_op_d2u:
|
||||
case nir_op_f2u:
|
||||
src[0] = to_float(ctx, src[0]);
|
||||
result = LLVMBuildFPToUI(ctx->builder, src[0], ctx->i32, "");
|
||||
result = LLVMBuildFPToUI(ctx->builder, src[0], def_type, "");
|
||||
break;
|
||||
case nir_op_i2d:
|
||||
case nir_op_i2f:
|
||||
result = LLVMBuildSIToFP(ctx->builder, src[0], ctx->f32, "");
|
||||
result = LLVMBuildSIToFP(ctx->builder, src[0], to_float_type(ctx, def_type), "");
|
||||
break;
|
||||
case nir_op_u2d:
|
||||
case nir_op_u2f:
|
||||
result = LLVMBuildUIToFP(ctx->builder, src[0], ctx->f32, "");
|
||||
result = LLVMBuildUIToFP(ctx->builder, src[0], to_float_type(ctx, def_type), "");
|
||||
break;
|
||||
case nir_op_f2d:
|
||||
result = LLVMBuildFPExt(ctx->builder, src[0], to_float_type(ctx, def_type), "");
|
||||
break;
|
||||
case nir_op_d2f:
|
||||
result = LLVMBuildFPTrunc(ctx->builder, src[0], to_float_type(ctx, def_type), "");
|
||||
break;
|
||||
case nir_op_bcsel:
|
||||
result = emit_bcsel(ctx, src[0], src[1], src[2]);
|
||||
|
@ -4249,8 +4334,8 @@ static LLVMValueRef
|
|||
emit_float_saturate(struct nir_to_llvm_context *ctx, LLVMValueRef v, float lo, float hi)
|
||||
{
|
||||
v = to_float(ctx, v);
|
||||
v = emit_intrin_2f_param(ctx, "llvm.maxnum.f32", v, LLVMConstReal(ctx->f32, lo));
|
||||
return emit_intrin_2f_param(ctx, "llvm.minnum.f32", v, LLVMConstReal(ctx->f32, hi));
|
||||
v = emit_intrin_2f_param(ctx, "llvm.maxnum.f32", ctx->f32, v, LLVMConstReal(ctx->f32, lo));
|
||||
return emit_intrin_2f_param(ctx, "llvm.minnum.f32", ctx->f32, v, LLVMConstReal(ctx->f32, hi));
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue