diff --git a/src/glsl/nir/nir.c b/src/glsl/nir/nir.c index 3b74e424bbe..4100f9770cc 100644 --- a/src/glsl/nir/nir.c +++ b/src/glsl/nir/nir.c @@ -1605,6 +1605,60 @@ nir_ssa_def_init(nir_function_impl *impl, nir_instr *instr, nir_ssa_def *def, def->num_components = num_components; } +struct ssa_def_rewrite_state { + void *mem_ctx; + nir_ssa_def *old; + nir_src new_src; +}; + +static bool +ssa_def_rewrite_uses_src(nir_src *src, void *void_state) +{ + struct ssa_def_rewrite_state *state = void_state; + + if (src->is_ssa && src->ssa == state->old) + *src = nir_src_copy(state->new_src, state->mem_ctx); + + return true; +} + +void +nir_ssa_def_rewrite_uses(nir_ssa_def *def, nir_src new_src, void *mem_ctx) +{ + struct ssa_def_rewrite_state state; + state.mem_ctx = mem_ctx; + state.old = def; + state.new_src = new_src; + + assert(!new_src.is_ssa || def != new_src.ssa); + + struct set *new_uses, *new_if_uses; + if (new_src.is_ssa) { + new_uses = new_src.ssa->uses; + new_if_uses = new_src.ssa->if_uses; + } else { + new_uses = new_src.reg.reg->uses; + new_if_uses = new_src.reg.reg->if_uses; + } + + struct set_entry *entry; + set_foreach(def->uses, entry) { + nir_instr *instr = (nir_instr *)entry->key; + + _mesa_set_remove(def->uses, entry); + nir_foreach_src(instr, ssa_def_rewrite_uses_src, &state); + _mesa_set_add(new_uses, _mesa_hash_pointer(instr), instr); + } + + set_foreach(def->if_uses, entry) { + nir_if *if_use = (nir_if *)entry->key; + + _mesa_set_remove(def->if_uses, entry); + if_use->condition = nir_src_copy(new_src, mem_ctx); + _mesa_set_add(new_if_uses, _mesa_hash_pointer(if_use), if_use); + } +} + static bool foreach_cf_node(nir_cf_node *node, nir_foreach_block_cb cb, bool reverse, void *state); diff --git a/src/glsl/nir/nir.h b/src/glsl/nir/nir.h index 558ec914044..5933b5dc447 100644 --- a/src/glsl/nir/nir.h +++ b/src/glsl/nir/nir.h @@ -1281,6 +1281,7 @@ bool nir_foreach_src(nir_instr *instr, nir_foreach_src_cb cb, void *state); void nir_ssa_def_init(nir_function_impl *impl, nir_instr *instr, nir_ssa_def *def, unsigned num_components, const char *name); +void nir_ssa_def_rewrite_uses(nir_ssa_def *def, nir_src new_src, void *mem_ctx); /* visits basic blocks in source-code order */ typedef bool (*nir_foreach_block_cb)(nir_block *block, void *state);