spirv/nir: add support for AMD_shader_ballot and Groups capability

This commit also renames existing AMD capabilities:
 - gcn_shader -> amd_gcn_shader
 - trinary_minmax -> amd_trinary_minmax

Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
Author: Daniel Schürmann, 2018-05-09 20:41:23 +02:00; committed by Connor Abbott
parent ea51275e07
commit 7a858f274c
6 changed files with 139 additions and 11 deletions

View File

@@ -245,6 +245,9 @@ radv_shader_compile_to_nir(struct radv_device *device,
const struct spirv_to_nir_options spirv_options = {
.lower_ubo_ssbo_access_to_offsets = true,
.caps = {
.amd_gcn_shader = true,
.amd_shader_ballot = false,
.amd_trinary_minmax = true,
.derivative_group = true,
.descriptor_array_dynamic_indexing = true,
.descriptor_array_non_uniform_indexing = true,
@@ -253,7 +256,6 @@ radv_shader_compile_to_nir(struct radv_device *device,
.draw_parameters = true,
.float16 = true,
.float64 = true,
.gcn_shader = true,
.geometry_streams = true,
.image_read_without_format = true,
.image_write_without_format = true,
@@ -277,7 +279,6 @@ radv_shader_compile_to_nir(struct radv_device *device,
.subgroup_vote = true,
.tessellation = true,
.transform_feedback = true,
.trinary_minmax = true,
.variable_pointers = true,
},
.ubo_addr_format = nir_address_format_32bit_index_offset,

View File

@@ -45,7 +45,6 @@ struct spirv_supported_capabilities {
bool fragment_shader_sample_interlock;
bool fragment_shader_pixel_interlock;
bool geometry_streams;
bool gcn_shader;
bool image_ms_array;
bool image_read_without_format;
bool image_write_without_format;
@@ -72,9 +71,11 @@ struct spirv_supported_capabilities {
bool subgroup_vote;
bool tessellation;
bool transform_feedback;
bool trinary_minmax;
bool variable_pointers;
bool float16;
bool amd_gcn_shader;
bool amd_shader_ballot;
bool amd_trinary_minmax;
};
typedef struct shader_info {

View File

@@ -394,10 +394,13 @@ vtn_handle_extension(struct vtn_builder *b, SpvOp opcode,
if (strcmp(ext, "GLSL.std.450") == 0) {
val->ext_handler = vtn_handle_glsl450_instruction;
} else if ((strcmp(ext, "SPV_AMD_gcn_shader") == 0)
&& (b->options && b->options->caps.gcn_shader)) {
&& (b->options && b->options->caps.amd_gcn_shader)) {
val->ext_handler = vtn_handle_amd_gcn_shader_instruction;
} else if ((strcmp(ext, "SPV_AMD_shader_ballot") == 0)
&& (b->options && b->options->caps.amd_shader_ballot)) {
val->ext_handler = vtn_handle_amd_shader_ballot_instruction;
} else if ((strcmp(ext, "SPV_AMD_shader_trinary_minmax") == 0)
&& (b->options && b->options->caps.trinary_minmax)) {
&& (b->options && b->options->caps.amd_trinary_minmax)) {
val->ext_handler = vtn_handle_amd_shader_trinary_minmax_instruction;
} else if (strcmp(ext, "OpenCL.std") == 0) {
val->ext_handler = vtn_handle_opencl_instruction;
@@ -3612,7 +3615,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvCapabilityImageReadWrite:
case SpvCapabilityImageMipmap:
case SpvCapabilityPipes:
case SpvCapabilityGroups:
case SpvCapabilityDeviceEnqueue:
case SpvCapabilityLiteralSampler:
case SpvCapabilityGenericPointer:
@@ -3677,6 +3679,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
spv_check_supported(subgroup_arithmetic, cap);
break;
case SpvCapabilityGroups:
spv_check_supported(amd_shader_ballot, cap);
break;
case SpvCapabilityVariablePointersStorageBuffer:
case SpvCapabilityVariablePointers:
spv_check_supported(variable_pointers, cap);
@@ -4525,12 +4531,31 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvOpGroupNonUniformLogicalXor:
case SpvOpGroupNonUniformQuadBroadcast:
case SpvOpGroupNonUniformQuadSwap:
case SpvOpGroupAll:
case SpvOpGroupAny:
case SpvOpGroupBroadcast:
case SpvOpGroupIAdd:
case SpvOpGroupFAdd:
case SpvOpGroupFMin:
case SpvOpGroupUMin:
case SpvOpGroupSMin:
case SpvOpGroupFMax:
case SpvOpGroupUMax:
case SpvOpGroupSMax:
case SpvOpSubgroupBallotKHR:
case SpvOpSubgroupFirstInvocationKHR:
case SpvOpSubgroupReadInvocationKHR:
case SpvOpSubgroupAllKHR:
case SpvOpSubgroupAnyKHR:
case SpvOpSubgroupAllEqualKHR:
case SpvOpGroupIAddNonUniformAMD:
case SpvOpGroupFAddNonUniformAMD:
case SpvOpGroupFMinNonUniformAMD:
case SpvOpGroupUMinNonUniformAMD:
case SpvOpGroupSMinNonUniformAMD:
case SpvOpGroupFMaxNonUniformAMD:
case SpvOpGroupUMaxNonUniformAMD:
case SpvOpGroupSMaxNonUniformAMD:
vtn_handle_subgroup(b, opcode, w, count);
break;

View File

@@ -56,6 +56,67 @@ vtn_handle_amd_gcn_shader_instruction(struct vtn_builder *b, SpvOp ext_opcode,
return true;
}
bool
vtn_handle_amd_shader_ballot_instruction(struct vtn_builder *b, SpvOp ext_opcode,
const uint32_t *w, unsigned count)
{
const struct glsl_type *dest_type =
vtn_value(b, w[1], vtn_value_type_type)->type->type;
struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
val->ssa = vtn_create_ssa_value(b, dest_type);
unsigned num_args;
nir_intrinsic_op op;
switch ((enum ShaderBallotAMD)ext_opcode) {
case SwizzleInvocationsAMD:
num_args = 1;
op = nir_intrinsic_quad_swizzle_amd;
break;
case SwizzleInvocationsMaskedAMD:
num_args = 1;
op = nir_intrinsic_masked_swizzle_amd;
break;
case WriteInvocationAMD:
num_args = 3;
op = nir_intrinsic_write_invocation_amd;
break;
case MbcntAMD:
num_args = 1;
op = nir_intrinsic_mbcnt_amd;
break;
default:
unreachable("Invalid opcode");
}
nir_intrinsic_instr *intrin = nir_intrinsic_instr_create(b->nb.shader, op);
nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest, dest_type, NULL);
intrin->num_components = intrin->dest.ssa.num_components;
for (unsigned i = 0; i < num_args; i++)
intrin->src[i] = nir_src_for_ssa(vtn_ssa_value(b, w[i + 5])->def);
if (intrin->intrinsic == nir_intrinsic_quad_swizzle_amd) {
struct vtn_value *val = vtn_value(b, w[6], vtn_value_type_constant);
unsigned mask = val->constant->values[0][0].u32 |
val->constant->values[0][1].u32 << 2 |
val->constant->values[0][2].u32 << 4 |
val->constant->values[0][3].u32 << 6;
nir_intrinsic_set_swizzle_mask(intrin, mask);
} else if (intrin->intrinsic == nir_intrinsic_masked_swizzle_amd) {
struct vtn_value *val = vtn_value(b, w[6], vtn_value_type_constant);
unsigned mask = val->constant->values[0][0].u32 |
val->constant->values[0][1].u32 << 5 |
val->constant->values[0][2].u32 << 10;
nir_intrinsic_set_swizzle_mask(intrin, mask);
}
nir_builder_instr_insert(&b->nb, &intrin->instr);
val->ssa->def = &intrin->dest.ssa;
return true;
}
bool
vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ext_opcode,
const uint32_t *w, unsigned count)

View File

@@ -833,6 +833,9 @@ vtn_u64_literal(const uint32_t *w)
bool vtn_handle_amd_gcn_shader_instruction(struct vtn_builder *b, SpvOp ext_opcode,
const uint32_t *words, unsigned count);
bool vtn_handle_amd_shader_ballot_instruction(struct vtn_builder *b, SpvOp ext_opcode,
const uint32_t *w, unsigned count);
bool vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ext_opcode,
const uint32_t *words, unsigned count);
#endif /* _VTN_PRIVATE_H_ */

View File

@@ -183,7 +183,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
break;
case SpvOpGroupNonUniformBroadcast: ++w;
case SpvOpGroupNonUniformBroadcast:
case SpvOpGroupBroadcast: ++w;
case SpvOpSubgroupReadInvocationKHR:
vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
val->ssa, vtn_ssa_value(b, w[3]),
@@ -193,6 +194,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
case SpvOpGroupNonUniformAll:
case SpvOpGroupNonUniformAny:
case SpvOpGroupNonUniformAllEqual:
case SpvOpGroupAll:
case SpvOpGroupAny:
case SpvOpSubgroupAllKHR:
case SpvOpSubgroupAnyKHR:
case SpvOpSubgroupAllEqualKHR: {
@@ -201,10 +204,12 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
nir_intrinsic_op op;
switch (opcode) {
case SpvOpGroupNonUniformAll:
case SpvOpGroupAll:
case SpvOpSubgroupAllKHR:
op = nir_intrinsic_vote_all;
break;
case SpvOpGroupNonUniformAny:
case SpvOpGroupAny:
case SpvOpSubgroupAnyKHR:
op = nir_intrinsic_vote_any;
break;
@@ -232,8 +237,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
}
nir_ssa_def *src0;
if (opcode == SpvOpGroupNonUniformAll ||
opcode == SpvOpGroupNonUniformAny ||
if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
opcode == SpvOpGroupNonUniformAllEqual) {
src0 = vtn_ssa_value(b, w[4])->def;
} else {
@@ -319,13 +324,33 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
case SpvOpGroupNonUniformBitwiseXor:
case SpvOpGroupNonUniformLogicalAnd:
case SpvOpGroupNonUniformLogicalOr:
case SpvOpGroupNonUniformLogicalXor: {
case SpvOpGroupNonUniformLogicalXor:
case SpvOpGroupIAdd:
case SpvOpGroupFAdd:
case SpvOpGroupFMin:
case SpvOpGroupUMin:
case SpvOpGroupSMin:
case SpvOpGroupFMax:
case SpvOpGroupUMax:
case SpvOpGroupSMax:
case SpvOpGroupIAddNonUniformAMD:
case SpvOpGroupFAddNonUniformAMD:
case SpvOpGroupFMinNonUniformAMD:
case SpvOpGroupUMinNonUniformAMD:
case SpvOpGroupSMinNonUniformAMD:
case SpvOpGroupFMaxNonUniformAMD:
case SpvOpGroupUMaxNonUniformAMD:
case SpvOpGroupSMaxNonUniformAMD: {
nir_op reduction_op;
switch (opcode) {
case SpvOpGroupNonUniformIAdd:
case SpvOpGroupIAdd:
case SpvOpGroupIAddNonUniformAMD:
reduction_op = nir_op_iadd;
break;
case SpvOpGroupNonUniformFAdd:
case SpvOpGroupFAdd:
case SpvOpGroupFAddNonUniformAMD:
reduction_op = nir_op_fadd;
break;
case SpvOpGroupNonUniformIMul:
@@ -335,21 +360,33 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
reduction_op = nir_op_fmul;
break;
case SpvOpGroupNonUniformSMin:
case SpvOpGroupSMin:
case SpvOpGroupSMinNonUniformAMD:
reduction_op = nir_op_imin;
break;
case SpvOpGroupNonUniformUMin:
case SpvOpGroupUMin:
case SpvOpGroupUMinNonUniformAMD:
reduction_op = nir_op_umin;
break;
case SpvOpGroupNonUniformFMin:
case SpvOpGroupFMin:
case SpvOpGroupFMinNonUniformAMD:
reduction_op = nir_op_fmin;
break;
case SpvOpGroupNonUniformSMax:
case SpvOpGroupSMax:
case SpvOpGroupSMaxNonUniformAMD:
reduction_op = nir_op_imax;
break;
case SpvOpGroupNonUniformUMax:
case SpvOpGroupUMax:
case SpvOpGroupUMaxNonUniformAMD:
reduction_op = nir_op_umax;
break;
case SpvOpGroupNonUniformFMax:
case SpvOpGroupFMax:
case SpvOpGroupFMaxNonUniformAMD:
reduction_op = nir_op_fmax;
break;
case SpvOpGroupNonUniformBitwiseAnd: