Skip to content

Commit dbefda7

Browse files
committed
add train_UNet_Onset_VAT.py
1 parent 6d4a95b commit dbefda7

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed

train_UNet_Onset_VAT.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import os
2+
3+
from datetime import datetime
4+
import pickle
5+
6+
import numpy as np
7+
from sacred import Experiment
8+
from sacred.commands import print_config, save_config
9+
from sacred.observers import FileStorageObserver
10+
from torch.optim.lr_scheduler import StepLR, CyclicLR
11+
from torch.utils.data import DataLoader
12+
from tqdm import tqdm
13+
14+
from model import *
15+
# Sacred experiment object; the @ex.config / @ex.automain decorators below hang off it.
ex = Experiment('train_original')

# parameters for the network
ds_ksize, ds_stride = (2,2),(2,2)  # downsampling kernel size / stride (forwarded to UNet_Onset below)
mode = 'imagewise'                 # normalization mode passed to the model constructor
sparsity = 1                       # NOTE(review): not referenced in this file — confirm it is used elsewhere
output_channel = 2                 # NOTE(review): not referenced in this file — confirm it is used elsewhere
logging_freq = 100                 # passed to tensorboard_log; presumably "log every N epochs" — confirm
saving_freq = 200                  # checkpoint the model/optimizer every `saving_freq` epochs (see train loop)
24+
25+
26+
# Sacred configuration scope: every local variable assigned in this function
# becomes a named experiment parameter, overridable from the command line and
# injected by name into the @ex.automain `train` function below. Do NOT rename
# these variables — the names ARE the config keys.
@ex.config
def config():
    root = 'runs'  # parent directory for all experiment run folders
    # logdir = f'runs_AE/test' + '-' + datetime.now().strftime('%y%m%d-%H%M%S')
    # Choosing GPU to use
    # GPU = '0'
    # os.environ['CUDA_VISIBLE_DEVICES']=str(GPU)
    onset_stack=True
    device = 'cuda:0'
    log = True               # forwarded to UNet_Onset as `log` (presumably log-scaled spectrogram — confirm)
    w_size = 31              # window size, forwarded to tensorboard_log
    spec = 'Mel'             # spectrogram type, forwarded to UNet_Onset
    resume_iteration = None  # checkpoint iteration to resume from; None = train from scratch
    train_on = 'MAPS'        # dataset name passed to prepare_VAT_dataset
    n_heads=4
    position=True
    iteration = 10           # forwarded to train_VAT_model each epoch
    VAT_start = 0            # epoch threshold passed to train_VAT_model / tensorboard_log
    alpha = 1                # VAT loss weight, forwarded to train_VAT_model
    VAT=True                 # enable Virtual Adversarial Training (adds the unsupervised loader)
    XI= 1e-6                 # VAT perturbation scale, forwarded to UNet_Onset
    eps=2                    # VAT perturbation radius, forwarded to UNet_Onset
    small = False            # use a reduced dataset (see prepare_VAT_dataset)
    supersmall = False
    KL_Div = False
    reconstruction = False   # enable the reconstruction branch of UNet_Onset

    batch_size = 8           # batch size for the unsupervised (VAT) loader
    train_batch_size = 8     # batch size for the supervised loader
    sequence_length = 327680
    # Halve batch size and sequence length on GPUs with less than ~10 GB of
    # memory to avoid OOM.
    if torch.cuda.is_available() and torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory < 10e9:
        batch_size //= 2
        sequence_length //= 2
        print(f'Reducing batch size to {batch_size} and sequence_length to {sequence_length} to save memory')

    epoches = 20000          # total training epochs
    step_size_up = 100       # NOTE(review): only used by the commented-out CyclicLR — currently unused
    max_lr = 1e-4            # NOTE(review): only used by the commented-out CyclicLR — currently unused
    learning_rate = 1e-3     # Adam initial learning rate
    # base_lr = learning_rate

    # StepLR schedule: multiply lr by decay_rate every decay_steps steps.
    learning_rate_decay_steps = 1000
    learning_rate_decay_rate = 0.98

    leave_one_out = None

    clip_gradient_norm = 3   # gradient-norm clip threshold, forwarded to train_VAT_model

    validation_length = sequence_length
    refresh = False          # force dataset re-preprocessing (see prepare_VAT_dataset)

    # Run directory name encodes the key hyperparameters plus a timestamp.
    logdir = f'{root}/Unet_Onset-recons={reconstruction}-XI={XI}-eps={eps}-alpha={alpha}-train_on=small_{small}_{train_on}-w_size={w_size}-n_heads={n_heads}-lr={learning_rate}-'+ datetime.now().strftime('%y%m%d-%H%M%S')
