diff --git a/libs/vkd3d/command.c b/libs/vkd3d/command.c index 11d06a39..8e80587c 100644 --- a/libs/vkd3d/command.c +++ b/libs/vkd3d/command.c @@ -4069,6 +4069,35 @@ static bool d3d12_command_list_update_compute_pipeline(struct d3d12_command_list return true; } +static bool d3d12_command_list_update_raygen_pipeline(struct d3d12_command_list *list) +{ + const struct vkd3d_vk_device_procs *vk_procs = &list->device->vk_procs; + + if (list->current_pipeline != VK_NULL_HANDLE) + return true; + + if (!list->rt_state) + { + WARN("Pipeline state %p is not a raygen pipeline.\n", list->rt_state); + return false; + } + + if (list->command_buffer_pipeline != list->rt_state->pipeline) + { + VK_CALL(vkCmdBindPipeline(list->vk_command_buffer, + VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, + list->rt_state->pipeline)); + list->command_buffer_pipeline = list->rt_state->pipeline; + } + + /* Pipeline stack size is part of the PSO, not any command buffer state + * for some reason ... */ + VK_CALL(vkCmdSetRayTracingPipelineStackSizeKHR(list->vk_command_buffer, + list->rt_state->pipeline_stack_size)); + + return true; +} + static void d3d12_command_list_check_vbo_alignment(struct d3d12_command_list *list) { const uint32_t *stride_masks; @@ -4523,6 +4552,20 @@ static bool d3d12_command_list_update_compute_state(struct d3d12_command_list *l return true; } +static bool d3d12_command_list_update_raygen_state(struct d3d12_command_list *list) +{ + d3d12_command_list_end_current_render_pass(list, false); + + if (!d3d12_command_list_update_raygen_pipeline(list)) + return false; + + /* DXR uses compute bind point for descriptors, we will redirect internally to + * raygen bind point in Vulkan. */ + d3d12_command_list_update_descriptors(list, VK_PIPELINE_BIND_POINT_COMPUTE); + + return true; +} + static void d3d12_command_list_update_dynamic_state(struct d3d12_command_list *list) { const struct vkd3d_vk_device_procs *vk_procs = &list->device->vk_procs; @@ -8310,10 +8353,51 @@ static void STDMETHODCALLTYPE d3d12_command_list_SetPipelineState1(d3d12_command } } +static VkStridedDeviceAddressRegionKHR convert_strided_range( + const D3D12_GPU_VIRTUAL_ADDRESS_RANGE_AND_STRIDE *region) +{ + VkStridedDeviceAddressRegionKHR table; + table.deviceAddress = region->StartAddress; + table.size = region->SizeInBytes; + table.stride = region->StrideInBytes; + return table; +} + static void STDMETHODCALLTYPE d3d12_command_list_DispatchRays(d3d12_command_list_iface *iface, const D3D12_DISPATCH_RAYS_DESC *desc) { - FIXME("iface %p, desc %p stub!\n", iface, desc); + struct d3d12_command_list *list = impl_from_ID3D12GraphicsCommandList(iface); + const struct vkd3d_vk_device_procs *vk_procs = &list->device->vk_procs; + VkStridedDeviceAddressRegionKHR callable_table; + VkStridedDeviceAddressRegionKHR raygen_table; + VkStridedDeviceAddressRegionKHR miss_table; + VkStridedDeviceAddressRegionKHR hit_table; + + TRACE("iface %p, desc %p\n", iface, desc); + + if (!d3d12_device_supports_ray_tracing_tier_1_0(list->device)) + { + WARN("Ray tracing is not supported. Calling this is invalid.\n"); + return; + } + + raygen_table.deviceAddress = desc->RayGenerationShaderRecord.StartAddress; + raygen_table.size = desc->RayGenerationShaderRecord.SizeInBytes; + raygen_table.stride = raygen_table.size; + miss_table = convert_strided_range(&desc->MissShaderTable); + hit_table = convert_strided_range(&desc->HitGroupTable); + callable_table = convert_strided_range(&desc->CallableShaderTable); + + if (!d3d12_command_list_update_raygen_state(list)) + { + WARN("Failed to update raygen state, ignoring dispatch.\n"); + return; + } + + /* TODO: Is DispatchRays predicated? */ + VK_CALL(vkCmdTraceRaysKHR(list->vk_command_buffer, + &raygen_table, &miss_table, &hit_table, &callable_table, + desc->Width, desc->Height, desc->Depth)); } static VkFragmentShadingRateCombinerOpKHR vk_shading_rate_combiner_from_d3d12(D3D12_SHADING_RATE_COMBINER combiner)