Skip to content

Commit

Permalink
full integration of LaVoceCont
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Buethe committed Sep 26, 2023
1 parent d090c7c commit ab3ba9a
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 8 deletions.
13 changes: 11 additions & 2 deletions dnn/torch/osce/adv_train_vocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@
lambda_feat = setup['training']['lambda_feat']
lambda_reg = setup['training']['lambda_reg']
adv_target = setup['training'].get('adv_target', 'target')

cont_ratio = setup['training'].get('cont_ratio', 0)
pre_frames_min = setup['training'].get('pre_frames_min', 8)
pre_frames_max = setup['training'].get('pre_frames_max', 50)

# load training dataset
data_config = setup['data']
Expand Down Expand Up @@ -355,7 +357,14 @@ def optimizer_to(optim, device):
disc_target = batch[adv_target].to(device)

# calculate model output
output = model(batch['features'], batch['periods'])
if random.random() < cont_ratio:
pre_frames = random.randint(pre_frames_min // 2, pre_frames_max // 2) * 2
pre_sig = target[:, :pre_frames * model.FEATURE_FRAME_SIZE]
else:
pre_frames = 0
pre_sig = None

output = model(batch['features'], batch['periods'], signal=pre_sig)

# discriminator update
scores_gen = disc(output.detach())
Expand Down
8 changes: 6 additions & 2 deletions dnn/torch/osce/engine/vocoder_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
from tqdm import tqdm
import sys
import random

def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, pre_frames=4, log_interval=10):
def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler, cont_ratio=0.5, pre_frames_min=8, pre_frames_max=64, log_interval=10):

model.to(device)
model.train()
Expand All @@ -26,10 +27,13 @@ def train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler,
target = batch['target']

# calculate model output
if pre_frames > 0:
if random.random() < cont_ratio:
pre_frames = random.randint(pre_frames_min // 2, pre_frames_max // 2) * 2
pre_sig = target[:, :pre_frames * model.FEATURE_FRAME_SIZE]
else:
pre_frames = 0
pre_sig = None

output = model(batch['features'], batch['periods'], signal=pre_sig)

# calculate loss
Expand Down
2 changes: 1 addition & 1 deletion dnn/torch/osce/make_default_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
parser = argparse.ArgumentParser()

parser.add_argument('name', type=str, help='name of default setup file')
parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce'], help='model name', default='lace')
parser.add_argument('--model', choices=['lace', 'nolace', 'lavoce', 'lavocecont'], help='model name', default='lace')
parser.add_argument('--adversarial', action='store_true', help='setup for adversarial training')
parser.add_argument('--path2dataset', type=str, help='dataset path', default=None)

Expand Down
2 changes: 1 addition & 1 deletion dnn/torch/osce/models/lavoce_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def forward(self, features, periods, signal=None, debug=False):

batch_size = features.size(0)
num_frames = features.size(1)
if len(signal.shape) < 3: signal = signal.unsqueeze(1)

if signal is not None:
if len(signal.shape) < 3: signal = signal.unsqueeze(1)
pre_frames = signal.size(-1) // self.FEATURE_FRAME_SIZE
phase_features = calculate_phase_features(signal, periods[:, :pre_frames].squeeze(-1))
full_phase_features = torch.cat(
Expand Down
13 changes: 12 additions & 1 deletion dnn/torch/osce/train_vocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@
epochs = setup['training']['epochs']
lr = setup['training']['lr']
lr_decay_factor = setup['training']['lr_decay_factor']
cont_ratio = setup['training'].get('cont_ratio', 0)
pre_frames_min = setup['training'].get('pre_frames_min', 8)
pre_frames_max = setup['training'].get('pre_frames_max', 50)

# load training dataset
data_config = setup['data']
Expand Down Expand Up @@ -253,7 +256,15 @@ def criterion(x, y):

for ep in range(1, epochs + 1):
print(f"training epoch {ep}...")
new_loss = train_one_epoch(model, criterion, optimizer, dataloader, device, scheduler)
new_loss = train_one_epoch(model,
criterion,
optimizer,
dataloader,
device,
scheduler,
cont_ratio=cont_ratio,
pre_frames_min=pre_frames_min,
pre_frames_max=pre_frames_max)


# save checkpoint
Expand Down
116 changes: 115 additions & 1 deletion dnn/torch/osce/utils/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,124 @@
}


# Default (non-adversarial) setup template for the 'lavocecont' model.
lavocecont_setup = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal',
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    # 'args' / 'kwargs' are presumably forwarded to the model constructor
    # selected by 'name' -- confirm against the training script.
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True,
        },
        'name': 'lavocecont',
    },
    'training': {
        'batch_size': 256,
        'epochs': 50,
        # Per-term loss coefficients; only 'w_slm' and 'w_sxcorr' are non-zero.
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
        'lr': 0.0005,
        'lr_decay_factor': 2.5e-05,
        # Continuation training: the training loop compares random.random()
        # against 'cont_ratio' and, when below it, passes a leading slice of
        # the target to the model as `signal`.
        'cont_ratio': 0.8,
        # Bounds for the randomly drawn length (in feature frames) of that
        # slice; the loop halves both bounds, draws an integer, and doubles it.
        'pre_frames_min': 8,
        'pre_frames_max': 64,
    },
    'validation_dataset': '/local/datasets/lpcnet_large/validation',
}

