intel/fs: fix metadata preservation for trace_ray intrinsic

c78be5da30 ("intel/fs: lower ray query intrinsics") introduced a
helper function that uses nir_(push|pop)_if, which invalidates the
dominance & block_index metadata when lowering
nir_intrinsic_rt_trace_ray.

We can still keep dominance/block_index metadata for the lowering of
nir_intrinsic_rt_execute_callable though.

This change uses two different lowering functions, each with correct
metadata preservation.

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Fixes: c78be5da30 ("intel/fs: lower ray query intrinsics")
Reviewed-by: Marcin Ślusarz <marcin.slusarz@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15910>
This commit is contained in:
Lionel Landwerlin 2022-04-12 21:59:58 +03:00 committed by Marge Bot
parent fcd6b2a47a
commit 9c0805ef91
1 changed files with 119 additions and 106 deletions

View File

@ -125,7 +125,7 @@ store_resume_addr(nir_builder *b, nir_intrinsic_instr *call)
}
static bool
lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
@ -134,117 +134,130 @@ lower_shader_calls_instr(struct nir_builder *b, nir_instr *instr, void *data)
* brw_nir_lower_rt_intrinsics()
*/
nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);
switch (call->intrinsic) {
case nir_intrinsic_rt_trace_ray: {
b->cursor = nir_instr_remove(instr);
store_resume_addr(b, call);
nir_ssa_def *as_addr = call->src[0].ssa;
nir_ssa_def *ray_flags = call->src[1].ssa;
/* From the SPIR-V spec:
*
* "Only the 8 least-significant bits of Cull Mask are used by this
* instruction - other bits are ignored.
*
* Only the 4 least-significant bits of SBT Offset and SBT Stride are
* used by this instruction - other bits are ignored.
*
* Only the 16 least-significant bits of Miss Index are used by this
* instruction - other bits are ignored."
*/
nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
nir_ssa_def *ray_orig = call->src[6].ssa;
nir_ssa_def *ray_t_min = call->src[7].ssa;
nir_ssa_def *ray_dir = call->src[8].ssa;
nir_ssa_def *ray_t_max = call->src[9].ssa;
nir_ssa_def *root_node_ptr =
brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);
/* The hardware packet requires an address to the first element of the
* hit SBT.
*
* In order to calculate this, we must multiply the "SBT Offset"
* provided to OpTraceRay by the SBT stride provided for the hit SBT in
* the call to vkCmdTraceRay() and add that to the base address of the
* hit SBT. This stride is not to be confused with the "SBT Stride"
* provided to OpTraceRay which is in units of this stride. It's a
* rather terrible overload of the word "stride". The hardware docs
* calls the SPIR-V stride value the "shader index multiplier" which is
* a much more sane name.
*/
nir_ssa_def *hit_sbt_stride_B =
nir_load_ray_hit_sbt_stride_intel(b);
nir_ssa_def *hit_sbt_offset_B =
nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
nir_ssa_def *hit_sbt_addr =
nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
nir_u2u64(b, hit_sbt_offset_B));
/* The hardware packet takes an address to the miss BSR. */
nir_ssa_def *miss_sbt_stride_B =
nir_load_ray_miss_sbt_stride_intel(b);
nir_ssa_def *miss_sbt_offset_B =
nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
nir_ssa_def *miss_sbt_addr =
nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
nir_u2u64(b, miss_sbt_offset_B));
struct brw_nir_rt_mem_ray_defs ray_defs = {
.root_node_ptr = root_node_ptr,
.ray_flags = nir_u2u16(b, ray_flags),
.ray_mask = cull_mask,
.hit_group_sr_base_ptr = hit_sbt_addr,
.hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
.miss_sr_ptr = miss_sbt_addr,
.orig = ray_orig,
.t_near = ray_t_min,
.dir = ray_dir,
.t_far = ray_t_max,
.shader_index_multiplier = sbt_stride,
};
brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
nir_trace_ray_intel(b,
nir_load_btd_global_arg_addr_intel(b),
nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD),
nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
.synchronous = false);
return true;
}
case nir_intrinsic_rt_execute_callable: {
b->cursor = nir_instr_remove(instr);
store_resume_addr(b, call);
nir_ssa_def *sbt_offset32 =
nir_imul(b, call->src[0].ssa,
nir_u2u32(b, nir_load_callable_sbt_stride_intel(b)));
nir_ssa_def *sbt_addr =
nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
nir_u2u64(b, sbt_offset32));
brw_nir_btd_spawn(b, sbt_addr);
return true;
}
default:
if (call->intrinsic != nir_intrinsic_rt_trace_ray)
return false;
}
b->cursor = nir_instr_remove(instr);
store_resume_addr(b, call);
nir_ssa_def *as_addr = call->src[0].ssa;
nir_ssa_def *ray_flags = call->src[1].ssa;
/* From the SPIR-V spec:
*
* "Only the 8 least-significant bits of Cull Mask are used by this
* instruction - other bits are ignored.
*
* Only the 4 least-significant bits of SBT Offset and SBT Stride are
* used by this instruction - other bits are ignored.
*
* Only the 16 least-significant bits of Miss Index are used by this
* instruction - other bits are ignored."
*/
nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff);
nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf);
nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf);
nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff);
nir_ssa_def *ray_orig = call->src[6].ssa;
nir_ssa_def *ray_t_min = call->src[7].ssa;
nir_ssa_def *ray_dir = call->src[8].ssa;
nir_ssa_def *ray_t_max = call->src[9].ssa;
nir_ssa_def *root_node_ptr =
brw_nir_rt_acceleration_structure_to_root_node(b, as_addr);
/* The hardware packet requires an address to the first element of the
* hit SBT.
*
* In order to calculate this, we must multiply the "SBT Offset"
* provided to OpTraceRay by the SBT stride provided for the hit SBT in
* the call to vkCmdTraceRay() and add that to the base address of the
* hit SBT. This stride is not to be confused with the "SBT Stride"
* provided to OpTraceRay which is in units of this stride. It's a
* rather terrible overload of the word "stride". The hardware docs
* calls the SPIR-V stride value the "shader index multiplier" which is
* a much more sane name.
*/
nir_ssa_def *hit_sbt_stride_B =
nir_load_ray_hit_sbt_stride_intel(b);
nir_ssa_def *hit_sbt_offset_B =
nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B));
nir_ssa_def *hit_sbt_addr =
nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b),
nir_u2u64(b, hit_sbt_offset_B));
/* The hardware packet takes an address to the miss BSR. */
nir_ssa_def *miss_sbt_stride_B =
nir_load_ray_miss_sbt_stride_intel(b);
nir_ssa_def *miss_sbt_offset_B =
nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B));
nir_ssa_def *miss_sbt_addr =
nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b),
nir_u2u64(b, miss_sbt_offset_B));
struct brw_nir_rt_mem_ray_defs ray_defs = {
.root_node_ptr = root_node_ptr,
.ray_flags = nir_u2u16(b, ray_flags),
.ray_mask = cull_mask,
.hit_group_sr_base_ptr = hit_sbt_addr,
.hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
.miss_sr_ptr = miss_sbt_addr,
.orig = ray_orig,
.t_near = ray_t_min,
.dir = ray_dir,
.t_far = ray_t_max,
.shader_index_multiplier = sbt_stride,
};
brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
nir_trace_ray_intel(b,
nir_load_btd_global_arg_addr_intel(b),
nir_imm_int(b, BRW_RT_BVH_LEVEL_WORLD),
nir_imm_int(b, GEN_RT_TRACE_RAY_INITAL),
.synchronous = false);
return true;
}
/* Per-instruction callback that lowers nir_intrinsic_rt_execute_callable.
 *
 * The intrinsic is removed, the resume address is stored through
 * store_resume_addr(), and a spawn is emitted with brw_nir_btd_spawn() on the
 * address computed as:
 *
 *    callable SBT address + src[0] * callable SBT stride
 *
 * Returns true when the instruction was replaced, false when it was left
 * untouched. The data argument is not used.
 */
