From a410d66c07aa88ba6ad0332e49cf84eda2cb12eb Mon Sep 17 00:00:00 2001
From: Shengqiang Li <49022799+Shengqiang-Li@users.noreply.github.com>
Date: Mon, 1 Apr 2024 13:34:13 +0800
Subject: [PATCH] [vits] Support WavLM Discriminator (#215)

Co-authored-by: ShengqiangLi
---
 examples/baker/configs/v1.json       |   8 +-
 examples/baker/configs/vits2_v1.json |   8 +-
 requirements.txt                     |   1 +
 wetts/vits/losses.py                 |  95 ++++++++++++++++++++++
 wetts/vits/model/discriminators.py   |  49 +++++++++++
 wetts/vits/train.py                  | 117 ++++++++++++++++++++++++---
 wetts/vits/utils/task.py             |  11 +++
 7 files changed, 274 insertions(+), 15 deletions(-)

diff --git a/examples/baker/configs/v1.json b/examples/baker/configs/v1.json
index dbdee1a..9c315ea 100644
--- a/examples/baker/configs/v1.json
+++ b/examples/baker/configs/v1.json
@@ -43,6 +43,12 @@
     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
     "n_layers_q": 3,
     "use_spectral_norm": false,
-    "gin_channels": 256
+    "gin_channels": 256,
+    "use_wd": true,
+    "slm_model": "exp/slm/wavlm-base-plus",
+    "slm_sr": 16000,
+    "slm_hidden": 768,
+    "slm_nlayers": 13,
+    "slm_initial_channel": 64
   }
 }
diff --git a/examples/baker/configs/vits2_v1.json b/examples/baker/configs/vits2_v1.json
index 1053d93..4ca9c7d 100644
--- a/examples/baker/configs/vits2_v1.json
+++ b/examples/baker/configs/vits2_v1.json
@@ -50,6 +50,12 @@
     "n_layers_q": 3,
     "use_sdp": true,
     "use_spectral_norm": false,
-    "gin_channels": 256
+    "gin_channels": 256,
+    "use_wd": true,
+    "slm_model": "exp/slm/wavlm-base-plus",
+    "slm_sr": 16000,
+    "slm_hidden": 768,
+    "slm_nlayers": 13,
+    "slm_initial_channel": 64
   }
 }
diff --git a/requirements.txt b/requirements.txt
index 59278e6..c41c9c6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ torch
 torchvision
 tqdm
 transformers
+huggingface_hub
diff --git a/wetts/vits/losses.py b/wetts/vits/losses.py
index 470abec..66da771 100644
--- a/wetts/vits/losses.py
+++ b/wetts/vits/losses.py
@@ -1,4 +1,6 @@
 import torch
+import torchaudio
+from transformers import AutoModel
 
 
 def feature_loss(fmap_r, fmap_g):
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
     kl = torch.sum(kl * z_mask)
     l = kl / torch.sum(z_mask)
     return l
+
+
+class WavLMLoss(torch.nn.Module):
+    def __init__(self, model, wd, model_sr, slm_sr=16000):
+        super(WavLMLoss, self).__init__()
+        self.wavlm = AutoModel.from_pretrained(model)
+        self.wd = wd
+        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
+        self.wavlm.eval()
+        for param in self.wavlm.parameters():
+            param.requires_grad = False
+
+    def forward(self, wav, y_rec):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+        y_rec_16 = self.resample(y_rec)
+        y_rec_embeddings = self.wavlm(
+            input_values=y_rec_16.squeeze(), output_hidden_states=True
+        ).hidden_states
+
+        floss = 0
+        for er, eg in zip(wav_embeddings, y_rec_embeddings):
+            floss += torch.mean(torch.abs(er - eg))
+
+        return floss.mean()
+
+    def generator(self, y_rec):
+        y_rec_16 = self.resample(y_rec)
+        y_rec_embeddings = self.wavlm(
+            input_values=y_rec_16, output_hidden_states=True
+        ).hidden_states
+        y_rec_embeddings = (
+            torch.stack(y_rec_embeddings, dim=1)
+            .transpose(-1, -2)
+            .flatten(start_dim=1, end_dim=2)
+        )
+        y_df_hat_g = self.wd(y_rec_embeddings)
+        loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
+
+        return loss_gen
+
+    def discriminator(self, wav, y_rec):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+            y_rec_16 = self.resample(y_rec)
+            y_rec_embeddings = self.wavlm(
+                input_values=y_rec_16, output_hidden_states=True
+            ).hidden_states
+
+            y_embeddings = (
+                torch.stack(wav_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+            y_rec_embeddings = (
+                torch.stack(y_rec_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+
+        y_d_rs = self.wd(y_embeddings)
+        y_d_gs = self.wd(y_rec_embeddings)
+
+        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
+        g_loss = torch.mean((y_df_hat_g) ** 2)
+
+        loss_disc_f = r_loss + g_loss
+
+        return loss_disc_f.mean()
+
+    def discriminator_forward(self, wav):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+            y_embeddings = (
+                torch.stack(wav_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+
+        y_d_rs = self.wd(y_embeddings)
+
+        return y_d_rs
diff --git a/wetts/vits/model/discriminators.py b/wetts/vits/model/discriminators.py
index a426e70..9d2534a 100644
--- a/wetts/vits/model/discriminators.py
+++ b/wetts/vits/model/discriminators.py
@@ -447,3 +447,52 @@ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
             output_probs.append([output_prob])
 
         return output_probs
+
+
+class WavLMDiscriminator(nn.Module):
+    """WavLM-based discriminator over stacked SLM hidden states."""
+
+    def __init__(
+        self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
+    ):
+        super(WavLMDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.pre = norm_f(
+            Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
+        )
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv1d(
+                        initial_channel, initial_channel * 2, kernel_size=5, padding=2
+                    )
+                ),
+                norm_f(
+                    nn.Conv1d(
+                        initial_channel * 2,
+                        initial_channel * 4,
+                        kernel_size=5,
+                        padding=2,
+                    )
+                ),
+                norm_f(
+                    nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
+                ),
+            ]
+        )
+
+        self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        x = self.pre(x)
+
+        fmap = []
+        for layer in self.convs:
+            x = layer(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x
diff --git a/wetts/vits/train.py b/wetts/vits/train.py
index 3e5e6b8..b0377f0 100644
--- a/wetts/vits/train.py
+++ b/wetts/vits/train.py
@@ -18,10 +18,15 @@
                                   MultiPeriodDiscriminator,
                                   MultiPeriodMultiResolutionDiscriminator,
                                   DurationDiscriminatorV1,
-                                  DurationDiscriminatorV2)
+                                  DurationDiscriminatorV2,
+                                  WavLMDiscriminator)
 from model.flows import AVAILABLE_FLOW_TYPES
 from model.models import SynthesizerTrn
-from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
+from losses import (generator_loss,
+                    discriminator_loss,
+                    feature_loss,
+                    kl_loss,
+                    WavLMLoss)
 from utils import commons, task
 from utils.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
 
@@ -157,6 +162,30 @@
         print("NOT using any duration discriminator like VITS1")
         net_dur_disc = None
 
+    if ("use_wd" in hps.model.keys() and hps.model.use_wd):
+        os.makedirs(hps.model.slm_model, exist_ok=True)
+        if not os.path.isfile(
+            os.path.join(hps.model.slm_model, "pytorch_model.bin")
+        ):
+            task.download_wavlm_model(hps.model.slm_model)
+
+        net_wd = WavLMDiscriminator(
+            hps.model.slm_hidden,
+            hps.model.slm_nlayers,
+            hps.model.slm_initial_channel
+        ).cuda(local_rank)
+
+        wl = WavLMLoss(
+            hps.model.slm_model,
+            net_wd,
+            hps.data.sampling_rate,
+            hps.model.slm_sr,
+        ).to(local_rank)
+        print("Using WavLMDiscriminator")
+    else:
+        net_wd = None
+        wl = None
+
     net_g = SynthesizerTrn(hps.data.num_phones,
                            posterior_channels,
                            hps.train.segment_size // hps.data.hop_length,
@@ -188,6 +217,12 @@
             device_ids=[local_rank],
             find_unused_parameters=True
         )
+    if net_wd:
+        net_wd = DDP(
+            net_wd,
+            device_ids=[local_rank],
+            find_unused_parameters=True
+        )
 
     # Get the optimizer
     optim_g = torch.optim.AdamW(
@@ -211,6 +246,15 @@
         )
     else:
         optim_dur_disc = None
+    if net_wd:
+        optim_wd = torch.optim.AdamW(
+            net_wd.parameters(),
+            hps.train.learning_rate,
+            betas=hps.train.betas,
+            eps=hps.train.eps,
+        )
+    else:
+        optim_wd = None
 
     # Load the checkpoint
     try:
@@ -226,6 +270,12 @@
                 net_dur_disc,
                 optim_dur_disc,
             )
+        if net_wd:
+            _, _, _, epoch_str = task.load_checkpoint(
+                task.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
+                net_wd,
+                optim_wd,
+            )
         global_step = (epoch_str - 1) * len(train_loader)
     except Exception as e:
         epoch_str = 1
@@ -241,6 +291,12 @@
             optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
     else:
         scheduler_dur_disc = None
+    if net_wd:
+        scheduler_wd = torch.optim.lr_scheduler.ExponentialLR(
+            optim_wd, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
+        )
+    else:
+        scheduler_wd = None
 
     # Get the tensorboard summary
     writer = None
@@ -260,9 +316,9 @@
                 local_rank,
                 epoch,
                 hps,
-                [net_g, net_d, net_dur_disc],
-                [optim_g, optim_d, optim_dur_disc],
-                [scheduler_g, scheduler_d, scheduler_dur_disc],
+                [net_g, net_d, net_dur_disc, net_wd, wl],
+                [optim_g, optim_d, optim_dur_disc, optim_wd],
+                [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
                 scaler,
                 [train_loader, eval_loader],
                 logger,
@@ -274,9 +330,9 @@
                 local_rank,
                 epoch,
                 hps,
-                [net_g, net_d, net_dur_disc],
-                [optim_g, optim_d, optim_dur_disc],
-                [scheduler_g, scheduler_d, scheduler_dur_disc],
+                [net_g, net_d, net_dur_disc, net_wd, wl],
+                [optim_g, optim_d, optim_dur_disc, optim_wd],
+                [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
                 scaler,
                 [train_loader, None],
                 None,
@@ -286,13 +342,14 @@
         scheduler_d.step()
         if net_dur_disc:
             scheduler_dur_disc.step()
-
+        if net_wd:
+            scheduler_wd.step()
 
 def train_and_evaluate(rank, local_rank, epoch, hps, nets, optims, schedulers,
                        scaler, loaders, logger, writers):
-    net_g, net_d, net_dur_disc = nets
-    optim_g, optim_d, optim_dur_disc = optims
-    scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
+    net_g, net_d, net_dur_disc, net_wd, wl = nets
+    optim_g, optim_d, optim_dur_disc, optim_wd = optims
+    scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd = schedulers
     train_loader, eval_loader = loaders
     if writers:
         writer, writer_eval = writers
@@ -304,7 +361,8 @@
     net_d.train()
     if net_dur_disc:
         net_dur_disc.train()
-
+    if net_wd:
+        net_wd.train()
     for batch_idx, (
             x,
             x_lengths,
@@ -406,6 +464,17 @@
         grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
         scaler.step(optim_d)
 
+        if wl is not None and net_wd is not None:
+            with autocast(enabled=hps.train.fp16_run, dtype=torch.float16):
+                loss_slm = wl.discriminator(
+                    y.detach().squeeze(), y_hat.detach().squeeze()
+                ).mean()
+            optim_wd.zero_grad()
+            scaler.scale(loss_slm).backward()
+            scaler.unscale_(optim_wd)
+            grad_norm_wd = commons.clip_grad_value_(net_wd.parameters(), None)
+            scaler.step(optim_wd)
+
         with autocast(enabled=hps.train.fp16_run):
             # Generator
             y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
@@ -424,6 +493,11 @@
             if net_dur_disc:
                 loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
                 loss_gen_all += loss_dur_gen
+            if net_wd is not None:
+                loss_lm = wl(y.detach().squeeze(), y_hat.squeeze()).mean()
+                loss_lm_gen = wl.generator(y_hat.squeeze())
+                loss_gen_all += loss_lm
+                loss_gen_all += loss_lm_gen
 
         optim_g.zero_grad()
         scaler.scale(loss_gen_all).backward()
@@ -457,6 +531,15 @@
                         "grad_norm_dur_disc": grad_norm_dur_disc,
                     })
 
+                if net_wd:
+                    scalar_dict.update(
+                        {
+                            "loss/wd/total": loss_slm.item(),
+                            "loss/g/lm": loss_lm.item(),
+                            "loss/g/lm_gen": loss_lm_gen.item(),
+                            "grad_norm_wd": grad_norm_wd,
+                        }
+                    )
                 scalar_dict.update({
                     "loss/g/fm": loss_fm,
                     "loss/g/mel": loss_mel,
@@ -524,6 +607,14 @@
                     os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)),
                 )
+            if net_wd:
+                task.save_checkpoint(
+                    net_wd,
+                    optim_wd,
+                    hps.train.learning_rate,
+                    epoch,
+                    os.path.join(hps.model_dir, "WD_{}.pth".format(global_step)),
+                )
 
         global_step += 1
 
     if rank == 0:
diff --git a/wetts/vits/utils/task.py b/wetts/vits/utils/task.py
index ece5423..3a7c88a 100644
--- a/wetts/vits/utils/task.py
+++ b/wetts/vits/utils/task.py
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import huggingface_hub
 from pathlib import Path
 
 import torch
@@ -17,6 +18,16 @@
 logger = logging
 
 
+def download_wavlm_model(wavlm_dir):
+    huggingface_hub.snapshot_download(
+        repo_id="microsoft/wavlm-base-plus",
+        local_dir=wavlm_dir,
+        local_dir_use_symlinks=False
+    )
+    logger.info("Downloaded wavlm-base-plus model to {}".format(wavlm_dir))
+    return
+
+
 def load_checkpoint(checkpoint_path, model, optimizer=None):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
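
--
Reviewer note (not part of the commit): this patch follows the SLM
discriminator recipe popularized by StyleTTS2. Hidden states from all 13
WavLM layers (embedding output plus 12 transformer layers) are stacked into a
(batch, 768 * 13, frames) tensor and scored by the small convolutional
WavLMDiscriminator. The sketch below is a standalone smoke test of the three
new loss terms, illustrative only. It assumes it is run from wetts/vits/ so
the repo's flat imports resolve, that transformers can load
"microsoft/wavlm-base-plus" (or a local snapshot fetched by
task.download_wavlm_model), and it uses random tensors and a 22050 Hz model
rate purely as placeholders:

import torch

from losses import WavLMLoss
from model.discriminators import WavLMDiscriminator

# Build the discriminator and the WavLM feature/adversarial loss wrapper.
net_wd = WavLMDiscriminator(slm_hidden=768, slm_layers=13, initial_channel=64)
wl = WavLMLoss(
    "microsoft/wavlm-base-plus",  # or a local dir such as exp/slm/wavlm-base-plus
    net_wd,
    22050,   # model_sr: the TTS sampling rate (hps.data.sampling_rate)
    16000,   # slm_sr: WavLM consumes 16 kHz audio
)

y = torch.randn(2, 8192)      # stand-in for real waveform segments, (batch, samples)
y_hat = torch.randn(2, 8192)  # stand-in for generated segments, same shape

loss_slm = wl.discriminator(y, y_hat)  # LSGAN real/fake terms, updates net_wd only
loss_lm = wl(y, y_hat)                 # L1 across all 13 WavLM hidden states
loss_lm_gen = wl.generator(y_hat)      # adversarial term for the generator
print(loss_slm.item(), loss_lm.item(), loss_lm_gen.item())

As in train.py, all three terms enter with an implicit weight of 1.0:
loss_slm drives the separate optim_wd step, while loss_lm and loss_lm_gen are
simply added to loss_gen_all.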