diff --git a/libs/vkd3d/raytracing_pipeline.c b/libs/vkd3d/raytracing_pipeline.c index 05e4f947..1b88a600 100644 --- a/libs/vkd3d/raytracing_pipeline.c +++ b/libs/vkd3d/raytracing_pipeline.c @@ -200,6 +200,12 @@ static HRESULT STDMETHODCALLTYPE d3d12_state_object_GetDevice(ID3D12StateObject return d3d12_device_query_interface(state_object->device, iid, device); } +static bool vkd3d_export_equal(LPCWSTR export, const struct vkd3d_shader_library_entry_point *entry) +{ + return vkd3d_export_strequal(export, entry->mangled_entry_point) || + vkd3d_export_strequal(export, entry->plain_entry_point); +} + static uint32_t d3d12_state_object_get_export_index(struct d3d12_state_object *object, const WCHAR *export_name, const WCHAR **out_subtype) { @@ -604,11 +610,8 @@ static uint32_t d3d12_state_object_pipeline_data_find_entry_inner( return VK_SHADER_UNUSED_KHR; for (i = 0; i < count; i++) - { - if (vkd3d_export_strequal(import, entry_points[i].mangled_entry_point) || - vkd3d_export_strequal(import, entry_points[i].plain_entry_point)) + if (vkd3d_export_equal(import, &entry_points[i])) return i; - } return VK_SHADER_UNUSED_KHR; } @@ -755,27 +758,63 @@ static VkDeviceSize d3d12_state_object_pipeline_data_compute_default_stack_size( return pipeline_stack_size; } -static struct d3d12_root_signature *d3d12_state_object_pipeline_data_get_local_root_signature( +static struct d3d12_root_signature *d3d12_state_object_find_associated_root_signature_entry( struct d3d12_state_object_pipeline_data *data, const struct vkd3d_shader_library_entry_point *entry) { size_t i; - for (i = 0; i < data->associations_count; i++) - { - if (vkd3d_export_strequal(data->associations[i].export, entry->mangled_entry_point) || - vkd3d_export_strequal(data->associations[i].export, entry->plain_entry_point)) - { + if (vkd3d_export_equal(data->associations[i].export, entry)) return data->associations[i].root_signature; + return NULL; +} + +static struct d3d12_root_signature *d3d12_state_object_find_associated_root_signature_export( + struct d3d12_state_object_pipeline_data *data, LPCWSTR export) +{ + size_t i; + for (i = 0; i < data->associations_count; i++) + if (vkd3d_export_strequal(data->associations[i].export, export)) + return data->associations[i].root_signature; + return NULL; +} + +static struct d3d12_root_signature *d3d12_state_object_pipeline_data_get_local_root_signature( + struct d3d12_state_object_pipeline_data *data, + const struct vkd3d_shader_library_entry_point *entry) +{ + const D3D12_HIT_GROUP_DESC *hit_group; + struct d3d12_root_signature *rs; + size_t i; + + rs = d3d12_state_object_find_associated_root_signature_entry(data, entry); + + /* If we didn't find an association for this entry point, we might have an association + * in a hit group export. + * FIXME: Is it possible to have multiple hit groups, all referring to same entry point, while using + * different root signatures for the different instances of the entry point? :| */ + for (i = 0; i < data->hit_groups_count && !rs; i++) + { + hit_group = data->hit_groups[i]; + if (vkd3d_export_equal(hit_group->ClosestHitShaderImport, entry) || + vkd3d_export_equal(hit_group->AnyHitShaderImport, entry) || + vkd3d_export_equal(hit_group->IntersectionShaderImport, entry)) + { + rs = d3d12_state_object_find_associated_root_signature_export(data, hit_group->HitGroupExport); } } - if (data->high_priority_local_root_signature) - return impl_from_ID3D12RootSignature(data->high_priority_local_root_signature); - else if (data->low_priority_local_root_signature) - return impl_from_ID3D12RootSignature(data->low_priority_local_root_signature); - else - return NULL; + if (!rs) + { + if (data->high_priority_local_root_signature) + rs = impl_from_ID3D12RootSignature(data->high_priority_local_root_signature); + else if (data->low_priority_local_root_signature) + rs = impl_from_ID3D12RootSignature(data->low_priority_local_root_signature); + else + rs = NULL; + } + + return rs; } static HRESULT d3d12_state_object_get_group_handles(struct d3d12_state_object *object,