main_train.py
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, random_split

from dataloader import DicomDataset, image_transform
from loss import KLDivergence, L2Loss
from model import VAE
from plotting import plot_and_save_loss, plot_model_architecture
from save_load import save_checkpoint, load_from_checkpoint, load_config, create_output_directories
from train import train_loop, val_loop


def parse_cl_args() -> argparse.Namespace:
    """Parse command-line arguments; expects the path to a training config."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", help="Path to the training config file.")
    return arg_parser.parse_args()
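
# Typical invocation, assuming a YAML config (the actual format is whatever
# load_config in save_load.py accepts; the path below is illustrative):
#
#   python main_train.py --config configs/train.yaml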

if __name__ == "__main__":
    cl_args = parse_cl_args()
    config = load_config(cl_args.config)
    create_output_directories(config)

    transform_composition = image_transform(config)

    # Datasets and loaders. The fixed seed makes the train/val/test split
    # reproducible across runs, which matters when resuming training.
    dataset = DicomDataset(config['data']['data_dir'],
                           transform=transform_composition,
                           target_transform=transform_composition)
    train_dataset, val_dataset, test_dataset = random_split(dataset,
                                                            config['data']['train_val_test_split'],
                                                            torch.Generator().manual_seed(91722))
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=config['model']['batch_size'],
                                  shuffle=True)
    # Validate on the held-out split so training images don't leak into the
    # validation loss.
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=config['model']['batch_size'],
                                shuffle=False)

    model = VAE(latent_dim=config['model']['latent_dim'],
                img_dim=config['model']['img_dim'])
    model.to(device=config['model']['device'])

    # beta weights the KL term, using the normalised-beta definition from the
    # beta-VAE paper: https://openreview.net/pdf?id=Sy2fzU9gl
    beta = config['model']['latent_dim'] / (config['model']['img_dim'] ** 2)
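    # With this weighting, each batch is scored as (a sketch; the exact
    # reduction lives in L2Loss/KLDivergence in loss.py):
    #   loss = recon_loss + beta * kl_loss,  beta = latent_dim / img_dim**2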

    plot_model_architecture(model=model,
                            batch_size=config['model']['batch_size'],
                            channels=config['model']['channels'],
                            img_dim=config['model']['img_dim'],
                            save_dir=config['logging']['plot_save_dir'])

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config['model']['learning_rate'],
                                 weight_decay=config['model']['weight_decay'])

    # If resuming, load the model and optimizer state from best_epoch.
    if config['model']['resume']:
        best_model = os.path.join(config['logging']['model_save_dir'], "best_epoch.pth")
        model, optimizer, loss_dict, start_epoch = load_from_checkpoint(checkpoint_path=best_model,
                                                                        model=model,
                                                                        optimizer=optimizer)
        # Recover the best combined (KL + reconstruction) validation loss so
        # it matches the criterion used to pick best_epoch below.
        val_losses = np.asarray(loss_dict["VAL_LOSS_KL"]) + np.asarray(loss_dict["VAL_LOSS_RECON"])
        best_val_loss = val_losses.min()
        best_val_epoch = int(val_losses.argmin())
    else:
        start_epoch = 0
        loss_dict = {"TRAIN_LOSS_KL": [],
                     "TRAIN_LOSS_RECON": [],
                     "VAL_LOSS_KL": [],
                     "VAL_LOSS_RECON": []}
        best_val_loss = 1e12
        best_val_epoch = 0
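
    # A checkpoint here (see save_checkpoint below) presumably bundles the
    # model and optimizer state, the loss history, and the epoch number, so a
    # resumed run picks up exactly where it stopped.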

    for t in range(start_epoch, config['model']['max_epochs']):
        print(f"Epoch {t}\n-------------------------------")
        train_loss_kl, train_loss_recon = train_loop(dataloader=train_dataloader,
                                                     model=model,
                                                     loss_fn_kl=KLDivergence(),
                                                     loss_fn_recon=L2Loss(),
                                                     beta=beta,
                                                     optimizer=optimizer,
                                                     amp_on=config['model']['use_amp'])
        val_loss_kl, val_loss_recon = val_loop(dataloader=val_dataloader,
                                               model=model,
                                               loss_fn_kl=KLDivergence(),
                                               loss_fn_recon=L2Loss(),
                                               beta=beta,
                                               epoch_number=t,
                                               plot_every_n_epochs=config['logging']['plot_every_n_epochs'],
                                               plot_save_dir=config['logging']['plot_save_dir'])
        val_loss = val_loss_kl + val_loss_recon
        print(f"Validation loss for epoch {t:>2d}:")
        print(f"  KL loss:    {val_loss_kl:>15.2f}")
        print(f"  Recon loss: {val_loss_recon:>15.2f}")

        loss_dict["TRAIN_LOSS_KL"].append(train_loss_kl)
        loss_dict["TRAIN_LOSS_RECON"].append(train_loss_recon)
        loss_dict["VAL_LOSS_KL"].append(val_loss_kl)
        loss_dict["VAL_LOSS_RECON"].append(val_loss_recon)
        plot_and_save_loss(loss_dict=loss_dict,
                           save_dir=config['logging']['plot_save_dir'])

        # Periodic checkpoint.
        if t % config['logging']['save_model_every_n_epochs'] == 0:
            save_checkpoint(save_path=os.path.join(config['logging']['model_save_dir'], f"epoch_{t}.pth"),
                            model=model,
                            optimizer=optimizer,
                            loss_dict=loss_dict,
                            epoch_number=t)

        # Track and checkpoint the best epoch by combined validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_epoch = t
            save_checkpoint(save_path=os.path.join(config['logging']['model_save_dir'], "best_epoch.pth"),
                            model=model,
                            optimizer=optimizer,
                            loss_dict=loss_dict,
                            epoch_number=t)

    print("Done training.")
    print(f"Best epoch: {best_val_epoch}")