diff --git a/src/amd/common/ac_nir_to_llvm.c b/src/amd/common/ac_nir_to_llvm.c index 41b67f1fd36..5c59d472b20 100644 --- a/src/amd/common/ac_nir_to_llvm.c +++ b/src/amd/common/ac_nir_to_llvm.c @@ -451,20 +451,6 @@ static void set_userdata_location_indirect(struct ac_userdata_info *ud_info, uin } #endif -#define AC_USERDATA_DESCRIPTOR_SET_0 0 -#define AC_USERDATA_DESCRIPTOR_SET_1 2 -#define AC_USERDATA_DESCRIPTOR_SET_2 4 -#define AC_USERDATA_DESCRIPTOR_SET_3 6 -#define AC_USERDATA_PUSH_CONST_DYN 8 - -#define AC_USERDATA_VS_VERTEX_BUFFERS 10 -#define AC_USERDATA_VS_BASE_VERTEX 12 -#define AC_USERDATA_VS_START_INSTANCE 13 - -#define AC_USERDATA_PS_SAMPLE_POS 10 - -#define AC_USERDATA_CS_GRID_SIZE 10 - static void create_function(struct nir_to_llvm_context *ctx, struct nir_shader *nir) { @@ -473,10 +459,15 @@ static void create_function(struct nir_to_llvm_context *ctx, unsigned array_count = 0; unsigned sgpr_count = 0, user_sgpr_count; unsigned i; + unsigned num_sets = ctx->options->layout ? ctx->options->layout->num_sets : 0; + unsigned user_sgpr_idx; /* 1 for each descriptor set */ - for (unsigned i = 0; i < 4; ++i) - arg_types[arg_idx++] = const_array(ctx->i8, 1024 * 1024); + for (unsigned i = 0; i < num_sets; ++i) { + if (ctx->options->layout->set[i].layout->shader_stages & (1 << ctx->stage)) { + arg_types[arg_idx++] = const_array(ctx->i8, 1024 * 1024); + } + } /* 1 for push constants and dynamic descriptors */ arg_types[arg_idx++] = const_array(ctx->i8, 1024 * 1024); @@ -549,18 +540,25 @@ static void create_function(struct nir_to_llvm_context *ctx, ctx->shader_info->num_input_vgprs += llvm_get_type_size(arg_types[i]) / 4; arg_idx = 0; - for (unsigned i = 0; i < 4; ++i) { - set_userdata_location(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], i * 2, 2); - ctx->descriptor_sets[i] = - LLVMGetParam(ctx->main_function, arg_idx++); + user_sgpr_idx = 0; + for (unsigned i = 0; i < num_sets; ++i) { + if (ctx->options->layout->set[i].layout->shader_stages & (1 << ctx->stage)) { + set_userdata_location(&ctx->shader_info->user_sgprs_locs.descriptor_sets[i], user_sgpr_idx, 2); + user_sgpr_idx += 2; + ctx->descriptor_sets[i] = + LLVMGetParam(ctx->main_function, arg_idx++); + } else + ctx->descriptor_sets[i] = NULL; } ctx->push_constants = LLVMGetParam(ctx->main_function, arg_idx++); - set_userdata_location_shader(ctx, AC_UD_PUSH_CONSTANTS, AC_USERDATA_PUSH_CONST_DYN, 2); + set_userdata_location_shader(ctx, AC_UD_PUSH_CONSTANTS, user_sgpr_idx, 2); + user_sgpr_idx += 2; switch (nir->stage) { case MESA_SHADER_COMPUTE: - set_userdata_location_shader(ctx, AC_UD_CS_GRID_SIZE, AC_USERDATA_CS_GRID_SIZE, 3); + set_userdata_location_shader(ctx, AC_UD_CS_GRID_SIZE, user_sgpr_idx, 3); + user_sgpr_idx += 3; ctx->num_work_groups = LLVMGetParam(ctx->main_function, arg_idx++); ctx->workgroup_ids = @@ -571,9 +569,11 @@ static void create_function(struct nir_to_llvm_context *ctx, LLVMGetParam(ctx->main_function, arg_idx++); break; case MESA_SHADER_VERTEX: - set_userdata_location_shader(ctx, AC_UD_VS_VERTEX_BUFFERS, AC_USERDATA_VS_VERTEX_BUFFERS, 2); + set_userdata_location_shader(ctx, AC_UD_VS_VERTEX_BUFFERS, user_sgpr_idx, 2); + user_sgpr_idx += 2; ctx->vertex_buffers = LLVMGetParam(ctx->main_function, arg_idx++); - set_userdata_location_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE, AC_USERDATA_VS_BASE_VERTEX, 2); + set_userdata_location_shader(ctx, AC_UD_VS_BASE_VERTEX_START_INSTANCE, user_sgpr_idx, 2); + user_sgpr_idx += 2; ctx->base_vertex = LLVMGetParam(ctx->main_function, arg_idx++); ctx->start_instance = LLVMGetParam(ctx->main_function, arg_idx++); ctx->vertex_id = LLVMGetParam(ctx->main_function, arg_idx++); @@ -582,7 +582,8 @@ static void create_function(struct nir_to_llvm_context *ctx, ctx->instance_id = LLVMGetParam(ctx->main_function, arg_idx++); break; case MESA_SHADER_FRAGMENT: - set_userdata_location_shader(ctx, AC_UD_PS_SAMPLE_POS, AC_USERDATA_PS_SAMPLE_POS, 2); + set_userdata_location_shader(ctx, AC_UD_PS_SAMPLE_POS, user_sgpr_idx, 2); + user_sgpr_idx += 2; ctx->sample_positions = LLVMGetParam(ctx->main_function, arg_idx++); ctx->prim_mask = LLVMGetParam(ctx->main_function, arg_idx++); ctx->persp_sample = LLVMGetParam(ctx->main_function, arg_idx++);