train.py
from infcomp.client import RequesterClient, RequesterFile
from infcomp.nn.nn import NN
from infcomp.logger import Logger
from infcomp.settings import settings
from infcomp.util import load_nn, save_if_its_time


def train(directory, save_file_name, load_file_name, address, obs_embedding,
          minibatch_size, save_after_n_traces, traces_dir=None):
    # Either create a fresh network or resume training from a saved artifact.
    if load_file_name is None:
        nn = NN(directory=directory, file_name=save_file_name, obs_embedding_type=obs_embedding)
        Logger.logger.log_info("New nn will be saved to: {}/{}".format(directory, save_file_name))
    else:
        nn = load_nn(load_file_name)
        Logger.set(nn.logger)
        Logger.logger.log_info("Resuming previous artifact: {}/{}".format(directory, load_file_name))
    nn.train()

    # Sort the trace-count checkpoints in descending order before training.
    save_after_n_traces.sort(reverse=True)

    # Traces come either from a live client connection or from files on disk.
    if not traces_dir:
        requester_class = RequesterClient
        params = [address]
    else:
        requester_class = RequesterFile
        params = [traces_dir]

    with requester_class(*params) as requester:
        errors = True
        try:
            i = 0
            n_processed_traces = 0
            best_loss_freq = 10
            Logger.logger.log_training_begin(*params)
            # Run optimization steps until save_if_its_time returns True.
            while not save_if_its_time(nn, save_after_n_traces, n_processed_traces):
                # TODO minibatch to CUDA
                minibatch = requester.minibatch(minibatch_size)
                train_loss = nn.optimize(minibatch)
                Logger.logger.log_training(len(minibatch), train_loss, nn)
                # Log the best loss every best_loss_freq iterations.
                if (i + 1) % best_loss_freq == 0:
                    Logger.logger.log_training_best()
                i += 1
                n_processed_traces += len(minibatch)
            errors = False
        except KeyboardInterrupt:
            pass
        finally:
            # Try to save if there was an exception
            if errors:
                nn.__exit__(None, None, None)
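

# A minimal usage sketch, not part of the original module: it only shows one
# plausible way to call train() with its signature as defined above. Every
# argument value here is hypothetical (the real entry point presumably builds
# these from command-line options), including the address format, the
# embedding type string, and the file/directory names.
if __name__ == "__main__":
    train(
        directory="artifacts",              # hypothetical output directory
        save_file_name="model",             # hypothetical name for the saved artifact
        load_file_name=None,                # None -> start a new network instead of resuming
        address="tcp://127.0.0.1:5555",     # hypothetical RequesterClient address
        obs_embedding="fc",                 # hypothetical observation embedding type
        minibatch_size=64,
        save_after_n_traces=[1000, 10000],  # save after these many processed traces
        traces_dir=None,                    # set to a directory of traces to use RequesterFile
    )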