79+
80+
# Attach a FileStorageObserver so sacred snapshots the config and source code
# into the run directory.
# NOTE(review): `logdir` is assigned inside the sacred config scope above, not
# at plain module level — confirm it is actually in scope at import time here.
ex.observers.append(FileStorageObserver.create(logdir)) # saving source code
81+
82+
@ex.automain
def train(spec, resume_iteration, train_on, batch_size, sequence_length,w_size, n_heads, small, train_batch_size,
          learning_rate, learning_rate_decay_steps, learning_rate_decay_rate, leave_one_out, position, alpha, KL_Div,
          clip_gradient_norm, validation_length, refresh, device, epoches, logdir, log, iteration, VAT_start, VAT, XI, eps,
          reconstruction, supersmall):
    """Train a UNet_Onset model, optionally with Virtual Adversarial Training.

    All parameters are injected by name from the sacred config scope above.
    Runs `epoches` epochs of train_VAT_model, logs every epoch to tensorboard,
    checkpoints model/optimizer every `saving_freq` epochs, then evaluates on
    the full test-split songs and pickles the metrics into `logdir`.
    """
    print_config(ex.current_run)

    supervised_set, unsupervised_set, validation_dataset, full_validation = prepare_VAT_dataset(
        sequence_length=sequence_length,
        # FIX: was hard-coded to sequence_length, silently ignoring the
        # `validation_length` config parameter (their defaults coincide).
        validation_length=validation_length,
        refresh=refresh,
        device=device,
        small=small,
        supersmall=supersmall,
        dataset=train_on)

    # The unlabelled loader is only needed when VAT regularization is enabled;
    # train_VAT_model receives None otherwise (as before).
    unsupervised_loader = None
    if VAT:
        unsupervised_loader = DataLoader(unsupervised_set, batch_size, shuffle=True, drop_last=True)

    supervised_loader = DataLoader(supervised_set, train_batch_size, shuffle=True, drop_last=True)
    valloader = DataLoader(validation_dataset, 4, shuffle=False, drop_last=True)
    batch_visualize = next(iter(valloader))  # one fixed batch reused for visualization every epoch

    ds_ksize, ds_stride = (2,2),(2,2)
    if resume_iteration is None:
        model = UNet_Onset(ds_ksize,ds_stride, log=log, reconstruction=reconstruction,
                           mode=mode, spec=spec, device=device, XI=XI, eps=eps)
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        resume_iteration = 0
    else:  # Loading checkpoints and continue training
        trained_dir='trained_MAPS' # Assume that the checkpoint is in this folder
        model_path = os.path.join(trained_dir, f'{resume_iteration}.pt')
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        model = torch.load(model_path)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer.load_state_dict(torch.load(os.path.join(trained_dir, 'last-optimizer-state.pt')))

    summary(model)
    # scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=step_size_up,cycle_momentum=False)
    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)

    writer = SummaryWriter(logdir)  # tensorboard logger (was lazily created at ep == 1)
    for ep in range(1, epoches+1):
        predictions, losses, optimizer = train_VAT_model(model, iteration, ep, supervised_loader, unsupervised_loader,
                                                         optimizer, scheduler, clip_gradient_norm, alpha, VAT, VAT_start)
        loss = sum(losses.values())

        # Logging results to tensorboard. The boolean flag is False until the
        # VAT phase starts (ep >= VAT_start), exactly as the original
        # duplicated if/else did.
        tensorboard_log(batch_visualize, model, validation_dataset, supervised_loader,
                        ep, logging_freq, saving_freq, n_heads, logdir, w_size, writer,
                        ep >= VAT_start, VAT_start, reconstruction)

        # Saving model
        if ep % saving_freq == 0:
            torch.save(model.state_dict(), os.path.join(logdir, f'model-{ep}.pt'))
            torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))
        for key, value in losses.items():  # FIX: {**losses} made a pointless copy
            writer.add_scalar(key, value.item(), global_step=ep)

    # Evaluating model performance on the full MAPS songs in the test split
    print('Training finished, now evaluating on the MAPS test split (full songs)')
    with torch.no_grad():
        model = model.eval()
        # NOTE(review): reconstruction is deliberately False here regardless of
        # the config flag — confirm this is intended for evaluation.
        metrics = evaluate_wo_velocity(tqdm(full_validation), model, reconstruction=False,
                                       save_path=os.path.join(logdir,'./MIDI_results'))

    for key, values in metrics.items():
        if key.startswith('metric/'):
            _, category, name = key.split('/')
            print(f'{category:>32} {name:25}: {np.mean(values):.3f} ± {np.std(values):.3f}')

    # FIX: use a context manager so the file handle is closed (the original
    # passed an anonymous open(...) to pickle.dump and leaked it).
    export_path = os.path.join(logdir, 'result_dict')
    with open(export_path, 'wb') as f:
        pickle.dump(metrics, f)

0 commit comments

Comments
 (0)