-
Notifications
You must be signed in to change notification settings - Fork 0
/
example.py
110 lines (76 loc) · 2.14 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from time import sleep
from rich.console import Console
from rich.progress import track, Progress, BarColumn, TimeRemainingColumn
import numpy as np
from tintml import Tint
tint = Tint()

## INITIALIZATION ##
# Demo script: the sleep() calls only pace the output and simulate work; the
# log messages describe steps this script does not actually perform.
with tint.status("Initialization"):
    # Hyperparameters (module-level; read by the training/testing sections).
    n_datapoints = 100  # iterations of the training and testing loops
    n_valpoints = 10  # iterations of the validation loop per epoch
    n_epochs = 4

    sleep(0.17)
    tint.log("Read command line arguments")
    sleep(0.12)
    tint.log("Set paths")
    sleep(0.28)
    tint.log("Finished initialization")
## LOAD DATA ##
tint.printh('Data Processing')
# Simulated data pipeline: each sleep() stands in for a processing step and is
# followed by a log line reporting that step as finished.
with tint.status("Processing"):
    sleep(1.0)
    tint.log("Read data from files")
    sleep(0.3)
    tint.log("Applied train/validation split")
    sleep(1.4)
    tint.log("Applied data augmentations")
## SET UP MODEL ##
tint.printh('Model Setup')
# Simulated model construction; no model object is actually created here.
with tint.status("Setup Model"):
    sleep(1.7)
    tint.log("Defined model graph")
    sleep(1.2)
    tint.log("Loaded model weights")
## TRAINING ##
tint.printh("Training")
sleep(0.3)

# Simulated losses: every epoch decays the previous value by 10% and adds
# Gaussian noise (mean 0, std 0.1), so both curves trend down with jitter.
# prev_val_error is also read by the testing section.
# NOTE(review): indentation was reconstructed from a whitespace-stripped copy;
# the loss updates are assumed to run once per epoch (after each progress
# loop), not once per iteration — confirm against the upstream example.
prev_val_error = 1.0
prev_train_error = 1.0
for epoch_idx in range(1, n_epochs + 1):
    tint.print(f"Epoch {epoch_idx}/{n_epochs}")

    # Train: the index is unused; the loop only paces the simulated work.
    for _ in tint.iter(range(n_datapoints), "Training..."):
        sleep(0.06)
    train_error = prev_train_error * 0.9 + np.random.normal(0, 0.1)
    prev_train_error = train_error

    # Val
    for _ in tint.iter(range(n_valpoints), "Validating..."):
        sleep(0.05)
    val_error = prev_val_error * 0.9 + np.random.normal(0, 0.1)
    prev_val_error = val_error

    # Metrics: the flag list appears to hold one entry per metric, in dict
    # order. NOTE(review): flag semantics are defined by tintml's
    # print_metrics — confirm what True means there.
    tint.print_metrics({
        'Train loss': train_error,
        'Val loss': val_error,
    }, [True, True], multi_line=True)
tint.log("Finished training")
## TESTING ##
tint.printh("Testing")
with tint.status("Setting up testing"):
    sleep(1)
    tint.log("Read test data")
    sleep(1.7)
    tint.log("Started testing procedure")

# NOTE(review): indentation was reconstructed from a whitespace-stripped copy;
# the test loop is assumed to run after the "Setting up testing" status has
# closed — confirm against the upstream example.
for _ in tint.iter(range(n_datapoints), "Testing"):
    sleep(0.012)

# Simulated test loss: the final validation loss plus Gaussian noise.
test_error = prev_val_error + np.random.normal(0, 0.1)
tint.print_metrics({
    'Test loss': test_error,
}, [True])
## SAVING ##
tint.printh('Saving')
# Simulated persistence: nothing is written to disk; the log lines only
# report the steps a real run would perform.
with tint.status("Saving run"):
    sleep(2)
    tint.log("Saved model weights.")
    sleep(0.8)
    tint.log("Saved metrics.")