radv: vkCmdTraceRaysIndirect2KHR

This changes the trace rays logic to always use
VkTraceRaysIndirectCommand2KHR and implements
vkCmdTraceRaysIndirect2KHR. I renamed
load_sbt_amd to sbt_base_amd and moved the SBT
load lowering from ACO to NIR.

Note that we cannot just upload a single pointer to
all of the trace parameters, because that would
be incompatible with traceRaysIndirect.

Signed-off-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16430>
This commit is contained in:
Konstantin Seurer 2022-05-12 20:22:59 +02:00 committed by Marge Bot
parent 3aa0ea8279
commit 16585664cd
8 changed files with 99 additions and 87 deletions

View File

@ -5754,19 +5754,6 @@ visit_load_ubo(isel_context* ctx, nir_intrinsic_instr* instr)
nir_intrinsic_align_mul(instr), nir_intrinsic_align_offset(instr));
}
/* Lowers nir_intrinsic_load_sbt_amd: loads the 128-bit (4-dword) SBT buffer
 * descriptor selected by the intrinsic's BINDING index from the descriptor
 * table whose address arrives in the sbt_descriptors shader argument.
 * (Removed by this commit; the equivalent load is now emitted in NIR.) */
void
visit_load_sbt_amd(isel_context* ctx, nir_intrinsic_instr* instr)
{
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
unsigned binding = nir_intrinsic_binding(instr);
Builder bld(ctx->program, ctx->block);
/* Widen the sbt_descriptors argument to a 64-bit address for SMEM. */
Temp desc_base = convert_pointer_to_64_bit(ctx, get_arg(ctx, ctx->args->ac.sbt_descriptors));
/* Each descriptor entry is 16 bytes, hence byte offset = binding * 16. */
Operand desc_off = bld.copy(bld.def(s1), Operand::c32(binding * 16u));
bld.smem(aco_opcode::s_load_dwordx4, Definition(dst), desc_base, desc_off);
/* Expose the four dwords of the descriptor as separate components. */
emit_split_vector(ctx, dst, 4);
}
void
visit_load_push_constant(isel_context* ctx, nir_intrinsic_instr* instr)
{
@ -9082,7 +9069,12 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
true);
break;
}
case nir_intrinsic_load_sbt_amd: visit_load_sbt_amd(ctx, instr); break;
case nir_intrinsic_load_sbt_base_amd: {
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
Temp addr = convert_pointer_to_64_bit(ctx, get_arg(ctx, ctx->args->ac.sbt_descriptors));
bld.copy(Definition(dst), Operand(addr));
break;
}
case nir_intrinsic_bvh64_intersect_ray_amd: visit_bvh64_intersect_ray_amd(ctx, instr); break;
case nir_intrinsic_overwrite_vs_arguments_amd: {
ctx->arg_temps[ctx->args->ac.vertex_id.arg_index] = get_ssa_temp(ctx, instr->src[0].ssa);

View File

@ -603,6 +603,7 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_load_workgroup_id:
case nir_intrinsic_load_num_workgroups:
case nir_intrinsic_load_ray_launch_size_addr_amd:
case nir_intrinsic_load_sbt_base_amd:
case nir_intrinsic_load_subgroup_id:
case nir_intrinsic_load_num_subgroups:
case nir_intrinsic_load_first_vertex:
@ -725,7 +726,6 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_inclusive_scan:
case nir_intrinsic_exclusive_scan:
case nir_intrinsic_reduce:
case nir_intrinsic_load_sbt_amd:
case nir_intrinsic_load_ubo:
case nir_intrinsic_load_ssbo:
case nir_intrinsic_load_global_amd:

View File

@ -7683,62 +7683,67 @@ radv_indirect_dispatch(struct radv_cmd_buffer *cmd_buffer, struct radeon_winsys_
radv_compute_dispatch(cmd_buffer, &info);
}
/* Thin wrapper: dispatches a ray-tracing launch by forwarding to the common
 * radv_dispatch() path with the currently bound RT pipeline and the
 * ray-tracing pipeline bind point. (Removed by this commit in favor of
 * radv_trace_rays() calling radv_dispatch() directly.) */
static void
radv_rt_dispatch(struct radv_cmd_buffer *cmd_buffer, const struct radv_dispatch_info *info)
{
radv_dispatch(cmd_buffer, info, cmd_buffer->state.rt_pipeline,
VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR);
}
/* Where the launch dimensions and SBT parameters of a trace-rays call come
 * from. */
enum radv_rt_mode {
/* vkCmdTraceRaysKHR: dimensions are known on the CPU at record time. */
radv_rt_mode_direct,
/* vkCmdTraceRaysIndirectKHR: dimensions are read from a device address,
 * SBT regions are still provided on the CPU. */
radv_rt_mode_indirect,
/* vkCmdTraceRaysIndirect2KHR: both dimensions and SBT regions are read
 * from a single VkTraceRaysIndirectCommand2KHR at a device address. */
radv_rt_mode_indirect2,
};
static bool
radv_rt_set_args(struct radv_cmd_buffer *cmd_buffer,
const VkStridedDeviceAddressRegionKHR *tables, uint64_t launch_size_va,
struct radv_dispatch_info *info)
static void
radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, const VkTraceRaysIndirectCommand2KHR *tables,
uint64_t indirect_va, enum radv_rt_mode mode)
{
struct radv_compute_pipeline *pipeline = cmd_buffer->state.rt_pipeline;
uint32_t base_reg = pipeline->base.user_data_0[MESA_SHADER_COMPUTE];
void *ptr;
uint32_t *write_ptr;
uint32_t offset;
info->unaligned = true;
struct radv_dispatch_info info = {0};
info.unaligned = true;
if (!radv_cmd_buffer_upload_alloc(cmd_buffer, 64 + (launch_size_va ? 0 : 12), &offset, &ptr))
return false;
uint64_t launch_size_va;
uint64_t sbt_va;
write_ptr = ptr;
for (unsigned i = 0; i < 4; ++i, write_ptr += 4) {
write_ptr[0] = tables[i].deviceAddress;
write_ptr[1] = tables[i].deviceAddress >> 32;
write_ptr[2] = tables[i].stride;
write_ptr[3] = 0;
}
if (mode != radv_rt_mode_indirect2) {
uint32_t upload_size = mode == radv_rt_mode_direct
? sizeof(VkTraceRaysIndirectCommand2KHR)
: offsetof(VkTraceRaysIndirectCommand2KHR, width);
if (!launch_size_va) {
write_ptr[0] = info->blocks[0];
write_ptr[1] = info->blocks[1];
write_ptr[2] = info->blocks[2];
uint32_t offset;
if (!radv_cmd_buffer_upload_data(cmd_buffer, upload_size, tables, &offset))
return;
uint64_t upload_va = radv_buffer_get_va(cmd_buffer->upload.upload_bo) + offset;
launch_size_va = (mode == radv_rt_mode_direct)
? upload_va + offsetof(VkTraceRaysIndirectCommand2KHR, width)
: indirect_va;
sbt_va = upload_va;
} else {
info->va = launch_size_va;
launch_size_va = indirect_va + offsetof(VkTraceRaysIndirectCommand2KHR, width);
sbt_va = indirect_va;
}
uint64_t va = radv_buffer_get_va(cmd_buffer->upload.upload_bo) + offset;
if (mode == radv_rt_mode_direct) {
info.blocks[0] = tables->width;
info.blocks[1] = tables->height;
info.blocks[2] = tables->depth;
} else
info.va = launch_size_va;
struct radv_userdata_info *desc_loc =
radv_lookup_user_sgpr(&pipeline->base, MESA_SHADER_COMPUTE, AC_UD_CS_SBT_DESCRIPTORS);
if (desc_loc->sgpr_idx != -1) {
radv_emit_shader_pointer(cmd_buffer->device, cmd_buffer->cs,
base_reg + desc_loc->sgpr_idx * 4, va, false);
base_reg + desc_loc->sgpr_idx * 4, sbt_va, false);
}
struct radv_userdata_info *size_loc =
radv_lookup_user_sgpr(&pipeline->base, MESA_SHADER_COMPUTE, AC_UD_CS_RAY_LAUNCH_SIZE_ADDR);
if (size_loc->sgpr_idx != -1) {
radv_emit_shader_pointer(cmd_buffer->device, cmd_buffer->cs,
base_reg + size_loc->sgpr_idx * 4, launch_size_va ? launch_size_va : (va + 64), false);
base_reg + size_loc->sgpr_idx * 4, launch_size_va, false);
}
return true;
radv_dispatch(cmd_buffer, &info, pipeline, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR);
}
VKAPI_ATTR void VKAPI_CALL
@ -7750,23 +7755,25 @@ radv_CmdTraceRaysKHR(VkCommandBuffer commandBuffer,
uint32_t width, uint32_t height, uint32_t depth)
{
RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
struct radv_dispatch_info info = {0};
info.blocks[0] = width;
info.blocks[1] = height;
info.blocks[2] = depth;
const VkStridedDeviceAddressRegionKHR tables[] = {
*pRaygenShaderBindingTable,
*pMissShaderBindingTable,
*pHitShaderBindingTable,
*pCallableShaderBindingTable,
VkTraceRaysIndirectCommand2KHR tables = {
.raygenShaderRecordAddress = pRaygenShaderBindingTable->deviceAddress,
.raygenShaderRecordSize = pRaygenShaderBindingTable->size,
.missShaderBindingTableAddress = pMissShaderBindingTable->deviceAddress,
.missShaderBindingTableSize = pMissShaderBindingTable->size,
.missShaderBindingTableStride = pMissShaderBindingTable->stride,
.hitShaderBindingTableAddress = pHitShaderBindingTable->deviceAddress,
.hitShaderBindingTableSize = pHitShaderBindingTable->size,
.hitShaderBindingTableStride = pHitShaderBindingTable->stride,
.callableShaderBindingTableAddress = pCallableShaderBindingTable->deviceAddress,
.callableShaderBindingTableSize = pCallableShaderBindingTable->size,
.callableShaderBindingTableStride = pCallableShaderBindingTable->stride,
.width = width,
.height = height,
.depth = depth,
};
if (!radv_rt_set_args(cmd_buffer, tables, 0, &info))
return;
radv_rt_dispatch(cmd_buffer, &info);
radv_trace_rays(cmd_buffer, &tables, 0, radv_rt_mode_direct);
}
VKAPI_ATTR void VKAPI_CALL
@ -7781,18 +7788,31 @@ radv_CmdTraceRaysIndirectKHR(VkCommandBuffer commandBuffer,
assert(cmd_buffer->device->use_global_bo_list);
const VkStridedDeviceAddressRegionKHR tables[] = {
*pRaygenShaderBindingTable,
*pMissShaderBindingTable,
*pHitShaderBindingTable,
*pCallableShaderBindingTable,
VkTraceRaysIndirectCommand2KHR tables = {
.raygenShaderRecordAddress = pRaygenShaderBindingTable->deviceAddress,
.raygenShaderRecordSize = pRaygenShaderBindingTable->size,
.missShaderBindingTableAddress = pMissShaderBindingTable->deviceAddress,
.missShaderBindingTableSize = pMissShaderBindingTable->size,
.missShaderBindingTableStride = pMissShaderBindingTable->stride,
.hitShaderBindingTableAddress = pHitShaderBindingTable->deviceAddress,
.hitShaderBindingTableSize = pHitShaderBindingTable->size,
.hitShaderBindingTableStride = pHitShaderBindingTable->stride,
.callableShaderBindingTableAddress = pCallableShaderBindingTable->deviceAddress,
.callableShaderBindingTableSize = pCallableShaderBindingTable->size,
.callableShaderBindingTableStride = pCallableShaderBindingTable->stride,
};
struct radv_dispatch_info info = {0};
if (!radv_rt_set_args(cmd_buffer, tables, indirectDeviceAddress, &info))
return;
radv_trace_rays(cmd_buffer, &tables, indirectDeviceAddress, radv_rt_mode_indirect);
}
radv_rt_dispatch(cmd_buffer, &info);
/* Entry point for vkCmdTraceRaysIndirect2KHR: the whole
 * VkTraceRaysIndirectCommand2KHR (SBT regions and launch dimensions) lives
 * on the device at indirectDeviceAddress, so no CPU-side table is passed. */
VKAPI_ATTR void VKAPI_CALL
radv_CmdTraceRaysIndirect2KHR(VkCommandBuffer commandBuffer, VkDeviceAddress indirectDeviceAddress)
{
RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
/* The command buffer only knows the raw device address, so the backing
 * buffer must be resident — presumably guaranteed via the global BO list. */
assert(cmd_buffer->device->use_global_bo_list);
radv_trace_rays(cmd_buffer, NULL, indirectDeviceAddress, radv_rt_mode_indirect2);
}
static void

View File

@ -333,23 +333,25 @@ insert_rt_return(nir_builder *b, const struct rt_variables *vars)
}
enum sbt_type {
SBT_RAYGEN,
SBT_MISS,
SBT_HIT,
SBT_CALLABLE,
SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
};
static nir_ssa_def *
get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
{
nir_ssa_def *desc = nir_load_sbt_amd(b, 4, .binding = binding);
nir_ssa_def *base_addr = nir_pack_64_2x32(b, nir_channels(b, desc, 0x3));
nir_ssa_def *stride = nir_channel(b, desc, 2);
nir_ssa_def *desc_base_addr = nir_load_sbt_base_amd(b);
nir_ssa_def *ret = nir_imul(b, idx, stride);
ret = nir_iadd(b, base_addr, nir_u2u64(b, ret));
nir_ssa_def *desc =
nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));
return ret;
nir_ssa_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
nir_ssa_def *stride =
nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, stride_offset));
return nir_iadd(b, desc, nir_imul(b, nir_u2u64(b, idx), stride));
}
static void

