-
Notifications
You must be signed in to change notification settings - Fork 0
/
estimate_nu.py
53 lines (43 loc) · 1.67 KB
/
estimate_nu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import csv
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
def _load_post_warmup(csv_path):
    """Read two-column integer rows from *csv_path* and drop the first half.

    The first ``len(rows) // 2`` rows are treated as warm-up and discarded.

    Args:
        csv_path: Path of a CSV file whose non-blank rows each hold exactly
            two integer fields.

    Returns:
        An ``(m, 2)`` int64 array holding the retained rows.

    Raises:
        ValueError: If a row does not have exactly two integer fields, or if
            the file has no data rows.
    """
    # newline="" lets the csv module do its own line-ending handling.
    with open(csv_path, newline="") as f:
        # csv.reader yields [] for blank lines; skip them so a stray empty
        # line does not abort the run with an unpacking error.
        rows = [[int(x), int(y)] for x, y in (row for row in csv.reader(f) if row)]
    samples = rows[len(rows) // 2:]  # omit warm-up entries
    if not samples:
        raise ValueError(f"{csv_path} contains no data rows")
    # Explicit 64-bit integers: the platform default may be 32-bit, which is
    # too narrow once the coordinates are squared.
    return np.array(samples, dtype=np.int64)


def _power_law(n, nu, C):
    """Model curve fitted to the data: ``C * n ** (2 * nu)``."""
    return C * n ** (2 * nu)


def main(path):
    """Estimate the exponent nu from CSV data generated by the pivot algorithm.

    For every ``n`` in ``1000 * 2 ** i`` (``i = 0 .. 10``) the file
    ``<path>/<n>.csv`` is read, the first half of its rows is discarded as
    warm-up, and the mean squared Euclidean norm of the remaining rows is
    computed. The model ``C * n ** (2 * nu)`` is then fitted to these means.

    Side effects: prints the fitted ``nu``, writes ``nu_estimate.png`` to the
    current working directory and opens the plot window.

    Args:
        path: Directory containing the ``<n>.csv`` input files.

    Raises:
        FileNotFoundError: If one of the expected CSV files is missing.
        ValueError: If a CSV file is empty or has a malformed row.
    """
    # parse data generated by the pivot algorithm
    num_steps = [1000 * 2 ** i for i in range(11)]

    # compute mean squared distance
    mean_sq_dists = {}
    for n in num_steps:
        samples = _load_post_warmup(os.path.join(path, f"{n}.csv"))
        # Sum of integer squares is exact; norm(...) ** 2 would take a square
        # root and square it again, adding floating-point round-off.
        sq_dists = np.sum(samples * samples, axis=1)
        mean_sq_dists[n] = np.mean(sq_dists)

    x_scatter = list(mean_sq_dists.keys())
    y_scatter = list(mean_sq_dists.values())

    # (optional) plot data and a reference curve
    x_curve = np.linspace(min(x_scatter), max(x_scatter), 100)
    # C = 1 is not very accurate but good enough for the plot
    y_curve = _power_law(x_curve, 0.75, 1)
    plt.scatter(x_scatter, y_scatter)
    plt.loglog(x_curve, y_curve, color="red")
    plt.xlabel("Number of steps")
    plt.ylabel("Average mean squared distance")
    plt.title("Estimating nu with the pivot algorithm")

    # fit curve with scipy
    # Start from the same values as the reference curve: curve_fit's default
    # guess (all ones) means n ** 2, which is orders of magnitude off at the
    # largest n.
    params, _ = curve_fit(_power_law, x_scatter, y_scatter, p0=(0.75, 1.0))
    nu, _ = params
    print(f"nu estimate: {nu}")

    # show plot
    plt.savefig("nu_estimate.png")
    plt.show()
if __name__ == "__main__":
    # Command-line entry point: the single positional argument is the
    # directory that holds the input CSV files.
    cli = argparse.ArgumentParser()
    cli.add_argument("path", type=str)
    main(cli.parse_args().path)