vkd3d-shader: Verify that we compile expected shader stage.

Signed-off-by: Hans-Kristian Arntzen <post@arntzen-software.no>
This commit is contained in:
Hans-Kristian Arntzen 2021-04-15 11:42:26 +02:00
parent 8f17fdd1fa
commit 744497274c
5 changed files with 69 additions and 2 deletions

View File

@ -212,6 +212,8 @@ struct vkd3d_shader_interface_info
const struct vkd3d_shader_descriptor_binding *push_constant_ubo_binding;
/* Ignored unless VKD3D_SHADER_INTERFACE_SSBO_OFFSET_BUFFER or TYPED_OFFSET_BUFFER is set */
const struct vkd3d_shader_descriptor_binding *offset_buffer_binding;
VkShaderStageFlagBits stage;
};
struct vkd3d_shader_descriptor_table

View File

@ -449,6 +449,30 @@ static void vkd3d_dxil_log_callback(void *userdata, dxil_spv_log_level level, co
}
}
static bool dxil_match_shader_stage(dxil_spv_shader_stage blob_stage, VkShaderStageFlagBits expected)
{
VkShaderStageFlagBits stage;
switch (blob_stage)
{
case DXIL_SPV_STAGE_VERTEX: stage = VK_SHADER_STAGE_VERTEX_BIT; break;
case DXIL_SPV_STAGE_HULL: stage = VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; break;
case DXIL_SPV_STAGE_DOMAIN: stage = VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; break;
case DXIL_SPV_STAGE_GEOMETRY: stage = VK_SHADER_STAGE_GEOMETRY_BIT; break;
case DXIL_SPV_STAGE_PIXEL: stage = VK_SHADER_STAGE_FRAGMENT_BIT; break;
case DXIL_SPV_STAGE_COMPUTE: stage = VK_SHADER_STAGE_COMPUTE_BIT; break;
default: return false;
}
if (stage != expected)
{
ERR("Expected VkShaderStage #%x, but got VkShaderStage #%x.\n", expected, stage);
return false;
}
return true;
}
int vkd3d_shader_compile_dxil(const struct vkd3d_shader_code *dxbc,
struct vkd3d_shader_code *spirv,
const struct vkd3d_shader_interface_info *shader_interface_info,
@ -463,6 +487,7 @@ int vkd3d_shader_compile_dxil(const struct vkd3d_shader_code *dxbc,
dxil_spv_converter converter = NULL;
dxil_spv_parsed_blob blob = NULL;
dxil_spv_compiled_spirv compiled;
dxil_spv_shader_stage stage;
unsigned int i, max_size;
vkd3d_shader_hash_t hash;
int ret = VKD3D_OK;
@ -489,6 +514,13 @@ int vkd3d_shader_compile_dxil(const struct vkd3d_shader_code *dxbc,
goto end;
}
stage = dxil_spv_parsed_blob_get_shader_stage(blob);
if (!dxil_match_shader_stage(stage, shader_interface_info->stage))
{
ret = VKD3D_ERROR_INVALID_ARGUMENT;
goto end;
}
if (dxil_spv_create_converter(blob, &converter) != DXIL_SPV_SUCCESS)
{
ret = VKD3D_ERROR_INVALID_SHADER;

View File

@ -300,6 +300,29 @@ static void vkd3d_shader_scan_destroy(struct vkd3d_shader_scan_info *scan_info)
hash_map_clear(&scan_info->register_map);
}
static int vkd3d_shader_validate_shader_type(enum vkd3d_shader_type type, VkShaderStageFlagBits stages)
{
static const VkShaderStageFlagBits table[VKD3D_SHADER_TYPE_COUNT] = {
VK_SHADER_STAGE_FRAGMENT_BIT,
VK_SHADER_STAGE_VERTEX_BIT,
VK_SHADER_STAGE_GEOMETRY_BIT,
VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
VK_SHADER_STAGE_COMPUTE_BIT,
};
if (type >= VKD3D_SHADER_TYPE_COUNT)
return VKD3D_ERROR_INVALID_ARGUMENT;
if (table[type] != stages)
{
ERR("Expected VkShaderStage #%x, but got VkShaderStage #%x.\n", stages, table[type]);
return VKD3D_ERROR_INVALID_ARGUMENT;
}
return 0;
}
int vkd3d_shader_compile_dxbc(const struct vkd3d_shader_code *dxbc,
struct vkd3d_shader_code *spirv, unsigned int compiler_options,
const struct vkd3d_shader_interface_info *shader_interface_info,
@ -353,6 +376,12 @@ int vkd3d_shader_compile_dxbc(const struct vkd3d_shader_code *dxbc,
return ret;
}
if ((ret = vkd3d_shader_validate_shader_type(parser.shader_version.type, shader_interface_info->stage)) < 0)
{
vkd3d_shader_scan_destroy(&scan_info);
return ret;
}
vkd3d_shader_dump_shader(hash, dxbc, "dxbc");
if (TRACE_ON())

View File

@ -875,6 +875,9 @@ static HRESULT d3d12_state_object_compile_pipeline(struct d3d12_state_object *ob
shader_interface_info.type = VKD3D_SHADER_STRUCTURE_TYPE_SHADER_INTERFACE_INFO;
shader_interface_info.min_ssbo_alignment = d3d12_device_get_ssbo_alignment(object->device);
/* Effectively ignored. */
shader_interface_info.stage = VK_SHADER_STAGE_ALL;
global_signature = unsafe_impl_from_ID3D12RootSignature(data->global_root_signature);
if (global_signature)

View File

@ -2021,7 +2021,7 @@ struct d3d12_pipeline_state *unsafe_impl_from_ID3D12PipelineState(ID3D12Pipeline
}
static HRESULT create_shader_stage(struct d3d12_device *device,
struct VkPipelineShaderStageCreateInfo *stage_desc, enum VkShaderStageFlagBits stage,
struct VkPipelineShaderStageCreateInfo *stage_desc, VkShaderStageFlagBits stage,
const D3D12_SHADER_BYTECODE *code, const struct vkd3d_shader_interface_info *shader_interface,
const struct vkd3d_shader_compile_arguments *compile_args, struct vkd3d_shader_meta *meta)
{
@ -2143,6 +2143,7 @@ static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *st
shader_interface.push_constant_buffer_count = root_signature->root_constant_count;
shader_interface.push_constant_ubo_binding = &root_signature->push_constant_ubo_binding;
shader_interface.offset_buffer_binding = &root_signature->offset_buffer_binding;
shader_interface.stage = VK_SHADER_STAGE_COMPUTE_BIT;
if ((hr = vkd3d_create_pipeline_cache_from_d3d12_desc(device, &desc->cached_pso, &state->vk_pso_cache)) < 0)
{
@ -3073,7 +3074,7 @@ static HRESULT d3d12_pipeline_state_init_graphics(struct d3d12_pipeline_state *s
}
shader_interface.next = shader_stages[i].stage == xfb_stage ? &xfb_info : NULL;
shader_interface.stage = shader_stages[i].stage;
if (FAILED(hr = create_shader_stage(device, &graphics->stages[graphics->stage_count],
shader_stages[i].stage, b, &shader_interface,
shader_stages[i].stage == VK_SHADER_STAGE_FRAGMENT_BIT ? &ps_compile_args : &compile_args,