View File

@ -197,7 +197,7 @@ gather_intrinsic_info(const nir_shader *nir, const nir_intrinsic_instr *instr,
case nir_intrinsic_store_output:
gather_intrinsic_store_output_info(nir, instr, info);
break;
case nir_intrinsic_load_sbt_amd:
case nir_intrinsic_load_sbt_base_amd:
info->cs.uses_sbt = true;
break;
case nir_intrinsic_load_force_vrs_rates_amd:

View File

@ -107,6 +107,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr)
case nir_intrinsic_load_num_subgroups:
case nir_intrinsic_load_ray_launch_size:
case nir_intrinsic_load_ray_launch_size_addr_amd:
case nir_intrinsic_load_sbt_base_amd:
case nir_intrinsic_load_subgroup_size:
case nir_intrinsic_load_subgroup_eq_mask:
case nir_intrinsic_load_subgroup_ge_mask:
@ -373,7 +374,6 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr)
case nir_intrinsic_masked_swizzle_amd:
case nir_intrinsic_is_sparse_texels_resident:
case nir_intrinsic_sparse_residency_code_and:
case nir_intrinsic_load_sbt_amd:
case nir_intrinsic_bvh64_intersect_ray_amd:
case nir_intrinsic_image_deref_load_param_intel:
case nir_intrinsic_image_load_raw_intel:

