Some improvements / bug fixes #476

Open · wants to merge 6 commits into master
14 changes: 11 additions & 3 deletions src/python/piper_train/__main__.py
@@ -5,7 +5,7 @@
 
 import torch
 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 
 from .vits.lightning import VitsModel
@@ -24,6 +24,11 @@ def main():
         type=int,
         help="Save checkpoint every N epochs (default: 1)",
     )
+    parser.add_argument(
+        "--patience",
+        type=int,
+        help="Number of validation cycles to allow to pass without improvement before stopping training"
+    )
     parser.add_argument(
         "--quality",
         default="medium",
@@ -57,12 +62,15 @@
     num_speakers = int(config["num_speakers"])
     sample_rate = int(config["audio"]["sample_rate"])
 
-    trainer = Trainer.from_argparse_args(args)
+    callbacks = []
     if args.checkpoint_epochs is not None:
-        trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
+        callbacks.append(ModelCheckpoint(every_n_epochs=args.checkpoint_epochs, monitor="val_loss", save_top_k=1, mode="min"))
         _LOGGER.debug(
             "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
         )
+    if args.patience is not None:
+        callbacks.append(EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, verbose=True, mode="min"))
+    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
 
     dict_args = vars(args)
     if args.quality == "x-low":
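These two hunks work together: callbacks are now collected in a list and passed to the Trainer at construction time, instead of assigning trainer.callbacks after the fact, which replaced the callback list (defaults included) that Lightning had already assembled. The ModelCheckpoint also gains monitor, save_top_k, and mode arguments so only the best checkpoint by val_loss is kept, and the new --patience flag enables early stopping. A minimal sketch of the same wiring, assuming pytorch_lightning 1.x (the version implied by Trainer.from_argparse_args); the concrete values are illustrative, not from this PR:

# Sketch of the callback wiring above (pytorch_lightning 1.x assumed;
# the flag values are illustrative, not part of this PR).
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Keep only the single best checkpoint by validation loss.
    ModelCheckpoint(every_n_epochs=1, monitor="val_loss", save_top_k=1, mode="min"),
    # Stop training after 10 validation cycles without improvement.
    EarlyStopping(monitor="val_loss", min_delta=0.0, patience=10, verbose=True, mode="min"),
]

# Passing callbacks at construction lets Lightning merge them with its
# defaults, rather than clobbering them via `trainer.callbacks = [...]`.
trainer = Trainer(max_epochs=1000, callbacks=callbacks)

With this in place, early stopping is opted into from the command line by adding --patience N to the usual python -m piper_train invocation.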
45 changes: 24 additions & 21 deletions src/python/piper_train/vits/lightning.py
@@ -282,29 +282,32 @@ def training_step_d(self, batch: Batch):
     def validation_step(self, batch: Batch, batch_idx: int):
         val_loss = self.training_step_g(batch) + self.training_step_d(batch)
         self.log("val_loss", val_loss)
-
-        # Generate audio examples
-        for utt_idx, test_utt in enumerate(self._test_dataset):
-            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
-            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
-            scales = [0.667, 1.0, 0.8]
-            sid = (
-                test_utt.speaker_id.to(self.device)
-                if test_utt.speaker_id is not None
-                else None
-            )
-            test_audio = self(text, text_lengths, scales, sid=sid).detach()
-
-            # Scale to make louder in [-1, 1]
-            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))
-
-            tag = test_utt.text or str(utt_idx)
-            self.logger.experiment.add_audio(
-                tag, test_audio, sample_rate=self.hparams.sample_rate
-            )
-
         return val_loss
 
+    def on_validation_end(self) -> None:
+        # Generate audio examples after validation, but not during sanity check
+        if not self.trainer.sanity_checking:
+            for utt_idx, test_utt in enumerate(self._test_dataset):
+                text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
+                text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
+                scales = [0.667, 1.0, 0.8]
+                sid = (
+                    test_utt.speaker_id.to(self.device)
+                    if test_utt.speaker_id is not None
+                    else None
+                )
+                test_audio = self(text, text_lengths, scales, sid=sid).detach()
+
+                # Scale to make louder in [-1, 1]
+                test_audio = test_audio * (1.0 / max(0.01, abs(test_audio).max()))
+
+                tag = test_utt.text or str(utt_idx)
+                self.logger.experiment.add_audio(
+                    tag, test_audio, sample_rate=self.hparams.sample_rate
+                )
+
+        return super().on_validation_end()
+
     def configure_optimizers(self):
         optimizers = [
             torch.optim.AdamW(
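Two fixes in one move. Generating audio inside validation_step meant the test utterances were re-synthesized for every validation batch, and also during Lightning's pre-training sanity check; on_validation_end runs once per validation pass, and the sanity_checking guard skips the sanity-check run. Separately, the old abs(test_audio.max()) took the signed maximum before the absolute value, so a waveform whose largest peak was negative could be scaled far outside [-1, 1]; the new abs(test_audio).max() normalizes by the true peak amplitude. A minimal sketch of the hook pattern, assuming pytorch_lightning 1.x; LitModel and the dummy audio tensor are illustrative, not from this PR:

# Sketch of the on_validation_end hook pattern (pytorch_lightning 1.x assumed).
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def on_validation_end(self) -> None:
        # trainer.sanity_checking is True only during the short validation
        # pass Lightning runs before epoch 0; skip side effects there.
        if not self.trainer.sanity_checking:
            audio = torch.tanh(torch.randn(1, 22050))  # stand-in for synthesized audio
            # Peak-normalize into [-1, 1]; abs(...).max() uses the true peak,
            # unlike abs(audio.max()), which ignores negative peaks.
            audio = audio * (1.0 / max(0.01, abs(audio).max()))
            self.logger.experiment.add_audio("example", audio, sample_rate=22050)
        return super().on_validation_end()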
67 changes: 35 additions & 32 deletions src/python/piper_train/vits/mel_processing.py
@@ -55,23 +55,24 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
         mode="reflect",
     )
     y = y.squeeze(1)
 
-    spec = torch.view_as_real(
-        torch.stft(
-            y,
-            n_fft,
-            hop_length=hop_size,
-            win_length=win_size,
-            window=hann_window[wnsize_dtype_device],
-            center=center,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
-    )
-
-    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    with torch.autocast(device_type=y.device.type, dtype=torch.float32):
+        y = y.to(y.device.type, torch.float32)
+        spec = torch.view_as_real(
+            torch.stft(
+                y,
+                n_fft,
+                hop_length=hop_size,
+                win_length=win_size,
+                window=hann_window[wnsize_dtype_device],
+                center=center,
+                pad_mode="reflect",
+                normalized=False,
+                onesided=True,
+                return_complex=True,
+            )
+        )
+
+        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
 
     return spec

@@ -116,24 +117,26 @@ def mel_spectrogram_torch(
         mode="reflect",
     )
     y = y.squeeze(1)
-    spec = torch.view_as_real(
-        torch.stft(
-            y,
-            n_fft,
-            hop_length=hop_size,
-            win_length=win_size,
-            window=hann_window[wnsize_dtype_device],
-            center=center,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
-    )
-
-    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-
-    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-    spec = spectral_normalize_torch(spec)
+    with torch.autocast(device_type=y.device.type, dtype=torch.float32):
+        y = y.to(y.device.type, torch.float32)
+        spec = torch.view_as_real(
+            torch.stft(
+                y,
+                n_fft,
+                hop_length=hop_size,
+                win_length=win_size,
+                window=hann_window[wnsize_dtype_device],
+                center=center,
+                pad_mode="reflect",
+                normalized=False,
+                onesided=True,
+                return_complex=True,
+            )
+        )
+        # print(y.dtype, spec.dtype)
+        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+        spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+        spec = spectral_normalize_torch(spec)
 
     return spec
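Both spectrogram functions receive the same fix: the STFT and magnitude computation are pinned to float32 so they stay numerically stable when the surrounding training step runs under automatic mixed precision, where low-precision STFTs are a common source of trouble. One caveat: dtype=torch.float32 is not a supported autocast dtype on CUDA, and recent PyTorch builds warn and simply disable autocast for the region, which happens to be the intended effect here. The documented way to opt a region out of autocast is enabled=False; a sketch under that assumption (spectrogram_fp32 and its parameters are illustrative, not from this PR):

# Sketch of an explicit autocast opt-out for the STFT; names and defaults
# here are illustrative, not the PR's exact code.
import torch

def spectrogram_fp32(y: torch.Tensor, n_fft: int = 1024, hop_size: int = 256) -> torch.Tensor:
    window = torch.hann_window(n_fft, device=y.device, dtype=torch.float32)
    # enabled=False excludes this region from autocast; everything inside
    # runs in the tensors' real dtypes.
    with torch.autocast(device_type=y.device.type, enabled=False):
        y = y.float()  # up-cast in case an autocast region produced float16
        spec = torch.view_as_real(
            torch.stft(
                y,
                n_fft,
                hop_length=hop_size,
                window=window,
                center=True,
                return_complex=True,
            )
        )
        # Magnitude spectrogram with a small epsilon for numerical safety.
        return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

# Usage: mags = spectrogram_fp32(torch.randn(1, 22050))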