From 0bc1b0fd2362df78f47d208e0fde4ff2183b5521 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Wed, 23 Dec 2015 23:45:47 -0800 Subject: [PATCH] nir/lower_return: Do it for real this time --- src/glsl/nir/nir_lower_returns.c | 205 ++++++++++++++++++++++++++++--- 1 file changed, 185 insertions(+), 20 deletions(-) diff --git a/src/glsl/nir/nir_lower_returns.c b/src/glsl/nir/nir_lower_returns.c index a15d0332ed4..f36fc9dd613 100644 --- a/src/glsl/nir/nir_lower_returns.c +++ b/src/glsl/nir/nir_lower_returns.c @@ -22,44 +22,209 @@ */ #include "nir.h" +#include "nir_builder.h" +#include "nir_control_flow.h" + +struct lower_returns_state { + nir_builder builder; + struct exec_list *parent_cf_list; + struct exec_list *cf_list; + nir_loop *loop; + nir_if *if_stmt; + nir_variable *return_flag; +}; + +static bool lower_returns_in_cf_list(struct exec_list *cf_list, + struct lower_returns_state *state); static bool -assert_no_returns_block(nir_block *block, void *state) +lower_returns_in_loop(nir_loop *loop, struct lower_returns_state *state) { - (void)state; + nir_loop *parent = state->loop; + state->loop = loop; + bool progress = lower_returns_in_cf_list(&loop->body, state); + state->loop = parent; - nir_foreach_instr(block, instr) { - if (instr->type != nir_instr_type_jump) - continue; + /* Nothing interesting */ + if (!progress) + return false; - nir_jump_instr *jump = nir_instr_as_jump(instr); - assert(jump->type != nir_jump_return); + /* In this case, there was a return somewhere inside of the loop. That + * return would have been turned into a write to the return_flag + * variable and a break. We need to insert a predicated return right + * after the loop ends. + */ + + assert(state->return_flag); + + nir_intrinsic_instr *load = + nir_intrinsic_instr_create(state->builder.shader, nir_intrinsic_load_var); + load->num_components = 1; + load->variables[0] = nir_deref_var_create(load, state->return_flag); + nir_ssa_dest_init(&load->instr, &load->dest, 1, "return"); + nir_instr_insert(nir_after_cf_node(&loop->cf_node), &load->instr); + + nir_if *if_stmt = nir_if_create(state->builder.shader); + if_stmt->condition = nir_src_for_ssa(&load->dest.ssa); + nir_cf_node_insert(nir_after_instr(&load->instr), &if_stmt->cf_node); + + nir_jump_instr *ret = + nir_jump_instr_create(state->builder.shader, nir_jump_return); + nir_instr_insert(nir_before_cf_list(&if_stmt->then_list), &ret->instr); + + return true; +} + +static bool +lower_returns_in_if(nir_if *if_stmt, struct lower_returns_state *state) +{ + bool progress; + + nir_if *parent = state->if_stmt; + state->if_stmt = if_stmt; + progress = lower_returns_in_cf_list(&if_stmt->then_list, state); + progress = lower_returns_in_cf_list(&if_stmt->else_list, state) || progress; + state->if_stmt = parent; + + return progress; +} + +static bool +lower_returns_in_block(nir_block *block, struct lower_returns_state *state) +{ + if (block->predecessors->entries == 0 && + block != nir_start_block(state->builder.impl)) { + /* This block is unreachable. Delete it and everything after it. */ + nir_cf_list list; + nir_cf_extract(&list, nir_before_cf_node(&block->cf_node), + nir_after_cf_list(state->cf_list)); + + if (exec_list_is_empty(&list.list)) { + /* There's nothing here, which also means there's nothing in this + * block so we have nothing to do. + */ + return false; + } else { + nir_cf_delete(&list); + return true; + } + } + + nir_instr *last_instr = nir_block_last_instr(block); + if (last_instr == NULL) + return false; + + if (last_instr->type != nir_instr_type_jump) + return false; + + nir_jump_instr *jump = nir_instr_as_jump(last_instr); + if (jump->type != nir_jump_return) + return false; + + if (state->loop) { + /* We're in a loop. Just set the return flag to true and break. + * lower_returns_in_loop will do the rest. + */ + nir_builder *b = &state->builder; + b->cursor = nir_before_instr(&jump->instr); + + if (state->return_flag == NULL) { + state->return_flag = + nir_local_variable_create(b->impl, glsl_bool_type(), "return"); + + /* Set a default value of false */ + state->return_flag->constant_initializer = + rzalloc(state->return_flag, nir_constant); + } + + nir_store_var(b, state->return_flag, nir_imm_int(b, NIR_TRUE)); + jump->type = nir_jump_return; + } else if (state->if_stmt) { + /* If we're not in a loop but in an if, just move the rest of the CF + * list into the the other case of the if. + */ + nir_cf_list list; + nir_cf_extract(&list, nir_after_cf_node(&state->if_stmt->cf_node), + nir_after_cf_list(state->parent_cf_list)); + + nir_instr_remove(&jump->instr); + + if (state->cf_list == &state->if_stmt->then_list) { + nir_cf_reinsert(&list, + nir_after_cf_list(&state->if_stmt->else_list)); + } else if (state->cf_list == &state->if_stmt->else_list) { + nir_cf_reinsert(&list, + nir_after_cf_list(&state->if_stmt->then_list)); + } else { + unreachable("Invalid CF list"); + } + } else { + nir_instr_remove(&jump->instr); + + /* No if, no nothing. Just delete the return and whatever follows. */ + nir_cf_list list; + nir_cf_extract(&list, nir_after_cf_node(&block->cf_node), + nir_after_cf_list(state->parent_cf_list)); + nir_cf_delete(&list); } return true; } -bool -nir_lower_returns_impl(nir_function_impl *impl) +static bool +lower_returns_in_cf_list(struct exec_list *cf_list, + struct lower_returns_state *state) { bool progress = false; - assert(impl->end_block->predecessors->entries == 1); + struct exec_list *prev_parent_list = state->parent_cf_list; + state->parent_cf_list = state->cf_list; + state->cf_list = cf_list; - struct set_entry *entry = - _mesa_set_next_entry(impl->end_block->predecessors, NULL); + foreach_list_typed_reverse_safe(nir_cf_node, node, node, cf_list) { + switch (node->type) { + case nir_cf_node_block: + if (lower_returns_in_block(nir_cf_node_as_block(node), state)) + progress = true; + break; - nir_block *last_block = (nir_block *)entry->key; + case nir_cf_node_if: + if (lower_returns_in_if(nir_cf_node_as_if(node), state)) + progress = true; + break; - nir_instr *last_instr = nir_block_last_instr(last_block); - if (last_instr && last_instr->type == nir_instr_type_jump) { - nir_jump_instr *jump = nir_instr_as_jump(last_instr); - assert(jump->type == nir_jump_return); - nir_instr_remove(&jump->instr); - progress = true; + case nir_cf_node_loop: + if (lower_returns_in_loop(nir_cf_node_as_loop(node), state)) + progress = true; + break; + + default: + unreachable("Invalid inner CF node type"); + } } - nir_foreach_block(impl, assert_no_returns_block, NULL); + state->cf_list = state->parent_cf_list; + state->parent_cf_list = prev_parent_list; + + return progress; +} + +bool +nir_lower_returns_impl(nir_function_impl *impl) +{ + struct lower_returns_state state; + + state.parent_cf_list = NULL; + state.cf_list = &impl->body; + state.loop = NULL; + state.if_stmt = NULL; + state.return_flag = NULL; + nir_builder_init(&state.builder, impl); + + bool progress = lower_returns_in_cf_list(&impl->body, &state); + + if (progress) + nir_metadata_preserve(impl, nir_metadata_none); return progress; }