mirror of https://gitlab.freedesktop.org/mesa/mesa
nak: Rewrite union_find and use it in repair_ssa
The new UnionFind is safe code, is generic over the element type, and uses constant stack space. Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27454>
This commit is contained in:
parent
b5f4c54d0d
commit
b681677f7d
|
@ -119,6 +119,10 @@ _libnak_rs = static_library(
|
||||||
link_with: [_libbitview_rs, libnak_bindings_gen, _libnak_ir_proc_rs],
|
link_with: [_libbitview_rs, libnak_bindings_gen, _libnak_ir_proc_rs],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if with_tests
|
||||||
|
rust.test('nak', _libnak_rs, suite : ['nouveau'])
|
||||||
|
endif
|
||||||
|
|
||||||
nak_nir_algebraic_c = custom_target(
|
nak_nir_algebraic_c = custom_target(
|
||||||
'nak_nir_algebraic.c',
|
'nak_nir_algebraic.c',
|
||||||
input : 'nak_nir_algebraic.py',
|
input : 'nak_nir_algebraic.py',
|
||||||
|
|
|
@ -27,3 +27,4 @@ mod repair_ssa;
|
||||||
mod sph;
|
mod sph;
|
||||||
mod spill_values;
|
mod spill_values;
|
||||||
mod to_cssa;
|
mod to_cssa;
|
||||||
|
mod union_find;
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
use crate::bitset::BitSet;
|
use crate::bitset::BitSet;
|
||||||
use crate::ir::*;
|
use crate::ir::*;
|
||||||
|
use crate::union_find::UnionFind;
|
||||||
|
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
@ -233,7 +234,7 @@ impl Function {
|
||||||
|
|
||||||
// For loop back-edges, we inserted a phi whether we need one or not.
|
// For loop back-edges, we inserted a phi whether we need one or not.
|
||||||
// We want to eliminate any redundant phis.
|
// We want to eliminate any redundant phis.
|
||||||
let mut ssa_map = HashMap::new();
|
let mut ssa_map = UnionFind::new();
|
||||||
if cfg.has_loop() {
|
if cfg.has_loop() {
|
||||||
let mut to_do = true;
|
let mut to_do = true;
|
||||||
while to_do {
|
while to_do {
|
||||||
|
@ -245,9 +246,7 @@ impl Function {
|
||||||
for (_, p_ssa) in phi.srcs.iter_mut() {
|
for (_, p_ssa) in phi.srcs.iter_mut() {
|
||||||
// Apply the remap to the phi sources so that we
|
// Apply the remap to the phi sources so that we
|
||||||
// pick up any remaps from previous loop iterations.
|
// pick up any remaps from previous loop iterations.
|
||||||
while let Some(new_ssa) = ssa_map.get(p_ssa) {
|
*p_ssa = ssa_map.find(*p_ssa);
|
||||||
*p_ssa = *new_ssa;
|
|
||||||
}
|
|
||||||
|
|
||||||
if *p_ssa == phi.dst {
|
if *p_ssa == phi.dst {
|
||||||
continue;
|
continue;
|
||||||
|
@ -261,7 +260,10 @@ impl Function {
|
||||||
// All sources are identical or the phi destination so
|
// All sources are identical or the phi destination so
|
||||||
// we can delete this phi and add it to the remap
|
// we can delete this phi and add it to the remap
|
||||||
let ssa = ssa.expect("Circular SSA def");
|
let ssa = ssa.expect("Circular SSA def");
|
||||||
ssa_map.insert(phi.dst, ssa);
|
// union(a, b) ensures that the representative is the representative
|
||||||
|
// for a. This means union(ssa, phi.dst) ensures that phi.dst gets
|
||||||
|
// mapped to ssa, not the other way around.
|
||||||
|
ssa_map.union(ssa, phi.dst);
|
||||||
to_do = true;
|
to_do = true;
|
||||||
false
|
false
|
||||||
});
|
});
|
||||||
|
@ -300,9 +302,7 @@ impl Function {
|
||||||
if !ssa_map.is_empty() {
|
if !ssa_map.is_empty() {
|
||||||
for instr in &mut bb.instrs {
|
for instr in &mut bb.instrs {
|
||||||
instr.for_each_ssa_use_mut(|ssa| {
|
instr.for_each_ssa_use_mut(|ssa| {
|
||||||
while let Some(new_ssa) = ssa_map.get(ssa) {
|
*ssa = ssa_map.find(*ssa);
|
||||||
*ssa = *new_ssa;
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -312,11 +312,9 @@ impl Function {
|
||||||
if !s_phis.is_empty() {
|
if !s_phis.is_empty() {
|
||||||
let phi_src = get_or_insert_phi_srcs(bb);
|
let phi_src = get_or_insert_phi_srcs(bb);
|
||||||
for phi in s_phis.iter() {
|
for phi in s_phis.iter() {
|
||||||
let mut ssa = phi.srcs.get(&b_idx).unwrap();
|
let mut ssa = *phi.srcs.get(&b_idx).unwrap();
|
||||||
while let Some(new_ssa) = ssa_map.get(ssa) {
|
ssa = ssa_map.find(ssa);
|
||||||
ssa = new_ssa;
|
phi_src.srcs.push(phi.idx, ssa.into());
|
||||||
}
|
|
||||||
phi_src.srcs.push(phi.idx, (*ssa).into());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,291 @@
|
||||||
|
// Copyright © 2024 Mel Henning
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::hash::Hash;
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
struct Root<X: Copy> {
|
||||||
|
size: usize,
|
||||||
|
representative: X,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
enum Node<X: Copy> {
|
||||||
|
Child { parent_idx: usize },
|
||||||
|
Root(Root<X>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Union-find structure
|
||||||
|
///
|
||||||
|
/// This implementation follows Tarjan and van Leeuwen - specifically the
|
||||||
|
/// "link by size" and "halving" variant.
|
||||||
|
///
|
||||||
|
/// Robert E. Tarjan and Jan van Leeuwen. 1984. Worst-case Analysis of Set
|
||||||
|
/// Union Algorithms. J. ACM 31, 2 (April 1984), 245–281.
|
||||||
|
/// https://doi.org/10.1145/62.2160
|
||||||
|
pub struct UnionFind<X: Copy + Hash + Eq> {
|
||||||
|
idx_map: HashMap<X, usize>,
|
||||||
|
nodes: Vec<Node<X>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<X: Copy + Hash + Eq> UnionFind<X> {
|
||||||
|
/// Create a new union-find structure
|
||||||
|
///
|
||||||
|
/// At initialization, each possible value is in its own set
|
||||||
|
pub fn new() -> Self {
|
||||||
|
UnionFind {
|
||||||
|
idx_map: HashMap::new(),
|
||||||
|
nodes: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_root(&mut self, mut idx: usize) -> (usize, Root<X>) {
|
||||||
|
loop {
|
||||||
|
match self.nodes[idx] {
|
||||||
|
Node::Child { parent_idx } => {
|
||||||
|
match self.nodes[parent_idx] {
|
||||||
|
Node::Child {
|
||||||
|
parent_idx: grandparent_idx,
|
||||||
|
} => {
|
||||||
|
// "Halving" in Tarjan and van Leeuwen
|
||||||
|
self.nodes[idx] = Node::Child {
|
||||||
|
parent_idx: grandparent_idx,
|
||||||
|
};
|
||||||
|
idx = grandparent_idx;
|
||||||
|
}
|
||||||
|
Node::Root(parent_root) => {
|
||||||
|
return (parent_idx, parent_root)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Node::Root(root) => return (idx, root),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the representative element for x
|
||||||
|
pub fn find(&mut self, x: X) -> X {
|
||||||
|
match self.idx_map.get(&x) {
|
||||||
|
Some(&idx) => {
|
||||||
|
let (_, Root { representative, .. }) = self.find_root(idx);
|
||||||
|
representative
|
||||||
|
}
|
||||||
|
None => x,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_or_create(&mut self, x: X) -> usize {
|
||||||
|
*self.idx_map.entry(x).or_insert_with(|| {
|
||||||
|
self.nodes.push(Node::Root(Root {
|
||||||
|
size: 1,
|
||||||
|
representative: x,
|
||||||
|
}));
|
||||||
|
self.nodes.len() - 1
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Union the sets containing a and b
|
||||||
|
///
|
||||||
|
/// The representative for a will become the representative of
|
||||||
|
/// the combined set
|
||||||
|
pub fn union(&mut self, a: X, b: X) {
|
||||||
|
if a == b {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let a_idx = self.map_or_create(a);
|
||||||
|
let b_idx = self.map_or_create(b);
|
||||||
|
let (a_root_idx, a_root) = self.find_root(a_idx);
|
||||||
|
let (b_root_idx, b_root) = self.find_root(b_idx);
|
||||||
|
|
||||||
|
if a_root_idx != b_root_idx {
|
||||||
|
// Keep the tree balanced
|
||||||
|
let (new_root_idx, new_child_idx) = if a_root.size >= b_root.size {
|
||||||
|
(a_root_idx, b_root_idx)
|
||||||
|
} else {
|
||||||
|
(b_root_idx, a_root_idx)
|
||||||
|
};
|
||||||
|
|
||||||
|
self.nodes[new_root_idx] = Node::Root(Root {
|
||||||
|
size: a_root.size + b_root.size,
|
||||||
|
representative: a_root.representative,
|
||||||
|
});
|
||||||
|
self.nodes[new_child_idx] = Node::Child {
|
||||||
|
parent_idx: new_root_idx,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return true if find() is the identity mapping
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.nodes.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::union_find::Node;
|
||||||
|
use crate::union_find::UnionFind;
|
||||||
|
use std::cmp::max;
|
||||||
|
use std::hash::Hash;
|
||||||
|
|
||||||
|
fn ceil_log2(x: usize) -> u32 {
|
||||||
|
assert!(x > 0);
|
||||||
|
usize::BITS - (x - 1).leading_zeros()
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HeightInfo {
|
||||||
|
height: u32,
|
||||||
|
size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HeightCalc<'a, X: Copy + Hash + Eq> {
|
||||||
|
uf: &'a UnionFind<X>,
|
||||||
|
downward_edges: Vec<Vec<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, X: Copy + Hash + Eq> HeightCalc<'a, X> {
|
||||||
|
fn new(uf: &'a UnionFind<X>) -> Self {
|
||||||
|
let mut downward_edges: Vec<Vec<usize>> =
|
||||||
|
uf.nodes.iter().map(|_| Vec::new()).collect();
|
||||||
|
for (i, node) in uf.nodes.iter().enumerate() {
|
||||||
|
if let Node::Child { parent_idx } = node {
|
||||||
|
downward_edges[*parent_idx].push(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HeightCalc { uf, downward_edges }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calc_info(&self, idx: usize) -> HeightInfo {
|
||||||
|
let mut result = HeightInfo { height: 0, size: 1 };
|
||||||
|
for child in &self.downward_edges[idx] {
|
||||||
|
let child_result = self.calc_info(*child);
|
||||||
|
result.height = max(result.height, child_result.height + 1);
|
||||||
|
result.size += child_result.size;
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_roots(&self) -> u32 {
|
||||||
|
let mut total_size = 0;
|
||||||
|
let mut max_height = 0;
|
||||||
|
for (i, node) in self.uf.nodes.iter().enumerate() {
|
||||||
|
if let Node::Root(root) = node {
|
||||||
|
let info = self.calc_info(i);
|
||||||
|
assert_eq!(root.size, info.size);
|
||||||
|
|
||||||
|
total_size += info.size;
|
||||||
|
max_height = max(max_height, info.height);
|
||||||
|
|
||||||
|
let max_expected_height = ceil_log2(root.size + 1) - 1;
|
||||||
|
if info.height > max_expected_height {
|
||||||
|
eprintln!(
|
||||||
|
"height {}\t max_expected_height {}\t size {}",
|
||||||
|
info.height, max_expected_height, info.size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
assert!(info.height <= max_expected_height);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert_eq!(total_size, self.uf.nodes.len());
|
||||||
|
assert_eq!(total_size, self.uf.idx_map.len());
|
||||||
|
return max_height;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(uf: &'a UnionFind<X>) -> u32 {
|
||||||
|
HeightCalc::new(uf).check_roots()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_basic() {
|
||||||
|
let mut f = UnionFind::new();
|
||||||
|
assert_eq!(f.find(10), 10);
|
||||||
|
assert_eq!(f.find(12), 12);
|
||||||
|
|
||||||
|
f.union(10, 12);
|
||||||
|
f.union(11, 13);
|
||||||
|
|
||||||
|
HeightCalc::check(&f);
|
||||||
|
|
||||||
|
assert_eq!(f.find(13), 11);
|
||||||
|
assert_eq!(f.find(12), 10);
|
||||||
|
assert_eq!(f.find(11), 11);
|
||||||
|
assert_eq!(f.find(10), 10);
|
||||||
|
|
||||||
|
f.union(12, 13);
|
||||||
|
|
||||||
|
HeightCalc::check(&f);
|
||||||
|
|
||||||
|
assert_eq!(f.find(13), 10);
|
||||||
|
assert_eq!(f.find(12), 10);
|
||||||
|
assert_eq!(f.find(11), 10);
|
||||||
|
assert_eq!(f.find(10), 10);
|
||||||
|
|
||||||
|
assert_eq!(f.find(14), 14);
|
||||||
|
|
||||||
|
HeightCalc::check(&f);
|
||||||
|
|
||||||
|
// Union the set with itself
|
||||||
|
f.union(11, 10);
|
||||||
|
|
||||||
|
HeightCalc::check(&f);
|
||||||
|
|
||||||
|
assert_eq!(f.find(13), 10);
|
||||||
|
assert_eq!(f.find(12), 10);
|
||||||
|
assert_eq!(f.find(11), 10);
|
||||||
|
assert_eq!(f.find(10), 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chain_a_height() {
|
||||||
|
let mut f = UnionFind::new();
|
||||||
|
for i in 0..1000 {
|
||||||
|
f.union(i, i + 1);
|
||||||
|
HeightCalc::check(&f);
|
||||||
|
}
|
||||||
|
assert_eq!(f.find(1000), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chain_b_height() {
|
||||||
|
let mut f = UnionFind::new();
|
||||||
|
for i in 0..1000 {
|
||||||
|
f.union(i + 1, i);
|
||||||
|
HeightCalc::check(&f);
|
||||||
|
}
|
||||||
|
assert_eq!(f.find(0), 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_binary_tree_height() {
|
||||||
|
let height = 8;
|
||||||
|
let count = 1 << height;
|
||||||
|
|
||||||
|
let mut f = UnionFind::new();
|
||||||
|
for current_height in 0..height {
|
||||||
|
let stride = 1 << current_height;
|
||||||
|
for i in (0..count).step_by(2 * stride) {
|
||||||
|
f.union(i, i + stride);
|
||||||
|
}
|
||||||
|
let actual_height = HeightCalc::check(&f);
|
||||||
|
|
||||||
|
// actual_height can vary based on tiebreaker condition
|
||||||
|
assert!(
|
||||||
|
actual_height == current_height
|
||||||
|
|| actual_height == current_height + 1
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check path halving
|
||||||
|
let actual_height_before = HeightCalc::check(&f);
|
||||||
|
for i in 0..count {
|
||||||
|
assert_eq!(f.find(i), 0);
|
||||||
|
}
|
||||||
|
let actual_height_after = HeightCalc::check(&f);
|
||||||
|
|
||||||
|
assert!(actual_height_after <= actual_height_before.div_ceil(2));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue