spirv: Pass SSA values through functions

Previously, we would create temporary variables and fill them out.
Instead, we create as many function parameters as we need and pass them
through as SSA defs.

Reviewed-by: Iago Toral Quiroga <itoral@igalia.com>
This commit is contained in:
Jason Ekstrand 2018-09-22 09:46:26 -05:00
parent bfe0e32913
commit a45b6fb452
1 changed files with 139 additions and 41 deletions

View File

@ -42,6 +42,135 @@ vtn_load_param_pointer(struct vtn_builder *b,
return vtn_pointer_from_ssa(b, nir_load_param(&b->nb, param_idx), ptr_type);
}
static unsigned
vtn_type_count_function_params(struct vtn_type *type)
{
switch (type->base_type) {
case vtn_base_type_array:
return type->length * vtn_type_count_function_params(type->array_element);
case vtn_base_type_struct: {
unsigned count = 0;
for (unsigned i = 0; i < type->length; i++)
count += vtn_type_count_function_params(type->members[i]);
return count;
}
case vtn_base_type_sampled_image:
return 2;
default:
return 1;
}
}
static void
vtn_type_add_to_function_params(struct vtn_type *type,
nir_function *func,
unsigned *param_idx)
{
static const nir_parameter nir_deref_param = {
.num_components = 1,
.bit_size = 32,
};
switch (type->base_type) {
case vtn_base_type_array:
for (unsigned i = 0; i < type->length; i++)
vtn_type_add_to_function_params(type->array_element, func, param_idx);
break;
case vtn_base_type_struct:
for (unsigned i = 0; i < type->length; i++)
vtn_type_add_to_function_params(type->members[i], func, param_idx);
break;
case vtn_base_type_sampled_image:
func->params[(*param_idx)++] = nir_deref_param;
func->params[(*param_idx)++] = nir_deref_param;
break;
case vtn_base_type_image:
case vtn_base_type_sampler:
func->params[(*param_idx)++] = nir_deref_param;
break;
case vtn_base_type_pointer:
if (type->type) {
func->params[(*param_idx)++] = (nir_parameter) {
.num_components = glsl_get_vector_elements(type->type),
.bit_size = glsl_get_bit_size(type->type),
};
} else {
func->params[(*param_idx)++] = nir_deref_param;
}
break;
default:
func->params[(*param_idx)++] = (nir_parameter) {
.num_components = glsl_get_vector_elements(type->type),
.bit_size = glsl_get_bit_size(type->type),
};
}
}
static void
vtn_ssa_value_add_to_call_params(struct vtn_builder *b,
struct vtn_ssa_value *value,
struct vtn_type *type,
nir_call_instr *call,
unsigned *param_idx)
{
switch (type->base_type) {
case vtn_base_type_array:
for (unsigned i = 0; i < type->length; i++) {
vtn_ssa_value_add_to_call_params(b, value->elems[i],
type->array_element,
call, param_idx);
}
break;
case vtn_base_type_struct:
for (unsigned i = 0; i < type->length; i++) {
vtn_ssa_value_add_to_call_params(b, value->elems[i],
type->members[i],
call, param_idx);
}
break;
default:
call->params[(*param_idx)++] = nir_src_for_ssa(value->def);
break;
}
}
static void
vtn_ssa_value_load_function_param(struct vtn_builder *b,
struct vtn_ssa_value *value,
struct vtn_type *type,
unsigned *param_idx)
{
switch (type->base_type) {
case vtn_base_type_array:
for (unsigned i = 0; i < type->length; i++) {
vtn_ssa_value_load_function_param(b, value->elems[i],
type->array_element, param_idx);
}
break;
case vtn_base_type_struct:
for (unsigned i = 0; i < type->length; i++) {
vtn_ssa_value_load_function_param(b, value->elems[i],
type->members[i], param_idx);
}
break;
default:
value->def = nir_load_param(&b->nb, (*param_idx)++);
break;
}
}
void
vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@ -86,12 +215,8 @@ vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
call->params[param_idx++] =
nir_src_for_ssa(vtn_pointer_to_ssa(b, pointer));
} else {
/* This is a regular SSA value and we need a temporary */
nir_variable *tmp =
nir_local_variable_create(b->nb.impl, arg_type->type, "arg_tmp");
nir_deref_instr *tmp_deref = nir_build_deref_var(&b->nb, tmp);
vtn_local_store(b, vtn_ssa_value(b, arg_id), tmp_deref);
call->params[param_idx++] = nir_src_for_ssa(&tmp_deref->dest.ssa);
vtn_ssa_value_add_to_call_params(b, vtn_ssa_value(b, arg_id),
arg_type, call, &param_idx);
}
}
assert(param_idx == call->num_params);
@ -130,12 +255,9 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
nir_function *func =
nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
unsigned num_params = func_type->length;
for (unsigned i = 0; i < func_type->length; i++) {
/* Sampled images are actually two parameters */
if (func_type->params[i]->base_type == vtn_base_type_sampled_image)
num_params++;
}
unsigned num_params = 0;
for (unsigned i = 0; i < func_type->length; i++)
num_params += vtn_type_count_function_params(func_type->params[i]);
/* Add one parameter for the function return value */
if (func_type->return_type->base_type != vtn_base_type_void)
@ -152,31 +274,8 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
};
}
for (unsigned i = 0; i < func_type->length; i++) {
if (func_type->params[i]->base_type == vtn_base_type_sampled_image) {
/* Sampled images are two pointer parameters */
func->params[idx++] = (nir_parameter) {
.num_components = 1, .bit_size = 32,
};
func->params[idx++] = (nir_parameter) {
.num_components = 1, .bit_size = 32,
};
} else if (func_type->params[i]->base_type == vtn_base_type_pointer &&
func_type->params[i]->type != NULL) {
/* Pointers with as storage class get passed by-value */
assert(glsl_type_is_vector_or_scalar(func_type->params[i]->type));
func->params[idx++] = (nir_parameter) {
.num_components =
glsl_get_vector_elements(func_type->params[i]->type),
.bit_size = glsl_get_bit_size(func_type->params[i]->type),
};
} else {
/* Everything else is a regular pointer */
func->params[idx++] = (nir_parameter) {
.num_components = 1, .bit_size = 32,
};
}
}
for (unsigned i = 0; i < func_type->length; i++)
vtn_type_add_to_function_params(func_type->params[i], func, &idx);
assert(idx == num_params);
b->func->impl = nir_function_impl_create(func);
@ -235,10 +334,9 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
vtn_load_param_pointer(b, type, b->func_param_idx++);
} else {
/* We're a regular SSA value. */
nir_ssa_def *param_val = nir_load_param(&b->nb, b->func_param_idx++);
nir_deref_instr *deref =
nir_build_deref_cast(&b->nb, param_val, nir_var_local, type->type);
vtn_push_ssa(b, w[2], type, vtn_local_load(b, deref));
struct vtn_ssa_value *value = vtn_create_ssa_value(b, type->type);
vtn_ssa_value_load_function_param(b, value, type, &b->func_param_idx);
vtn_push_ssa(b, w[2], type, value);
}
break;
}