-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment_timbre_transfer_flatten.py
127 lines (103 loc) · 5.58 KB
/
experiment_timbre_transfer_flatten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
from abc import ABC
from pathlib import Path
import pytorch_lightning as pl
import torch
import torchvision.utils as vutils
from torch import optim
from models import TimbreTransfer, TimbreTransferFlatten
from models.types_ import *
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
import wandb
from utils import re_nest_configs
from prodict import Prodict
class TimbreTransferFlattenLM(pl.LightningModule, ABC):
    """Lightning wrapper around a :class:`TimbreTransferFlatten` model.

    Responsibilities:
      * run training / validation steps and compute the VAE loss,
      * mirror all metrics to both W&B and Lightning's logger,
      * configure the Adam optimizer from ``exp_params`` in the config,
      * render a sample reconstruction image at the end of each validation run.

    The merged W&B config is treated as the single source of truth and is
    re-wrapped in a ``Prodict`` for attribute-style access.
    """

    def __init__(self,
                 model: TimbreTransferFlatten,  # contains Music vae
                 config: dict,
                 # config_dump: dict, # This is for logging
                 ) -> None:
        super().__init__()
        self.save_hyperparameters()
        # Merge the locally supplied config into the (possibly sweep-provided)
        # wandb config, so sweep overrides win and wandb stays authoritative.
        wandb.config.update(re_nest_configs({**self.hparams.config, **wandb.config}))
        # Registers gradient/parameter hooks on the module; these only fire
        # when the model is invoked via __call__ (see training_step).
        wandb.watch(model)
        print(model)
        self.model = model
        self.config = Prodict.from_dict(wandb.config)
        self.curr_device = None
        # Optional flag: retain the computation graph on the first backward
        # pass. Catch only the exceptions a missing key can raise, instead of
        # a bare `except:` that would also swallow KeyboardInterrupt/SystemExit.
        self.hold_graph = False
        try:
            self.hold_graph = self.config['exp_params']['retain_first_backpass']
        except (KeyError, TypeError):
            pass

    def training_step(self, batch_item, batch_idx, optimizer_idx=0):
        """Run one optimization step and return the scalar training loss."""
        re_a, di_a, re_b, di_b = self.create_input_batch(batch_item)
        self.curr_device = re_b.device
        # Invoke the module (not .forward()) so nn.Module hooks — including
        # the ones registered by wandb.watch() in __init__ — actually fire.
        recons, mu_t, log_var_t, _, mu_m, log_var_m, _ = self.model(re_a, di_a, re_b, di_b)
        music_train_loss = self.model.loss_function(recons=recons,
                                                    re_b=re_b,
                                                    kld_weight_timbre=self.config['exp_params']['kld_weight_timbre'],
                                                    kld_weight_music=self.config['exp_params']['kld_weight_music'],
                                                    mu_t=mu_t,
                                                    log_var_t=log_var_t,
                                                    mu_m=mu_m,
                                                    log_var_m=log_var_m
                                                    )
        # Log to both W&B (raw scalars) and Lightning (sync'd across devices).
        log_dict = {key: val.item() for key, val in music_train_loss.items()}
        log_dict['epoch'] = self.trainer.current_epoch
        wandb.log(log_dict)
        self.log_dict(log_dict, sync_dist=True)
        return music_train_loss['loss']

    def validation_step(self, batch_item, batch_idx):
        """Compute and log the validation loss (keys prefixed with ``val_``)."""
        re_a, di_a, re_b, di_b = self.create_input_batch(batch_item)
        self.curr_device = re_b.device
        recons, mu_t, log_var_t, _, mu_m, log_var_m, _ = self.model(re_a, di_a, re_b, di_b)
        music_val_loss = self.model.loss_function(recons=recons,
                                                  re_b=re_b,
                                                  kld_weight_timbre=self.config['exp_params']['kld_weight_timbre'],
                                                  kld_weight_music=self.config['exp_params']['kld_weight_music'],
                                                  mu_t=mu_t,
                                                  log_var_t=log_var_t,
                                                  mu_m=mu_m,
                                                  log_var_m=log_var_m
                                                  )
        log_dict = {f"val_{key}": val.item() for key, val in music_val_loss.items()}
        log_dict['epoch'] = self.trainer.current_epoch
        wandb.log(log_dict)
        self.log_dict(log_dict, sync_dist=True)

    def configure_optimizers(self):
        """Build the Adam optimizer from ``exp_params`` (LR, weight decay)."""
        music_optimizer = optim.Adam(self.model.parameters(),
                                     lr=self.config['exp_params']['LR'],
                                     weight_decay=self.config['exp_params']['weight_decay'])
        return music_optimizer

    def on_validation_end(self) -> None:
        """Render a sample reconstruction image at the end of validation."""
        # The val dataloader does not place tensors on the current device,
        # so inputs are moved to self.curr_device explicitly below.
        batch_item = next(iter(self.trainer.datamodule.val_dataloader()))
        re_a, di_a, re_b, di_b = self.create_input_batch(batch_item)
        # Inference only: disable autograd so no graph is built, and detach
        # the output — Tensor.numpy() raises on tensors that require grad,
        # which `recons` otherwise would outside a no_grad context.
        with torch.no_grad():
            recons, mu_t, log_var_t, _, mu_m, log_var_m, _ = self.model(
                re_a.to(self.curr_device),
                di_a.to(self.curr_device),
                re_b.to(self.curr_device),
                di_b.to(self.curr_device))
        di_b = di_b.detach().to("cpu").numpy()
        re_b = re_b.detach().to("cpu").numpy()
        recons = torch.squeeze(recons, 0).detach().to("cpu").numpy()
        # Channel 0 of each tensor is visualized side by side.
        spectrograms = [di_b[:, 0, :, :], re_b[:, 0, :, :], recons[:, 0, :, :]]
        self.trainer.datamodule.dataset.preprocessing_pipeline.visualizer.visualize_multiple(
            spectrograms,
            file_dir=Path(self.trainer.logger.log_dir) / 'recons',
            col_titles=["DI", "Expected reamped", "Reconstructed reamped"],
            filename=f"reconstruction-e_{self.trainer.current_epoch}.png",
            title=f"reconstruction for epoch e_{self.trainer.current_epoch}: DI v/s Expected v/s Reconstructed reamped clip"
        )

    def create_input_batch(self, batch):
        """Split one dataloader item into (re_a, di_a, re_b, di_b) halves.

        The dataloader yields (reamped, di, *metadata) with a leading
        singleton dimension that is squeezed away; the batch is then split
        in half along dim 0: the first half acts as the "A" pair and the
        second half as the "B" pair.
        # NOTE(review): assumes batch dim 0 has exactly
        # config.data_params.batch_size rows — TODO confirm against dataset.
        """
        batch, batch_di, _, _, _, _ = batch
        batch = torch.squeeze(batch, 0)
        batch_di = torch.squeeze(batch_di, 0)
        b_size = self.config.data_params.batch_size // 2
        re_a = batch[0:b_size, :]
        di_a = batch_di[0:b_size, :]
        re_b = batch[b_size:, :]
        di_b = batch_di[b_size:, :]
        return re_a, di_a, re_b, di_b