From 64bc58a88ee3c0131a7d540b2ff61a0c707563e4 Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Wed, 6 May 2015 12:37:10 -0700 Subject: [PATCH] nir/spirv: Handle control-flow with loops --- src/glsl/nir/spirv_to_nir.c | 168 ++++++++++++++++++++++++---- src/glsl/nir/spirv_to_nir_private.h | 4 +- 2 files changed, 151 insertions(+), 21 deletions(-) diff --git a/src/glsl/nir/spirv_to_nir.c b/src/glsl/nir/spirv_to_nir.c index 3bbf91453fd..a4f13603dac 100644 --- a/src/glsl/nir/spirv_to_nir.c +++ b/src/glsl/nir/spirv_to_nir.c @@ -1000,6 +1000,13 @@ vtn_handle_first_cfg_pass_instruction(struct vtn_builder *b, SpvOp opcode, b->block = NULL; break; + case SpvOpSelectionMerge: + case SpvOpLoopMerge: + assert(b->block && b->block->merge_op == SpvOpNop); + b->block->merge_op = opcode; + b->block->merge_block_id = w[1]; + break; + default: /* Continue on as per normal */ return true; @@ -1015,19 +1022,20 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, switch (opcode) { case SpvOpLabel: { struct vtn_block *block = vtn_value(b, w[1], vtn_value_type_block)->block; + assert(block->block == NULL); + struct exec_node *list_tail = exec_list_get_tail(b->nb.cf_node_list); nir_cf_node *tail_node = exec_node_data(nir_cf_node, list_tail, node); assert(tail_node->type == nir_cf_node_block); block->block = nir_cf_node_as_block(tail_node); + assert(exec_list_is_empty(&block->block->instr_list)); break; } case SpvOpLoopMerge: case SpvOpSelectionMerge: - assert(b->merge_block == NULL); - /* TODO: Selection Control */ - b->merge_block = vtn_value(b, w[1], vtn_value_type_block)->block; + /* This is handled by cfg pre-pass and walk_blocks */ break; case SpvOpUndef: @@ -1186,19 +1194,68 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, static void vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start, - struct vtn_block *end) + struct vtn_block *break_block, struct vtn_block *cont_block, + struct vtn_block *end_block) { struct vtn_block *block = start; - while (block != end) { + while (block != end_block) { + const uint32_t *w = block->branch; + SpvOp branch_op = w[0] & SpvOpCodeMask; + + if (block->block != NULL) { + /* We've already visited this block once before so this is a + * back-edge. Back-edges are only allowed to point to a loop + * merge. + */ + assert(block == cont_block); + return; + } + + b->block = block; vtn_foreach_instruction(b, block->label, block->branch, vtn_handle_body_instruction); - const uint32_t *w = block->branch; - SpvOp branch_op = w[0] & SpvOpCodeMask; switch (branch_op) { case SpvOpBranch: { - assert(vtn_value(b, w[1], vtn_value_type_block)->block == end); - return; + struct vtn_block *branch_block = + vtn_value(b, w[1], vtn_value_type_block)->block; + + if (branch_block == break_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_break); + nir_builder_instr_insert(&b->nb, &jump->instr); + + return; + } else if (branch_block == cont_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_continue); + nir_builder_instr_insert(&b->nb, &jump->instr); + + return; + } else if (branch_block == end_block) { + return; + } else if (branch_block->merge_op == SpvOpLoopMerge) { + /* This is the jump into a loop. */ + cont_block = branch_block; + break_block = vtn_value(b, branch_block->merge_block_id, + vtn_value_type_block)->block; + + nir_loop *loop = nir_loop_create(b->shader); + nir_cf_node_insert_end(b->nb.cf_node_list, &loop->cf_node); + + struct exec_list *old_list = b->nb.cf_node_list; + + nir_builder_insert_after_cf_list(&b->nb, &loop->body); + vtn_walk_blocks(b, branch_block, break_block, cont_block, NULL); + + nir_builder_insert_after_cf_list(&b->nb, old_list); + block = break_block; + continue; + } else { + /* TODO: Can this ever happen? */ + block = branch_block; + continue; + } } case SpvOpBranchConditional: { @@ -1207,28 +1264,99 @@ vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start, vtn_value(b, w[2], vtn_value_type_block)->block; struct vtn_block *else_block = vtn_value(b, w[3], vtn_value_type_block)->block; - struct vtn_block *merge_block = b->merge_block; nir_if *if_stmt = nir_if_create(b->shader); if_stmt->condition = nir_src_for_ssa(vtn_ssa_value(b, w[1])); nir_cf_node_insert_end(b->nb.cf_node_list, &if_stmt->cf_node); - struct exec_list *old_list = b->nb.cf_node_list; + if (then_block == break_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_break); + nir_instr_insert_after_cf_list(&if_stmt->then_list, + &jump->instr); + block = else_block; + } else if (else_block == break_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_break); + nir_instr_insert_after_cf_list(&if_stmt->else_list, + &jump->instr); + block = then_block; + } else if (then_block == cont_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_continue); + nir_instr_insert_after_cf_list(&if_stmt->then_list, + &jump->instr); + block = else_block; + } else if (else_block == cont_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_continue); + nir_instr_insert_after_cf_list(&if_stmt->else_list, + &jump->instr); + block = then_block; + } else { + /* Conventional if statement */ + assert(block->merge_op == SpvOpSelectionMerge); + struct vtn_block *merge_block = + vtn_value(b, block->merge_block_id, vtn_value_type_block)->block; - nir_builder_insert_after_cf_list(&b->nb, &if_stmt->then_list); - vtn_walk_blocks(b, then_block, merge_block); + struct exec_list *old_list = b->nb.cf_node_list; - nir_builder_insert_after_cf_list(&b->nb, &if_stmt->else_list); - vtn_walk_blocks(b, else_block, merge_block); + nir_builder_insert_after_cf_list(&b->nb, &if_stmt->then_list); + vtn_walk_blocks(b, then_block, break_block, cont_block, merge_block); - nir_builder_insert_after_cf_list(&b->nb, old_list); - block = merge_block; + nir_builder_insert_after_cf_list(&b->nb, &if_stmt->else_list); + vtn_walk_blocks(b, else_block, break_block, cont_block, merge_block); + + nir_builder_insert_after_cf_list(&b->nb, old_list); + block = merge_block; + continue; + } + + /* If we got here then we inserted a predicated break or continue + * above and we need to handle the other case. We already set + * `block` above to indicate what block to visit after the + * predicated break. + */ + + /* It's possible that the other branch is also a break/continue. + * If it is, we handle that here. + */ + if (block == break_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_break); + nir_builder_instr_insert(&b->nb, &jump->instr); + + return; + } else if (block == cont_block) { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_continue); + nir_builder_instr_insert(&b->nb, &jump->instr); + + return; + } + + /* If we got here then there was a predicated break/continue but + * the other half of the if has stuff in it. `block` was already + * set above so there is nothing left for us to do. + */ continue; } + case SpvOpReturn: { + nir_jump_instr *jump = nir_jump_instr_create(b->shader, + nir_jump_return); + nir_builder_instr_insert(&b->nb, &jump->instr); + return; + } + + case SpvOpKill: { + nir_intrinsic_instr *discard = + nir_intrinsic_instr_create(b->shader, nir_intrinsic_discard); + nir_builder_instr_insert(&b->nb, &discard->instr); + return; + } + case SpvOpSwitch: - case SpvOpKill: - case SpvOpReturn: case SpvOpReturnValue: case SpvOpUnreachable: default: @@ -1275,7 +1403,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count, b->impl = nir_function_impl_create(func->overload); nir_builder_init(&b->nb, b->impl); nir_builder_insert_after_cf_list(&b->nb, &b->impl->body); - vtn_walk_blocks(b, func->start_block, NULL); + vtn_walk_blocks(b, func->start_block, NULL, NULL, NULL); } ralloc_free(b); diff --git a/src/glsl/nir/spirv_to_nir_private.h b/src/glsl/nir/spirv_to_nir_private.h index fd80dd4e161..d2b364bdfeb 100644 --- a/src/glsl/nir/spirv_to_nir_private.h +++ b/src/glsl/nir/spirv_to_nir_private.h @@ -47,6 +47,9 @@ enum vtn_value_type { }; struct vtn_block { + /* Merge opcode if this block contains a merge; SpvOpNop otherwise. */ + SpvOp merge_op; + uint32_t merge_block_id; const uint32_t *label; const uint32_t *branch; nir_block *block; @@ -92,7 +95,6 @@ struct vtn_builder { nir_shader *shader; nir_function_impl *impl; struct vtn_block *block; - struct vtn_block *merge_block; unsigned value_id_bound; struct vtn_value *values;