# Default adversarial-training setup template for the 'lavocecont' model.
# Unlike the non-adversarial template it carries a 'discriminator' section
# and has no 'validation_dataset' entry.
lavocecont_setup_adv = {
    'data': {
        'frames_per_sample': 100,
        'target': 'signal',
    },
    'dataset': '/local/datasets/lpcnet_large/training',
    'discriminator': {
        'args': [],
        'kwargs': {
            'architecture': 'free',
            'design': 'f_down',
            'fft_sizes_16k': [64, 128, 256, 512, 1024, 2048],
            'freq_roi': [0, 7400],
            'fs': 16000,
            'max_channels': 256,
            'noise_gain': 0.0,
        },
        'name': 'fdmresdisc',
    },
    # 'args' / 'kwargs' are presumably forwarded to the model constructor
    # selected by 'name' -- confirm against the training script.
    'model': {
        'args': [],
        'kwargs': {
            'comb_gain_limit_db': 10,
            'cond_dim': 256,
            'conv_gain_limits_db': [-12, 12],
            'global_gain_limits_db': [-6, 6],
            'kernel_size': 15,
            'num_features': 19,
            'pitch_embedding_dim': 64,
            'pitch_max': 300,
            'preemph': 0.85,
            'pulses': True,
        },
        'name': 'lavocecont',
    },
    'training': {
        'batch_size': 64,
        'epochs': 50,
        'gen_lr_reduction': 1,
        'lambda_feat': 1.0,
        'lambda_reg': 0.6,
        # Per-term loss coefficients; only 'w_slm' and 'w_sxcorr' are non-zero.
        'loss': {
            'w_l1': 0,
            'w_l2': 0,
            'w_lm': 0,
            'w_logmel': 0,
            'w_sc': 0,
            'w_slm': 2,
            'w_sxcorr': 1,
            'w_wsc': 0,
            'w_xcorr': 0,
        },
        'lr': 0.0001,
        'lr_decay_factor': 2.5e-09,
        # Continuation training: the training loop compares random.random()
        # against 'cont_ratio' and, when below it, passes a leading slice of
        # the target to the model as `signal`.
        'cont_ratio': 0.8,
        # Bounds for the randomly drawn length (in feature frames) of that
        # slice; the loop halves both bounds, draws an integer, and doubles it.
        'pre_frames_min': 8,
        # NOTE(review): the non-adversarial lavocecont template uses 64 here;
        # confirm that 50 is intentional.
        'pre_frames_max': 50,
    },
}

# Lookup table from setup name to its default setup template dict.
# Each key is the model name, with an '_adv' suffix for the templates that
# configure adversarial training.
setup_dict = {
    'lace': lace_setup,
    'nolace': nolace_setup,
    'nolace_adv': nolace_setup_adv,
    'lavoce': lavoce_setup,
    'lavoce_adv': lavoce_setup_adv,
    'lavocecont': lavocecont_setup,
    'lavocecont_adv': lavocecont_setup_adv,
}

0 comments on commit ab3ba9a

Please sign in to comment.