static bool
lower_shader_call_instr(struct nir_builder *b, nir_instr *instr, void *data)
{
   /* Anything that is not rt_execute_callable is ignored here. That includes
    * nir_intrinsic_rt_resume, which is left to be lowered by
    * brw_nir_lower_rt_intrinsics().
    */
   if (instr->type != nir_instr_type_intrinsic ||
       nir_instr_as_intrinsic(instr)->intrinsic !=
       nir_intrinsic_rt_execute_callable)
      return false;

   nir_intrinsic_instr *callable = nir_instr_as_intrinsic(instr);

   /* Continue building from the cursor returned by the removal. */
   b->cursor = nir_instr_remove(instr);
   store_resume_addr(b, callable);

   /* 32-bit offset: src[0] times the stride widened/narrowed to 32 bits. */
   nir_ssa_def *stride32 =
      nir_u2u32(b, nir_load_callable_sbt_stride_intel(b));
   nir_ssa_def *offset32 = nir_imul(b, callable->src[0].ssa, stride32);

   /* 64-bit target address: base address plus the converted offset. */
   nir_ssa_def *target_addr =
      nir_iadd(b, nir_load_callable_sbt_addr_intel(b),
                  nir_u2u64(b, offset32));

   brw_nir_btd_spawn(b, target_addr);
   return true;
}
/* Lowers the shader call intrinsics of a shader.
 *
 * The two-pass form below runs lower_shader_trace_ray_instr and
 * lower_shader_call_instr as separate instruction passes so that each one
 * declares its own preserved metadata: nir_metadata_none for the trace_ray
 * lowering, block_index and dominance for the callable lowering.
 *
 * NOTE(review): this listing is a diff view whose +/- markers were lost. The
 * single-pass return using lower_shader_calls_instr appears to be the
 * pre-change statement and the two-pass return after it its replacement;
 * confirm against the actual file.
 */
bool
brw_nir_lower_shader_calls(nir_shader *shader)
{
return nir_shader_instructions_pass(shader,
lower_shader_calls_instr,
nir_metadata_block_index |
nir_metadata_dominance,
NULL);
/* Bitwise | rather than ||: neither call is short-circuited, so both passes
 * are always invoked and the result is true if either call returned true.
 */
return
nir_shader_instructions_pass(shader,
lower_shader_trace_ray_instr,
nir_metadata_none,
NULL) |
nir_shader_instructions_pass(shader,
lower_shader_call_instr,
nir_metadata_block_index |
nir_metadata_dominance,
NULL);
}
/** Creates a trivial return shader