Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
add script to quickly choose the best validation checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Jun 23, 2020
1 parent 8d77e1b commit 4cffac1
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 7 deletions.
6 changes: 3 additions & 3 deletions test_all_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from utils.tensorboard import TensorboardWriter

from utils.dataset import test_dataloader
from utils.dataset import test_dataloader, eval_dataloader

from utils.generic_utils import validation, PowerLaw_Compressed_Loss, SiSNR_With_Pit

Expand Down Expand Up @@ -105,8 +105,8 @@ def test(args, log_dir, checkpoint_path, testloader, tensorboard, c, model_name,
c.dataset['test_dir'] = args.dataset_dir
# set batchsize = 1
c.train_config['batch_size'] = 1
test_dataloader = test_dataloader(c, ap)

test_dataloader = eval_dataloader(c, ap)
print(c.dataset['format'])
best_sdr = 0
best_loss = 999999999
best_sdr_checkpoint = ''
Expand Down
121 changes: 121 additions & 0 deletions test_fast_all_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
import math
import torch
import torch.nn as nn
import traceback
from glob import glob

import time
import numpy as np

import tqdm

import argparse

from utils.generic_utils import load_config, load_config_from_str
from utils.generic_utils import set_init_dict

from utils.tensorboard import TensorboardWriter

from utils.dataset import test_dataloader

from utils.generic_utils import validation, PowerLaw_Compressed_Loss, SiSNR_With_Pit, test_fast_with_si_srn

from models.voicefilter.model import VoiceFilter
from models.voicesplit.model import VoiceSplit
from utils.audio_processor import WrapperAudioProcessor as AudioProcessor

from shutil import copyfile
import yaml

def test(args, log_dir, checkpoint_path, testloader, tensorboard, c, model_name, ap, cuda=True):
    """Load one checkpoint and return its mean loss from ``test_fast_with_si_srn``.

    Args:
        args: parsed command-line arguments (not read in this function).
        log_dir: log directory (not read in this function).
        checkpoint_path: path of the checkpoint file to evaluate; required.
        testloader: data loader forwarded to ``test_fast_with_si_srn``.
        tensorboard: writer forwarded to ``test_fast_with_si_srn``.
        c: configuration object; ``c.train_config`` and ``c.loss`` are read here.
        model_name: ``'voicefilter'`` or ``'voicesplit'``.
        ap: audio processor forwarded to ``test_fast_with_si_srn``.
        cuda: when true, the model is moved to the GPU before evaluation.

    Returns:
        Whatever ``test_fast_with_si_srn`` returns (the mean test loss).

    Raises:
        Exception: on an unsupported model, optimizer or loss name, when no
            checkpoint path is given, or when the model weights cannot be
            loaded from the checkpoint.
    """
    if model_name == 'voicefilter':
        model = VoiceFilter(c)
    elif model_name == 'voicesplit':
        model = VoiceSplit(c)
    else:
        raise Exception(" The model '"+model_name+"' is not suported")

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=c.train_config['learning_rate'])
    else:
        # The optimizer name lives in c.train_config (the key read just above).
        raise Exception("The %s not is a optimizer supported" % c.train_config['optimizer'])

    if checkpoint_path is None:
        raise Exception("You need specific a checkpoint for test")

    # Load outside the try block: if reading the file itself fails, the real
    # error must surface instead of being masked by an unbound `checkpoint`.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    try:
        model.load_state_dict(checkpoint['model'])
    except Exception as err:
        raise Exception("Fail in load checkpoint, you need use this configs: %s"
                        % checkpoint.get('config_str')) from err

    try:
        optimizer.load_state_dict(checkpoint['optimizer'])
    except Exception:
        # Best effort only: the optimizer state is not needed for evaluation.
        print(" > Optimizer state is not loaded from checkpoint path, you see this mybe you change the optimizer")

    step = checkpoint['step']

    # Move the model to the GPU once, after the weights are in place.
    if cuda:
        model = model.cuda()

    if c.loss['loss_name'] == 'power_law_compression':
        # definitions for power-law compressed loss (only needed for this loss)
        power = c.loss['power']
        complex_ratio = c.loss['complex_loss_ratio']
        criterion = PowerLaw_Compressed_Loss(power, complex_ratio)
    elif c.loss['loss_name'] == 'si_snr':
        criterion = SiSNR_With_Pit()
    else:
        raise Exception(" The loss '"+c.loss['loss_name']+"' is not suported")
    return test_fast_with_si_srn(criterion, ap, model, testloader, tensorboard, step,
                                 cuda=cuda, loss_name=c.loss['loss_name'], test=True)


