vkd3d: Sink shader interface struct build to where we need it.

Signed-off-by: Hans-Kristian Arntzen <post@arntzen-software.no>
This commit is contained in:
Hans-Kristian Arntzen 2022-03-18 17:05:43 +01:00
parent dc45142b93
commit 73fa8b9588
1 changed files with 61 additions and 68 deletions

View File

@ -2121,14 +2121,61 @@ static HRESULT vkd3d_load_spirv_from_cached_state(struct d3d12_device *device,
return hr;
}
static HRESULT vkd3d_create_shader_stage(struct d3d12_device *device,
static void d3d12_pipeline_state_init_shader_interface(struct d3d12_pipeline_state *state,
struct d3d12_device *device,
VkShaderStageFlagBits stage,
struct vkd3d_shader_interface_info *shader_interface)
{
const struct d3d12_root_signature *root_signature = state->root_signature;
shader_interface->flags = d3d12_root_signature_get_shader_interface_flags(root_signature);
shader_interface->min_ssbo_alignment = d3d12_device_get_ssbo_alignment(device);
shader_interface->descriptor_tables.offset = root_signature->descriptor_table_offset;
shader_interface->descriptor_tables.count = root_signature->descriptor_table_count;
shader_interface->bindings = root_signature->bindings;
shader_interface->binding_count = root_signature->binding_count;
shader_interface->push_constant_buffers = root_signature->root_constants;
shader_interface->push_constant_buffer_count = root_signature->root_constant_count;
shader_interface->push_constant_ubo_binding = &root_signature->push_constant_ubo_binding;
shader_interface->offset_buffer_binding = &root_signature->offset_buffer_binding;
shader_interface->stage = stage;
shader_interface->xfb_info =
(stage != VK_SHADER_STAGE_COMPUTE_BIT && stage == state->graphics.cached_desc.xfb_stage) ?
state->graphics.cached_desc.xfb_info : NULL;
#ifdef VKD3D_ENABLE_DESCRIPTOR_QA
shader_interface->descriptor_qa_global_binding = &root_signature->descriptor_qa_global_info;
shader_interface->descriptor_qa_heap_binding = &root_signature->descriptor_qa_heap_binding;
#endif
}
static void d3d12_pipeline_state_init_compile_arguments(struct d3d12_pipeline_state *state,
struct d3d12_device *device, VkShaderStageFlagBits stage,
struct vkd3d_shader_compile_arguments *compile_arguments)
{
memset(compile_arguments, 0, sizeof(*compile_arguments));
compile_arguments->target = VKD3D_SHADER_TARGET_SPIRV_VULKAN_1_0;
compile_arguments->target_extension_count = device->vk_info.shader_extension_count;
compile_arguments->target_extensions = device->vk_info.shader_extensions;
compile_arguments->quirks = &vkd3d_shader_quirk_info;
if (stage == VK_SHADER_STAGE_FRAGMENT_BIT)
{
/* Options which are exclusive to PS. Especially output swizzles must only be used in PS. */
compile_arguments->parameter_count = ARRAY_SIZE(state->graphics.cached_desc.ps_shader_parameters);
compile_arguments->parameters = state->graphics.cached_desc.ps_shader_parameters;
compile_arguments->dual_source_blending = state->graphics.cached_desc.is_dual_source_blending;
compile_arguments->output_swizzles = state->graphics.cached_desc.ps_output_swizzle;
compile_arguments->output_swizzle_count = state->graphics.rt_count;
}
}
static HRESULT vkd3d_create_shader_stage(struct d3d12_pipeline_state *state, struct d3d12_device *device,
VkPipelineShaderStageCreateInfo *stage_desc, VkShaderStageFlagBits stage,
VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT *required_subgroup_size_info,
const D3D12_SHADER_BYTECODE *code,
const struct vkd3d_shader_interface_info *shader_interface,
const struct vkd3d_shader_compile_arguments *compile_args, struct vkd3d_shader_code *spirv_code)
const D3D12_SHADER_BYTECODE *code, struct vkd3d_shader_code *spirv_code)
{
struct vkd3d_shader_code dxbc = {code->pShaderBytecode, code->BytecodeLength};
struct vkd3d_shader_interface_info shader_interface;
struct vkd3d_shader_compile_arguments compile_args;
vkd3d_shader_hash_t recovered_hash = 0;
vkd3d_shader_hash_t compiled_hash = 0;
int ret;
@ -2150,7 +2197,11 @@ static HRESULT vkd3d_create_shader_stage(struct d3d12_device *device,
if (!spirv_code->code)
{
TRACE("Calling vkd3d_shader_compile_dxbc.\n");
if ((ret = vkd3d_shader_compile_dxbc(&dxbc, spirv_code, 0, shader_interface, compile_args)) < 0)
d3d12_pipeline_state_init_shader_interface(state, device, stage, &shader_interface);
d3d12_pipeline_state_init_compile_arguments(state, device, stage, &compile_args);
if ((ret = vkd3d_shader_compile_dxbc(&dxbc, spirv_code, 0, &shader_interface, &compile_args)) < 0)
{
WARN("Failed to compile shader, vkd3d result %d.\n", ret);
return hresult_from_vkd3d_result(ret);
@ -2259,31 +2310,9 @@ static void vkd3d_report_pipeline_creation_feedback_results(const VkPipelineCrea
}
}
static void d3d12_pipeline_state_init_compile_arguments(struct d3d12_pipeline_state *state,
struct d3d12_device *device, VkShaderStageFlagBits stage,
struct vkd3d_shader_compile_arguments *compile_arguments)
{
memset(compile_arguments, 0, sizeof(*compile_arguments));
compile_arguments->target = VKD3D_SHADER_TARGET_SPIRV_VULKAN_1_0;
compile_arguments->target_extension_count = device->vk_info.shader_extension_count;
compile_arguments->target_extensions = device->vk_info.shader_extensions;
compile_arguments->quirks = &vkd3d_shader_quirk_info;
if (stage == VK_SHADER_STAGE_FRAGMENT_BIT)
{
/* Options which are exclusive to PS. Especially output swizzles must only be used in PS. */
compile_arguments->parameter_count = ARRAY_SIZE(state->graphics.cached_desc.ps_shader_parameters);
compile_arguments->parameters = state->graphics.cached_desc.ps_shader_parameters;
compile_arguments->dual_source_blending = state->graphics.cached_desc.is_dual_source_blending;
compile_arguments->output_swizzles = state->graphics.cached_desc.ps_output_swizzle;
compile_arguments->output_swizzle_count = state->graphics.rt_count;
}
}
static HRESULT vkd3d_create_compute_pipeline(struct d3d12_pipeline_state *state,
struct d3d12_device *device,
const D3D12_SHADER_BYTECODE *code,
const struct vkd3d_shader_interface_info *shader_interface,
VkPipelineLayout vk_pipeline_layout, VkPipelineCache vk_cache, VkPipeline *vk_pipeline,
struct vkd3d_shader_code *spirv_code)
{
@ -2291,22 +2320,19 @@ static HRESULT vkd3d_create_compute_pipeline(struct d3d12_pipeline_state *state,
const struct vkd3d_vk_device_procs *vk_procs = &device->vk_procs;
VkPipelineCreationFeedbackCreateInfoEXT feedback_info;
struct vkd3d_shader_debug_ring_spec_info spec_info;
struct vkd3d_shader_compile_arguments compile_args;
VkPipelineCreationFeedbackEXT feedbacks[1];
VkComputePipelineCreateInfo pipeline_info;
VkPipelineCreationFeedbackEXT feedback;
VkResult vr;
HRESULT hr;
d3d12_pipeline_state_init_compile_arguments(state, device, VK_SHADER_STAGE_COMPUTE_BIT, &compile_args);
pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
pipeline_info.pNext = NULL;
pipeline_info.flags = 0;
if (FAILED(hr = vkd3d_create_shader_stage(device,
if (FAILED(hr = vkd3d_create_shader_stage(state, device,
&pipeline_info.stage,
VK_SHADER_STAGE_COMPUTE_BIT, &required_subgroup_size_info,
code, shader_interface, &compile_args, spirv_code)))
code, spirv_code)))
return hr;
pipeline_info.layout = vk_pipeline_layout;
pipeline_info.basePipelineHandle = VK_NULL_HANDLE;
@ -2350,30 +2376,6 @@ static HRESULT vkd3d_create_compute_pipeline(struct d3d12_pipeline_state *state,
return S_OK;
}
static void d3d12_pipeline_state_init_shader_interface(struct d3d12_pipeline_state *state,
struct d3d12_device *device,
VkShaderStageFlagBits stage,
struct vkd3d_shader_interface_info *shader_interface)
{
const struct d3d12_root_signature *root_signature = state->root_signature;
shader_interface->flags = d3d12_root_signature_get_shader_interface_flags(root_signature);
shader_interface->min_ssbo_alignment = d3d12_device_get_ssbo_alignment(device);
shader_interface->descriptor_tables.offset = root_signature->descriptor_table_offset;
shader_interface->descriptor_tables.count = root_signature->descriptor_table_count;
shader_interface->bindings = root_signature->bindings;
shader_interface->binding_count = root_signature->binding_count;
shader_interface->push_constant_buffers = root_signature->root_constants;
shader_interface->push_constant_buffer_count = root_signature->root_constant_count;
shader_interface->push_constant_ubo_binding = &root_signature->push_constant_ubo_binding;
shader_interface->offset_buffer_binding = &root_signature->offset_buffer_binding;
shader_interface->stage = stage;
shader_interface->xfb_info = NULL;
#ifdef VKD3D_ENABLE_DESCRIPTOR_QA
shader_interface->descriptor_qa_global_binding = &root_signature->descriptor_qa_global_info;
shader_interface->descriptor_qa_heap_binding = &root_signature->descriptor_qa_heap_binding;
#endif
}
static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *state,
struct d3d12_device *device, const struct d3d12_pipeline_state_desc *desc,
const struct d3d12_cached_pipeline_state *cached_pso)
@ -2400,7 +2402,7 @@ static HRESULT d3d12_pipeline_state_init_compute(struct d3d12_pipeline_state *st
VK_SHADER_STAGE_COMPUTE_BIT, &state->compute.code);
hr = vkd3d_create_compute_pipeline(state, device,
&desc->cs, &shader_interface,
&desc->cs,
state->root_signature->compute.vk_pipeline_layout,
state->vk_pso_cache,
&state->compute.vk_pipeline,
@ -3395,22 +3397,13 @@ static HRESULT d3d12_pipeline_state_init_graphics(struct d3d12_pipeline_state *s
for (i = 0; i < ARRAY_SIZE(shader_stages); i++)
{
const D3D12_SHADER_BYTECODE *b = (const void *)((uintptr_t)desc + shader_stages[i].offset);
struct vkd3d_shader_interface_info shader_interface;
struct vkd3d_shader_compile_arguments compile_args;
if (!b->pShaderBytecode)
continue;
/* TODO: Move this to vkd3d_create_shader_stage itself. */
d3d12_pipeline_state_init_shader_interface(state, device, shader_stages[i].stage, &shader_interface);
shader_interface.xfb_info = shader_stages[i].stage == graphics->cached_desc.xfb_stage ?
graphics->cached_desc.xfb_info : NULL;
d3d12_pipeline_state_init_compile_arguments(state, device, shader_interface.stage, &compile_args);
if (FAILED(hr = vkd3d_create_shader_stage(device,
if (FAILED(hr = vkd3d_create_shader_stage(state, device,
&graphics->stages[stage_count],
shader_stages[i].stage, NULL, b, &shader_interface,
&compile_args,
shader_stages[i].stage, NULL, b,
&graphics->code[stage_count])))
goto fail;