diff --git a/uberduck_ml_dev/utils/denoiser.py b/uberduck_ml_dev/utils/denoiser.py index 09939664..48a14b18 100644 --- a/uberduck_ml_dev/utils/denoiser.py +++ b/uberduck_ml_dev/utils/denoiser.py @@ -22,6 +22,7 @@ import sys import torch from ..models.common import STFT +from ..vocoders.istftnet import iSTFTNetGenerator, TorchSTFT class Denoiser(torch.nn.Module): @@ -46,11 +47,21 @@ def __init__( raise Exception("Mode {} if not supported".format(mode)) with torch.no_grad(): - bias_audio = ( - hifigan.vocoder.forward(mel_input.to(hifigan.device)) - .view(1, -1) - .float() - ) + if isinstance(hifigan, iSTFTNetGenerator): + self.stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16, device="cpu").to("cpu") + spec, phase = hifigan.vocoder(mel_input.to(hifigan.device)) + y_g_hat = self.stft.inverse(spec.cpu(), phase.cpu()) + bias_audio = ( + y_g_hat + .view(1, -1) + .float() + ) + else: + bias_audio = ( + hifigan.vocoder.forward(mel_input.to(hifigan.device)) + .view(1, -1) + .float() + ) bias_spec, _ = self.stft.transform(bias_audio.cpu()) self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py new file mode 100644 index 00000000..af1041af --- /dev/null +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -0,0 +1,633 @@ +__all__ = [ + "iSTFTNetGenerator", + "ResBlock1", + "ResBlock2", + "Generator", + "DiscriminatorP", + "MultiPeriodDiscriminator", + "DiscriminatorS", + "MultiScaleDiscriminator", + "feature_loss", + "discriminator_loss", + "generator_loss", + "LRELU_SLOPE", + "AttrDict", + "build_env", + "init_weights", + "apply_weight_norm", + "get_padding", +] + +""" from https://github.com/rishikksh20/iSTFTNet-pytorch """ + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +import numpy as np +from torch.autograd import Variable +from scipy.signal import get_window +from librosa.util import pad_center, tiny +import librosa.util as librosa_util +import glob +import os +import shutil +import matplotlib +import torch +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt +import json + + +LRELU_SLOPE = 0.1 + + +""" +BSD 3-Clause License +Copyright (c) 2017, Prem Seetharaman +All rights reserved. +* Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + + +def window_sumsquare(window, n_frames, hop_length=200, win_length=800, + n_fft=800, dtype=np.float32, norm=None): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + n_frames : int > 0 + The number of analysis frames + hop_length : int > 0 + The number of samples to advance between frames + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + n_fft : int > 0 + The length of each analysis frame. + dtype : np.dtype + The data type of the output + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm)**2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + return x + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + def __init__(self, filter_length=800, hop_length=200, win_length=800, + window='hann'): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :])]) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :]) + + if window is not None: + assert(filter_length >= win_length) + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer('forward_basis', forward_basis.float()) + self.register_buffer('inverse_basis', inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode='reflect') + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable( + torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, magnitude.size(-1), hop_length=self.hop_length, + win_length=self.win_length, n_fft=self.filter_length, + dtype=np.float32) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False) + window_sum = window_sum.to(inverse_transform.device()) if magnitude.is_cuda else window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] + inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TorchSTFT(torch.nn.Module): + def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann', device="cpu"): + super().__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32)) + self.device = device + + def transform(self, input_data): + forward_transform = torch.stft( + input_data, + self.filter_length, self.hop_length, self.win_length, window=self.window, + return_complex=True) + + return torch.abs(forward_transform), torch.angle(forward_transform) + + def inverse(self, magnitude, phase): + inverse_transform = torch.istft( + magnitude * torch.exp(phase * 1j), + self.filter_length, self.hop_length, self.win_length, window=self.window.to(self.device)) + + return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class iSTFTNetGenerator(nn.Module): + def __init__(self, config, checkpoint, cudnn_enabled=False): + super().__init__() + self.config = config + self.checkpoint = checkpoint + self.device = "cuda" if torch.cuda.is_available() and cudnn_enabled else "cpu" + self.vocoder, self.stft = self.load_checkpoint() + self.vocoder.remove_weight_norm() + + @torch.no_grad() + def load_checkpoint(self): + h = self.load_config() + vocoder = Generator(h) + stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft, device=self.device).to(self.device) + vocoder.load_state_dict( + torch.load( + self.checkpoint, + map_location="cuda" if self.device == "cuda" else "cpu", + )["generator"] + ) + if self.device == "cuda": + vocoder = vocoder.cuda() + return vocoder.eval(), stft + + @torch.no_grad() + def load_config(self): + with open(self.config) as f: + h = AttrDict(json.load(f)) + return h + + def forward(self, mel, max_wav_value=32768): + return self.infer(mel, max_wav_value=max_wav_value) + + @torch.no_grad() + def infer(self, mel, max_wav_value=32768): + spec, phase = self.vocoder(mel) + y_g_hat = self.stft.inverse(spec, phase) + audio = ( + y_g_hat.cpu().squeeze().clamp(-1, 1).numpy() + * max_wav_value + ).astype(np.int16) + return audio + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.post_n_fft = h.gen_istft_n_fft + self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.reflection_pad(x) + x = self.conv_post(x) + spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) + phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) + + return spec, phase + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i-1](y) + y_hat = self.meanpools[i-1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss*2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses