diff --git a/src/python/piper_train/__main__.py b/src/python/piper_train/__main__.py
index 3a4ff51e..4bc440be 100644
--- a/src/python/piper_train/__main__.py
+++ b/src/python/piper_train/__main__.py
@@ -5,7 +5,7 @@
 
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 
 from .vits.lightning import VitsModel
 
@@ -24,6 +24,11 @@ def main():
         type=int,
         help="Save checkpoint every N epochs (default: 1)",
     )
+    parser.add_argument(
+        "--patience",
+        type=int,
+        help="Number of validation cycles without improvement before training is stopped early",
+    )
     parser.add_argument(
         "--quality",
         default="medium",
@@ -57,12 +62,15 @@ def main():
     num_speakers = int(config["num_speakers"])
     sample_rate = int(config["audio"]["sample_rate"])
 
-    trainer = Trainer.from_argparse_args(args)
+    callbacks = []
     if args.checkpoint_epochs is not None:
-        trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
+        callbacks.append(ModelCheckpoint(every_n_epochs=args.checkpoint_epochs, monitor="val_loss", save_top_k=1, mode="min"))
         _LOGGER.debug(
             "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
         )
+    if args.patience is not None:
+        callbacks.append(EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, verbose=True, mode="min"))
+    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
 
     dict_args = vars(args)
     if args.quality == "x-low":
diff --git a/src/python/piper_train/vits/lightning.py b/src/python/piper_train/vits/lightning.py
index c6b7250a..0f9183e0 100644
--- a/src/python/piper_train/vits/lightning.py
+++ b/src/python/piper_train/vits/lightning.py
@@ -282,29 +282,32 @@ def training_step_d(self, batch: Batch):
     def validation_step(self, batch: Batch, batch_idx: int):
         val_loss = self.training_step_g(batch) + self.training_step_d(batch)
         self.log("val_loss", val_loss)
-
-        # Generate audio examples
-        for utt_idx, test_utt in enumerate(self._test_dataset):
-            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
-            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
-            scales = [0.667, 1.0, 0.8]
-            sid = (
-                test_utt.speaker_id.to(self.device)
-                if test_utt.speaker_id is not None
-                else None
-            )
-            test_audio = self(text, text_lengths, scales, sid=sid).detach()
-
-            # Scale to make louder in [-1, 1]
-            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))
-
-            tag = test_utt.text or str(utt_idx)
-            self.logger.experiment.add_audio(
-                tag, test_audio, sample_rate=self.hparams.sample_rate
-            )
-
         return val_loss
 
+    def on_validation_end(self) -> None:
+        # Generate audio examples after validation, but not during sanity check
+        if not self.trainer.sanity_checking:
+            for utt_idx, test_utt in enumerate(self._test_dataset):
+                text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
+                text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
+                scales = [0.667, 1.0, 0.8]
+                sid = (
+                    test_utt.speaker_id.to(self.device)
+                    if test_utt.speaker_id is not None
+                    else None
+                )
+                test_audio = self(text, text_lengths, scales, sid=sid).detach()
+
+                # Scale to make louder in [-1, 1]
+                test_audio = test_audio * (1.0 / max(0.01, abs(test_audio).max()))
+
+                tag = test_utt.text or str(utt_idx)
+                self.logger.experiment.add_audio(
+                    tag, test_audio, sample_rate=self.hparams.sample_rate
+                )
+
+        return super().on_validation_end()
+
     def configure_optimizers(self):
         optimizers = [
             torch.optim.AdamW(
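Taken together, the `__main__.py` hunks stop overwriting `trainer.callbacks` after construction (which replaced the trainer's already-configured callback list) and instead pass the callbacks to `Trainer.from_argparse_args`, while the `lightning.py` hunk moves audio-example generation out of `validation_step` into `on_validation_end`, skips it during the sanity check, and fixes the peak computation (`abs(test_audio).max()` instead of `abs(test_audio.max())`, i.e. peak absolute amplitude rather than the absolute value of the signed maximum). A minimal sketch of the resulting callback wiring, with hypothetical epoch and patience values, assuming a pytorch_lightning 1.x release where `Trainer.from_argparse_args` still exists:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Keep only the single best checkpoint by validation loss, checked every 5 epochs.
    ModelCheckpoint(every_n_epochs=5, monitor="val_loss", save_top_k=1, mode="min"),
    # Stop once val_loss has not improved for 10 consecutive validation runs.
    EarlyStopping(monitor="val_loss", min_delta=0.0, patience=10, verbose=True, mode="min"),
]
trainer = Trainer(max_epochs=10_000, callbacks=callbacks)
```

Note that `EarlyStopping.patience` counts validation runs, not epochs, so with a validation interval greater than one epoch the effective patience in epochs is proportionally larger; the `--patience` help text above is worded accordingly.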
diff --git a/src/python/piper_train/vits/mel_processing.py b/src/python/piper_train/vits/mel_processing.py
index 72e81e24..40efc2f1 100644
--- a/src/python/piper_train/vits/mel_processing.py
+++ b/src/python/piper_train/vits/mel_processing.py
@@ -55,23 +55,24 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
         mode="reflect",
     )
     y = y.squeeze(1)
-
-    spec = torch.view_as_real(
-        torch.stft(
-            y,
-            n_fft,
-            hop_length=hop_size,
-            win_length=win_size,
-            window=hann_window[wnsize_dtype_device],
-            center=center,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
+    with torch.autocast(device_type=y.device.type, dtype=torch.float32):
+        y = y.to(y.device.type, torch.float32)
+        spec = torch.view_as_real(
+            torch.stft(
+                y,
+                n_fft,
+                hop_length=hop_size,
+                win_length=win_size,
+                window=hann_window[wnsize_dtype_device],
+                center=center,
+                pad_mode="reflect",
+                normalized=False,
+                onesided=True,
+                return_complex=True,
+            )
         )
-    )
-    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
 
     return spec
 
 
@@ -116,24 +117,26 @@ def mel_spectrogram_torch(
         mode="reflect",
     )
     y = y.squeeze(1)
-    spec = torch.view_as_real(
-        torch.stft(
-            y,
-            n_fft,
-            hop_length=hop_size,
-            win_length=win_size,
-            window=hann_window[wnsize_dtype_device],
-            center=center,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
+    with torch.autocast(device_type=y.device.type, dtype=torch.float32):
+        y = y.to(y.device.type, torch.float32)
+        spec = torch.view_as_real(
+            torch.stft(
+                y,
+                n_fft,
+                hop_length=hop_size,
+                win_length=win_size,
+                window=hann_window[wnsize_dtype_device],
+                center=center,
+                pad_mode="reflect",
+                normalized=False,
+                onesided=True,
+                return_complex=True,
+            )
         )
-    )
+        # print(y.dtype, spec.dtype)
+        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
 
-    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-
-    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-    spec = spectral_normalize_torch(spec)
+        spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+        spec = spectral_normalize_torch(spec)
 
     return spec
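Both `mel_processing.py` hunks wrap the STFT in an explicit float32 region: under mixed-precision training the audio tensor can arrive as float16, where `torch.stft` is numerically unreliable (or unimplemented on some backends), so the patch casts `y` to float32 and computes the spectrogram in full precision. A standalone sketch of the same pattern, with hypothetical sizes (note that on some PyTorch builds `torch.autocast(..., dtype=torch.float32)` warns and disables autocast, which likewise leaves the region in full precision):

```python
import torch

y = torch.randn(1, 22050)  # hypothetical batch of mono audio
n_fft, hop, win = 1024, 256, 1024

with torch.autocast(device_type=y.device.type, dtype=torch.float32):
    y32 = y.to(y.device.type, torch.float32)  # force full precision for the STFT
    window = torch.hann_window(win, dtype=y32.dtype, device=y32.device)
    spec = torch.view_as_real(
        torch.stft(y32, n_fft, hop_length=hop, win_length=win, window=window,
                   center=True, return_complex=True)
    )
    # Magnitude spectrogram; the 1e-6 keeps the gradient of sqrt finite at zero.
    magnitude = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
```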