Skip to content

Commit

Permalink
Merge pull request #2 from TrueDoctor/master
Browse files Browse the repository at this point in the history
Fix bug in unify function
  • Loading branch information
Khojasteh authored Feb 12, 2024
2 parents b3e977c + 528665a commit 5d8a97d
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 11 deletions.
120 changes: 120 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,11 @@ edition = "2021"
license = "MIT"
readme = "README.md"
repository = "https://github.com/khojasteh/uf_rush"
documentation = "https://docs.rs/uf_rush/0.1.1"
documentation = "https://docs.rs/uf_rush/0.1.1"

[dev-dependencies]
rand = "0.8.5"
rayon = "1.8.1"

[profile.test]
opt-level = 3
76 changes: 66 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,40 +158,52 @@ impl UFRush {
/// deep, which helps keep the operation's time complexity nearly constant.
pub fn unite(&self, x: usize, y: usize) -> bool {
loop {
// Load representative for x and y
let mut x_rep = self.find(x);
let mut y_rep = self.find(y);

// If they are already part of the same set, return false
if x_rep == y_rep {
return false;
}

// Load the encoded representation of the representatives
let x_node = self.nodes[x_rep].load(Ordering::Relaxed);
let y_node = self.nodes[y_rep].load(Ordering::Relaxed);

let mut x_rank = rank(x_node);
let mut y_rank = rank(y_node);

if x_rank > y_rank || (x_rank == y_rank && x_rank < y_rank) {
// Swap the elements around to always make x the smaller one
if x_rank > y_rank || (x_rank == y_rank && x_rep > y_rep) {
std::mem::swap(&mut x_rep, &mut y_rep);
std::mem::swap(&mut x_rank, &mut y_rank);
}

// x_rep is a root
let cur_value = encode(x_rep, x_rank);
// assign the new root to be y
let new_value = encode(y_rep, x_rank);
// change the value of the smaller subtree root to point to the other one
if self.nodes[x_rep]
.compare_exchange(cur_value, new_value, Ordering::Release, Ordering::Acquire)
.is_ok()
{
let cur_value = encode(y_rep, y_rank);
let new_value = encode(y_rep, y_rank + 1);
let _ = self.nodes[y_rep].compare_exchange_weak(
cur_value,
new_value,
Ordering::Release,
Ordering::Relaxed,
);
// x_repr now points to y_repr
// If the subtrees has the same height, increase the rank of the new root
if x_rank == y_rank {
let cur_value = encode(y_rep, y_rank);
let new_value = encode(y_rep, y_rank + 1);
let _ = self.nodes[y_rep].compare_exchange_weak(
cur_value,
new_value,
Ordering::Release,
Ordering::Relaxed,
);
}
return true;
}
// A different thread has already merged modified the value of x_repr -> repeat
}
}

Expand Down Expand Up @@ -313,6 +325,50 @@ mod tests {
assert!(!is_cyclic(4, [(0, 1), (1, 2), (2, 3)]));
}

#[test]
fn stress_test() {
use rand::prelude::*;
use std::sync::{Arc, Barrier};
use std::thread;

let num_elements = 1_00_000; // Adjust based on the system's capability

let elements = 1 << 9;

// Preparing a pool of element pairs for unification
let mut pairs = Vec::new();
// Make sure everythin is connected
for i in 0..=elements {
let i = i % elements;
pairs.push((i, i + 1));
}
// Add random edges to the graph
for i in 0..num_elements - 1 {
let source = rand::random::<usize>() % elements;
let target = rand::random::<usize>() % elements;
pairs.push((source, target));
}

for i in 0..1000 {
// Shuffle pairs to randomize access patterns
let uf = Arc::new(UFRush::new(elements + 1));
use rand::{thread_rng, Rng};
use rayon::prelude::*;
let mut rng = thread_rng();
let total_unites = AtomicUsize::new(0);
let total_unites = &total_unites;
pairs.shuffle(&mut rng);

pairs.par_iter().for_each(|(x, y)| {
if uf.unite(*x, *y) {
total_unites.fetch_add(1, Ordering::Relaxed);
}
});

assert_eq!(total_unites.load(Ordering::SeqCst), elements);
}
}

fn is_cyclic<I>(vertices: usize, edges: I) -> bool
where
I: IntoIterator<Item = (usize, usize)>,
Expand Down Expand Up @@ -340,4 +396,4 @@ mod tests {
// Wait for all threads to finish and check if any of them found a cycle
handles.into_iter().any(|handle| handle.join().unwrap())
}
}
}

0 comments on commit 5d8a97d

Please sign in to comment.