mirror of https://gitlab.freedesktop.org/mesa/mesa
nak: Rewrite union_find and use it in repair_ssa
The new UnionFind is safe code, is generic over the element type, and uses constant stack space. Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27454>
This commit is contained in:
parent
b5f4c54d0d
commit
b681677f7d
|
@ -119,6 +119,10 @@ _libnak_rs = static_library(
|
|||
link_with: [_libbitview_rs, libnak_bindings_gen, _libnak_ir_proc_rs],
|
||||
)
|
||||
|
||||
if with_tests
|
||||
rust.test('nak', _libnak_rs, suite : ['nouveau'])
|
||||
endif
|
||||
|
||||
nak_nir_algebraic_c = custom_target(
|
||||
'nak_nir_algebraic.c',
|
||||
input : 'nak_nir_algebraic.py',
|
||||
|
|
|
@ -27,3 +27,4 @@ mod repair_ssa;
|
|||
mod sph;
|
||||
mod spill_values;
|
||||
mod to_cssa;
|
||||
mod union_find;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
use crate::bitset::BitSet;
|
||||
use crate::ir::*;
|
||||
use crate::union_find::UnionFind;
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
|
@ -233,7 +234,7 @@ impl Function {
|
|||
|
||||
// For loop back-edges, we inserted a phi whether we need one or not.
|
||||
// We want to eliminate any redundant phis.
|
||||
let mut ssa_map = HashMap::new();
|
||||
let mut ssa_map = UnionFind::new();
|
||||
if cfg.has_loop() {
|
||||
let mut to_do = true;
|
||||
while to_do {
|
||||
|
@ -245,9 +246,7 @@ impl Function {
|
|||
for (_, p_ssa) in phi.srcs.iter_mut() {
|
||||
// Apply the remap to the phi sources so that we
|
||||
// pick up any remaps from previous loop iterations.
|
||||
while let Some(new_ssa) = ssa_map.get(p_ssa) {
|
||||
*p_ssa = *new_ssa;
|
||||
}
|
||||
*p_ssa = ssa_map.find(*p_ssa);
|
||||
|
||||
if *p_ssa == phi.dst {
|
||||
continue;
|
||||
|
@ -261,7 +260,10 @@ impl Function {
|
|||
// All sources are identical or the phi destination so
|
||||
// we can delete this phi and add it to the remap
|
||||
let ssa = ssa.expect("Circular SSA def");
|
||||
ssa_map.insert(phi.dst, ssa);
|
||||
// union(a, b) ensures that the representative is the representative
|
||||
// for a. This means union(ssa, phi.dst) ensures that phi.dst gets
|
||||
// mapped to ssa, not the other way around.
|
||||
ssa_map.union(ssa, phi.dst);
|
||||
to_do = true;
|
||||
false
|
||||
});
|
||||
|
@ -300,9 +302,7 @@ impl Function {
|
|||
if !ssa_map.is_empty() {
|
||||
for instr in &mut bb.instrs {
|
||||
instr.for_each_ssa_use_mut(|ssa| {
|
||||
while let Some(new_ssa) = ssa_map.get(ssa) {
|
||||
*ssa = *new_ssa;
|
||||
}
|
||||
*ssa = ssa_map.find(*ssa);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -312,11 +312,9 @@ impl Function {
|
|||
if !s_phis.is_empty() {
|
||||
let phi_src = get_or_insert_phi_srcs(bb);
|
||||
for phi in s_phis.iter() {
|
||||
let mut ssa = phi.srcs.get(&b_idx).unwrap();
|
||||
while let Some(new_ssa) = ssa_map.get(ssa) {
|
||||
ssa = new_ssa;
|
||||
}
|
||||
phi_src.srcs.push(phi.idx, (*ssa).into());
|
||||
let mut ssa = *phi.srcs.get(&b_idx).unwrap();
|
||||
ssa = ssa_map.find(ssa);
|
||||
phi_src.srcs.push(phi.idx, ssa.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,291 @@
|
|||
// Copyright © 2024 Mel Henning
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
|
||||
/// Bookkeeping stored at the root of every tree in the forest.
#[derive(Copy, Clone)]
struct Root<X: Copy> {
    /// Number of nodes in the tree rooted here, the root itself included.
    size: usize,
    /// Element that `find()` reports for every member of this set.
    ///
    /// This is not necessarily the element stored in the root's own slot:
    /// `union()` links by size but always keeps the representative of its
    /// first argument.
    representative: X,
}

/// One slot of the forest, addressed by its index in `UnionFind::nodes`.
#[derive(Copy, Clone)]
enum Node<X: Copy> {
    /// Non-root node; `parent_idx` is the index of the next node on the
    /// way to the root.
    Child { parent_idx: usize },
    /// Root of a tree, carrying the per-set data.
    Root(Root<X>),
}

/// Union-find structure
///
/// This implementation follows Tarjan and van Leeuwen - specifically the
/// "link by size" and "halving" variant.
///
/// Robert E. Tarjan and Jan van Leeuwen. 1984. Worst-case Analysis of Set
/// Union Algorithms. J. ACM 31, 2 (April 1984), 245–281.
/// https://doi.org/10.1145/62.2160
pub struct UnionFind<X: Copy + Hash + Eq> {
    /// Maps each element that has taken part in a `union()` to its slot in
    /// `nodes`.
    idx_map: HashMap<X, usize>,
    /// The forest itself; parent links are indices into this vector.
    nodes: Vec<Node<X>>,
}

impl<X: Copy + Hash + Eq> UnionFind<X> {
    /// Create a new union-find structure
    ///
    /// At initialization, each possible value is in its own set
    pub fn new() -> Self {
        UnionFind {
            idx_map: HashMap::new(),
            nodes: Vec::new(),
        }
    }

    /// Walk from `idx` up to the root of its tree.
    ///
    /// Returns the root's index together with a copy of its data. While
    /// walking, every visited node whose parent is not the root is re-linked
    /// to its grandparent ("halving"), so later lookups are shorter.  The
    /// walk is a loop, not a recursion, so stack usage is constant.
    fn find_root(&mut self, mut idx: usize) -> (usize, Root<X>) {
        loop {
            let parent_idx = match self.nodes[idx] {
                Node::Root(root) => return (idx, root),
                Node::Child { parent_idx } => parent_idx,
            };

            match self.nodes[parent_idx] {
                Node::Root(parent_root) => return (parent_idx, parent_root),
                Node::Child {
                    parent_idx: grandparent_idx,
                } => {
                    // "Halving" in Tarjan and van Leeuwen
                    self.nodes[idx] = Node::Child {
                        parent_idx: grandparent_idx,
                    };
                    idx = grandparent_idx;
                }
            }
        }
    }

    /// Find the representative element for x
    ///
    /// An element that has never been passed to `union()` (with a different
    /// partner) is its own representative.
    pub fn find(&mut self, x: X) -> X {
        match self.idx_map.get(&x).copied() {
            Some(idx) => self.find_root(idx).1.representative,
            None => x,
        }
    }

    /// Return the slot index for `x`, creating a fresh singleton root for it
    /// if it has not been seen before.
    fn map_or_create(&mut self, x: X) -> usize {
        // Borrow `nodes` on its own so the closure does not need to capture
        // `self` while `idx_map` is mutably borrowed by `entry()`.  This keeps
        // the function independent of edition-2021 disjoint closure captures.
        let nodes = &mut self.nodes;
        *self.idx_map.entry(x).or_insert_with(|| {
            nodes.push(Node::Root(Root {
                size: 1,
                representative: x,
            }));
            nodes.len() - 1
        })
    }

    /// Union the sets containing a and b
    ///
    /// The representative for a will become the representative of
    /// the combined set
    pub fn union(&mut self, a: X, b: X) {
        // Nothing to merge; also avoids allocating a node for `a`, which
        // keeps `is_empty()` meaningful.
        if a == b {
            return;
        }

        let a_idx = self.map_or_create(a);
        let b_idx = self.map_or_create(b);
        let (a_root_idx, a_root) = self.find_root(a_idx);
        let (b_root_idx, b_root) = self.find_root(b_idx);

        // Already in the same set
        if a_root_idx == b_root_idx {
            return;
        }

        // Keep the tree balanced: the smaller tree hangs off the larger one
        let (new_root_idx, new_child_idx) = if a_root.size >= b_root.size {
            (a_root_idx, b_root_idx)
        } else {
            (b_root_idx, a_root_idx)
        };

        // Whichever slot ends up as the root, it advertises a's
        // representative.
        self.nodes[new_root_idx] = Node::Root(Root {
            size: a_root.size + b_root.size,
            representative: a_root.representative,
        });
        self.nodes[new_child_idx] = Node::Child {
            parent_idx: new_root_idx,
        };
    }

    /// Return true if find() is the identity mapping
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
}

impl<X: Copy + Hash + Eq> Default for UnionFind<X> {
    /// Equivalent to [`UnionFind::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::union_find::{Node, UnionFind};
    use std::cmp::max;
    use std::hash::Hash;

    /// Ceiling of log2(x); asserts that `x` is non-zero.
    fn ceil_log2(x: usize) -> u32 {
        assert!(x > 0);
        usize::BITS - (x - 1).leading_zeros()
    }

    /// Measured shape of one subtree.
    struct HeightInfo {
        /// Longest downward path, counted in edges.
        height: u32,
        /// Number of nodes in the subtree, its root included.
        size: usize,
    }

    /// Test-only checker that inverts the parent links of a `UnionFind` so
    /// that tree sizes and heights can be measured from the roots down.
    pub struct HeightCalc<'a, X: Copy + Hash + Eq> {
        uf: &'a UnionFind<X>,
        /// `downward_edges[p]` lists the indices of all children of node `p`.
        downward_edges: Vec<Vec<usize>>,
    }

    impl<'a, X: Copy + Hash + Eq> HeightCalc<'a, X> {
        /// Build the child lists from the parent indices stored in `uf`.
        fn new(uf: &'a UnionFind<X>) -> Self {
            let mut downward_edges: Vec<Vec<usize>> =
                vec![Vec::new(); uf.nodes.len()];
            for (child_idx, node) in uf.nodes.iter().enumerate() {
                match *node {
                    Node::Child { parent_idx } => {
                        downward_edges[parent_idx].push(child_idx)
                    }
                    Node::Root(_) => {}
                }
            }

            HeightCalc { uf, downward_edges }
        }

        /// Recursively measure the subtree rooted at `idx`.
        fn calc_info(&self, idx: usize) -> HeightInfo {
            self.downward_edges[idx]
                .iter()
                .map(|&child| self.calc_info(child))
                .fold(HeightInfo { height: 0, size: 1 }, |acc, sub| {
                    HeightInfo {
                        height: max(acc.height, sub.height + 1),
                        size: acc.size + sub.size,
                    }
                })
        }

        /// Validate every tree and return the largest height seen.
        ///
        /// Asserts that each root's recorded size matches the measured one,
        /// that each tree's height stays within
        /// `ceil_log2(size + 1) - 1`, and that the trees together account
        /// for every node and every mapped element.
        fn check_roots(&self) -> u32 {
            let mut total_size = 0;
            let mut max_height = 0;

            for (root_idx, node) in self.uf.nodes.iter().enumerate() {
                let root = match node {
                    Node::Root(root) => root,
                    Node::Child { .. } => continue,
                };

                let info = self.calc_info(root_idx);
                assert_eq!(root.size, info.size);

                total_size += info.size;
                max_height = max(max_height, info.height);

                let max_expected_height = ceil_log2(root.size + 1) - 1;
                // Print the offending numbers before the assertion fires
                if info.height > max_expected_height {
                    eprintln!(
                        "height {}\t max_expected_height {}\t size {}",
                        info.height, max_expected_height, info.size
                    );
                }
                assert!(info.height <= max_expected_height);
            }

            assert_eq!(total_size, self.uf.nodes.len());
            assert_eq!(total_size, self.uf.idx_map.len());
            max_height
        }

        /// Run all structural checks on `uf`; returns the maximum tree height.
        pub fn check(uf: &'a UnionFind<X>) -> u32 {
            HeightCalc::new(uf).check_roots()
        }
    }

    #[test]
    fn test_basic() {
        let mut f = UnionFind::new();
        // Untouched elements map to themselves
        for x in [10, 12] {
            assert_eq!(f.find(x), x);
        }

        f.union(10, 12);
        f.union(11, 13);

        HeightCalc::check(&f);

        for (x, representative) in [(13, 11), (12, 10), (11, 11), (10, 10)] {
            assert_eq!(f.find(x), representative);
        }

        f.union(12, 13);

        HeightCalc::check(&f);

        for x in [13, 12, 11, 10] {
            assert_eq!(f.find(x), 10);
        }

        assert_eq!(f.find(14), 14);

        HeightCalc::check(&f);

        // Union the set with itself
        f.union(11, 10);

        HeightCalc::check(&f);

        for x in [13, 12, 11, 10] {
            assert_eq!(f.find(x), 10);
        }
    }

    #[test]
    fn test_chain_a_height() {
        let mut f = UnionFind::new();
        // Grow one set by appending to the side that keeps the representative
        for i in 0..1000 {
            f.union(i, i + 1);
            HeightCalc::check(&f);
        }
        assert_eq!(f.find(1000), 0);
    }

    #[test]
    fn test_chain_b_height() {
        let mut f = UnionFind::new();
        // Same chain, but every union hands the representative to the
        // newcomer
        for i in 0..1000 {
            f.union(i + 1, i);
            HeightCalc::check(&f);
        }
        assert_eq!(f.find(0), 1000);
    }

    #[test]
    fn test_binary_tree_height() {
        let height = 8;
        let count = 1 << height;

        let mut f = UnionFind::new();
        // Merge equal-sized sets pairwise, doubling the set size each round
        for current_height in 0..height {
            let stride = 1 << current_height;
            for i in (0..count).step_by(2 * stride) {
                f.union(i, i + stride);
            }
            let actual_height = HeightCalc::check(&f);

            // actual_height can vary based on tiebreaker condition
            assert!(
                actual_height == current_height
                    || actual_height == current_height + 1
            );
        }

        // Check path halving
        let actual_height_before = HeightCalc::check(&f);
        for i in 0..count {
            assert_eq!(f.find(i), 0);
        }
        let actual_height_after = HeightCalc::check(&f);

        assert!(actual_height_after <= actual_height_before.div_ceil(2));
    }
}
|
Loading…
Reference in New Issue