radv: Add helper to inline shaders into the main shader.

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12592>
Bas Nieuwenhuizen 2021-08-27 02:07:57 +02:00 committed by Marge Bot
parent dcb02dbe73
commit 207ce6d658
1 changed file with 468 additions and 0 deletions

@@ -21,6 +21,7 @@
* IN THE SOFTWARE.
*/
#include "radv_acceleration_structure.h"
#include "radv_private.h"
#include "radv_shader.h"
@@ -303,6 +304,473 @@ create_inner_vars(nir_builder *b, const struct rt_variables *vars)
return inner_vars;
}
/* The hit attributes are stored on the stack. This is their offset relative to the current
 * stack pointer. */
const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
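/* Pops the resume index that the caller pushed onto the stack into vars->idx, so the main
 * loop continues execution in the calling shader. */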
static void
insert_rt_return(nir_builder *b, const struct rt_variables *vars)
{
nir_store_var(b, vars->stack_ptr,
nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, -16)), 1);
nir_store_var(b, vars->idx,
nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1);
}
enum sbt_type {
SBT_RAYGEN,
SBT_MISS,
SBT_HIT,
SBT_CALLABLE,
};
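/* Computes the address of an SBT entry: the table base address plus idx * stride, both taken
 * from the SBT descriptor for the given binding. */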
static nir_ssa_def *
get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
{
nir_ssa_def *desc = nir_load_sbt_amd(b, 4, .binding = binding);
nir_ssa_def *base_addr = nir_pack_64_2x32(b, nir_channels(b, desc, 0x3));
nir_ssa_def *stride = nir_channel(b, desc, 2);
nir_ssa_def *ret = nir_imul(b, idx, stride);
ret = nir_iadd(b, base_addr, nir_u2u64(b, ret));
return ret;
}
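/* Loads the shader index from the selected SBT entry into vars->idx and points
 * vars->shader_record_ptr at the shader record data that follows the handle. */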
static void
load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx,
enum sbt_type binding, unsigned offset)
{
nir_ssa_def *addr = get_sbt_ptr(b, idx, binding);
nir_ssa_def *load_addr = addr;
if (offset)
load_addr = nir_iadd(b, load_addr, nir_imm_int64(b, offset));
nir_ssa_def *v_idx =
nir_build_load_global(b, 1, 32, load_addr, .align_mul = 4, .align_offset = 0);
nir_store_var(b, vars->idx, v_idx, 1);
nir_ssa_def *record_addr = nir_iadd(b, addr, nir_imm_int64(b, RADV_RT_HANDLE_SIZE));
nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
}
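/* Multiplies a vec3 by a matrix given as three four-component rows; the fourth (translation)
 * column is only added when 'translation' is set. */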
static nir_ssa_def *
nir_build_vec3_mat_mult(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[], bool translation)
{
nir_ssa_def *result_components[3] = {
nir_channel(b, matrix[0], 3),
nir_channel(b, matrix[1], 3),
nir_channel(b, matrix[2], 3),
};
for (unsigned i = 0; i < 3; ++i) {
for (unsigned j = 0; j < 3; ++j) {
nir_ssa_def *v =
nir_fmul(b, nir_channels(b, vec, 1 << j), nir_channels(b, matrix[i], 1 << j));
result_components[i] = (translation || j) ? nir_fadd(b, result_components[i], v) : v;
}
}
return nir_vec(b, result_components, 3);
}
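/* Subtracts the matrix's translation column from the vector before doing the multiply. */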
static nir_ssa_def *
nir_build_vec3_mat_mult_pre(nir_builder *b, nir_ssa_def *vec, nir_ssa_def *matrix[])
{
nir_ssa_def *result_components[3] = {
nir_channel(b, matrix[0], 3),
nir_channel(b, matrix[1], 3),
nir_channel(b, matrix[2], 3),
};
return nir_build_vec3_mat_mult(b, nir_fsub(b, vec, nir_vec(b, result_components, 3)), matrix,
false);
}
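/* Loads the three rows of the world-to-object matrix from the BVH instance node at
 * instance_addr. */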
static void
nir_build_wto_matrix_load(nir_builder *b, nir_ssa_def *instance_addr, nir_ssa_def **out)
{
unsigned offset = offsetof(struct radv_bvh_instance_node, wto_matrix);
for (unsigned i = 0; i < 3; ++i) {
out[i] = nir_build_load_global(b, 4, 32,
nir_iadd(b, instance_addr, nir_imm_int64(b, offset + i * 16)),
.align_mul = 64, .align_offset = offset + i * 16);
}
}
/* This lowers all the RT instructions that we do not want to pass on to the combined shader and
* that we can implement using the variables from the shader we are going to inline into. */
static void
lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)
{
nir_builder b_shader;
nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader));
nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
nir_foreach_instr_safe (instr, block) {
switch (instr->type) {
case nir_instr_type_intrinsic: {
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
switch (intr->intrinsic) {
case nir_intrinsic_rt_execute_callable: {
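/* Reserve the caller's stack region, push the resume index, select the callable shader
 * through the SBT and rebase the argument offset onto the new stack pointer; the main
 * loop then dispatches to the callee. */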
uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
b_shader.cursor = nir_instr_remove(instr);
nir_store_var(&b_shader, vars->stack_ptr,
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
nir_imm_int(&b_shader, size)),
1);
nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret),
nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16,
.write_mask = 1);
nir_store_var(&b_shader, vars->stack_ptr,
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
nir_imm_int(&b_shader, 16)),
1);
load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0);
nir_store_var(
&b_shader, vars->arg,
nir_isub(&b_shader, intr->src[1].ssa, nir_imm_int(&b_shader, size + 16)), 1);
vars->stack_sizes[vars->group_idx].recursive_size =
MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
break;
}
case nir_intrinsic_rt_trace_ray: {
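/* traceRay uses the same stack and resume-index convention as executeCallable, but
 * dispatches to shader index 1 and captures the ray parameters in the shared variables. */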
uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
uint32_t ret = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
b_shader.cursor = nir_instr_remove(instr);
nir_store_var(&b_shader, vars->stack_ptr,
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
nir_imm_int(&b_shader, size)),
1);
nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret),
nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16,
.write_mask = 1);
nir_store_var(&b_shader, vars->stack_ptr,
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
nir_imm_int(&b_shader, 16)),
1);
nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1);
nir_store_var(
&b_shader, vars->arg,
nir_isub(&b_shader, intr->src[10].ssa, nir_imm_int(&b_shader, size + 16)), 1);
vars->stack_sizes[vars->group_idx].recursive_size =
MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
/* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1);
nir_store_var(&b_shader, vars->cull_mask,
nir_iand(&b_shader, intr->src[2].ssa, nir_imm_int(&b_shader, 0xff)),
0x1);
nir_store_var(&b_shader, vars->sbt_offset,
nir_iand(&b_shader, intr->src[3].ssa, nir_imm_int(&b_shader, 0xf)),
0x1);
nir_store_var(&b_shader, vars->sbt_stride,
nir_iand(&b_shader, intr->src[4].ssa, nir_imm_int(&b_shader, 0xf)),
0x1);
nir_store_var(&b_shader, vars->miss_index,
nir_iand(&b_shader, intr->src[5].ssa, nir_imm_int(&b_shader, 0xffff)),
0x1);
nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
break;
}
case nir_intrinsic_rt_resume: {
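/* When resuming after a call, release the stack region the call reserved. */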
uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
b_shader.cursor = nir_instr_remove(instr);
nir_store_var(&b_shader, vars->stack_ptr,
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr),
nir_imm_int(&b_shader, -size)),
1);
break;
}
case nir_intrinsic_rt_return_amd: {
b_shader.cursor = nir_instr_remove(instr);
if (shader->info.stage == MESA_SHADER_RAYGEN) {
nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1);
break;
}
insert_rt_return(&b_shader, vars);
break;
}
case nir_intrinsic_load_scratch: {
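/* Scratch offsets are relative to the current shader's stack, so rebase them on the stack
 * pointer (the store_scratch case below does the same). */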
b_shader.cursor = nir_before_instr(instr);
nir_instr_rewrite_src_ssa(
instr, &intr->src[0],
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
break;
}
case nir_intrinsic_store_scratch: {
b_shader.cursor = nir_before_instr(instr);
nir_instr_rewrite_src_ssa(
instr, &intr->src[1],
nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
break;
}
case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->arg);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_shader_record_ptr: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->shader_record_ptr);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_launch_id: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_global_invocation_id(&b_shader, 32);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_t_min: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmin);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_t_max: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->tmax);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_world_origin: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->origin);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_world_direction: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->direction);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_instance_custom_index: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->custom_instance_and_mask);
ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFF));
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_primitive_id: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->primitive_id);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_geometry_index: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
ret = nir_iand(&b_shader, ret, nir_imm_int(&b_shader, 0xFFFFFFF));
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_instance_id: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->instance_id);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_flags: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->flags);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_hit_kind: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->hit_kind);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_load_ray_world_to_object: {
unsigned c = nir_intrinsic_column(intr);
nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
nir_ssa_def *wto_matrix[3];
nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
nir_ssa_def *vals[3];
for (unsigned i = 0; i < 3; ++i)
vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
nir_ssa_def *val = nir_vec(&b_shader, vals, 3);
if (c == 3)
val = nir_fneg(&b_shader,
nir_build_vec3_mat_mult(&b_shader, val, wto_matrix, false));
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
break;
}
case nir_intrinsic_load_ray_object_to_world: {
unsigned c = nir_intrinsic_column(intr);
nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
nir_ssa_def *val;
if (c == 3) {
nir_ssa_def *wto_matrix[3];
nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
nir_ssa_def *vals[3];
for (unsigned i = 0; i < 3; ++i)
vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
val = nir_vec(&b_shader, vals, 3);
} else {
val = nir_build_load_global(
&b_shader, 3, 32,
nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 92 + c * 12)),
.align_mul = 4, .align_offset = 0);
}
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
break;
}
case nir_intrinsic_load_ray_object_origin: {
nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
nir_ssa_def *wto_matrix[] = {
nir_build_load_global(
&b_shader, 4, 32,
nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 16)),
.align_mul = 64, .align_offset = 16),
nir_build_load_global(
&b_shader, 4, 32,
nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 32)),
.align_mul = 64, .align_offset = 32),
nir_build_load_global(
&b_shader, 4, 32,
nir_iadd(&b_shader, instance_node_addr, nir_imm_int64(&b_shader, 48)),
.align_mul = 64, .align_offset = 48)};
nir_ssa_def *val = nir_build_vec3_mat_mult_pre(
&b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix);
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
break;
}
case nir_intrinsic_load_ray_object_direction: {
nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
nir_ssa_def *wto_matrix[3];
nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
nir_ssa_def *val = nir_build_vec3_mat_mult(
&b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false);
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, val);
break;
}
case nir_intrinsic_load_intersection_opaque_amd: {
b_shader.cursor = nir_instr_remove(instr);
nir_ssa_def *ret = nir_load_var(&b_shader, vars->opaque);
nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
break;
}
case nir_intrinsic_ignore_ray_intersection: {
b_shader.cursor = nir_instr_remove(instr);
nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 1), 1);
/* The if is a workaround to avoid having to fix up control flow manually */
nir_push_if(&b_shader, nir_imm_true(&b_shader));
nir_jump(&b_shader, nir_jump_return);
nir_pop_if(&b_shader, NULL);
break;
}
case nir_intrinsic_terminate_ray: {
b_shader.cursor = nir_instr_remove(instr);
nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 2), 1);
/* The if is a workaround to avoid having to fix up control flow manually */
nir_push_if(&b_shader, nir_imm_true(&b_shader));
nir_jump(&b_shader, nir_jump_return);
nir_pop_if(&b_shader, NULL);
break;
}
case nir_intrinsic_report_ray_intersection: {
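/* Only accept the hit if the candidate T lies within the current [tmin, tmax) range; on
 * acceptance, commit the new tmax and hit kind. */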
b_shader.cursor = nir_instr_remove(instr);
nir_push_if(
&b_shader,
nir_iand(
&b_shader,
nir_flt(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmax)),
nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
{
nir_store_var(&b_shader, vars->ahit_status, nir_imm_int(&b_shader, 0), 1);
nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
}
nir_pop_if(&b_shader, NULL);
break;
}
default:
break;
}
break;
}
case nir_instr_type_jump: {
nir_jump_instr *jump = nir_instr_as_jump(instr);
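/* After inlining, a halt must only terminate this shader and not the whole main shader,
 * so turn it into a return. */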
if (jump->type == nir_jump_halt) {
b_shader.cursor = nir_instr_remove(instr);
nir_jump(&b_shader, nir_jump_return);
}
break;
}
default:
break;
}
}
}
nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
}
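/* Inlines the given shader into the main shader as an "if (idx == call_idx)" case,
 * remapping its RT variables onto the main shader's variables and accounting for its
 * scratch usage in the stack sizes. */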
static void
insert_rt_case(nir_builder *b, nir_shader *shader, const struct rt_variables *vars,
nir_ssa_def *idx, uint32_t call_idx_base, uint32_t call_idx)
{
struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
nir_opt_dead_cf(shader);
struct rt_variables src_vars = create_rt_variables(shader, vars->stack_sizes);
map_rt_variables(var_remap, &src_vars, vars);
NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
NIR_PASS_V(shader, nir_opt_remove_phis);
NIR_PASS_V(shader, nir_lower_returns);
NIR_PASS_V(shader, nir_opt_dce);
if (b->shader->info.stage == MESA_SHADER_ANY_HIT ||
b->shader->info.stage == MESA_SHADER_INTERSECTION) {
src_vars.stack_sizes[src_vars.group_idx].non_recursive_size =
MAX2(src_vars.stack_sizes[src_vars.group_idx].non_recursive_size, shader->scratch_size);
} else {
src_vars.stack_sizes[src_vars.group_idx].recursive_size =
MAX2(src_vars.stack_sizes[src_vars.group_idx].recursive_size, shader->scratch_size);
}
nir_push_if(b, nir_ieq(b, idx, nir_imm_int(b, call_idx)));
nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
nir_pop_if(b, NULL);
/* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
ralloc_free(var_remap);
}
static nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_pipeline_shader_stack_size *stack_sizes)