if __name__ == '__main__':
    # Evaluate every '*.pt' checkpoint of a directory on the test set, keep a
    # copy of the one with the lowest mean loss and save all (loss, path) pairs.
    parser = argparse.ArgumentParser()

    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help="Root directory of run.")
    parser.add_argument('-c', '--config_path', type=str, required=False, default=None,
                        help="json file with configurations")
    parser.add_argument('--checkpoints_path', type=str, required=True,
                        help="path of checkpoint pt file, for continue training")
    args = parser.parse_args()

    best_copy_name = 'fast_best_checkpoint.pt'
    # Skip the copy written by a previous run so it is not evaluated again
    # (and never copied onto itself).
    all_checkpoints = sorted(
        path for path in glob(os.path.join(args.checkpoints_path, '*.pt'))
        if os.path.basename(path) != best_copy_name
    )
    if not all_checkpoints:
        raise FileNotFoundError("No '*.pt' checkpoint found in: %s" % args.checkpoints_path)

    if args.config_path:
        c = load_config(args.config_path)
    else:  # load config in checkpoint
        checkpoint = torch.load(all_checkpoints[0], map_location='cpu')
        c = load_config_from_str(checkpoint['config_str'])

    ap = AudioProcessor(c.audio)

    log_path = os.path.join(c.train_config['logs_path'], c.model_name)
    audio_config = c.audio[c.audio['backend']]
    tensorboard = TensorboardWriter(log_path, audio_config)
    # set test dataset dir
    c.dataset['test_dir'] = args.dataset_dir
    # set the test batch size
    c.test_config['batch_size'] = 5
    testloader = test_dataloader(c, ap)

    best_loss = 999999999
    best_loss_checkpoint = ''
    sdrs_checkpoint = []
    for checkpoint_file in tqdm.tqdm(all_checkpoints):
        mean_loss = test(args, log_path, checkpoint_file, testloader, tensorboard, c,
                         c.model_name, ap, cuda=True)
        sdrs_checkpoint.append([mean_loss, checkpoint_file])
        if mean_loss < best_loss:
            best_loss = mean_loss
            best_loss_checkpoint = checkpoint_file

    if not best_loss_checkpoint:
        # Happens when no checkpoint produced a loss below the initial bound
        # (for example when every mean loss is NaN).
        raise RuntimeError("No checkpoint produced a valid loss, nothing to copy")

    print("Best Loss checkpoint is: ", best_loss_checkpoint, "Best Loss:", best_loss)
    copyfile(best_loss_checkpoint, os.path.join(args.checkpoints_path, best_copy_name))
    # NOTE: numpy.save appends '.npy' when the name does not already end with it,
    # so the file on disk ends in '.np.npy'.
    np.save(os.path.join(args.checkpoints_path,
                         "Loss_validation_with_VCTK_best_SI-SNR_is_"+str(best_loss)+".np"),
            np.array(sdrs_checkpoint))
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from utils.tensorboard import TensorboardWriter

from utils.dataset import train_dataloader, test_dataloader
from utils.dataset import train_dataloader, eval_dataloader

from utils.generic_utils import validation, PowerLaw_Compressed_Loss, SiSNR_With_Pit

Expand Down Expand Up @@ -159,5 +159,5 @@ def train(args, log_dir, checkpoint_path, trainloader, testloader, tensorboard,
raise Exception("Please verify directories of dataset in "+args.config_path)

train_dataloader = train_dataloader(c, ap)
test_dataloader = test_dataloader(c, ap)
test_dataloader = eval_dataloader(c, ap)
train(args, log_path, args.checkpoint_path, train_dataloader, test_dataloader, tensorboard, c, c.model_name, ap, cuda=True)
48 changes: 47 additions & 1 deletion utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __getitem__(self, idx):
target_spec, _ = self.ap.get_spec_from_audio(target_wav, return_phase=True)
target_spec = torch.from_numpy(target_spec)
mixed_spec = torch.from_numpy(mixed_spec)
mixed_phase = torch.from_numpy(mixed_phase)
target_wav = torch.from_numpy(target_wav)
mixed_wav = torch.from_numpy(mixed_wav)
seq_len = torch.from_numpy(np.array([mixed_wav.shape[0]]))
return emb, target_spec, mixed_spec, target_wav, mixed_wav, mixed_phase, seq_len

Expand All @@ -69,6 +72,15 @@ def test_dataloader(c, ap):
collate_fn=test_collate_fn, batch_size=c.test_config['batch_size'],
shuffle=False, num_workers=c.test_config['num_workers'])

def eval_dataloader(c, ap):
    """Build the evaluation ``DataLoader``.

    Wraps ``Dataset(c, ap, train=False)`` without shuffling, takes the batch
    size and worker count from ``c.test_config``, and uses ``eval_collate_fn``
    to assemble batches.
    """
    eval_set = Dataset(c, ap, train=False)
    test_cfg = c.test_config
    return DataLoader(
        dataset=eval_set,
        collate_fn=eval_collate_fn,
        batch_size=test_cfg['batch_size'],
        shuffle=False,
        num_workers=test_cfg['num_workers'],
    )


def eval_collate_fn(batch):
    """Identity collate function: return the list of samples untouched."""
    return batch

def train_collate_fn(item):
embs_list = []
target_list = []
Expand Down Expand Up @@ -102,4 +114,38 @@ def train_collate_fn(item):
return embs_list, target_list, mixed_list, seq_len_list, target_wav_list, mixed_phase_list

