From accf2f99636ff43af09bf0abca2ce21cbd139697 Mon Sep 17 00:00:00 2001
From: KinWaiCheuk
Date: Fri, 9 Jul 2021 23:20:58 +0800
Subject: [PATCH] add train_baseline_Prestack.py

---
 train_baseline_Prestack.py | 164 +++++++++++++++++++++++++++++++++++++
 1 file changed, 164 insertions(+)
 create mode 100755 train_baseline_Prestack.py

diff --git a/train_baseline_Prestack.py b/train_baseline_Prestack.py
new file mode 100755
index 0000000..7e01dbf
--- /dev/null
+++ b/train_baseline_Prestack.py
@@ -0,0 +1,164 @@
+import os
+import pickle
+from datetime import datetime
+
+import numpy as np
+import torch
+from sacred import Experiment
+from sacred.commands import print_config
+from sacred.observers import FileStorageObserver
+from torch.optim.lr_scheduler import StepLR, CyclicLR
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+# prepare_VAT_dataset, Prestack_Model, train_model, summary,
+# tensorboard_log_without_VAT and evaluate_wo_velocity are all
+# provided by model.py via this wildcard import
+from model import *
+
+ex = Experiment('train_original')
+
+# parameters for the network
+ds_ksize, ds_stride = (2, 2), (2, 2)
+mode = 'imagewise'
+sparsity = 1
+output_channel = 2
+logging_freq = 100
+saving_freq = 100
+
+# file_path = 'Retrain_Prestack-lr=0.0001210325-141510'
+
+@ex.config
+def config():
+    root = 'runs'
+    # logdir = f'runs_AE/test' + '-' + datetime.now().strftime('%y%m%d-%H%M%S')
+    # Choosing which GPU to use
+#     GPU = '0'
+#     os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU)
+    onset_stack = True
+    device = 'cuda:0'
+    log = True
+    w_size = 31
+    spec = 'Mel'
+    resume_iteration = None
+    train_on = 'String'
+    n_heads = 4
+    position = True
+    iteration = 10
+    VAT_start = 0
+    alpha = 1
+    VAT = True
+    XI = 1e-6
+    eps = 1.3
+    small = True
+    KL_Div = False
+    reconstruction = False
+
+    batch_size = 1
+    train_batch_size = 1
+    sequence_length = 327680 // 8
+    # Halve the batch size and sequence length on GPUs with less than ~10 GB of memory
+    if torch.cuda.is_available() and torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory < 10e9:
+        batch_size = max(1, batch_size // 2)  # max() keeps a batch_size of 1 from rounding down to 0
+        sequence_length //= 2
+        print(f'Reducing batch_size to {batch_size} and sequence_length to {sequence_length} to save memory')
+
+    epochs = 20000
+    step_size_up = 100
+    max_lr = 1e-4
+    learning_rate = 1e-5
+#     base_lr = learning_rate
+
+    learning_rate_decay_steps = 1000
+    learning_rate_decay_rate = 0.98
+
+    leave_one_out = None
+
+    clip_gradient_norm = 3
+
+    validation_length = sequence_length
+    refresh = False
+
+    logdir = f'{root}/baseline_Prestack-' + datetime.now().strftime('%y%m%d-%H%M%S')
+
+    ex.observers.append(FileStorageObserver.create(logdir))  # saves source code and config with the run
+
+@ex.automain
+def train(spec, resume_iteration, train_on, batch_size, sequence_length, w_size, n_heads, small, train_batch_size,
+          learning_rate, learning_rate_decay_steps, learning_rate_decay_rate, leave_one_out, position, alpha, KL_Div,
+          clip_gradient_norm, validation_length, refresh, device, epochs, logdir, log, iteration, VAT_start, VAT, XI, eps,
+          reconstruction, root):
+    print_config(ex.current_run)
+
+    supervised_set, unsupervised_set, validation_dataset, full_validation = prepare_VAT_dataset(
+        sequence_length=sequence_length,
+        validation_length=validation_length,
+        refresh=refresh,
+        device=device,
+        small=small,
+        supersmall=True,
+        dataset=train_on)
+
+    # Cap the validation batch size at 4 so small validation sets still form a full batch
+    val_batch_size = min(len(validation_dataset), 4)
+    supervised_loader = DataLoader(supervised_set, train_batch_size, shuffle=True, drop_last=True)
+    valloader = DataLoader(validation_dataset, val_batch_size, shuffle=False, drop_last=True)
+    batch_visualize = next(iter(valloader))  # one fixed batch for visualization
+
+    model = Prestack_Model()
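+    # Hypothetical sketch, not part of the original script: config() exposes
+    # `resume_iteration` but never uses it. Resuming could look like the lines
+    # below, pointed at a previous run directory (`file_path` at the top of
+    # this file), since `logdir` is freshly timestamped on every run and the
+    # saving code further down names checkpoints `model-{ep}.pt`:
+    # if resume_iteration is not None:
+    #     checkpoint = os.path.join(root, file_path, f'model-{resume_iteration}.pt')
+    #     model.load_state_dict(torch.load(checkpoint, map_location=device))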
+    model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
+
+    # This model crashes from time to time, so weights are saved frequently
+    # and can be loaded back after a crash:
+#     weight_path = os.path.join(root, file_path, 'model-400.pt')
+#     weight_dict = torch.load(weight_path, map_location=device)
+#     model.load_state_dict(weight_dict)
+
+    summary(model)
+#     scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=step_size_up, cycle_momentum=False)
+    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)
+
+    writer = SummaryWriter(logdir)  # create the TensorBoard logger once, before training starts
+
+    for ep in range(1, epochs + 1):
+        predictions, losses, optimizer = train_model(model, ep, supervised_loader,
+                                                     optimizer, scheduler, clip_gradient_norm)
+        loss = sum(losses.values())
+
+        # Logging results to TensorBoard
+        tensorboard_log_without_VAT(batch_visualize, model, validation_dataset, supervised_loader,
+                                    ep, logging_freq, saving_freq, n_heads, logdir, w_size, writer,
+                                    False, VAT_start, reconstruction)
+
+        # Saving the model
+        if ep % saving_freq == 0:
+            torch.save(model.state_dict(), os.path.join(logdir, f'model-{ep}.pt'))
+            torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))
+        for key, value in losses.items():
+            writer.add_scalar(key, value.item(), global_step=ep)
+
+    # Evaluating model performance on the full MAPS songs in the test split
+    print('Training finished, now evaluating on the MAPS test split (full songs)')
+    with torch.no_grad():
+        model.eval()
+        metrics = evaluate_wo_velocity(tqdm(full_validation), model, reconstruction=False,
+                                       save_path=os.path.join(logdir, 'MIDI_results'))
+
+    for key, values in metrics.items():
+        if key.startswith('metric/'):
+            _, category, name = key.split('/')
+            print(f'{category:>32} {name:25}: {np.mean(values):.3f} ± {np.std(values):.3f}')
+
+    export_path = os.path.join(logdir, 'result_dict')
+    with open(export_path, 'wb') as f:
+        pickle.dump(metrics, f)
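
Usage note (a sketch, not part of the patch itself): since the script is a
Sacred Experiment, every value defined in config() can be overridden from the
command line with Sacred's standard `with` syntax, for example:

    python train_baseline_Prestack.py with train_on=String epochs=100 device=cuda:0

Each run creates its own timestamped directory under runs/, where the
FileStorageObserver stores the config and source code next to the checkpoints
and TensorBoard logs.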