diff --git a/src/python/piper_train/__main__.py b/src/python/piper_train/__main__.py index da7221a3..0f9a55eb 100644 --- a/src/python/piper_train/__main__.py +++ b/src/python/piper_train/__main__.py @@ -5,7 +5,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from .vits.lightning import VitsModel @@ -27,6 +27,11 @@ def main(): type=int, help="Save checkpoint every N epochs (default: 1)", ) + parser.add_argument( + "--patience", + type=int, + help="Number of validation cycles to allow to pass without improvement before stopping training" + ) parser.add_argument( "--quality", default="medium", @@ -76,19 +81,22 @@ def main(): num_speakers = int(config["num_speakers"]) sample_rate = int(config["audio"]["sample_rate"]) - trainer = Trainer.from_argparse_args(args) + callbacks = [] if args.checkpoint_epochs is not None: - trainer.callbacks = [ModelCheckpoint( - every_n_epochs=args.checkpoint_epochs, - save_top_k=args.num_ckpt, - save_last=args.save_last - )] + callbacks.append( + ModelCheckpoint(every_n_epochs=args.checkpoint_epochs, monitor="val_loss", save_top_k=args.num_ckpt, save_last=args.save_last, mode="min") + ) _LOGGER.debug( "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs ) _LOGGER.debug( "%s Checkpoints will be saved", args.num_ckpt ) + if args.patience is not None: + callbacks.append( + EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, verbose=False, mode="min") + ) + trainer = Trainer.from_argparse_args(args, callbacks=callbacks) dict_args = vars(args) if args.quality == "x-low": diff --git a/src/python/piper_train/vits/lightning.py b/src/python/piper_train/vits/lightning.py index 5dc1c96c..a3698742 100644 --- a/src/python/piper_train/vits/lightning.py +++ b/src/python/piper_train/vits/lightning.py @@ -428,30 +428,34 @@ def validation_step(self, batch: Batch, batch_idx: int): val_loss = self.training_step_g(batch) + self.training_step_d(batch) + self.training_step_dur(batch) self.log("val_loss", val_loss) print(f"Epoch: {self.current_epoch}. Steps: {self.global_step}. Validation loss: {val_loss}") - # Generate audio examples - for utt_idx, test_utt in enumerate(self._test_dataset): - text = test_utt.phoneme_ids.unsqueeze(0).to(self.device) - text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device) - scales = [1.0, 1.0, 1.0] - sid = ( - test_utt.speaker_id.to(self.device) - if test_utt.speaker_id is not None - else None - ) - test_audio = self(text, text_lengths, scales, sid=sid).detach() + return val_loss - # Scale to make louder in [-1, 1] - test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max()))) + def on_validation_end(self) -> None: + # Generate audio examples after validation, but not during sanity check + if not self.trainer.sanity_checking: + for utt_idx, test_utt in enumerate(self._test_dataset): + text = test_utt.phoneme_ids.unsqueeze(0).to(self.device) + text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device) + scales = [1.0, 1.0, 1.0] + sid = ( + test_utt.speaker_id.to(self.device) + if test_utt.speaker_id is not None + else None + ) + test_audio = self(text, text_lengths, scales, sid=sid).detach() - tag = test_utt.text or str(utt_idx) - self.logger.experiment.add_audio( - tag, - test_audio, - self.global_step, - sample_rate=self.hparams.sample_rate - ) + # Scale to make louder in [-1, 1] + test_audio = test_audio * (1.0 / max(0.01, abs(test_audio).max())) - return val_loss + tag = test_utt.text or str(utt_idx) + self.logger.experiment.add_audio( + tag, + test_audio, + self.global_step, + sample_rate=self.hparams.sample_rate + ) + + return super().on_validation_end() def configure_optimizers(self): optimizers = [