def test_collate_fn(batch):
    """Collate test samples into stacked batch tensors.

    Args:
        batch: iterable of 7-tuples
            ``(emb, target, mixed, target_wav, mixed_wav, mixed_phase, seq_len)``.

    Returns:
        Tuple ``(embs, target, mixed, target_wav, mixed_wav, mixed_phase,
        seq_len)`` where every element is stacked along dim 0. ``embs`` falls
        back to a plain list when the embeddings cannot be stacked.
    """
    embs_list = []
    target_list = []
    mixed_list = []
    seq_len_list = []
    mixed_phase_list = []
    target_wav_list = []
    mixed_wav_list = []

    for emb, target, mixed, target_wav, mixed_wav, mixed_phase, seq_len in batch:
        # Samples whose embedding is exactly [0] are dropped
        # (presumably a placeholder for an unusable sample - TODO confirm).
        if emb.tolist() == [0]:
            continue
        embs_list.append(emb)
        target_list.append(target)
        mixed_list.append(mixed)
        seq_len_list.append(seq_len)
        mixed_phase_list.append(mixed_phase)
        target_wav_list.append(target_wav)
        mixed_wav_list.append(mixed_wav)

    # stack tensors along a new dim 0
    target_list = stack(target_list, dim=0)
    mixed_list = stack(mixed_list, dim=0)
    seq_len_list = stack(seq_len_list, dim=0)
    target_wav_list = stack(target_wav_list, dim=0)
    mixed_phase_list = stack(mixed_phase_list, dim=0)
    mixed_wav_list = stack(mixed_wav_list, dim=0)
    try:
        embs_list = stack(embs_list, dim=0)
    except Exception:
        # Best effort: keep the embeddings as a list when they cannot be stacked.
        pass
    return embs_list, target_list, mixed_list, target_wav_list, mixed_wav_list, mixed_phase_list, seq_len_list

33 changes: 32 additions & 1 deletion utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,11 @@ def validation(criterion, ap, model, testloader, tensorboard, step, cuda=True, l
mixed_spec = mixed_spec[0].cpu().detach().numpy()
clean_spec = clean_spec[0].cpu().detach().numpy()
est_mag = est_mag[0].cpu().detach().numpy()
mixed_phase = mixed_phase[0].cpu().detach().numpy()

est_wav = ap.inv_spectrogram(est_mag, phase=mixed_phase)
est_mask = est_mask[0].cpu().detach().numpy()

if loss_name == 'si_snr':
test_loss = criterion(torch.from_numpy(np.array([[clean_wav]])), torch.from_numpy(np.array([[est_wav]])), seq_len).item()
sdr = bss_eval_sources(clean_wav, est_wav, False)[0][0]
Expand All @@ -244,12 +246,41 @@ def validation(criterion, ap, model, testloader, tensorboard, step, cuda=True, l
print("Mean Test Loss:", mean_test_loss)
print("Mean Test SDR:", mean_sdr)
return mean_test_loss, mean_sdr

def test_fast_with_si_srn(criterion, ap, model, testloader, tensorboard, step, cuda=True, loss_name='si_snr', test=False):
    """Return the mean ``SiSNR_With_Pit`` loss of ``model`` over ``testloader``.

    Args:
        criterion: ignored; the loss is always replaced by ``SiSNR_With_Pit()``.
        ap: audio processor; its ``torch_inv_spectrogram`` turns spectrograms
            plus phase back into waveforms.
        model: network called as ``model(mixed_spec, emb)`` to produce a mask.
        testloader: iterable yielding ``(emb, clean_spec, mixed_spec,
            clean_wav, mixed_wav, mixed_phase, seq_len)`` batches.
        tensorboard, step, loss_name, test: accepted but not used here.
        cuda: when true, the batch tensors are moved to the GPU.

    Returns:
        Mean of the per-batch loss values (numpy float).
    """
    losses = []
    model.eval()
    # set fast and best criterion
    criterion = SiSNR_With_Pit()
    with torch.no_grad():
        for emb, clean_spec, mixed_spec, clean_wav, mixed_wav, mixed_phase, seq_len in testloader:
            if cuda:
                emb = emb.cuda()
                clean_spec = clean_spec.cuda()
                mixed_spec = mixed_spec.cuda()
                mixed_phase = mixed_phase.cuda()
                seq_len = seq_len.cuda()
            est_mask = model(mixed_spec, emb)
            est_mag = est_mask * mixed_spec
            # convert spec to wav using phase
            output = ap.torch_inv_spectrogram(est_mag, mixed_phase)
            target = ap.torch_inv_spectrogram(clean_spec, mixed_phase)
            shape = list(target.shape)
            with_channel = [shape[0], 1] + shape[1:]
            target = torch.reshape(target, with_channel)  # append channel dim
            output = torch.reshape(output, with_channel)  # append channel dim
            losses.append(criterion(output, target, seq_len).item())

    mean_test_loss = np.array(losses).mean()
    print("Mean Si-SRN with Pit Loss:", mean_test_loss)
    return mean_test_loss

class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so attribute
        # access and item access hit the same storage.
        self.__dict__ = self


def load_config(config_path):
config = AttrDict()
with open(config_path, "r") as f:
Expand Down

0 comments on commit 4cffac1

Please sign in to comment.