radv/rt: use derefs for the traversal stack

Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17301>
This commit is contained in:
Daniel Schürmann 2022-04-22 00:29:02 +02:00 committed by Marge Bot
parent 076ea8b35a
commit 3750663c72
1 changed files with 22 additions and 24 deletions

View File

@ -1435,15 +1435,15 @@ build_traversal_shader(struct radv_device *device,
struct rt_variables vars = create_rt_variables(b.shader, dst_vars->stack_sizes);
map_rt_variables(var_remap, &vars, dst_vars);
unsigned stack_entry_size = 4;
unsigned lanes = device->physical_device->rt_wave_size;
unsigned stack_entry_stride = stack_entry_size * lanes;
nir_ssa_def *stack_entry_stride_def = nir_imm_int(&b, stack_entry_stride);
nir_ssa_def *stack_base =
nir_iadd_imm(&b, nir_imul_imm(&b, nir_load_local_invocation_index(&b), stack_entry_size),
b.shader->info.shared_size);
b.shader->info.shared_size += stack_entry_stride * MAX_STACK_ENTRY_COUNT;
unsigned elements = lanes * MAX_STACK_ENTRY_COUNT;
nir_variable *stack_var = nir_variable_create(b.shader, nir_var_mem_shared,
glsl_array_type(glsl_uint_type(), elements, 0),
"trav_stack");
nir_deref_instr *stack_deref = nir_build_deref_var(&b, stack_var);
nir_deref_instr *stack;
nir_ssa_def *stack_idx_stride = nir_imm_int(&b, lanes);
nir_ssa_def *stack_idx_base = nir_load_local_invocation_index(&b);
nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct);
@ -1471,14 +1471,15 @@ build_traversal_shader(struct radv_device *device,
nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
nir_store_var(&b, trav_vars.stack, nir_iadd(&b, stack_base, stack_entry_stride_def), 1);
nir_store_shared(&b, bvh_root, stack_base, .base = 0, .align_mul = stack_entry_size);
nir_store_var(&b, trav_vars.stack, nir_iadd(&b, stack_idx_base, stack_idx_stride), 1);
stack = nir_build_deref_array(&b, stack_deref, stack_idx_base);
nir_store_deref(&b, stack, bvh_root, 0x1);
nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
nir_push_loop(&b);
nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, trav_vars.stack), stack_base));
nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, trav_vars.stack), stack_idx_base));
nir_jump(&b, nir_jump_break);
nir_pop_if(&b, NULL);
@ -1495,10 +1496,10 @@ build_traversal_shader(struct radv_device *device,
nir_pop_if(&b, NULL);
nir_store_var(&b, trav_vars.stack,
nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def), 1);
nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
nir_ssa_def *bvh_node = nir_load_shared(&b, 1, 32, nir_load_var(&b, trav_vars.stack), .base = 0,
.align_mul = stack_entry_size);
stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
nir_ssa_def *bvh_node = nir_load_deref(&b, stack);
nir_ssa_def *bvh_node_type = nir_iand_imm(&b, bvh_node, 7);
bvh_node = nir_iadd(&b, nir_load_var(&b, trav_vars.bvh_base), nir_u2u(&b, bvh_node, 64));
@ -1548,12 +1549,11 @@ build_traversal_shader(struct radv_device *device,
build_addr_to_node(
&b, nir_pack_64_2x32(&b, nir_channels(&b, instance_data, 0x3))),
1);
nir_store_shared(&b, nir_iand_imm(&b, nir_channel(&b, instance_data, 0), 63),
nir_load_var(&b, trav_vars.stack), .base = 0,
.align_mul = stack_entry_size);
stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
nir_store_deref(&b, stack, nir_iand_imm(&b, nir_channel(&b, instance_data, 0), 63), 0x1);
nir_store_var(&b, trav_vars.stack,
nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def),
1);
nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
nir_store_var(
&b, trav_vars.origin,
@ -1588,11 +1588,11 @@ build_traversal_shader(struct radv_device *device,
nir_ssa_def *new_node = nir_channel(&b, result, i);
nir_push_if(&b, nir_ine_imm(&b, new_node, 0xffffffff));
{
nir_store_shared(&b, new_node, nir_load_var(&b, trav_vars.stack), .base = 0,
.align_mul = stack_entry_size);
stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
nir_store_deref(&b, stack, new_node, 0x1);
nir_store_var(
&b, trav_vars.stack,
nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def), 1);
nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
}
nir_pop_if(&b, NULL);
}
@ -1646,8 +1646,6 @@ insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInf
{
struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);
b->shader->info.shared_size += shader->info.shared_size;
assert(b->shader->info.shared_size <= 32768);
/* For now, just inline the traversal shader */
nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));