From 379bee117d937857d3529b39600df1d3e3c9130e Mon Sep 17 00:00:00 2001 From: Mitchell DeHaven Date: Tue, 23 Apr 2024 13:57:17 -0600 Subject: [PATCH 1/6] Change callback population to preserve progress bar --- src/python/piper_train/__main__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/python/piper_train/__main__.py b/src/python/piper_train/__main__.py index 3a4ff51e..51bc5abe 100644 --- a/src/python/piper_train/__main__.py +++ b/src/python/piper_train/__main__.py @@ -57,12 +57,13 @@ def main(): num_speakers = int(config["num_speakers"]) sample_rate = int(config["audio"]["sample_rate"]) - trainer = Trainer.from_argparse_args(args) + callbacks = [] if args.checkpoint_epochs is not None: - trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)] + callbacks.append(ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)) _LOGGER.debug( "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs ) + trainer = Trainer.from_argparse_args(args, callbacks=callbacks) dict_args = vars(args) if args.quality == "x-low": From 555e8aa7aa6db1e24a685aa674cc9109153bd74f Mon Sep 17 00:00:00 2001 From: Mitchell DeHaven Date: Tue, 23 Apr 2024 14:01:44 -0600 Subject: [PATCH 2/6] Add early stopping to prevent training for too long and overfitting --- src/python/piper_train/__main__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/python/piper_train/__main__.py b/src/python/piper_train/__main__.py index 51bc5abe..d3630bb0 100644 --- a/src/python/piper_train/__main__.py +++ b/src/python/piper_train/__main__.py @@ -5,7 +5,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from .vits.lightning import VitsModel @@ -24,6 +24,11 @@ def main(): type=int, help="Save checkpoint every N epochs (default: 1)", ) + parser.add_argument( + "--patience", + type=int, + help="Number 
of validation cycles to allow to pass without improvement before stopping training" + ) parser.add_argument( "--quality", default="medium", @@ -63,6 +68,8 @@ def main(): _LOGGER.debug( "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs ) + if args.patience is not None: + callbacks.append(EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, verbose=True, mode="min")) trainer = Trainer.from_argparse_args(args, callbacks=callbacks) dict_args = vars(args) From d8655a19ca18a5a08106e225db01d7bbcc8f3f6b Mon Sep 17 00:00:00 2001 From: Mitchell DeHaven Date: Tue, 23 Apr 2024 14:04:08 -0600 Subject: [PATCH 3/6] Only save best model based on validation loss --- src/python/piper_train/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/piper_train/__main__.py b/src/python/piper_train/__main__.py index d3630bb0..4bc440be 100644 --- a/src/python/piper_train/__main__.py +++ b/src/python/piper_train/__main__.py @@ -64,7 +64,7 @@ def main(): callbacks = [] if args.checkpoint_epochs is not None: - callbacks.append(ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)) + callbacks.append(ModelCheckpoint(every_n_epochs=args.checkpoint_epochs, monitor="val_loss", save_top_k=1, mode="min")) _LOGGER.debug( "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs ) From e6af4c6dd223064e443b42de3e6345da4e903535 Mon Sep 17 00:00:00 2001 From: Mitchell DeHaven Date: Tue, 23 Apr 2024 15:18:17 -0600 Subject: [PATCH 4/6] Put test set generation in 'on_validation_end', otherwise it gets run for each validation batch --- src/python/piper_train/vits/lightning.py | 45 +++++++++++++----------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/src/python/piper_train/vits/lightning.py b/src/python/piper_train/vits/lightning.py index c6b7250a..58f85794 100644 --- a/src/python/piper_train/vits/lightning.py +++ b/src/python/piper_train/vits/lightning.py @@ -282,29 +282,32 @@ def training_step_d(self, 
batch: Batch): def validation_step(self, batch: Batch, batch_idx: int): val_loss = self.training_step_g(batch) + self.training_step_d(batch) self.log("val_loss", val_loss) - - # Generate audio examples - for utt_idx, test_utt in enumerate(self._test_dataset): - text = test_utt.phoneme_ids.unsqueeze(0).to(self.device) - text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device) - scales = [0.667, 1.0, 0.8] - sid = ( - test_utt.speaker_id.to(self.device) - if test_utt.speaker_id is not None - else None - ) - test_audio = self(text, text_lengths, scales, sid=sid).detach() - - # Scale to make louder in [-1, 1] - test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max()))) - - tag = test_utt.text or str(utt_idx) - self.logger.experiment.add_audio( - tag, test_audio, sample_rate=self.hparams.sample_rate - ) - return val_loss + def on_validation_end(self) -> None: + # Generate audio examples after validation, but not during sanity check + if not self.trainer.sanity_checking: + for utt_idx, test_utt in enumerate(self._test_dataset): + text = test_utt.phoneme_ids.unsqueeze(0).to(self.device) + text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device) + scales = [0.667, 1.0, 0.8] + sid = ( + test_utt.speaker_id.to(self.device) + if test_utt.speaker_id is not None + else None + ) + test_audio = self(text, text_lengths, scales, sid=sid).detach() + + # Scale to make louder in [-1, 1] + test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max()))) + + tag = test_utt.text or str(utt_idx) + self.logger.experiment.add_audio( + tag, test_audio, sample_rate=self.hparams.sample_rate + ) + + return super().on_validation_end() + def configure_optimizers(self): optimizers = [ torch.optim.AdamW( From de98f64cc726cc3e6119679c01fd70f944f4668e Mon Sep 17 00:00:00 2001 From: Mitchell DeHaven Date: Tue, 23 Apr 2024 17:40:25 -0600 Subject: [PATCH 5/6] Absolute value needs to be taken before max, this results in warning messages due to amplitude 
issues --- src/python/piper_train/vits/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/piper_train/vits/lightning.py b/src/python/piper_train/vits/lightning.py index 58f85794..0f9183e0 100644 --- a/src/python/piper_train/vits/lightning.py +++ b/src/python/piper_train/vits/lightning.py @@ -299,7 +299,7 @@ def on_validation_end(self) -> None: test_audio = self(text, text_lengths, scales, sid=sid).detach() # Scale to make louder in [-1, 1] - test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max()))) + test_audio = test_audio * (1.0 / max(0.01, abs(test_audio).max())) tag = test_utt.text or str(utt_idx) self.logger.experiment.add_audio( From 8eda7596f3857b6e160284603bbc06cb3ce77a31 Mon Sep 17 00:00:00 2001 From: Mitchell DeHaven Date: Tue, 23 Apr 2024 17:41:54 -0600 Subject: [PATCH 6/6] Force ops without half precision support in float32 context, allowing fp16 training --- src/python/piper_train/vits/mel_processing.py | 67 ++++++++++--------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/src/python/piper_train/vits/mel_processing.py b/src/python/piper_train/vits/mel_processing.py index 72e81e24..40efc2f1 100644 --- a/src/python/piper_train/vits/mel_processing.py +++ b/src/python/piper_train/vits/mel_processing.py @@ -55,23 +55,24 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) mode="reflect", ) y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, + with torch.autocast(device_type=y.device.type, dtype=torch.float32): + y = y.to(y.device.type, torch.float32) + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + 
onesided=True, + return_complex=True, + ) ) - ) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) return spec @@ -116,24 +117,26 @@ def mel_spectrogram_torch( mode="reflect", ) y = y.squeeze(1) - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, + with torch.autocast(device_type=y.device.type, dtype=torch.float32): + y = y.to(y.device.type, torch.float32) + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) - ) + # print(y.dtype, spec.dtype) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) return spec