diff --git a/pybloom/pybloom.py b/pybloom/pybloom.py index beeefe4..82370c2 100644 --- a/pybloom/pybloom.py +++ b/pybloom/pybloom.py @@ -376,6 +376,43 @@ def add(self, key): filter.add(key, skip_check=True) return False + def copy(self): + """Return a copy of this scalable bloom filter. + """ + new_filter = ScalableBloomFilter(initial_capacity=self.initial_capacity, + error_rate=self.error_rate, + mode=self.SMALL_SET_GROWTH) + new_filter.filters = self.filters[:] + return new_filter + + def union(self, other): + """ Calculates the union of the underlying classic bloom filters and returns + a new scalable bloom filter object.""" + + if self.scale != other.scale or \ + self.initial_capacity != other.initial_capacity or \ + self.error_rate != other.error_rate: + raise ValueError("Unioning two scalable bloom filters requires \ + both filters to have both the same mode, initial capacity and error rate") + if len(self.filters) > len(other.filters): + larger_sbf = self.copy() + smaller_sbf = other.copy() + else: + larger_sbf = other.copy() + smaller_sbf = self.copy() + # Union the underlying classic bloom filters + new_filters = [] + for i in range(len(smaller_sbf.filters)): + new_filter = larger_sbf.filters[i] | smaller_sbf.filters[i] + new_filters.append(new_filter) + for i in range(len(smaller_sbf.filters), len(larger_sbf.filters)): + new_filters.append(larger_sbf.filters[i]) + larger_sbf.filters = new_filters + return larger_sbf + + def __or__(self, other): + return self.union(other) + @property def capacity(self): """Returns the total capacity for all filters in this SBF""" diff --git a/pybloom/tests.py b/pybloom/tests.py index 13d9b7d..babbadd 100644 --- a/pybloom/tests.py +++ b/pybloom/tests.py @@ -77,6 +77,18 @@ def _run(): new_bloom = bloom_one.union(bloom_two) self.assertRaises(ValueError, _run) + def test_union_scalable_bloom_filter(self): + bloom_one = ScalableBloomFilter(mode=ScalableBloomFilter.SMALL_SET_GROWTH) + bloom_two = ScalableBloomFilter(mode=ScalableBloomFilter.SMALL_SET_GROWTH) + chars = [chr(i) for i in range_fn(97, 123)] + for char in chars[int(len(chars) / 2):]: + bloom_one.add(char) + for char in chars[:int(len(chars) / 2)]: + bloom_two.add(char) + new_bloom = bloom_one.union(bloom_two) + for char in chars: + self.assertTrue(char in new_bloom) + class Serialization(unittest.TestCase): SIZE = 12345 EXPECTED = set([random.randint(0, 10000100) for _ in range_fn(SIZE)])