diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index ca2c507a4a0..82602fa0ab3 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -21,6 +21,7 @@
  * IN THE SOFTWARE.
  */
 
+#include "radv_acceleration_structure.h"
 #include "radv_private.h"
 #include "radv_shader.h"
 
@@ -303,6 +304,473 @@ create_inner_vars(nir_builder *b, const struct rt_variables *vars)
    return inner_vars;
 }
 
+/* The hit attributes are stored on the stack. This is the offset compared to the current stack
+ * pointer of where the hit attrib is stored. */
+const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
+
+static void
+insert_rt_return(nir_builder *b, const struct rt_variables *vars)
+{
+   nir_store_var(b, vars->stack_ptr,
+                 nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, -16)), 1);
+   nir_store_var(b, vars->idx,
+                 nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1);
+}
+
+enum sbt_type {
+   SBT_RAYGEN,
+   SBT_MISS,
+   SBT_HIT,
+   SBT_CALLABLE,
+};
+
+static nir_ssa_def *
+get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
+{
+   nir_ssa_def *desc = nir_load_sbt_amd(b, 4, .binding = binding);
+   nir_ssa_def *base_addr = nir_pack_64_2x32(b, nir_channels(b, desc, 0x3));
+   nir_ssa_def *stride = nir_channel(b, desc, 2);
+
+   nir_ssa_def *ret = nir_imul(b, idx, stride);
+   ret = nir_iadd(b, base_addr, nir_u2u64(b, ret));
+
+   return ret;
+}
+
+static void
+load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx,
+               enum sbt_type binding, unsigned offset)
+{
+   nir_ssa_def *addr = get_sbt_ptr(b, idx, binding);
+
+   nir_ssa_def *load_addr = addr;
+   if (offset)
+      load_addr = nir_iadd(b, load_addr, nir_imm_int64(b, offset));
+   nir_ssa_def *v_idx =
+      nir_build_load_global(b, 1, 32, load_addr, .align_mul = 4, .align_offset = 0);
+
+   nir_store_var(b, vars->idx, v_idx, 1);
+
+   nir_ssa_def *record_addr = nir_iadd(b, addr, nir_imm_int64(b, RADV_RT_HANDLE_SIZE));
+   nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
+}
+
+static nir_ssa_def *
+nir_build_vec3_mat_mult(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[], bool translation)
+{
+   nir_ssa_def *result_components[3] = {
+      nir_channel(b, matrix[0], 3),
+      nir_channel(b, matrix[1], 3),
+      nir_channel(b, matrix[2], 3),
+   };
+   for (unsigned i = 0; i < 3; ++i) {
+      for (unsigned j = 0; j < 3; ++j) {
+         nir_ssa_def *v =
+            nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j));
+         result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v;
+      }
+   }
+   return nir_vec(b, result_components, 3);
+}
+
+static nir_ssa_def *
+nir_build_vec3_mat_mult_pre(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[])
+{
+   nir_ssa_def *result_components[3] = {
+      nir_channel(b, matrix[0], 3),
+      nir_channel(b, matrix[1], 3),
+      nir_channel(b, matrix[2], 3),
+   };
+   return nir_build_vec3_mat_mult(b, nir_fsub(b, vec, nir_vec(b, result_components, 3)), matrix,
+                                  false);
+}
+
+static void
+nir_build_wto_matrix_load(nir_builder *b, nir_ssa_def *instance_addr, nir_ssa_def **out)
+{
+   unsigned offset = offsetof(struct radv_bvh_instance_node, wto_matrix);
+   for (unsigned i = 0; i < 3; ++i) {
+      out[i] = nir_build_load_global(b, 4, 32,
+                                     nir_iadd(b, instance_addr, nir_imm_int64(b, offset + i * 16)),
+                                     .align_mul = 64, .align_offset = offset + i * 16);
+   }
+}
+
+/* This lowers all the RT instructions that we do not want to pass on to the combined shader and
+ * that we can implement using the variables from the shader we are going to inline into. */
+static void
+lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)
+{
+   nir_builder b_shader;
+   nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader));
+
+   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
+      nir_foreach_instr_safe (instr, block) {
+         switch (instr->type) {
+         case nir_instr_type_intrinsic: {
+            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+            switch (intr->intrinsic) {
+            case nir_intrinsic_rt_execute_callable: {
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
+               b_shader.cursor = nir_instr_remove(instr);
+
+               nir_store_var(&b_shader, vars->stack_ptr,
+                             nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
+                                      nir_imm_int(&b_shader, size)),
+                             1);
+               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret),
+                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16,
+                                 .write_mask = 1);
+
+               nir_store_var(&b_shader, vars->stack_ptr,
+                             nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
+                                      nir_imm_int(&b_shader, 16)),
+                             1);
+               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0);
+
+               nir_store_var(
+                  &b_shader, vars->arg,
+                  nir_isub(&b_shader, intr->src[1].ssa, nir_imm_int(&b_shader, size + 16)), 1);
+
+               vars->stack_sizes[vars->group_idx].recursive_size =
+                  MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
+               break;
+            }
+            case nir_intrinsic_rt_trace_ray: {
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
+               b_shader.cursor = nir_instr_remove(instr);
+
+               nir_store_var(&b_shader, vars->stack_ptr,
+                             nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
+                                      nir_imm_int(&b_shader, size)),
+                             1);
+               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret),
+                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16,
+                                 .write_mask = 1);
+
+               nir_store_var(&b_shader, vars->stack_ptr,
+                             nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
+                                      nir_imm_int(&b_shader, 16)),
+                             1);
+
+               nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1);
+               nir_store_var(
+                  &b_shader, vars->arg,
+                  nir_isub(&b_shader, intr->src[10].ssa, nir_imm_int(&b_shader, size + 16)), 1);
+
+               vars->stack_sizes[vars->group_idx].recursive_size =
+                  MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
+
+               /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
+               nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
+               nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1);
+               nir_store_var(&b_shader, vars->cull_mask,
+                             nir_iand(&b_shader, intr->src[2].ssa, nir_imm_int(&b_shader, 0xff)),
+                             0x1);
+               nir_store_var(&b_shader, vars->sbt_offset,
+                             nir_iand(&b_shader, intr->src[3].ssa, nir_imm_int(&b_shader, 0xf)),
+                             0x1);
+               nir_store_var(&b_shader, vars->sbt_stride,
+                             nir_iand(&b_shader, intr->src[4].ssa, nir_imm_int(&b_shader, 0xf)),
+                             0x1);
+               nir_store_var(&b_shader, vars->miss_index,
+                             nir_iand(&b_shader, intr->src[5].ssa, nir_imm_int(&b_shader, 0xffff)),
+                             0x1);
+               nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
+               nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
+               nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
+               nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
+               break;
+            }
+            case nir_intrinsic_rt_resume: {
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               b_shader.cursor = nir_instr_remove(instr);
+
+               nir_store_var(&b_shader, vars->stack_ptr,
+                             nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
+                                      nir_imm_int(&b_shader, -size)),
+                             1);
+               break;
+            }
+            case nir_intrinsic_rt_return_amd: {
+               b_shader.cursor = nir_instr_remove(instr);
+
+               if (shader->info.stage == MESA_SHADER_RAYGEN) {
+                  nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1);
+                  break;
+               }
+               insert_rt_return(&b_shader, vars);
+               break;
+            }
+            case nir_intrinsic_load_scratch: {
+               b_shader.cursor = nir_before_instr(instr);
+               nir_instr_rewrite_src_ssa(
+                  instr, &intr->src[0],
+                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
+               break;
+            }
+            case nir_intrinsic_store_scratch: {
+               b_shader.cursor = nir_before_instr(instr);
+               nir_instr_rewrite_src_ssa(
+                  instr, &intr->src[1],
+                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
+               break;
+            }
+            case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->arg);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_shader_record_ptr: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->shader_record_ptr);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_launch_id: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_global_invocation_id(&b_shader, 32);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_t_min: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmin);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_t_max: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmax);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_world_origin: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->origin);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_world_direction: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->direction);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_instance_custom_index: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->custom_instance_and_mask);
+               ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFF));
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_primitive_id: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->primitive_id);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_geometry_index: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
+               ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFFF));
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_instance_id: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->instance_id);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_flags: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->flags);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_hit_kind: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->hit_kind);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_load_ray_world_to_object: {
+               unsigned c = nir_intrinsic_column(intr);
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *wto_matrix[3];
+               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
+
+               nir_ssa_def *vals[3];
+               for (unsigned i = 0; i < 3; ++i)
+                  vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
+
+               nir_ssa_def *val = nir_vec(&b_shader, vals, 3);
+               if (c == 3)
+                  val = nir_fneg(&b_shader,
+                                 nir_build_vec3_mat_mult(&b_shader, val, wto_matrix, false));
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
+               break;
+            }
+            case nir_intrinsic_load_ray_object_to_world: {
+               unsigned c = nir_intrinsic_column(intr);
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *val;
+               if (c == 3) {
+                  nir_ssa_def *wto_matrix[3];
+                  nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
+
+                  nir_ssa_def *vals[3];
+                  for (unsigned i = 0; i < 3; ++i)
+                     vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
+
+                  val = nir_vec(&b_shader, vals, 3);
+               } else {
+                  val = nir_build_load_global(
+                     &b_shader, 3, 32,
+                     nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 92 + c * 12)),
+                     .align_mul = 4, .align_offset = 0);
+               }
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
+               break;
+            }
+            case nir_intrinsic_load_ray_object_origin: {
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *wto_matrix[] = {
+                  nir_build_load_global(
+                     &b_shader, 4, 32,
+                     nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 16)),
+                     .align_mul = 64, .align_offset = 16),
+                  nir_build_load_global(
+                     &b_shader, 4, 32,
+                     nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 32)),
+                     .align_mul = 64, .align_offset = 32),
+                  nir_build_load_global(
+                     &b_shader, 4, 32,
+                     nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 48)),
+                     .align_mul = 64, .align_offset = 48)};
+               nir_ssa_def *val = nir_build_vec3_mat_mult_pre(
+                  &b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix);
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
+               break;
+            }
+            case nir_intrinsic_load_ray_object_direction: {
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *wto_matrix[3];
+               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
+               nir_ssa_def *val = nir_build_vec3_mat_mult(
+                  &b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false);
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
+               break;
+            }
+            case nir_intrinsic_load_intersection_opaque_amd: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_ssa_def *ret = nir_load_var(&b_shader, vars->opaque);
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+               break;
+            }
+            case nir_intrinsic_ignore_ray_intersection: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 1), 1);
+
+               /* The if is a workaround to avoid having to fix up control flow manually */
+               nir_push_if(&b_shader, nir_imm_true(&b_shader));
+               nir_jump(&b_shader, nir_jump_return);
+               nir_pop_if(&b_shader, NULL);
+               break;
+            }
+            case nir_intrinsic_terminate_ray: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 2), 1);
+
+               /* The if is a workaround to avoid having to fix up control flow manually */
+               nir_push_if(&b_shader, nir_imm_true(&b_shader));
+               nir_jump(&b_shader, nir_jump_return);
+               nir_pop_if(&b_shader, NULL);
+               break;
+            }
+            case nir_intrinsic_report_ray_intersection: {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_push_if(
+                  &b_shader,
+                  nir_iand(
+                     &b_shader,
+                     nir_flt(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmax)),
+                     nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
+               {
+                  nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 0), 1);
+                  nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
+                  nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
+               }
+               nir_pop_if(&b_shader, NULL);
+               break;
+            }
+            default:
+               break;
+            }
+            break;
+         }
+         case nir_instr_type_jump: {
+            nir_jump_instr *jump = nir_instr_as_jump(instr);
+            if (jump->type == nir_jump_halt) {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_jump(&b_shader, nir_jump_return);
+            }
+            break;
+         }
+         default:
+            break;
+         }
+      }
+   }
+
+   nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
+}
+
+static void
+insert_rt_case(nir_builder *b, nir_shader *shader, const struct rt_variables *vars,
+               nir_ssa_def *idx, uint32_t call_idx_base, uint32_t call_idx)
+{
+   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
+
+   nir_opt_dead_cf(shader);
+
+   struct rt_variables src_vars = create_rt_variables(shader, vars->stack_sizes);
+   map_rt_variables(var_remap, &src_vars, vars);
+
+   NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
+
+   NIR_PASS_V(shader, nir_opt_remove_phis);
+   NIR_PASS_V(shader, nir_lower_returns);
+   NIR_PASS_V(shader, nir_opt_dce);
+
+   if (b->shader->info.stage == MESA_SHADER_ANY_HIT ||
+       b->shader->info.stage == MESA_SHADER_INTERSECTION) {
+      src_vars.stack_sizes[src_vars.group_idx].non_recursive_size =
+         MAX2(src_vars.stack_sizes[src_vars.group_idx].non_recursive_size, shader->scratch_size);
+   } else {
+      src_vars.stack_sizes[src_vars.group_idx].recursive_size =
+         MAX2(src_vars.stack_sizes[src_vars.group_idx].recursive_size, shader->scratch_size);
+   }
+
+   nir_push_if(b, nir_ieq(b, idx, nir_imm_int(b, call_idx)));
+   nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
+   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
+   nir_pop_if(b, NULL);
+
+   /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
+   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
+
+   ralloc_free(var_remap);
+}
+
 static nir_shader *
 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  struct radv_pipeline_shader_stack_size *stack_sizes)