View File

@ -1356,10 +1356,8 @@ intrinsic("overwrite_vs_arguments_amd", src_comp=[1, 1], indices=[])
# Overwrites TES input registers, for use with vertex compaction after culling. src = {tes_u, tes_v, rel_patch_id, patch_id}.
intrinsic("overwrite_tes_arguments_amd", src_comp=[1, 1, 1, 1], indices=[])
# loads a descriptor for an sbt.
# src = [index] BINDING = which table
intrinsic("load_sbt_amd", dest_comp=4, bit_sizes=[32], indices=[BINDING],
flags=[CAN_ELIMINATE, CAN_REORDER])
# The address of the sbt descriptors.
system_value("sbt_base_amd", 1, bit_sizes=[64])
# 1. HW descriptor
# 2. BVH node(64-bit pointer as 2x32 ...)

View File

@ -119,6 +119,7 @@ can_move_intrinsic(nir_intrinsic_instr *instr, opt_preamble_ctx *ctx)
case nir_intrinsic_load_workgroup_size:
case nir_intrinsic_load_ray_launch_size:
case nir_intrinsic_load_ray_launch_size_addr_amd:
case nir_intrinsic_load_sbt_base_amd:
case nir_intrinsic_load_is_indexed_draw:
case nir_intrinsic_load_viewport_scale:
case nir_intrinsic_load_user_clip_plane:
@ -188,7 +189,6 @@ can_move_intrinsic(nir_intrinsic_instr *instr, opt_preamble_ctx *ctx)
case nir_intrinsic_load_vulkan_descriptor:
case nir_intrinsic_quad_swizzle_amd:
case nir_intrinsic_masked_swizzle_amd:
case nir_intrinsic_load_sbt_amd:
case nir_intrinsic_load_ssbo_address:
case nir_intrinsic_bindless_resource_ir3:
return can_move_srcs(&instr->instr, ctx);