Skip to content

Commit

Permalink
Major fixes.
Browse files: browse the repository at this point in the history
  • Loading branch information
rmcpantoja committed Apr 25, 2024
1 parent 21d9597 commit 9e3a65d
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 24 deletions.
9 changes: 6 additions & 3 deletions src/python/piper_train/norm_audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def cache_norm_audio(
audio_cache_id = sha256(str(audio_path).encode()).hexdigest()

audio_norm_path = cache_dir / f"{audio_cache_id}.pt"
audio_spec_path = cache_dir / f"{audio_cache_id}.spec.pt"
if use_mel_spec_posterior:
audio_spec_path = audio_spec_path.replace(".spec.pt", ".mel.pt")
audio_spec_path = cache_dir / f"{audio_cache_id}.mel.pt"
else:
audio_spec_path = cache_dir / f"{audio_cache_id}.spec.pt"
# Normalize audio
audio_norm_tensor: Optional[torch.FloatTensor] = None
if ignore_cache or (not audio_norm_path.exists()):
Expand Down Expand Up @@ -81,14 +82,16 @@ def cache_norm_audio(
if audio_norm_tensor is None:
# Load pre-cached normalized audio
audio_norm_tensor = torch.load(audio_norm_path)
if self.use_mel_spec_posterior:
if use_mel_spec_posterior:
audio_spec_tensor = mel_spectrogram_torch(
y=audio_norm_tensor,
n_fft=filter_length,
num_mels = n_mels,
sampling_rate=sample_rate,
hop_size=hop_length,
win_size=window_length,
fmin = 0.0,
fmax = None,
center=False,
).squeeze(0)
else:
Expand Down
35 changes: 19 additions & 16 deletions src/python/piper_train/vits/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from .models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
DurationDiscriminator,
DurationDiscriminator2,
DurationDiscriminatorV1,
DurationDiscriminatorV2,
AVAILABLE_FLOW_TYPES,
AVAILABLE_DURATION_DISCRIMINATOR_TYPES
)
from pqmf import PQMF
from .pqmf import PQMF

_LOGGER = logging.getLogger("vits.lightning")

Expand All @@ -44,11 +44,13 @@ def __init__(
# Vits2
use_mel_posterior_encoder: bool = True,
use_duration_discriminator: bool = True,
duration_discriminator_type: str = dur_disc_2,
duration_discriminator_type: str = "dur_disc_2",
use_transformer_flows: bool = True,
transformer_flow_type: str = "pre_conv2",
use_spk_conditioned_encoder: bool = False,
use_noise_scaled_mas: bool = True,
mas_noise_scale_initial: float = 0.01,
noise_scale_delta = 2e-6,
# mel
filter_length: int = 1024,
hop_length: int = 256,
Expand Down Expand Up @@ -99,6 +101,7 @@ def __init__(
num_test_examples: int = 5,
validation_split: float = 0.1,
max_phoneme_ids: Optional[int] = None,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
Expand All @@ -115,21 +118,21 @@ def __init__(
self.posterior_channels = self.hparams.filter_length // 2 + 1

# More VITS2 features:
if hparams.use_transformer_flows == True:
if self.hparams.use_transformer_flows == True:
self.transformer_flow_type = self.hparams.transformer_flow_type
print(f"Using transformer flows {self.transformer_flow_type} for VITS2")
assert self.transformer_flow_type in AVAILABLE_FLOW_TYPES, f"transformer_flow_type must be one of {AVAILABLE_FLOW_TYPES}"
else:
print("Using normal flows for VITS1")

if self.hparams.use_spk_conditioned_encoder == True:
if self.hparams.n_speakers == 0:
print("Warning: use_spk_conditioned_encoder is True but n_speakers is 0")
if self.hparams.num_speakers == 0:
print("Warning: use_spk_conditioned_encoder is True but num_speakers is 0")
print("Setting use_spk_conditioned_encoder to False as model is a single speaker model")
else:
print("Using normal encoder for VITS1 (cuz it's single speaker after all)")

if self.hparams.use_noise_scaled_mas == True:
if self.hparams.use_noise_scaled_mas:
print("Using noise scaled MAS for VITS2")
self.mas_noise_scale_initial = 0.01
self.noise_scale_delta = 2e-6
Expand Down Expand Up @@ -181,26 +184,26 @@ def __init__(
# print("Using duration discriminator for VITS2")
#- for duration_discriminator2
# duration_discriminator_type = getattr(hps.model, "duration_discriminator_type", "dur_disc_1")
duration_discriminator_type = hparams.duration_discriminator_type
duration_discriminator_type = self.hparams.duration_discriminator_type
print(f"Using duration discriminator {duration_discriminator_type} for VITS2")
assert duration_discriminator_type in AVAILABLE_DURATION_DISCRIMINATOR_TYPES.keys(), f"duration_discriminator_type must be one of {list(AVAILABLE_DURATION_DISCRIMINATOR_TYPES.keys())}"
#DurationDiscriminator = AVAILABLE_DURATION_DISCRIMINATOR_TYPES[duration_discriminator_type]

if duration_discriminator_type == "dur_disc_1":
self.net_dur_disc = DurationDiscriminator(
self.net_dur_disc = DurationDiscriminatorV1(
self.hparams.hidden_channels,
self.hparams.hidden_channels,
3,
0.1,
gin_channels=self.hparams.gin_channels if self.hparams.n_speakers != 0 else 0,
gin_channels=self.hparams.gin_channels if self.hparams.num_speakers != 0 else 0,
)
elif duration_discriminator_type == "dur_disc_2":
self.net_dur_disc = DurationDiscriminator2(
self.net_dur_disc = DurationDiscriminatorV2(
self.hparams.hidden_channels,
self.hparams.hidden_channels,
3,
0.1,
gin_channels=self.hparams.gin_channels if self.hparams.n_speakers != 0 else 0,
gin_channels=self.hparams.gin_channels if self.hparams.num_speakers != 0 else 0,
)
else:
print("NOT using any duration discriminator like VITS1")
Expand Down Expand Up @@ -309,9 +312,9 @@ def training_step_g(self, batch: Batch):
batch.speaker_ids if batch.speaker_ids is not None else None,
)
# VITS2:
if self.model_g.module.use_noise_scaled_mas:
current_mas_noise_scale = self.model_g.module.mas_noise_scale_initial - self.model_g.module.noise_scale_delta * self.global_step
self.model_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
if self.model_g.use_noise_scaled_mas:
current_mas_noise_scale = self.model_g.mas_noise_scale_initial - self.model_g.noise_scale_delta * self.global_step
self.model_g.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
(
y_hat,
y_hat_mb,
Expand Down
2 changes: 1 addition & 1 deletion src/python/piper_train/vits/losses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from stft_loss import MultiResolutionSTFTLoss
from .stft_loss import MultiResolutionSTFTLoss


def feature_loss(fmap_r, fmap_g):
Expand Down
8 changes: 4 additions & 4 deletions src/python/piper_train/vits/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from . import attentions, commons, modules, monotonic_align
from .commons import get_padding, init_weights
from pqmf import PQMF
from stft import TorchSTFT, OnnxSTFT
from .pqmf import PQMF
from .stft import TorchSTFT, OnnxSTFT

AVAILABLE_FLOW_TYPES = ["pre_conv", "pre_conv2", "fft", "mono_layer_inter_residual", "mono_layer_post_residual"]
AVAILABLE_DURATION_DISCRIMINATOR_TYPES = {"dur_disc_1": "DurationDiscriminator", "dur_disc_2": "DurationDiscriminator2"}
Expand Down Expand Up @@ -169,7 +169,7 @@ def forward(self, x, x_mask, g=None):
return x * x_mask


class DurationDiscriminator(nn.Module): # vits2
class DurationDiscriminatorV1(nn.Module): # vits2
# TODO : not using "spk conditioning" for now according to the paper.
# Can be a better discriminator if we use it.
def __init__(
Expand Down Expand Up @@ -254,7 +254,7 @@ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
return output_probs


class DurationDiscriminator2(nn.Module): # vits2 - DurationDiscriminator2
class DurationDiscriminatorV2(nn.Module): # vits2 - DurationDiscriminator2
# TODO : not using "spk conditioning" for now according to the paper.
# Can be a better discriminator if we use it.
def __init__(
Expand Down

0 comments on commit 9e3a65d

Please sign in to comment.