radv: create RT traversal as separate shader

This will help in the future to keep the main shader slim
when we have actual function calls.

Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17301>
Author: Daniel Schürmann, 2022-04-21 21:33:10 +02:00 (committed by Marge Bot)
parent 8e056af399
commit 076ea8b35a
1 changed file with 137 additions and 109 deletions
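
For orientation before reading the diff: the new insert_traversal() no longer emits the traversal code directly, it builds the traversal as a standalone compute shader and splices it into the caller. The following is a condensed, commented restatement of that function as added by this commit; the comments are added here for illustration and are not part of the commit.

/* Condensed restatement of the new insert_traversal() from the diff below. */
static void
insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                 nir_builder *b, const struct rt_variables *vars)
{
   /* var_remap maps the traversal shader's private rt_variables back to the caller's. */
   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);

   /* Build BVH traversal as its own compute shader instead of emitting it inline. */
   nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);

   /* The traversal stack lives in shared memory; account for it in the caller. */
   b->shader->info.shared_size += shader->info.shared_size;
   assert(b->shader->info.shared_size <= 32768);

   /* For now the separate shader is still inlined at the call site. */
   nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));
   nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
   nir_pop_if(b, NULL);

   /* The inlined instructions are moved, not cloned, so adopt their ralloc memory. */
   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
   ralloc_free(var_remap);
}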


@@ -1422,150 +1422,157 @@ insert_traversal_aabb_case(struct radv_device *device,
 nir_pop_if(b, NULL);
 }
-static void
-insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-nir_builder *b, const struct rt_variables *vars)
+static nir_shader *
+build_traversal_shader(struct radv_device *device,
+const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+const struct rt_variables *dst_vars,
+struct hash_table *var_remap)
 {
+nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
+b.shader->info.internal = false;
+b.shader->info.workgroup_size[0] = 8;
+b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
+struct rt_variables vars = create_rt_variables(b.shader, dst_vars->stack_sizes);
+map_rt_variables(var_remap, &vars, dst_vars);
 unsigned stack_entry_size = 4;
-unsigned lanes = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
-b->shader->info.workgroup_size[2];
+unsigned lanes = device->physical_device->rt_wave_size;
 unsigned stack_entry_stride = stack_entry_size * lanes;
-nir_ssa_def *stack_entry_stride_def = nir_imm_int(b, stack_entry_stride);
+nir_ssa_def *stack_entry_stride_def = nir_imm_int(&b, stack_entry_stride);
 nir_ssa_def *stack_base =
-nir_iadd_imm(b, nir_imul_imm(b, nir_load_local_invocation_index(b), stack_entry_size),
-b->shader->info.shared_size);
+nir_iadd_imm(&b, nir_imul_imm(&b, nir_load_local_invocation_index(&b), stack_entry_size),
+b.shader->info.shared_size);
-b->shader->info.shared_size += stack_entry_stride * MAX_STACK_ENTRY_COUNT;
-assert(b->shader->info.shared_size <= 32768);
+b.shader->info.shared_size += stack_entry_stride * MAX_STACK_ENTRY_COUNT;
-nir_ssa_def *accel_struct = nir_load_var(b, vars->accel_struct);
+nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct);
-struct rt_traversal_vars trav_vars = init_traversal_vars(b);
+struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
 /* Initialize the follow-up shader idx to 0, to be replaced by the miss shader
 * if we actually miss. */
-nir_store_var(b, vars->idx, nir_imm_int(b, 0), 1);
+nir_store_var(&b, vars.idx, nir_imm_int(&b, 0), 1);
-nir_store_var(b, trav_vars.should_return, nir_imm_bool(b, false), 1);
+nir_store_var(&b, trav_vars.should_return, nir_imm_bool(&b, false), 1);
-nir_push_if(b, nir_ine_imm(b, accel_struct, 0));
+nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0));
 {
-nir_store_var(b, trav_vars.bvh_base, build_addr_to_node(b, accel_struct), 1);
+nir_store_var(&b, trav_vars.bvh_base, build_addr_to_node(&b, accel_struct), 1);
 nir_ssa_def *bvh_root = nir_build_load_global(
-b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64);
+&b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64);
-nir_ssa_def *desc = create_bvh_descriptor(b);
-nir_ssa_def *vec3ones = nir_channels(b, nir_imm_vec4(b, 1.0, 1.0, 1.0, 1.0), 0x7);
+nir_ssa_def *desc = create_bvh_descriptor(&b);
+nir_ssa_def *vec3ones = nir_channels(&b, nir_imm_vec4(&b, 1.0, 1.0, 1.0, 1.0), 0x7);
-nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
-nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
-nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
-nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
-nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
+nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
+nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
+nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
+nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
+nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
-nir_store_var(b, trav_vars.stack, nir_iadd(b, stack_base, stack_entry_stride_def), 1);
-nir_store_shared(b, bvh_root, stack_base, .base = 0, .align_mul = stack_entry_size);
+nir_store_var(&b, trav_vars.stack, nir_iadd(&b, stack_base, stack_entry_stride_def), 1);
+nir_store_shared(&b, bvh_root, stack_base, .base = 0, .align_mul = stack_entry_size);
-nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
+nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
-nir_push_loop(b);
+nir_push_loop(&b);
-nir_push_if(b, nir_ieq(b, nir_load_var(b, trav_vars.stack), stack_base));
-nir_jump(b, nir_jump_break);
-nir_pop_if(b, NULL);
+nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, trav_vars.stack), stack_base));
+nir_jump(&b, nir_jump_break);
+nir_pop_if(&b, NULL);
 nir_push_if(
-b, nir_uge(b, nir_load_var(b, trav_vars.top_stack), nir_load_var(b, trav_vars.stack)));
-nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
-nir_store_var(b, trav_vars.bvh_base,
-build_addr_to_node(b, nir_load_var(b, vars->accel_struct)), 1);
-nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
-nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
-nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
-nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
+&b, nir_uge(&b, nir_load_var(&b, trav_vars.top_stack), nir_load_var(&b, trav_vars.stack)));
+nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
+nir_store_var(&b, trav_vars.bvh_base,
+build_addr_to_node(&b, nir_load_var(&b, vars.accel_struct)), 1);
+nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
+nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
+nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
+nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
-nir_pop_if(b, NULL);
+nir_pop_if(&b, NULL);
-nir_store_var(b, trav_vars.stack,
-nir_isub(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
+nir_store_var(&b, trav_vars.stack,
+nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def), 1);
-nir_ssa_def *bvh_node = nir_load_shared(b, 1, 32, nir_load_var(b, trav_vars.stack), .base = 0,
+nir_ssa_def *bvh_node = nir_load_shared(&b, 1, 32, nir_load_var(&b, trav_vars.stack), .base = 0,
 .align_mul = stack_entry_size);
-nir_ssa_def *bvh_node_type = nir_iand_imm(b, bvh_node, 7);
+nir_ssa_def *bvh_node_type = nir_iand_imm(&b, bvh_node, 7);
-bvh_node = nir_iadd(b, nir_load_var(b, trav_vars.bvh_base), nir_u2u(b, bvh_node, 64));
+bvh_node = nir_iadd(&b, nir_load_var(&b, trav_vars.bvh_base), nir_u2u(&b, bvh_node, 64));
 nir_ssa_def *intrinsic_result = NULL;
 if (!radv_emulate_rt(device->physical_device)) {
 intrinsic_result = nir_bvh64_intersect_ray_amd(
-b, 32, desc, nir_unpack_64_2x32(b, bvh_node), nir_load_var(b, vars->tmax),
-nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir),
-nir_load_var(b, trav_vars.inv_dir));
+&b, 32, desc, nir_unpack_64_2x32(&b, bvh_node), nir_load_var(&b, vars.tmax),
+nir_load_var(&b, trav_vars.origin), nir_load_var(&b, trav_vars.dir),
+nir_load_var(&b, trav_vars.inv_dir));
 }
-nir_push_if(b, nir_ine_imm(b, nir_iand_imm(b, bvh_node_type, 4), 0));
+nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 4), 0));
 {
-nir_push_if(b, nir_ine_imm(b, nir_iand_imm(b, bvh_node_type, 2), 0));
+nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 2), 0));
 {
 /* custom */
-nir_push_if(b, nir_ine_imm(b, nir_iand_imm(b, bvh_node_type, 1), 0));
+nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 1), 0));
 if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR)) {
-insert_traversal_aabb_case(device, pCreateInfo, b, vars, &trav_vars, bvh_node);
+insert_traversal_aabb_case(device, pCreateInfo, &b, &vars, &trav_vars, bvh_node);
 }
-nir_push_else(b, NULL);
+nir_push_else(&b, NULL);
 {
 /* instance */
-nir_ssa_def *instance_node_addr = build_node_to_addr(device, b, bvh_node);
+nir_ssa_def *instance_node_addr = build_node_to_addr(device, &b, bvh_node);
 nir_ssa_def *instance_data =
-nir_build_load_global(b, 4, 32, instance_node_addr, .align_mul = 64);
+nir_build_load_global(&b, 4, 32, instance_node_addr, .align_mul = 64);
 nir_ssa_def *wto_matrix[] = {
-nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_node_addr, 16),
+nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 16),
 .align_mul = 64, .align_offset = 16),
-nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_node_addr, 32),
+nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 32),
 .align_mul = 64, .align_offset = 32),
-nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_node_addr, 48),
+nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 48),
 .align_mul = 64, .align_offset = 48)};
 nir_ssa_def *instance_id =
-nir_build_load_global(b, 1, 32, nir_iadd_imm(b, instance_node_addr, 88));
-nir_ssa_def *instance_and_mask = nir_channel(b, instance_data, 2);
-nir_ssa_def *instance_mask = nir_ushr_imm(b, instance_and_mask, 24);
+nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, instance_node_addr, 88));
+nir_ssa_def *instance_and_mask = nir_channel(&b, instance_data, 2);
+nir_ssa_def *instance_mask = nir_ushr_imm(&b, instance_and_mask, 24);
 nir_push_if(
-b,
-nir_ieq_imm(b, nir_iand(b, instance_mask, nir_load_var(b, vars->cull_mask)), 0));
-nir_jump(b, nir_jump_continue);
-nir_pop_if(b, NULL);
+&b,
+nir_ieq_imm(&b, nir_iand(&b, instance_mask, nir_load_var(&b, vars.cull_mask)), 0));
+nir_jump(&b, nir_jump_continue);
+nir_pop_if(&b, NULL);
-nir_store_var(b, trav_vars.top_stack, nir_load_var(b, trav_vars.stack), 1);
-nir_store_var(b, trav_vars.bvh_base,
+nir_store_var(&b, trav_vars.top_stack, nir_load_var(&b, trav_vars.stack), 1);
+nir_store_var(&b, trav_vars.bvh_base,
 build_addr_to_node(
-b, nir_pack_64_2x32(b, nir_channels(b, instance_data, 0x3))),
+&b, nir_pack_64_2x32(&b, nir_channels(&b, instance_data, 0x3))),
 1);
-nir_store_shared(b, nir_iand_imm(b, nir_channel(b, instance_data, 0), 63),
-nir_load_var(b, trav_vars.stack), .base = 0,
+nir_store_shared(&b, nir_iand_imm(&b, nir_channel(&b, instance_data, 0), 63),
+nir_load_var(&b, trav_vars.stack), .base = 0,
 .align_mul = stack_entry_size);
-nir_store_var(b, trav_vars.stack,
-nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def),
+nir_store_var(&b, trav_vars.stack,
+nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def),
 1);
 nir_store_var(
-b, trav_vars.origin,
-nir_build_vec3_mat_mult_pre(b, nir_load_var(b, vars->origin), wto_matrix), 7);
+&b, trav_vars.origin,
+nir_build_vec3_mat_mult_pre(&b, nir_load_var(&b, vars.origin), wto_matrix), 7);
 nir_store_var(
-b, trav_vars.dir,
-nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false),
+&b, trav_vars.dir,
+nir_build_vec3_mat_mult(&b, nir_load_var(&b, vars.direction), wto_matrix, false),
 7);
-nir_store_var(b, trav_vars.inv_dir,
-nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
-nir_store_var(b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
-nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3),
+nir_store_var(&b, trav_vars.inv_dir,
+nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
+nir_store_var(&b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
+nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_channel(&b, instance_data, 3),
 1);
-nir_store_var(b, trav_vars.instance_id, instance_id, 1);
-nir_store_var(b, trav_vars.instance_addr, instance_node_addr, 1);
+nir_store_var(&b, trav_vars.instance_id, instance_id, 1);
+nir_store_var(&b, trav_vars.instance_addr, instance_node_addr, 1);
 }
-nir_pop_if(b, NULL);
+nir_pop_if(&b, NULL);
 }
-nir_push_else(b, NULL);
+nir_push_else(&b, NULL);
 {
 /* box */
 nir_ssa_def *result = intrinsic_result;
@@ -1573,61 +1580,85 @@ insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInf
 /* If we didn't run the intrinsic cause the hardware didn't support it,
 * emulate ray/box intersection here */
 result = intersect_ray_amd_software_box(device,
-b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin),
-nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir));
+&b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
+nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
 }
 for (unsigned i = 4; i-- > 0; ) {
-nir_ssa_def *new_node = nir_channel(b, result, i);
-nir_push_if(b, nir_ine_imm(b, new_node, 0xffffffff));
+nir_ssa_def *new_node = nir_channel(&b, result, i);
+nir_push_if(&b, nir_ine_imm(&b, new_node, 0xffffffff));
 {
-nir_store_shared(b, new_node, nir_load_var(b, trav_vars.stack), .base = 0,
+nir_store_shared(&b, new_node, nir_load_var(&b, trav_vars.stack), .base = 0,
 .align_mul = stack_entry_size);
 nir_store_var(
-b, trav_vars.stack,
-nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
+&b, trav_vars.stack,
+nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def), 1);
 }
-nir_pop_if(b, NULL);
+nir_pop_if(&b, NULL);
 }
 }
-nir_pop_if(b, NULL);
+nir_pop_if(&b, NULL);
 }
-nir_push_else(b, NULL);
+nir_push_else(&b, NULL);
 if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)) {
 nir_ssa_def *result = intrinsic_result;
 if (!result) {
 /* If we didn't run the intrinsic cause the hardware didn't support it,
 * emulate ray/tri intersection here */
 result = intersect_ray_amd_software_tri(device,
-b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin),
-nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir));
+&b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
+nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
 }
-insert_traversal_triangle_case(device, pCreateInfo, b, result, vars, &trav_vars, bvh_node);
+insert_traversal_triangle_case(device, pCreateInfo, &b, result, &vars, &trav_vars, bvh_node);
 }
-nir_pop_if(b, NULL);
+nir_pop_if(&b, NULL);
-nir_pop_loop(b, NULL);
+nir_pop_loop(&b, NULL);
 }
-nir_pop_if(b, NULL);
+nir_pop_if(&b, NULL);
 /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence
 * need to return immediately to the calling shader. */
-nir_push_if(b, nir_load_var(b, trav_vars.should_return));
+nir_push_if(&b, nir_load_var(&b, trav_vars.should_return));
 {
-insert_rt_return(b, vars);
+insert_rt_return(&b, &vars);
 }
-nir_push_else(b, NULL);
+nir_push_else(&b, NULL);
 {
 /* Only load the miss shader if we actually miss, which we determining by not having set
 * a closest hit shader. It is valid to not specify an SBT pointer for miss shaders if none
 * of the rays miss. */
-nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 0));
+nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
 {
-load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, 0);
+load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0);
 }
-nir_pop_if(b, NULL);
+nir_pop_if(&b, NULL);
 }
+nir_pop_if(&b, NULL);
+return b.shader;
+}
+static void
+insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+nir_builder *b, const struct rt_variables *vars)
+{
+struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
+nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);
+b->shader->info.shared_size += shader->info.shared_size;
+assert(b->shader->info.shared_size <= 32768);
+/* For now, just inline the traversal shader */
+nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));
+nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
+nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
 nir_pop_if(b, NULL);
+/* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
+ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
+ralloc_free(var_remap);
 }
 static unsigned
@@ -1770,10 +1801,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
 nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, false), 1);
-nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 1));
-nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);
 insert_traversal(device, pCreateInfo, &b, &vars);
-nir_pop_if(&b, NULL);
 nir_ssa_def *idx = nir_load_var(&b, vars.idx);
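
The splicing above relies on map_rt_variables(), which is not part of this diff. Presumably it fills var_remap with entries mapping the traversal shader's freshly created rt_variables to the caller's, so that nir_inline_function_impl() can rewrite variable derefs while moving the instructions. A hypothetical sketch of that idea follows; the member names are taken from the rt_variables uses visible in the diff, and the real helper may differ.

/* Hypothetical sketch, not the actual radv implementation: remap each variable of
 * the traversal shader (src) to the corresponding variable of the caller (dst). */
static void
map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
                 const struct rt_variables *dst)
{
   _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
   _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
   _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask);
   _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
   _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
   _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
   _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
   /* ...and likewise for the remaining rt_variables members. */
}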