소스 검색

Gave it a little more user-friendly constructor: So we pass ideal_num_elements and error_rate instead of arcana

dstromberg 14 년 전
부모
커밋
cf18ebb68a
2개의 변경된 파일35개의 추가작업 그리고 13개의 파일을 삭제
  1. 31 10
      bloom_filter_mod.py
  2. 4 3
      test-bloom-filter

+ 31 - 10
bloom_filter_mod.py

@@ -6,6 +6,7 @@
 
 # Tweaked a bit by Daniel Richard Stromberg, mostly to make it pass pylint
 
+import math
 import array
 import random
 
@@ -13,24 +14,44 @@ def get_probes(bfilter, key):
 	'''Generate a bunch of fast hash functions'''
 	hasher = random.Random(key).randrange
 	for _ in range(bfilter.num_probes):
-		array_index = hasher(len(bfilter.arr))
+		array_index = hasher(len(bfilter.array_))
 		bit_index = hasher(32)
 		yield array_index, 1 << bit_index
 
 class Bloom_filter:
 	'''Probabilistic set membership testing for large sets'''
 
-	def __init__(self, num_bits, num_probes, probe_func=get_probes):
-		self.num_bits = num_bits
-		num_words = (num_bits + 31) // 32
-		self.arr = array.array('L', [0]) * num_words
-		self.num_probes = num_probes
+	def __init__(self, ideal_num_elements, error_rate, probe_func=get_probes):
+		if ideal_num_elements <= 0:
+			raise ValueError('ideal_num_elements must be > 0')
+		if not (0 < error_rate < 1):
+			raise ValueError('error_rate must be between 0 and 1 inclusive')
+
+		self.error_rate = error_rate
+		# With fewer elements, we should do very well.  With more elements, our error rate "guarantee"
+		# drops rapidly.
+		self.ideal_num_elements = ideal_num_elements
+
+		self.num_bits = - int((self.ideal_num_elements * math.log(self.error_rate)) / (math.log(2) ** 2))
+
+		self.num_words = int((self.num_bits + 31) / 32)
+		self.array_ = array.array('L', [0]) * self.num_words
+
+		self.num_probes = int((self.num_bits / self.ideal_num_elements) * math.log(2))
+
 		self.probe_func = probe_func
 
+	def __repr__(self):
+		return 'Bloom_filter(ideal_num_elements=%d, error_rate=%f, num_bits=%d)' % (
+			self.ideal_num_elements,
+			self.error_rate,
+			self.num_bits,
+			)
+
 	def add(self, key):
 		'''Add an element to the filter'''
 		for i, mask in self.probe_func(self, key):
-			self.arr[i] |= mask
+			self.array_[i] |= mask
 
 	def _match_template(self, bfilter):
 		'''Compare a sort of signature for two bloom filters.  Used in preparation for binary operations'''
@@ -41,7 +62,7 @@ class Bloom_filter:
 	def union(self, bfilter):
 		'''Compute the set union of two bloom filters'''
 		if self._match_template(bfilter):
-			self.arr = [a | b for a, b in zip(self.arr, bfilter.arr)]
+			self.array_ = [a | b for a, b in zip(self.array_, bfilter.array_)]
 		else:
 			# Union b/w two unrelated bloom filter raises this
 			raise ValueError("Mismatched bloom filters")
@@ -52,7 +73,7 @@ class Bloom_filter:
 	def intersection(self, bfilter):
 		'''Compute the set intersection of two bloom filters'''
 		if self._match_template(bfilter):
-			self.arr = [a & b for a, b in zip(self.arr, bfilter.arr)]
+			self.array_ = [a & b for a, b in zip(self.array_, bfilter.array_)]
 		else:
 			# Intersection b/w two unrelated bloom filter raises this
 			raise ValueError("Mismatched bloom filters")
@@ -61,5 +82,5 @@ class Bloom_filter:
 		return self.intersection(bfilter)
 
 	def __contains__(self, key):
-		return all(self.arr[i] & mask for i, mask in self.probe_func(self, key))
+		return all(self.array_[i] & mask for i, mask in self.probe_func(self, key))
 

+ 4 - 3
test-bloom-filter

@@ -21,14 +21,16 @@ def tests():
 		Pennsylvania RhodeIsland SouthCarolina SouthDakota Tennessee Texas Utah
 		Vermont Virginia Washington WestVirginia Wisconsin Wyoming'''.split()
 
-	bloom_filter = bloom_filter_mod.Bloom_filter(num_bits=1000, num_probes=14)
+	trials = 100000
+
+	bloom_filter = bloom_filter_mod.Bloom_filter(ideal_num_elements=trials * 2, error_rate=0.01)
+	print(repr(bloom_filter))
 	for state in states:
 		bloom_filter.add(state)
 
 	states_in_count = sum(state in bloom_filter for state in states)
 	print('%d true positives out of %d trials' % (states_in_count, len(states)))
 
-	trials = 100000
 	false_positives = 0
 	for trialno in range(trials):
 		dummy = trialno
@@ -44,4 +46,3 @@ def tests():
 
 tests()
 
-