-
-
Notifications
You must be signed in to change notification settings - Fork 174
/
training_script.py
192 lines (144 loc) · 9.03 KB
/
training_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
Notes:
* I won't add model checkpoint averaging as mentioned in the paper - it just feels like an arbitrary heuristic
and it won't add anything to the learning experience this repo aims to provide.
"""
import argparse
import os
import time

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

from utils.optimizers_and_distributions import CustomLRAdamOptimizer, LabelSmoothingDistribution
from models.definitions.transformer_model import Transformer
from utils.data_utils import get_data_loaders, get_masks_and_count_tokens, get_src_and_trg_batches, DatasetType, LanguageDirection
import utils.utils as utils
from utils.constants import *
# Module-level logging state; train_val_loop updates these through `global` declarations.
num_of_trg_tokens_processed = 0  # target tokens accumulated since the last console log line
bleu_scores = []  # NOTE(review): not referenced anywhere else in the visible code
global_train_step = 0  # x-axis value for the 'training_loss' tensorboard scalar
global_val_step = 0  # x-axis value for the 'val_loss' tensorboard scalar
writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
# Factory (closure) function so that I don't have to pass these arguments every time I call train_val_loop
def get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time_start, training_config=None):
    """Build one loop function shared by the training and validation passes.

    The returned ``train_val_loop`` closes over the model, optimizer, loss and label-smoothing objects, which
    removes the redundancy between the train and val loops.

    Args:
        baseline_transformer: the Transformer model being trained/evaluated.
        custom_lr_optimizer: optimizer wrapper exposing ``zero_grad()`` and ``step()``.
        kl_div_loss: loss module called as ``kl_div_loss(log_probs, target_probs)``.
        label_smoothing: callable mapping ground-truth target token ids to smoothed target distributions.
        pad_token_id: id of the padding token, forwarded to ``get_masks_and_count_tokens``.
        time_start: ``time.time()`` timestamp used to report elapsed time in console logs.
        training_config: dict providing 'enable_tensorboard', 'console_log_freq' and 'checkpoint_freq'
            (and passed to ``utils.get_training_state`` when checkpointing). Defaults to None, in which case
            the module-level ``training_config`` global (set in the ``__main__`` section) is used, as before.

    Returns:
        ``train_val_loop(is_train, token_ids_loader, epoch)``, which runs one full pass over ``token_ids_loader``.
    """
    def train_val_loop(is_train, token_ids_loader, epoch):
        global num_of_trg_tokens_processed, global_train_step, global_val_step, writer

        # Prefer the explicitly passed config. Otherwise fall back to the script-level global, which previously
        # was the only option and made train_transformer unusable when this file is imported as a module.
        config = training_config if training_config is not None else globals().get('training_config')
        if config is None:
            raise NameError("training_config was neither passed to get_train_val_loop nor defined at module level")

        num_batches_since_log = 0
        if is_train:
            baseline_transformer.train()
            # Start every epoch with a fresh tokens/batch window so leftovers from the previous epoch's
            # trailing batches don't leak into this epoch's first log line.
            num_of_trg_tokens_processed = 0
        else:
            baseline_transformer.eval()

        device = next(baseline_transformer.parameters()).device

        #
        # Main loop - start of the CORE PART
        #
        for batch_idx, token_ids_batch in enumerate(token_ids_loader):
            src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt = get_src_and_trg_batches(token_ids_batch)
            src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device)

            # log because the KL loss expects log probabilities (just an implementation detail)
            predicted_log_distributions = baseline_transformer(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask)
            smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt)  # these are regular probabilities

            if is_train:
                custom_lr_optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph

            loss = kl_div_loss(predicted_log_distributions, smooth_target_distributions)

            if is_train:
                loss.backward()  # compute the gradients for every trainable weight in the computational graph
                custom_lr_optimizer.step()  # apply the gradients to weights

            # End of CORE PART

            #
            # Logging and metrics
            #
            if is_train:
                global_train_step += 1
                num_of_trg_tokens_processed += num_trg_tokens
                num_batches_since_log += 1

                if config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss.item(), global_train_step)

                if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                    # Divide by the batches actually accumulated. The first log line of an epoch covers only a
                    # single batch, so dividing by console_log_freq there under-reported the average.
                    print(f'Transformer training: time elapsed= {(time.time() - time_start):.2f} [s] '
                          f'| epoch={epoch + 1} | batch= {batch_idx + 1} '
                          f'| target tokens/batch= {num_of_trg_tokens_processed / num_batches_since_log}')

                    num_of_trg_tokens_processed = 0
                    num_batches_since_log = 0
            else:
                global_val_step += 1

                if config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss.item(), global_val_step)

        # Save the model checkpoint only after the epoch has fully finished, so that "epoch_N" really holds the
        # weights after N epochs. (It used to be saved at batch_idx == 0, i.e. after N-1 epochs + 1 batch.)
        if is_train and config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0:
            ckpt_model_name = f"transformer_ckpt_epoch_{epoch + 1}.pth"
            torch.save(utils.get_training_state(config, baseline_transformer), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    return train_val_loop


def train_transformer(training_config):
    """Train the baseline Transformer, validate and score BLEU every epoch, then save the final weights.

    Args:
        training_config: dict with at least 'dataset_path', 'language_direction', 'dataset_name', 'batch_size',
            'num_warmup_steps', 'num_of_epochs', 'enable_tensorboard', 'console_log_freq' and 'checkpoint_freq'.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU, I hope so!

    # Step 1: Prepare data loaders
    train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders(
        training_config['dataset_path'],
        training_config['language_direction'],
        training_config['dataset_name'],
        training_config['batch_size'],
        device)

    pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN]  # pad token id is the same for target as well
    src_vocab_size = len(src_field_processor.vocab)
    trg_vocab_size = len(trg_field_processor.vocab)

    # Step 2: Prepare the model (original transformer) and push to GPU
    baseline_transformer = Transformer(
        model_dimension=BASELINE_MODEL_DIMENSION,
        src_vocab_size=src_vocab_size,
        trg_vocab_size=trg_vocab_size,
        number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS,
        number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
        dropout_probability=BASELINE_MODEL_DROPOUT_PROB
    ).to(device)

    # Step 3: Prepare other training related utilities
    kl_div_loss = nn.KLDivLoss(reduction='batchmean')  # gives better BLEU score than "mean"

    # Makes smooth target distributions as opposed to conventional one-hot distributions
    # My feeling is that this is a really dummy and arbitrary heuristic but time will tell.
    label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, pad_token_id, trg_vocab_size, device)

    # Check out playground.py for an intuitive visualization of how the LR changes with time/training steps, easy stuff.
    custom_lr_optimizer = CustomLRAdamOptimizer(
        Adam(baseline_transformer.parameters(), betas=(0.9, 0.98), eps=1e-9),
        BASELINE_MODEL_DIMENSION,
        training_config['num_warmup_steps']
    )

    # The factory function makes things cleaner since there is a lot of redundancy between the train and val loops.
    # Pass the config explicitly so this function works even when the script is imported rather than run.
    train_val_loop = get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing,
                                        pad_token_id, time.time(), training_config=training_config)

    # Step 4: Start the training
    for epoch in range(training_config['num_of_epochs']):
        # Training loop
        train_val_loop(is_train=True, token_ids_loader=train_token_ids_loader, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            train_val_loop(is_train=False, token_ids_loader=val_token_ids_loader, epoch=epoch)

            bleu_score = utils.calculate_bleu_score(baseline_transformer, val_token_ids_loader, trg_field_processor)
            if training_config['enable_tensorboard']:
                writer.add_scalar('bleu_score', bleu_score, epoch)

    # Save the latest transformer in the binaries directory
    torch.save(utils.get_training_state(training_config, baseline_transformer), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
def str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' into a bool.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True because any non-empty string is truthy,
    so ``--enable_tensorboard False`` used to leave TensorBoard logging enabled.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):  # argparse passes non-string defaults through unchanged
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


if __name__ == "__main__":
    #
    # Fixed args - don't change these unless you have a good reason
    #
    num_warmup_steps = 4000

    #
    # Modifiable args - feel free to play with these (only small subset is exposed by design to avoid cluttering)
    #
    parser = argparse.ArgumentParser()
    # According to the paper I inferred that the baseline was trained for ~19 epochs on the WMT-14 dataset and I got
    # nice returns up to epoch ~20 on IWSLT as well (nice round number)
    parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=20)
    # You should adjust this for your particular machine (I have RTX 2080 with 8 GBs of VRAM so 1500 fits nicely!)
    parser.add_argument("--batch_size", type=int, help="target number of tokens in a src/trg batch", default=1500)

    # Data related args
    parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='which dataset to use for training', default=DatasetType.IWSLT.name)
    parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.E2G.name)
    parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_DIR_PATH)

    # Logging/debugging/checkpoint related (helps a lot with experimentation)
    # str2bool instead of type=bool so that "--enable_tensorboard False" actually disables logging
    parser.add_argument("--enable_tensorboard", type=str2bool, help="enable tensorboard logging", default=True)
    parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq", default=10)
    parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq", default=1)
    args = parser.parse_args()

    # Wrapping training configuration into a dictionary (kept as a module-level name for backward compatibility)
    training_config = dict(vars(args))
    training_config['num_warmup_steps'] = num_warmup_steps

    # Train the original transformer model
    train_transformer(training_config)