-
Notifications
You must be signed in to change notification settings - Fork 0
/
3_streaming_countSketch.py
134 lines (105 loc) · 4.75 KB
/
3_streaming_countSketch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from pyspark import SparkContext, SparkConf
from pyspark.streaming import StreamingContext
from pyspark import StorageLevel
import threading
import sys
import numpy as np
# Constants
THRESHOLD = 10000000  # After how many items should we stop?
P = 8191  # Prime modulus for the universal hash family (2^13 - 1)


class CountSketch:
    """Count Sketch for streaming frequency and F2 (second moment) estimation.

    Maintains a d x w signed-counter matrix. Each of the d rows uses an
    independent universal hash h_j(x) = ((a_j * x + b_j) mod P) mod w to pick
    a bucket, and a deterministic +/-1 sign g(x) derived from x itself.
    """

    def __init__(self, d, w):
        # d: number of hash rows (depth); w: buckets per row (width).
        self.d = d
        self.w = w
        self.count_sketch = np.zeros((self.d, self.w), dtype=np.int64)
        # Per-row hash coefficients: a_j in [1, P), b_j in [0, P).
        self.ha = np.random.randint(1, P, size=self.d)
        self.hb = np.random.randint(0, P, size=self.d)

    def _get_count_sketch(self, index):
        """Return row `index` of the counter matrix."""
        return self.count_sketch[index]

    def update(self, elem, count):
        """Add `count` occurrences of integer key `elem` to every row."""
        # The sign depends only on elem, so it is hoisted out of the row loop.
        g = self._generate_sign(elem)
        for j in range(self.d):
            hash_val = self._hash_function(self.ha[j], self.hb[j], elem)
            self.count_sketch[j][hash_val] += g * count

    def estimate_item_frequency(self, elem):
        """Estimate the frequency of `elem`.

        Returns the median over rows of g(elem) * C[j][h_j(elem)], which is
        an unbiased estimator of the true frequency.
        """
        g = self._generate_sign(elem)
        f_tilde = []
        for j in range(self.d):
            hash_val = self._hash_function(self.ha[j], self.hb[j], elem)
            f_tilde.append(self._get_count_sketch(j)[hash_val] * g)
        return np.median(f_tilde)

    def estimate_F2(self):
        """Estimate the second frequency moment F2.

        Returns the median over rows of sum_k C[j][k]^2.
        Bug fix: iterate over self.d (this sketch's own depth) instead of the
        module-level global D, so the class works independently of script state.
        """
        return np.median([np.sum(self._get_count_sketch(j) ** 2) for j in range(self.d)])

    def _hash_function(self, a, b, x):
        """Universal hash ((a*x + b) mod P) mod w -> bucket index in [0, w)."""
        return ((a * x + b) % P) % self.w

    @staticmethod
    def _generate_sign(x):
        """Deterministic +/-1 sign for x: RNG seeded by x, so repeatable across calls."""
        random_generator = np.random.default_rng(x)
        return random_generator.choice([-1, 1])
# Operations to perform after receiving an RDD 'batch'
def process_batch(batch):
    """Consume one streaming batch.

    Counts items, accumulates exact frequencies for keys in [left, right],
    and feeds those same keys into the count sketch. Sets `stopping_condition`
    once THRESHOLD total items have been observed.
    """
    global streamLength, true_freq_histogram, count_sketch, D, W, left, right
    # Ignore batches that arrive after the stream target was already reached.
    if streamLength[0] >= THRESHOLD:
        return
    n_items = batch.count()
    if n_items > 0:
        streamLength[0] += n_items
        # Exact per-key counts within this batch only.
        freq_in_batch = (batch.map(lambda s: (int(s), 1))
                              .reduceByKey(lambda i1, i2: i1 + i2)
                              .collectAsMap())
        for item, cnt in freq_in_batch.items():
            # Only keys inside the interval of interest are tracked.
            if item in range(left, right + 1):
                true_freq_histogram[item] = true_freq_histogram.get(item, 0) + cnt
                count_sketch.update(item, cnt)
    # Request shutdown once enough items have been processed.
    if streamLength[0] >= THRESHOLD:
        stopping_condition.set()
