-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
82 lines (60 loc) · 2.72 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from pybaselines import Baseline
from sklearn.neighbors import LocalOutlierFactor
from data.data_loader import SignalDataset
from src.handlers.gnn_handler import GNNhandler
from src.handlers.generic_handler import GenericHandler
from src.evaluator import Evaluator
def evaluate_asls_lof(json_file, max_samples=None, visualise=False):
    """Evaluate an AsLS baseline-correction + LOF outlier-detection pipeline.

    Args:
        json_file: Path to the JSON file containing the signal dataset.
        max_samples: Optional positive int capping how many samples are loaded;
            any other value loads the full dataset.
        visualise: When True, render per-sample plots after evaluation.
    """
    # Normalized x-grid for 1000-point signals, matching the dataset resolution.
    fitter = Baseline(x_data=np.arange(1000) / 1000)
    pipeline = GenericHandler(
        # asls returns (baseline, params); only the baseline array is needed.
        model=lambda signal: fitter.asls(signal, lam=200, p=0.3, max_iter=100)[0],
        clf=LocalOutlierFactor(n_neighbors=900),
    )
    # Only apply the sample cap when it is a meaningful (positive) limit.
    if max_samples is not None and max_samples > 0:
        dataset = SignalDataset(json_file, max_samples=max_samples)
    else:
        dataset = SignalDataset(json_file)
    evaluator = Evaluator(epsilon=11, model=pipeline, dataset=dataset)
    evaluator.evaluate_all_samples()
    if visualise:
        evaluator.visualize_per_sample()
def evaluate_snip_lof(json_file, max_samples=None, visualise=False):
    """Evaluate a SNIP baseline-correction + LOF outlier-detection pipeline.

    Args:
        json_file: Path to the JSON file containing the signal dataset.
        max_samples: Optional positive int capping how many samples are loaded;
            any other value loads the full dataset.
        visualise: When True, render per-sample plots after evaluation.
    """
    # Normalized x-grid for 1000-point signals, matching the dataset resolution.
    baseline_fitter = Baseline(x_data=np.arange(1000)/1000)
    handler = GenericHandler(
        # BUG FIX: this function previously called asls(), duplicating
        # evaluate_asls_lof. It now uses SNIP as its name advertises.
        # snip returns (baseline, params); only the baseline array is needed.
        # NOTE(review): default max_half_window is used — tune if results differ
        # from the AsLS variant for the wrong reason.
        model=lambda x: baseline_fitter.snip(x)[0],
        clf=LocalOutlierFactor(n_neighbors=900)
    )
    # Only apply the sample cap when it is a meaningful (positive) limit.
    if max_samples is not None and max_samples > 0:
        dataset = SignalDataset(json_file, max_samples=max_samples)
    else:
        dataset = SignalDataset(json_file)
    evaluator = Evaluator(epsilon=11, model=handler, dataset=dataset)
    evaluator.evaluate_all_samples()
    if visualise:
        evaluator.visualize_per_sample()
def evaluate_gnn_lof(json_file, model_file, train=False, max_samples=None, visualise=False, test_json_file=None):
    """Evaluate a GNN baseline-correction + LOF outlier-detection pipeline.

    Args:
        json_file: Path to the JSON file containing the signal dataset.
        model_file: Path to the GNN weight file to load (and save to, if trained).
        train: When True (and a test set is given), train the model first.
        max_samples: Optional positive int capping how many samples are loaded;
            any other value loads the full dataset.
        visualise: When True, render per-sample plots after evaluation.
        test_json_file: Held-out dataset path; required for training to run.
    """
    gnn = GNNhandler(
        load_weights=True,
        batch_size=20,
        num_epochs=10,
        n_points=1000,
        scale=1,
        model_weights_pth=model_file,
        clf=LocalOutlierFactor(n_neighbors=900),
    )
    # Training is opt-in and silently skipped without a held-out test set.
    if train and test_json_file is not None:
        gnn.train_model(json_file, test_json_file)
    # Only apply the sample cap when it is a meaningful (positive) limit.
    if max_samples is not None and max_samples > 0:
        dataset = SignalDataset(json_file, max_samples=max_samples)
    else:
        dataset = SignalDataset(json_file)
    evaluator = Evaluator(epsilon=11, model=gnn, dataset=dataset)
    evaluator.evaluate_all_samples()
    if visualise:
        evaluator.visualize_per_sample()
if __name__ == "__main__":
synthetic_json_file = "data/2000_generated_signals.json"
print("ASLS and LOF:")
evaluate_asls_lof(synthetic_json_file)
print()
print("SNIP and LOF:")
evaluate_snip_lof(synthetic_json_file)
print()
print("GNN and LOF:")
evaluate_gnn_lof(synthetic_json_file, "gnn_frontal.pth", train=False)