-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
107 lines (77 loc) · 2.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import random
import torch
from net import SineModel
from DataLoader import SineWaveTask
from tools import sine_fit1, plot_sine_test, plot_sine_learning, maml_sine, reptile_sine
import matplotlib.pyplot as plt
# Number of tasks in the meta-training and meta-test pools.
TRAIN_SIZE = 10000
TEST_SIZE = 1000


def _make_tasks(count):
    """Return a list of ``count`` freshly constructed ``SineWaveTask`` objects."""
    return [SineWaveTask() for _ in range(count)]


# Task pools are built once at import time: the training pool first, then the
# test pool (the construction order is kept so any RNG draws happen in the
# same sequence).
SINE_TRAIN = _make_tasks(TRAIN_SIZE)
SINE_TEST = _make_tasks(TEST_SIZE)

# Single model trained by fit_transfer() and used as the "Transfer" baseline
# in main().
SINE_TRANSFER = SineModel()
def fit_transfer(epochs=1):
    """Train the shared transfer-learning baseline ``SINE_TRANSFER``.

    One Adam optimizer over ``SINE_TRANSFER.params()`` is created per call and
    reused for every epoch. Each epoch visits every task in ``SINE_TRAIN``
    exactly once, in a freshly shuffled order, and hands it to ``sine_fit1``
    together with the model and optimizer.

    Args:
        epochs: number of full passes over ``SINE_TRAIN``.
    """
    adam = torch.optim.Adam(SINE_TRANSFER.params())
    for _epoch in range(epochs):
        # random.sample over the full length yields a shuffled copy and leaves
        # the SINE_TRAIN list itself untouched.
        shuffled_tasks = random.sample(SINE_TRAIN, len(SINE_TRAIN))
        for task in shuffled_tasks:
            sine_fit1(SINE_TRANSFER, task, adam)
def _find_one_sided_task():
    """Sample new ``SineWaveTask`` objects until one has one-sided inputs.

    Returns:
        The first sampled task whose training inputs ``x`` are either all
        strictly negative or all strictly positive.
    """
    while True:
        task = SineWaveTask()
        x, _ = task.training_set()
        x = x.numpy()
        if np.max(x) < 0 or np.min(x) > 0:
            return task


def _show_sine_test(model, task, lr=0.01):
    """Draw ``plot_sine_test`` for ``model`` on ``task`` (fits 0, 1, 10) and show it."""
    plot_sine_test(model, task, fits=[0, 1, 10], lr=lr)
    plt.show()


def _show_learning(models, fits, **plot_kwargs):
    """Draw ``plot_sine_learning`` over ``SINE_TEST`` for ``models`` and show it.

    Args:
        models: list of ``(label, model_or_list_of_models)`` pairs.
        fits: values passed through as the ``fits`` argument.
        **plot_kwargs: extra styling options (e.g. ``marker``, ``linestyle``).
    """
    plot_sine_learning(models, fits, SINE_TEST=SINE_TEST, **plot_kwargs)
    plt.show()


def main():
    """Compare transfer learning, MAML, first-order MAML and Reptile on sine tasks.

    Sections, in order:
      1. Pick a task whose training inputs lie on one side of zero.
      2. Train the ``SINE_TRANSFER`` baseline and plot it against an untrained
         model.
      3. Train 5 models with ``maml_sine`` and plot them against the baselines
         and on the one-sided task.
      4. Same with ``maml_sine(..., first_order=True)``.
      5. Same with ``reptile_sine(..., k=3, batch_size=1)``.

    Every figure is displayed with ``plt.show()`` right after it is drawn.
    """
    # Chosen before any training so the sampling order matches the original run.
    one_sided_example = _find_one_sided_task()

    # --- Transfer learning baseline ---
    fit_transfer()
    # Show each figure before the next plot call so earlier plots are not left
    # open and displayed together with (or drawn over by) later ones.
    _show_sine_test(SINE_TRANSFER, SINE_TEST[0], lr=0.02)
    _show_learning(
        [('Transfer', SINE_TRANSFER), ('Random', SineModel())],
        list(range(100)),
        marker='',
        linestyle='-',
    )

    # --- MAML ---
    sine_maml = [SineModel() for _ in range(5)]
    for model in sine_maml:
        maml_sine(model, 4, SINE_TRAIN=SINE_TRAIN)
    _show_sine_test(sine_maml[0], SINE_TEST[0])
    _show_learning(
        [('Transfer', SINE_TRANSFER), ('MAML', sine_maml[0]), ('Random', SineModel())],
        list(range(10)),
    )
    _show_sine_test(sine_maml[0], one_sided_example)

    # --- First-order MAML ---
    sine_maml_first_order = [SineModel() for _ in range(5)]
    for model in sine_maml_first_order:
        maml_sine(model, 4, first_order=True, SINE_TRAIN=SINE_TRAIN)
    _show_sine_test(sine_maml_first_order[0], SINE_TEST[0])
    _show_learning(
        [('MAML', sine_maml), ('MAML First Order', sine_maml_first_order)],
        list(range(10)),
    )
    _show_sine_test(sine_maml_first_order[0], one_sided_example)

    # --- Reptile ---
    sine_reptile = [SineModel() for _ in range(5)]
    for model in sine_reptile:
        reptile_sine(model, 4, k=3, batch_size=1, SINE_TRAIN=SINE_TRAIN)
    _show_sine_test(sine_reptile[0], SINE_TEST[0])
    _show_learning(
        [
            ('MAML', sine_maml),
            ('MAML First Order', sine_maml_first_order),
            ('Reptile', sine_reptile),
        ],
        list(range(32)),
    )
    _show_sine_test(sine_reptile[0], one_sided_example)


if __name__ == "__main__":
    main()