def print_statistics(portNumber, length_StreamInterval, distinct_StreamInterval):
    """Print final run statistics to stdout.

    portNumber: the stream port used (for the header line).
    length_StreamInterval: total item count observed inside [left, right].
    distinct_StreamInterval: number of distinct keys observed in [left, right].
    Reads module globals D, W, left, right, K, streamLength,
    true_freq_histogram, approximate_freq, F2, F2_aprox.
    """
    print(f'D = {D}, W = {W}, [left,right] = [{left},{right}], K = {K}, Port = {portNumber}')
    print(f'Total number of items = {streamLength[0]}')
    print(f'Total number of items in [{left}, {right}] = {length_StreamInterval}')
    print(f'Number of distinct items in [{left}, {right}] = {distinct_StreamInterval}')
    # Top-K true frequencies in descending order; may hold fewer than K entries
    # when fewer than K distinct keys fell inside the interval.
    top_K_freq = sorted(true_freq_histogram.items(), key=lambda x: x[1], reverse=True)[:K]
    if K <= 20:
        for i, freq in top_K_freq:
            print(f'Item {i} Freq = {freq} Est. Freq = {approximate_freq[i]}')
    # Average relative estimation error over the items actually reported.
    # Bug fix: the original divided by K even when len(top_K_freq) < K,
    # understating the average error.
    if top_K_freq:
        avg_error = sum(abs(freq - approximate_freq[i]) / freq
                        for i, freq in top_K_freq) / len(top_K_freq)
    else:
        avg_error = 0.0
    print(f'Avg err for top {K} = {avg_error:.4f}')
    print(f'F2 {F2:.4f} F2 Estimate {F2_aprox:.4f}')
if __name__ == '__main__':
    # Command-line arguments: sketch depth D, sketch width W, tracked key
    # interval [left, right], top-K size K, and the port of the remote stream.
    # NOTE(review): assert is stripped under `python -O`; an explicit check
    # with sys.exit would be more robust for input validation.
    assert len(sys.argv) == 7, "USAGE: D W left right K portExp"
    conf = SparkConf().setAppName("Streaming_CountSketch_HW3")
    sc = SparkContext(conf=conf)
    ssc = StreamingContext(sc, 1)  # Batch duration of 1 second
    ssc.sparkContext.setLogLevel("ERROR")
    # Event signalled by process_batch once THRESHOLD items have been seen;
    # the main thread blocks on it before stopping the streaming context.
    stopping_condition = threading.Event()
    # INPUT READING
    D = int(sys.argv[1])
    W = int(sys.argv[2])
    left = int(sys.argv[3])
    right = int(sys.argv[4])
    K = int(sys.argv[5])
    portExp = int(sys.argv[6])
    # Mutable single-element list so process_batch can update the running
    # stream length through the global reference.
    streamLength = [0]
    # Exact key -> frequency counts for keys inside [left, right].
    true_freq_histogram = {}
    count_sketch = CountSketch(D, W)
    # CODE TO PROCESS AN UNBOUNDED STREAM OF DATA IN BATCHES
    stream = ssc.socketTextStream("algo.dei.unipd.it", portExp, StorageLevel.MEMORY_AND_DISK)
    stream.foreachRDD(lambda batch: process_batch(batch))
    # MANAGING STREAMING SPARK CONTEXT
    ssc.start()
    stopping_condition.wait()
    # Stop streaming (keep the SparkContext alive: first arg False),
    # waiting gracefully for in-flight batches (second arg True).
    ssc.stop(False, True)
    # COMPUTE AND PRINT FINAL STATISTICS
    # Sketch-based frequency estimate for every key we tracked exactly,
    # so true and estimated values can be compared side by side.
    approximate_freq = {}
    for item in true_freq_histogram.keys():
        approximate_freq[item] = count_sketch.estimate_item_frequency(item)
    # Total items and distinct keys observed inside [left, right].
    sigmaR_length = sum(val for val in true_freq_histogram.values())
    sigmaR_distinct = len(true_freq_histogram)
    # Normalized second moments: exact F2 and its sketch estimate, both
    # divided by the squared interval stream length.
    F2 = sum(val ** 2 for val in true_freq_histogram.values()) / sigmaR_length ** 2
    F2_aprox = count_sketch.estimate_F2() / sigmaR_length ** 2
    print_statistics(portExp, sigmaR_length, sigmaR_distinct)
    sc.stop()