Skip to content

Commit

Permalink
Update models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Bebra777228 authored Aug 4, 2024
1 parent 2da67eb commit e9c4bdb
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions src/infer_pack/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
class TextEncoder(nn.Module):
def __init__(self, input_dim, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=True):
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.emb_phone = nn.Linear(input_dim, hidden_channels)
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
self.emb_pitch = nn.Embedding(256, hidden_channels) if f0 else None
Expand All @@ -25,13 +27,13 @@ def __init__(self, input_dim, out_channels, hidden_channels, filter_channels, n_

def forward(self, phone, pitch, lengths):
x = self.emb_phone(phone) + self.emb_pitch(pitch) if pitch is not None else self.emb_phone(phone)
x = x * math.sqrt(hidden_channels)
x = x * math.sqrt(self.hidden_channels)
x = self.lrelu(x)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, out_channels, dim=1)
m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs, x_mask


Expand All @@ -57,6 +59,7 @@ def remove_weight_norm(self):
class PosteriorEncoder(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
super().__init__()
self.out_channels = out_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
Expand All @@ -66,7 +69,7 @@ def forward(self, x, x_lengths, g=None):
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, out_channels, dim=1)
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask

Expand Down Expand Up @@ -127,11 +130,16 @@ def remove_weight_norm(self):
class SineGen(torch.nn.Module):
def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
super(SineGen, self).__init__()
self.dim = harmonic_num + 1
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold

def _f02uv(self, f0):
uv = torch.ones_like(f0)
uv = uv * (f0 > voiced_threshold)
uv = uv * (f0 > self.voiced_threshold)
if uv.device.type == "privateuseone":
uv = uv.float()
return uv
Expand All @@ -141,9 +149,9 @@ def forward(self, f0, upp):
f0 = f0[:, None].transpose(1, 2)
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in np.arange(harmonic_num):
for idx in np.arange(self.harmonic_num):
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
rad_values = (f0_buf / samp_rate) % 1
rad_values = (f0_buf / self.sampling_rate) % 1
rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
Expand All @@ -155,10 +163,10 @@ def forward(self, f0, upp):
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
cumsum_shift = torch.zeros_like(rad_values)
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) * sine_amp
sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) * self.sine_amp
uv = self._f02uv(f0)
uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode="nearest").transpose(2, 1)
noise_amp = uv * noise_std + (1 - uv) * sine_amp / 3
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
Expand Down Expand Up @@ -257,12 +265,12 @@ class Synthesizer(nn.Module):
def __init__(self, input_dim, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, **kwargs):
super().__init__()
sr = sr2sr[sr] if isinstance(sr, str) else sr
self.segment_size = segment_size
self.enc_p = TextEncoder(input_dim, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs["is_half"])
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
self.emb_g = nn.Embedding(spk_embed_dim, gin_channels)


def remove_weight_norm(self):
self.dec.remove_weight_norm()
Expand All @@ -274,8 +282,8 @@ def forward(self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds):
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
z_slice, ids_slice = rand_slice_segments(z, y_lengths, segment_size)
pitchf = slice_segments2(pitchf, ids_slice, segment_size)
z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
pitchf = slice_segments2(pitchf, ids_slice, self.segment_size)
o = self.dec(z_slice, pitchf, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

Expand All @@ -296,12 +304,12 @@ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
class Synthesizer_nono(nn.Module):
def __init__(self, input_dim, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr=None, **kwargs):
super().__init__()
self.segment_size = segment_size
self.enc_p = TextEncoder(input_dim, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, f0=False)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
self.emb_g = nn.Embedding(spk_embed_dim, gin_channels)


def remove_weight_norm(self):
self.dec.remove_weight_norm()
Expand All @@ -313,7 +321,7 @@ def forward(self, phone, phone_lengths, y, y_lengths, ds):
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
z_slice, ids_slice = rand_slice_segments(z, y_lengths, segment_size)
z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

Expand All @@ -340,10 +348,7 @@ def __init__(self, use_spectral_norm=False):
self.discriminators = nn.ModuleList(discs)

def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
Expand All @@ -364,10 +369,7 @@ def __init__(self, use_spectral_norm=False):
self.discriminators = nn.ModuleList(discs)

def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
Expand Down

0 comments on commit e9c4bdb

Please sign in to comment.