-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment_timbre_transfer.py
135 lines (109 loc) · 5.4 KB
/
experiment_timbre_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import random
from abc import ABC
from pathlib import Path
import pytorch_lightning as pl
import torch
import torchvision.utils as vutils
from pytorch_lightning.loggers import WandbLogger
from torch import optim
from datasets import TimbreDataModule
from models import TimbreTransfer
from models.types_ import *
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
from utils import re_nest_configs, merge
from prodict import Prodict
from pytorch_lightning.loggers import TensorBoardLogger
class TimbreTransferLM(pl.LightningModule, ABC):
    """Lightning wrapper around a ``TimbreTransfer`` model.

    Implements the training/validation steps, the Adam optimizer setup, and an
    end-of-validation hook that renders a reconstruction preview image (and
    uploads it when a W&B logger is attached).
    """

    def __init__(self,
                 model: TimbreTransfer,
                 config: Prodict,
                 ) -> None:
        """Store the model and experiment config.

        Args:
            model: the timbre-transfer network to train.
            config: experiment configuration; ``exp_params`` is expected to
                hold ``kld_weight``, ``LR`` and ``weight_decay``.
        """
        super(TimbreTransferLM, self).__init__()
        self.save_hyperparameters()
        print(model)
        self.model = model
        self.config = config
        self.curr_device = None
        self.hold_graph = False
        # `retain_first_backpass` is optional; default stays False when absent.
        # Narrowed from a bare `except:` — only a missing key is expected here.
        try:
            self.hold_graph = self.config['exp_params']['retain_first_backpass']
        except KeyError:
            pass

    def get_data(self):
        # NOTE(review): `self.data` is never assigned anywhere in this class,
        # so calling this raises AttributeError unless a subclass or external
        # code sets it first — confirm whether this accessor is still used.
        return self.data

    def training_step(self, batch_item, batch_idx, optimizer_idx=0):
        """One training step: forward pass + loss on a reamped/DI pair split."""
        re_a, di_a, re_b, di_b = self.create_input_batch(batch_item)
        self.curr_device = re_b.device
        recons, mu, log_var, _ = self.model.forward(re_a, di_a, re_b, di_b)
        music_train_loss = self.model.loss_function(
            recons=recons,
            re_b=re_b,
            kld_weight=self.config['exp_params']['kld_weight'],
            mu=mu,
            log_var=log_var,
        )
        self.log('epoch', self.trainer.current_epoch)
        for key, val in music_train_loss.items():
            self.log(key, val)
        return music_train_loss['loss']

    def validation_step(self, batch_item, batch_idx, ):
        """One validation step; logs each loss term under a ``val_`` prefix."""
        re_a, di_a, re_b, di_b = self.create_input_batch(batch_item)
        self.curr_device = re_b.device
        recons, mu, log_var, _ = self.model.forward(re_a, di_a, re_b, di_b)
        music_val_loss = self.model.loss_function(
            recons=recons,
            re_b=re_b,
            kld_weight=self.config['exp_params']['kld_weight'],
            mu=mu,
            log_var=log_var,
        )
        for key, val in music_val_loss.items():
            self.log(f"val_{key}", val)

    def configure_optimizers(self):
        """Single Adam optimizer over all model parameters (LR / weight decay
        taken from ``exp_params``)."""
        music_optimizer = optim.Adam(self.model.parameters(),
                                     lr=self.config['exp_params']['LR'],
                                     weight_decay=self.config['exp_params']['weight_decay'])
        return music_optimizer

    def on_validation_end(self) -> None:
        """Render a DI / expected / reconstructed spectrogram comparison image
        for one validation batch and, for W&B runs, upload it."""
        # This dataloader fetch does not return tensors on the current device,
        # hence the explicit .to(self.curr_device) below.
        batch_item = next(iter(self.trainer.datamodule.val_dataloader()))
        key = batch_item[-2].detach().to("cpu").numpy()[0]
        re_a, di_a, re_b, di_b = self.create_input_batch(batch_item)
        # Inference only: no_grad avoids tracking gradients, and also makes the
        # later .numpy() conversion safe.
        with torch.no_grad():
            recons, mu, log_var, _ = self.model.forward(
                re_a.to(self.curr_device), di_a.to(self.curr_device),
                re_b.to(self.curr_device), di_b.to(self.curr_device))
        di_b = di_b.detach().to("cpu").numpy()
        re_b = re_b.detach().to("cpu").numpy()
        # BUGFIX: original called .numpy() without .detach(); that raises a
        # RuntimeError whenever the tensor requires grad.
        recons = torch.squeeze(recons, 0).detach().to("cpu").numpy()
        spectrograms = [di_b[:, 0, :, :], re_b[:, 0, :, :], recons[:, 0, :, :]]
        # isinstance instead of type == : also accepts logger subclasses.
        if isinstance(self.trainer.logger, WandbLogger):
            recons_dir = Path(self.trainer.logger.save_dir) / 'recons'
        elif isinstance(self.trainer.logger, TensorBoardLogger):
            recons_dir = Path(self.trainer.logger.log_dir) / 'recons'
        else:
            # BUGFIX: original left `recons_dir` unbound here (NameError).
            print(f"Unsupported logger {type(self.trainer.logger).__name__}; "
                  f"skipping reconstruction preview")
            return
        self.trainer.datamodule.dataset.preprocessing_pipeline.visualizer.visualize_multiple(
            spectrograms,
            file_dir=recons_dir,
            col_titles=["DI", "Expected reamped", "Reconstructed reamped"],
            filename=f"reconstruction-e_{self.trainer.current_epoch}.png",
            title=f"reconstruction for epoch e_{self.trainer.current_epoch}: DI v/s Expected v/s Reconstructed "
                  f"reamped clip (key={key})"
        )
        if isinstance(self.trainer.logger, WandbLogger):
            # Best-effort upload: a failed wandb call must not abort training.
            try:
                self.trainer.logger.log_image(key=f"epoch-{self.trainer.current_epoch}", images=[
                    str(recons_dir / f"reconstruction-e_{self.trainer.current_epoch}.png")])
            except Exception as e:
                print(f"Ignoring Exception while wandb.log_image: {e}")

    def create_input_batch(self, batch):
        """Split one dataset item into (reamped_a, di_a, reamped_b, di_b).

        The item packs reamped and DI tensors (plus four unused fields); each
        tensor is de-batched via squeeze and split in half along dim 0, the
        first half forming the *_a pair and the second the *_b pair.
        """
        batch, batch_di, _, _, _, _ = batch
        batch = torch.squeeze(batch, 0)
        batch_di = torch.squeeze(batch_di, 0)
        b_size = self.config.data_params.batch_size // 2
        re_a = batch[0:b_size, :]
        di_a = batch_di[0:b_size, :]
        re_b = batch[b_size:, :]
        di_b = batch_di[b_size:, :]
        return re_a, di_a, re_b, di_b

    def on_fit_end(self) -> None:
        print("On fit end")