diff --git a/src/v2/sketch.rs b/src/v2/sketch.rs index 9d44e5d..9e98b81 100644 --- a/src/v2/sketch.rs +++ b/src/v2/sketch.rs @@ -1,6 +1,8 @@ +const RESET_MASK: u64 = 0x7777777777777777; +const ONE_MASK: u64 = 0x1111111111111111; + pub struct CountMinSketch { - row_64_size: usize, - row_mask: usize, + block_mask: usize, table: Vec, additions: usize, sample_size: usize, @@ -8,29 +10,25 @@ pub struct CountMinSketch { impl CountMinSketch { pub fn new(size: usize) -> CountMinSketch { - let row_counter_size = ((size * 3) as u64).next_power_of_two() as usize; - let row_64_size = row_counter_size / 16; - let row_mask = row_counter_size - 1; - // each u64 contains 16 counters, so vec size is (row_counter_size * 4 / 16) - let table = vec![0; row_counter_size >> 2]; + let counter_size = size.next_power_of_two(); + let block_mask = (counter_size >> 3) - 1; + let table = vec![0; counter_size]; CountMinSketch { additions: 0, - sample_size: 10 * row_counter_size, + sample_size: 10 * counter_size, table, - row_mask, - row_64_size, + block_mask, } } - fn index_of(&self, h: u64, offset: u8) -> (usize, usize) { - let hn = h + (offset as u64) * (h >> 32); - let i = (hn & (self.row_mask as u64)) as usize; - let index = offset as usize * self.row_64_size + (i >> 4); - let offset = (i & 0xF) << 2; - (index, offset) + fn index_of(&self, counter_hash: u64, block: u64, offset: u8) -> (usize, usize) { + let h = counter_hash >> (offset << 3); + let index = block + (h & 1) + (offset << 1) as u64; + (index as usize, (h >> 1 & 0xf) as usize) } fn inc(&mut self, index: usize, offset: usize) -> bool { + let offset = offset << 2; let mask = 0xF << offset; if self.table[index] & mask != mask { self.table[index] += 1 << offset; @@ -40,10 +38,13 @@ impl CountMinSketch { } pub fn add(&mut self, h: u64) { - let (index0, offset0) = self.index_of(h, 0); - let (index1, offset1) = self.index_of(h, 1); - let (index2, offset2) = self.index_of(h, 2); - let (index3, offset3) = self.index_of(h, 3); + let counter_hash = rehash(h); + let block_hash = h; + let block = (block_hash & (self.block_mask as u64)) << 3; + let (index0, offset0) = self.index_of(counter_hash, block, 0); + let (index1, offset1) = self.index_of(counter_hash, block, 1); + let (index2, offset2) = self.index_of(counter_hash, block, 2); + let (index3, offset3) = self.index_of(counter_hash, block, 3); let mut added: bool; added = self.inc(index0, offset0); @@ -60,46 +61,62 @@ impl CountMinSketch { } fn reset(&mut self) { - let _ = self.table.iter().map(|x| x >> 1); - self.additions >>= 1; + let mut count = 0; + + for i in self.table.iter_mut() { + count += (*i & ONE_MASK).count_ones(); + *i = (*i >> 1) & RESET_MASK; + } + + self.additions = (self.additions - ((count >> 2) as usize)) >> 1; } - fn count(&self, h: u64, offset: u8) -> usize { - let (index, offset) = self.index_of(h, offset); + fn count(&self, h: u64, block: u64, offset: u8) -> usize { + let (index, offset) = self.index_of(h, block, offset); + let offset = offset << 2; let count = (self.table[index] >> offset) & 0xF; count as usize } pub fn estimate(&self, h: u64) -> usize { - let count0 = self.count(h, 0); - let count1 = self.count(h, 1); - let count2 = self.count(h, 2); - let count3 = self.count(h, 3); + let counter_hash = rehash(h); + let block_hash = h; + let block = (block_hash & (self.block_mask as u64)) << 3; + let count0 = self.count(counter_hash, block, 0); + let count1 = self.count(counter_hash, block, 1); + let count2 = self.count(counter_hash, block, 2); + let count3 = self.count(counter_hash, block, 3); let s = [count0, count1, count2, count3]; let min = s.iter().min().unwrap(); *min } } +fn rehash(h: u64) -> u64 { + let mut h = h.wrapping_mul(0x94d049bb133111eb); + h ^= h >> 31; + h +} + #[cfg(test)] mod tests { - use std::hash::{BuildHasher, RandomState}; + use std::{ + collections::HashMap, + hash::{BuildHasher, RandomState}, + }; use super::CountMinSketch; #[test] fn test_sketch() { - let mut sketch = CountMinSketch::new(100); - // 512 counters per row, 2048 bits per row, 32 uint64 per row - assert_eq!(sketch.row_64_size, 32); - assert_eq!(sketch.row_mask, 511); - // 32 uint64 * 4 rows - assert_eq!(sketch.table.len(), 128); - assert_eq!(sketch.sample_size, 5120); + let mut sketch = CountMinSketch::new(10000); + assert_eq!(sketch.table.len(), 16384); + assert_eq!(sketch.block_mask, 2047); + assert_eq!(sketch.sample_size, 163840); let hasher = RandomState::new(); let mut failed = 0; - for i in 0..500 { + for i in 0..8000 { let key = format!("foo:bar:{}", i); let h = hasher.hash_one(key); sketch.add(h); @@ -115,17 +132,76 @@ mod tests { let es1 = sketch.estimate(h); let es2 = sketch.estimate(h2); - if es2 > es1 { + if es1 != 5 { + failed += 1 + } + if es2 != 3 { failed += 1 } assert!(es1 >= 5); assert!(es2 >= 3); } - assert!(failed as f64 / 4000.0 < 0.1); - assert!(sketch.additions > 3900); - let a = sketch.additions; + assert!(failed < 40); + } + #[test] + fn test_sketch_reset_counter() { + let mut sketch = CountMinSketch::new(1000); + for i in sketch.table.iter_mut() { + *i = !0; + } + sketch.additions = 100000; + let hasher = RandomState::new(); + let h = hasher.hash_one("foo"); + assert_eq!(sketch.estimate(h), 15); sketch.reset(); - assert_eq!(sketch.additions, a >> 1); + assert_eq!(sketch.estimate(h), 7); + } + + #[test] + fn test_sketch_reset_addition() { + let mut sketch = CountMinSketch::new(500); + let hasher = RandomState::new(); + let mut counts = HashMap::new(); + for i in 0..5 { + let key = format!("foo:bar:{}", i); + let h = hasher.hash_one(key); + sketch.add(h); + sketch.add(h); + sketch.add(h); + sketch.add(h); + sketch.add(h); + let keyb = format!("foo:bar:{}:b", i); + let h2 = hasher.hash_one(keyb); + sketch.add(h2); + sketch.add(h2); + sketch.add(h2); + + let es1 = sketch.estimate(h); + let es2 = sketch.estimate(h2); + counts.insert(h, es1); + counts.insert(h2, es2); + } + let total_before = sketch.additions; + let mut diff = 0; + sketch.reset(); + for i in 0..5 { + let key = format!("foo:bar:{}", i); + let h = hasher.hash_one(key); + let keyb = format!("foo:bar:{}:b", i); + let h2 = hasher.hash_one(keyb); + + let es1 = sketch.estimate(h); + let es2 = sketch.estimate(h2); + let es1_prev = *counts.get(&h).unwrap(); + let es2_prev = *counts.get(&h2).unwrap(); + diff += es1_prev - es1; + diff += es2_prev - es2; + + assert_eq!(es1, es1_prev / 2 as usize); + assert_eq!(es2, es2_prev / 2 as usize); + } + + assert_eq!(total_before - sketch.additions, diff); } }