# train.py (forked from jlrussin/transformer_scan)
# Training script for experiments with the SCAN dataset
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim

from data import build_scan
from models.transformer import Transformer
from test import test

def train(run, args):
    # CUDA
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Data
    SRC, TRG, train_data, dev_data, test_data = build_scan(
        args.split, args.batch_size, device)

    # Vocab
    src_vocab_size = len(SRC.vocab.stoi)
    trg_vocab_size = len(TRG.vocab.stoi)
    pad_idx = SRC.vocab[SRC.pad_token]
    assert TRG.vocab[TRG.pad_token] == pad_idx
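    # The assert above matters because a single pad_idx is reused everywhere:
    # it is passed to the model (presumably to build its padding masks) and to
    # CrossEntropyLoss below, so padded target positions never affect the loss.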

    # Model (only a transformer variant is defined in this script)
    if args.model_type == 'transformer':
        model = Transformer(src_vocab_size, trg_vocab_size, args.d_model,
                            args.nhead, args.num_encoder_layers,
                            args.num_decoder_layers, args.dim_feedforward,
                            args.dropout, pad_idx, device)
    else:
        raise ValueError('Unknown model_type: %s' % args.model_type)
    if args.load_weights_from is not None:
        # map_location lets weights saved on GPU be loaded on a CPU-only machine
        model.load_state_dict(torch.load(args.load_weights_from,
                                         map_location=device))
    model = model.to(device)
    model.train()

    # Loss function
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    loss_fn = loss_fn.to(device)
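    # (CrossEntropyLoss has no learnable parameters here, so moving it to the
    # device is harmless but not strictly required.)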

    # Optimizer
    params = model.parameters()
    optimizer = optim.Adam(params, lr=args.learning_rate)

    # Setup things to record
    loss_data = []   # training losses
    train_accs = []  # train accuracy at each checkpoint
    dev_accs = []    # development accuracy at each checkpoint
    test_accs = []   # test accuracy at each checkpoint
    best_dev_acc = float('-inf')  # best dev accuracy (selects which weights to save)

    # Training loop
    for epoch in range(args.num_epochs):
        for iter, batch in enumerate(train_data):
            optimizer.zero_grad()
            out, attn_wts = model(batch.src, batch.trg)
            # Flatten predictions and targets so the loss scores every target
            # position at once; pad positions are skipped via ignore_index
            loss = loss_fn(out.view(-1, trg_vocab_size), batch.trg.view(-1))
            loss.backward()
            optimizer.step()

            # Record loss
            if iter % args.record_loss_every == 0:
                loss_datapoint = loss.item()
                print('Run:', run,
                      'Epoch:', epoch,
                      'Iter:', iter,
                      'Loss:', loss_datapoint)
                loss_data.append(loss_datapoint)

        # Checkpoint
        if epoch % args.checkpoint_every == 0:
            # Checkpoint on train data
            print("Checking training accuracy...")
            train_acc = test(train_data, model, pad_idx, device, args)
            print("Training accuracy is", train_acc)
            train_accs.append(train_acc)

            # Checkpoint on development data
            print("Checking development accuracy...")
            dev_acc = test(dev_data, model, pad_idx, device, args)
            print("Development accuracy is", dev_acc)
            dev_accs.append(dev_acc)

            # Checkpoint on test data
            print("Checking test accuracy...")
            test_acc = test(test_data, model, pad_idx, device, args)
            print("Test accuracy is", test_acc)
            test_accs.append(test_acc)

            # Write stats file (makedirs also creates the parent results/
            # directory if it does not exist yet)
            results_path = 'results/%s' % args.results_dir
            if not os.path.isdir(results_path):
                os.makedirs(results_path)
            stats = {'loss_data': loss_data,
                     'train_accs': train_accs,
                     'dev_accs': dev_accs,
                     'test_accs': test_accs}
            results_fn = '%s/%s%d.json' % (results_path, args.out_data_file, run)
            with open(results_fn, 'w') as f:
                json.dump(stats, f)
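            # Note: the stats file is rewritten at every checkpoint, so
            # results/<results_dir>/<out_data_file><run>.json always holds the
            # full history up to the current epoch.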

            # Save model weights (first run only), keeping the checkpoint from
            # the epoch with the best development accuracy
            if run == 0:
                if dev_acc > best_dev_acc:  # use dev accuracy to decide when to save
                    best_dev_acc = dev_acc
                    if args.checkpoint_path is not None:
                        torch.save(model.state_dict(),
                                   args.checkpoint_path)
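

# -----------------------------------------------------------------------------
# Example entry point (not part of the original file): a minimal, hypothetical
# sketch of how train() might be driven from the command line. Every flag below
# simply mirrors an `args` attribute referenced in train() above; the defaults
# are illustrative assumptions, not the repo's actual settings.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a transformer on SCAN')
    # Data
    parser.add_argument('--split', default='simple')  # SCAN split to load
    parser.add_argument('--batch_size', type=int, default=64)
    # Model
    parser.add_argument('--model_type', default='transformer')
    parser.add_argument('--d_model', type=int, default=128)
    parser.add_argument('--nhead', type=int, default=8)
    parser.add_argument('--num_encoder_layers', type=int, default=2)
    parser.add_argument('--num_decoder_layers', type=int, default=2)
    parser.add_argument('--dim_feedforward', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--load_weights_from', default=None)
    # Optimization
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_epochs', type=int, default=100)
    # Logging and checkpointing
    parser.add_argument('--record_loss_every', type=int, default=100)
    parser.add_argument('--checkpoint_every', type=int, default=1)
    parser.add_argument('--results_dir', default='example')
    parser.add_argument('--out_data_file', default='train_stats')
    parser.add_argument('--checkpoint_path', default=None)
    args = parser.parse_args()

    # Single run; the run index only names the output file and gates weight saving
    train(0, args)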