mesa/src/compiler/nir/nir_builder.c

/*
 * Copyright © 2014-2015 Broadcom
 * Copyright © 2021 Google
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir_builder.h"
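
/* Initialize a builder for emitting code into an existing function
 * implementation.  The cursor is left zeroed; callers are expected to place
 * it (e.g. with nir_before_cf_list(&impl->body)) before inserting anything.
 */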
void
nir_builder_init(nir_builder *build, nir_function_impl *impl)
{
   memset(build, 0, sizeof(*build));
   build->exact = false;
   build->impl = impl;
   build->shader = impl->function->shader;
}
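
/* Create a fresh shader with a single "main" entrypoint and return a builder
 * whose cursor sits at the end of its (empty) body.  The printf-style name,
 * if non-NULL, becomes shader->info.name.
 *
 * A minimal usage sketch (the stage, options, and format arguments below are
 * only placeholders):
 *
 *    nir_builder b =
 *       nir_builder_init_simple_shader(MESA_SHADER_FRAGMENT, options,
 *                                      "clear_%d", target);
 *    nir_ssa_def *color = nir_imm_vec4(&b, 1.0, 0.0, 0.0, 1.0);
 *    ...emit the rest of the shader with the builder helpers...
 */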
nir_builder MUST_CHECK PRINTFLIKE(3, 4)
nir_builder_init_simple_shader(gl_shader_stage stage,
                               const nir_shader_compiler_options *options,
                               const char *name, ...)
{
   nir_builder b;
   memset(&b, 0, sizeof(b));
   b.shader = nir_shader_create(NULL, stage, options, NULL);

   if (name) {
      va_list args;
      va_start(args, name);
      b.shader->info.name = ralloc_vasprintf(b.shader, name, args);
      va_end(args);
   }

   nir_function *func = nir_function_create(b.shader, "main");
   func->is_entrypoint = true;

   b.exact = false;
   b.impl = nir_function_impl_create(func);
   b.cursor = nir_after_cf_list(&b.impl->body);

   /* Simple shaders are typically internal, e.g. blit shaders */
   b.shader->info.internal = true;

   return b;
}
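
/* Finish a partially-built ALU instruction whose sources are already set:
 * infer the destination's component count and bit size from the opcode and
 * the sources, then insert the instruction at the builder's cursor.
 */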
nir_ssa_def *
nir_builder_alu_instr_finish_and_insert(nir_builder *build, nir_alu_instr *instr)
{
   const nir_op_info *op_info = &nir_op_infos[instr->op];

   instr->exact = build->exact;

   /* Guess the number of components the destination temporary should have
    * based on our input sizes, if it's not fixed for the op.
    */
   unsigned num_components = op_info->output_size;
   if (num_components == 0) {
      for (unsigned i = 0; i < op_info->num_inputs; i++) {
         if (op_info->input_sizes[i] == 0)
            num_components = MAX2(num_components,
                                  instr->src[i].src.ssa->num_components);
      }
   }
   assert(num_components != 0);

   /* Figure out the bitwidth based on the source bitwidth if the instruction
    * is variable-width.
    */
   unsigned bit_size = nir_alu_type_get_type_size(op_info->output_type);
   if (bit_size == 0) {
      for (unsigned i = 0; i < op_info->num_inputs; i++) {
         unsigned src_bit_size = instr->src[i].src.ssa->bit_size;
         if (nir_alu_type_get_type_size(op_info->input_types[i]) == 0) {
            if (bit_size)
               assert(src_bit_size == bit_size);
            else
               bit_size = src_bit_size;
         } else {
            assert(src_bit_size ==
                   nir_alu_type_get_type_size(op_info->input_types[i]));
         }
      }
   }

   /* When in doubt, assume 32. */
   if (bit_size == 0)
      bit_size = 32;

   /* Make sure we don't swizzle from outside of our source vector (like if a
    * scalar value was passed into a multiply with a vector).
    */
   for (unsigned i = 0; i < op_info->num_inputs; i++) {
      for (unsigned j = instr->src[i].src.ssa->num_components;
           j < NIR_MAX_VEC_COMPONENTS; j++) {
         instr->src[i].swizzle[j] = instr->src[i].src.ssa->num_components - 1;
      }
   }

   nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components,
                     bit_size, NULL);
   instr->dest.write_mask = (1 << num_components) - 1;

   nir_builder_instr_insert(build, &instr->instr);
   return &instr->dest.dest.ssa;
}
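
/* Build and insert an ALU instruction with up to four SSA sources; unused
 * trailing sources may be NULL.  The fixed-arity variants below do the same
 * for call sites where the source count is known.
 */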
nir_ssa_def *
nir_build_alu(nir_builder *build, nir_op op, nir_ssa_def *src0,
              nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
{
   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
   if (!instr)
      return NULL;

   instr->src[0].src = nir_src_for_ssa(src0);
   if (src1)
      instr->src[1].src = nir_src_for_ssa(src1);
   if (src2)
      instr->src[2].src = nir_src_for_ssa(src2);
   if (src3)
      instr->src[3].src = nir_src_for_ssa(src3);

   return nir_builder_alu_instr_finish_and_insert(build, instr);
}

nir_ssa_def *
nir_build_alu1(nir_builder *build, nir_op op, nir_ssa_def *src0)
{
   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
   if (!instr)
      return NULL;

   instr->src[0].src = nir_src_for_ssa(src0);
   return nir_builder_alu_instr_finish_and_insert(build, instr);
}

nir_ssa_def *
nir_build_alu2(nir_builder *build, nir_op op, nir_ssa_def *src0,
               nir_ssa_def *src1)
{
   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
   if (!instr)
      return NULL;

   instr->src[0].src = nir_src_for_ssa(src0);
   instr->src[1].src = nir_src_for_ssa(src1);
   return nir_builder_alu_instr_finish_and_insert(build, instr);
}

nir_ssa_def *
nir_build_alu3(nir_builder *build, nir_op op, nir_ssa_def *src0,
               nir_ssa_def *src1, nir_ssa_def *src2)
{
   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
   if (!instr)
      return NULL;

   instr->src[0].src = nir_src_for_ssa(src0);
   instr->src[1].src = nir_src_for_ssa(src1);
   instr->src[2].src = nir_src_for_ssa(src2);
   return nir_builder_alu_instr_finish_and_insert(build, instr);
}

nir_ssa_def *
nir_build_alu4(nir_builder *build, nir_op op, nir_ssa_def *src0,
               nir_ssa_def *src1, nir_ssa_def *src2, nir_ssa_def *src3)
{
   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
   if (!instr)
      return NULL;

   instr->src[0].src = nir_src_for_ssa(src0);
   instr->src[1].src = nir_src_for_ssa(src1);
   instr->src[2].src = nir_src_for_ssa(src2);
   instr->src[3].src = nir_src_for_ssa(src3);
   return nir_builder_alu_instr_finish_and_insert(build, instr);
}

/* for the couple special cases with more than 4 src args: */
nir_ssa_def *
nir_build_alu_src_arr(nir_builder *build, nir_op op, nir_ssa_def **srcs)
{
   const nir_op_info *op_info = &nir_op_infos[op];
   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
   if (!instr)
      return NULL;

   for (unsigned i = 0; i < op_info->num_inputs; i++)
      instr->src[i].src = nir_src_for_ssa(srcs[i]);

   return nir_builder_alu_instr_finish_and_insert(build, instr);
}
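
/* Build a vecN (or mov, for a single component) from an array of
 * nir_ssa_scalar values, honoring each component's swizzle.
 */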
nir_ssa_def *
nir_vec_scalars(nir_builder *build, nir_ssa_scalar *comp, unsigned num_components)
{
   nir_op op = nir_op_vec(num_components);
   nir_alu_instr *instr = nir_alu_instr_create(build->shader, op);
   if (!instr)
      return NULL;

   for (unsigned i = 0; i < num_components; i++) {
      instr->src[i].src = nir_src_for_ssa(comp[i].def);
      instr->src[i].swizzle[0] = comp[i].comp;
   }
   instr->exact = build->exact;

   /* Note: not reusing nir_builder_alu_instr_finish_and_insert() because it
    * can't re-guess the num_components when num_components == 1 (nir_op_mov).
    */
   nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components,
                     comp[0].def->bit_size, NULL);
   instr->dest.write_mask = (1 << num_components) - 1;

   nir_builder_instr_insert(build, &instr->instr);

   return &instr->dest.dest.ssa;
}

/**
 * Turns a nir_src into a nir_ssa_def * so it can be passed to
 * nir_build_alu()-based builder calls.
 *
 * See nir_ssa_for_alu_src() for alu instructions.
 */
nir_ssa_def *
nir_ssa_for_src(nir_builder *build, nir_src src, int num_components)
{
   if (src.is_ssa && src.ssa->num_components == num_components)
      return src.ssa;

   assert((unsigned)num_components <= nir_src_num_components(src));

   nir_alu_src alu = { NIR_SRC_INIT };
   alu.src = src;
   for (int j = 0; j < NIR_MAX_VEC_COMPONENTS; j++)
      alu.swizzle[j] = j;

   return nir_mov_alu(build, alu, num_components);
}

/**
 * Similar to nir_ssa_for_src(), but for alu srcs, respecting the
 * nir_alu_src's swizzle.
 */
nir_ssa_def *
nir_ssa_for_alu_src(nir_builder *build, nir_alu_instr *instr, unsigned srcn)
{
   if (nir_alu_src_is_trivial_ssa(instr, srcn))
      return instr->src[srcn].src.ssa;

   nir_alu_src *src = &instr->src[srcn];
   unsigned num_components = nir_ssa_alu_instr_src_components(instr, srcn);
   return nir_mov_alu(build, *src, num_components);
}

/* Generic builder for system values. */
nir_ssa_def *
nir_load_system_value(nir_builder *build, nir_intrinsic_op op, int index,
                      unsigned num_components, unsigned bit_size)
{
   nir_intrinsic_instr *load = nir_intrinsic_instr_create(build->shader, op);
   if (nir_intrinsic_infos[op].dest_components > 0)
      assert(num_components == nir_intrinsic_infos[op].dest_components);
   else
      load->num_components = num_components;
   load->const_index[0] = index;

   nir_ssa_dest_init(&load->instr, &load->dest,
                     num_components, bit_size, NULL);
   nir_builder_instr_insert(build, &load->instr);
   return &load->dest.ssa;
}
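
/* Insert an already-constructed instruction at the builder's cursor, update
 * divergence metadata if requested, and advance the cursor past it.
 */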
void
nir_builder_instr_insert(nir_builder *build, nir_instr *instr)
{
   nir_instr_insert(build->cursor, instr);

   if (build->update_divergence)
      nir_update_instr_divergence(build->shader, instr);

   /* Move the cursor forward. */
   build->cursor = nir_after_instr(instr);
}
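
/* Insert a control-flow node (if or loop) at the builder's cursor.  Unlike
 * nir_builder_instr_insert(), this does not move the cursor.
 */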
void
nir_builder_cf_insert(nir_builder *build, nir_cf_node *cf)
{
   nir_cf_node_insert(build->cursor, cf);
}
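
/* Return true if the builder's cursor currently lies inside the given
 * control-flow node, walking up through the parents of the cursor's block.
 */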
bool
nir_builder_is_inside_cf(nir_builder *build, nir_cf_node *cf_node)
{
   nir_block *block = nir_cursor_current_block(build->cursor);
   for (nir_cf_node *n = &block->cf_node; n; n = n->parent) {
      if (n == cf_node)
         return true;
   }
   return false;
}
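
/* The nir_push_if()/nir_push_else()/nir_pop_if() helpers below build
 * structured control flow by moving the cursor into and back out of a new
 * nir_if.  A sketch of the usual pattern (cond, x and y stand for SSA values
 * the caller has already built):
 *
 *    nir_push_if(&b, cond);
 *    nir_ssa_def *then_val = nir_fadd(&b, x, y);
 *    nir_push_else(&b, NULL);
 *    nir_ssa_def *else_val = nir_imm_float(&b, 0.0f);
 *    nir_pop_if(&b, NULL);
 *    nir_ssa_def *result = nir_if_phi(&b, then_val, else_val);
 */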
nir_if *
nir_push_if_src(nir_builder *build, nir_src condition)
{
   nir_if *nif = nir_if_create(build->shader);
   nif->condition = condition;
   nir_builder_cf_insert(build, &nif->cf_node);
   build->cursor = nir_before_cf_list(&nif->then_list);
   return nif;
}

nir_if *
nir_push_if(nir_builder *build, nir_ssa_def *condition)
{
   return nir_push_if_src(build, nir_src_for_ssa(condition));
}

nir_if *
nir_push_else(nir_builder *build, nir_if *nif)
{
   if (nif) {
      assert(nir_builder_is_inside_cf(build, &nif->cf_node));
   } else {
      nir_block *block = nir_cursor_current_block(build->cursor);
      nif = nir_cf_node_as_if(block->cf_node.parent);
   }
   build->cursor = nir_before_cf_list(&nif->else_list);
   return nif;
}

void
nir_pop_if(nir_builder *build, nir_if *nif)
{
   if (nif) {
      assert(nir_builder_is_inside_cf(build, &nif->cf_node));
   } else {
      nir_block *block = nir_cursor_current_block(build->cursor);
      nif = nir_cf_node_as_if(block->cf_node.parent);
   }
   build->cursor = nir_after_cf_node(&nif->cf_node);
}
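
/* After nir_pop_if(), create a phi that selects between the value produced
 * in the then branch and the one produced in the else branch.  Both defs
 * must have the same number of components and bit size.
 */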
nir_ssa_def *
nir_if_phi(nir_builder *build, nir_ssa_def *then_def, nir_ssa_def *else_def)
{
   nir_block *block = nir_cursor_current_block(build->cursor);
   nir_if *nif = nir_cf_node_as_if(nir_cf_node_prev(&block->cf_node));

   nir_phi_instr *phi = nir_phi_instr_create(build->shader);
   nir_phi_instr_add_src(phi, nir_if_last_then_block(nif), nir_src_for_ssa(then_def));
   nir_phi_instr_add_src(phi, nir_if_last_else_block(nif), nir_src_for_ssa(else_def));

   assert(then_def->num_components == else_def->num_components);
   assert(then_def->bit_size == else_def->bit_size);
   nir_ssa_dest_init(&phi->instr, &phi->dest,
                     then_def->num_components, then_def->bit_size, NULL);

   nir_builder_instr_insert(build, &phi->instr);

   return &phi->dest.ssa;
}
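
/* nir_push_loop()/nir_pop_loop() wrap the cursor in a new loop.  NIR loops
 * are infinite, so the body has to emit its own break.  A sketch of the
 * common shape (exit_cond stands for a condition computed in the body):
 *
 *    nir_loop *loop = nir_push_loop(&b);
 *    ...compute exit_cond...
 *    nir_push_if(&b, exit_cond);
 *    nir_jump(&b, nir_jump_break);
 *    nir_pop_if(&b, NULL);
 *    ...rest of the loop body...
 *    nir_pop_loop(&b, loop);
 */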
nir_loop *
nir_push_loop(nir_builder *build)
{
   nir_loop *loop = nir_loop_create(build->shader);
   nir_builder_cf_insert(build, &loop->cf_node);
   build->cursor = nir_before_cf_list(&loop->body);
   return loop;
}

void
nir_pop_loop(nir_builder *build, nir_loop *loop)
{
   if (loop) {
      assert(nir_builder_is_inside_cf(build, &loop->cf_node));
   } else {
      nir_block *block = nir_cursor_current_block(build->cursor);
      loop = nir_cf_node_as_loop(block->cf_node.parent);
   }
   build->cursor = nir_after_cf_node(&loop->cf_node);
}
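
/* Emit the floating-point comparison selected by func.  NEVER and ALWAYS
 * become integer immediates; GREATER and LEQUAL are expressed by swapping
 * the operands of flt/fge.
 */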
nir_ssa_def *
nir_compare_func(nir_builder *b, enum compare_func func,
                 nir_ssa_def *src0, nir_ssa_def *src1)
{
   switch (func) {
   case COMPARE_FUNC_NEVER:
      return nir_imm_int(b, 0);
   case COMPARE_FUNC_ALWAYS:
      return nir_imm_int(b, ~0);
   case COMPARE_FUNC_EQUAL:
      return nir_feq(b, src0, src1);
   case COMPARE_FUNC_NOTEQUAL:
      return nir_fneu(b, src0, src1);
   case COMPARE_FUNC_GREATER:
      return nir_flt(b, src1, src0);
   case COMPARE_FUNC_GEQUAL:
      return nir_fge(b, src0, src1);
   case COMPARE_FUNC_LESS:
      return nir_flt(b, src0, src1);
   case COMPARE_FUNC_LEQUAL:
      return nir_fge(b, src1, src0);
   }
   unreachable("bad compare func");
}
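
/* Convert src from src_type to dest_type using the standard conversion
 * opcode for that pair, with undefined rounding.  If src_type has no
 * explicit bit size, the source's own bit size is used.
 */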
nir_ssa_def *
nir_type_convert(nir_builder *b,
                 nir_ssa_def *src,
                 nir_alu_type src_type,
                 nir_alu_type dest_type)
{
   assert(nir_alu_type_get_type_size(src_type) == 0 ||
          nir_alu_type_get_type_size(src_type) == src->bit_size);

   src_type = (nir_alu_type) (src_type | src->bit_size);

   nir_op opcode =
      nir_type_conversion_op(src_type, dest_type, nir_rounding_mode_undef);

   return nir_build_alu(b, opcode, src, NULL, NULL, NULL);
}