diff --git a/dnn/torch/osce/adv_train_vocoder.py b/dnn/torch/osce/adv_train_vocoder.py
index 754a15297..3728bc2f6 100644
--- a/dnn/torch/osce/adv_train_vocoder.py
+++ b/dnn/torch/osce/adv_train_vocoder.py
@@ -152,7 +152,9 @@
 lambda_feat = setup['training']['lambda_feat']
 lambda_reg = setup['training']['lambda_reg']
 adv_target = setup['training'].get('adv_target', 'target')
-
+cont_ratio = setup['training'].get('cont_ratio', 0)
+pre_frames_min = setup['training'].get('pre_frames_min', 8)
+pre_frames_max = setup['training'].get('pre_frames_max', 50)
 
 # load training dataset
 data_config = setup['data']
@@ -355,7 +357,14 @@ def optimizer_to(optim, device):
             disc_target = batch[adv_target].to(device)
 
             # calculate model output
-            output = model(batch['features'], batch['periods'])
+            if random.random() < cont_ratio:
+                pre_frames = random.randint(pre_frames_min // 2, pre_frames_max // 2) * 2
+                pre_sig = target[:, :pre_frames * model.FEATURE_FRAME_SIZE]
+            else:
+                pre_frames = 0
+                pre_sig = None
+
+            output = model(batch['features'], batch['periods'], signal=pre_sig)
 
             # discriminator update
             scores_gen = disc(output.detach())
diff --git a/dnn/torch/osce/engine/vocoder_engine.py b/dnn/torch/osce/engine/vocoder_engine.py
index cfbad6cf4..96aa24b2b 100644
--- a/dnn/torch/osce/engine/vocoder_engine.py
+++ b/dnn/torch/osce/engine/vocoder_engine.py
@@ -1,8 +1,9 @@
 import torch
 from tqdm import tqdm
 import sys
+import random
 
-def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, pre_frames=4, log_interval=10):
+def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, cont_ratio=0.5, pre_frames_min=8, pre_frames_max=64, log_interval=10):
 
     model.to(device)
     model.train()
@@ -26,10 +27,13 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
             target = batch['target']
 
             # calculate model output
-            if pre_frames > 0:
+            if random.random() < cont_ratio:
+                pre_frames = random.randint(pre_frames_min // 2, pre_frames_max // 2) * 2
                 pre_sig = target[:, :pre_frames * model.FEATURE_FRAME_SIZE]
             else:
+                pre_frames = 0
                 pre_sig = None
+
             output = model(batch['features'], batch['periods'], signal=pre_sig)
 
             # calculate loss
diff --git a/dnn/torch/osce/make_default_setup.py b/dnn/torch/osce/make_default_setup.py
index d7365fff9..3c5c05cb0 100644
--- a/dnn/torch/osce/make_default_setup.py
+++ b/dnn/torch/osce/make_default_setup.py
@@ -66,7 +66,7 @@
 parser = argparse.ArgumentParser()
 
 parser.add_argument('name', type=str, help='name of default setup file')
-parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace')
+parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce', 'lavocecont'], help='model name', default='lace')
 parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training')
 parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)
 
diff --git a/dnn/torch/osce/models/lavoce_cont.py b/dnn/torch/osce/models/lavoce_cont.py
index 01eefa6ec..f8b776b90 100644
--- a/dnn/torch/osce/models/lavoce_cont.py
+++ b/dnn/torch/osce/models/lavoce_cont.py
@@ -205,9 +205,9 @@ def forward(self, features, periods, signal=None, debug=False):
         batch_size = features.size(0)
         num_frames = features.size(1)
 
-        if len(signal.shape) < 3: signal = signal.unsqueeze(1)
 
         if signal is not None:
+            if len(signal.shape) < 3: signal = signal.unsqueeze(1)
             pre_frames = signal.size(-1) // self.FEATURE_FRAME_SIZE
             phase_features = calculate_phase_features(signal, periods[:, :pre_frames].squeeze(-1))
             full_phase_features = torch.cat(
diff --git a/dnn/torch/osce/train_vocoder.py b/dnn/torch/osce/train_vocoder.py
index f4d8157d7..3572e962a 100644
--- a/dnn/torch/osce/train_vocoder.py
+++ b/dnn/torch/osce/train_vocoder.py
@@ -143,6 +143,9 @@
 epochs = setup['training']['epochs']
 lr = setup['training']['lr']
 lr_decay_factor = setup['training']['lr_decay_factor']
+cont_ratio = setup['training'].get('cont_ratio', 0)
+pre_frames_min = setup['training'].get('pre_frames_min', 8)
+pre_frames_max = setup['training'].get('pre_frames_max', 50)
 
 # load training dataset
 data_config = setup['data']
@@ -253,7 +256,15 @@ def criterion(x, y):
 for ep in range(1, epochs + 1):
     print(f"training epoch {ep}...")
 
-    new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
+    new_loss = train_one_epoch(model,
+                               criterion,
+                               optimizer,
+                               dataloader,
+                               device,
+                               scheduler,
+                               cont_ratio=cont_ratio,
+                               pre_frames_min=pre_frames_min,
+                               pre_frames_max=pre_frames_max)
 
     # save checkpoint
 
diff --git a/dnn/torch/osce/utils/templates.py b/dnn/torch/osce/utils/templates.py
index 42137b26d..306bc1f7f 100644
--- a/dnn/torch/osce/utils/templates.py
+++ b/dnn/torch/osce/utils/templates.py
@@ -326,10 +326,124 @@
 }
 
+
+lavocecont_setup = {
+    'data': {
+        'frames_per_sample': 100,
+        'target': 'signal'
+    },
+    'dataset': '/local/datasets/lpcnet_large/training',
+    'model': {
+        'args': [],
+        'kwargs': {
+            'comb_gain_limit_db': 10,
+            'cond_dim': 256,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'kernel_size': 15,
+            'num_features': 19,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'pulses': True
+        },
+        'name': 'lavocecont'
+    },
+    'training': {
+        'batch_size': 256,
+        'epochs': 50,
+        'loss': {
+            'w_l1': 0,
+            'w_l2': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_slm': 2,
+            'w_sxcorr': 1,
+            'w_wsc': 0,
+            'w_xcorr': 0
+        },
+        'lr': 0.0005,
+        'lr_decay_factor': 2.5e-05,
+        'cont_ratio': 0.8,
+        'pre_frames_min': 8,
+        'pre_frames_max': 64
+    },
+    'validation_dataset': '/local/datasets/lpcnet_large/validation'
+}
+
+lavocecont_setup_adv = {
+    'data': {
+        'frames_per_sample': 100,
+        'target': 'signal'
+    },
+    'dataset': '/local/datasets/lpcnet_large/training',
+    'discriminator': {
+        'args': [],
+        'kwargs': {
+            'architecture': 'free',
+            'design': 'f_down',
+            'fft_sizes_16k': [
+                64,
+                128,
+                256,
+                512,
+                1024,
+                2048,
+            ],
+            'freq_roi': [0, 7400],
+            'fs': 16000,
+            'max_channels': 256,
+            'noise_gain': 0.0,
+        },
+        'name': 'fdmresdisc',
+    },
+    'model': {
+        'args': [],
+        'kwargs': {
+            'comb_gain_limit_db': 10,
+            'cond_dim': 256,
+            'conv_gain_limits_db': [-12, 12],
+            'global_gain_limits_db': [-6, 6],
+            'kernel_size': 15,
+            'num_features': 19,
+            'pitch_embedding_dim': 64,
+            'pitch_max': 300,
+            'preemph': 0.85,
+            'pulses': True
+        },
+        'name': 'lavocecont'
+    },
+    'training': {
+        'batch_size': 64,
+        'epochs': 50,
+        'gen_lr_reduction': 1,
+        'lambda_feat': 1.0,
+        'lambda_reg': 0.6,
+        'loss': {
+            'w_l1': 0,
+            'w_l2': 0,
+            'w_lm': 0,
+            'w_logmel': 0,
+            'w_sc': 0,
+            'w_slm': 2,
+            'w_sxcorr': 1,
+            'w_wsc': 0,
+            'w_xcorr': 0
+        },
+        'lr': 0.0001,
+        'lr_decay_factor': 2.5e-09,
+        'cont_ratio': 0.8,
+        'pre_frames_min': 8,
+        'pre_frames_max': 50
+    },
+}
+
 setup_dict = {
     'lace': lace_setup,
     'nolace': nolace_setup,
     'nolace_adv': nolace_setup_adv,
     'lavoce': lavoce_setup,
-    'lavoce_adv': lavoce_setup_adv
+    'lavoce_adv': lavoce_setup_adv,
+    'lavocecont': lavocecont_setup,
+    'lavocecont_adv': lavocecont_setup_adv
 }
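
For reference, a minimal standalone sketch of the continuation-signal sampling that the updated training loops perform: with probability cont_ratio, an even number of leading feature frames of the target is sliced off and passed to the model as a continuation signal. This is an illustration, not part of the patch; the FEATURE_FRAME_SIZE value and tensor shapes below are assumptions.

import random
import torch

FEATURE_FRAME_SIZE = 160  # assumed number of samples per feature frame

def sample_pre_signal(target, cont_ratio=0.8, pre_frames_min=8, pre_frames_max=64):
    # With probability cont_ratio, draw an even frame count in
    # [pre_frames_min, pre_frames_max] and return the corresponding leading
    # slice of the target signal; otherwise return 0 frames and None.
    if random.random() < cont_ratio:
        pre_frames = random.randint(pre_frames_min // 2, pre_frames_max // 2) * 2
        pre_sig = target[:, :pre_frames * FEATURE_FRAME_SIZE]
    else:
        pre_frames = 0
        pre_sig = None
    return pre_frames, pre_sig

# example usage with a random (batch, samples) target tensor
target = torch.randn(4, 100 * FEATURE_FRAME_SIZE)
pre_frames, pre_sig = sample_pre_signal(target)
print(pre_frames, None if pre_sig is None else tuple(pre_sig.shape))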