From 0dd0d42e12cab0a8a2f31e64b08119ad98efb5a2 Mon Sep 17 00:00:00 2001 From: Nugine Date: Tue, 7 May 2024 21:04:40 +0800 Subject: [PATCH 1/8] Add FreeVC implementation --- config/freevc.json | 97 +++++ egs/vc/FreeVC/README.md | 141 +++++++ egs/vc/FreeVC/exp_config.json | 18 + egs/vc/FreeVC/freevc.json | 54 +++ egs/vc/FreeVC/run.sh | 113 ++++++ models/vc/FreeVC/__init__.py | 0 models/vc/FreeVC/commons.py | 17 + models/vc/FreeVC/data.py | 287 ++++++++++++++ models/vc/FreeVC/hifigan.py | 246 ++++++++++++ models/vc/FreeVC/inference.py | 91 +++++ models/vc/FreeVC/losses.py | 62 +++ models/vc/FreeVC/mel_processing.py | 147 +++++++ models/vc/FreeVC/model.py | 326 +++++++++++++++ models/vc/FreeVC/preprocess.py | 319 +++++++++++++++ models/vc/FreeVC/speaker_encoder/__init__.py | 0 models/vc/FreeVC/speaker_encoder/audio.py | 115 ++++++ .../ckpt/pretrained_bak_5805000.pt.txt | 1 + .../FreeVC/speaker_encoder/compute_embed.py | 43 ++ models/vc/FreeVC/speaker_encoder/config.py | 27 ++ .../speaker_encoder/data_objects/__init__.py | 6 + .../data_objects/random_cycler.py | 37 ++ .../speaker_encoder/data_objects/speaker.py | 43 ++ .../data_objects/speaker_batch.py | 19 + .../speaker_verification_dataset.py | 67 ++++ .../speaker_encoder/data_objects/utterance.py | 26 ++ models/vc/FreeVC/speaker_encoder/hparams.py | 31 ++ models/vc/FreeVC/speaker_encoder/inference.py | 190 +++++++++ models/vc/FreeVC/speaker_encoder/model.py | 144 +++++++ .../vc/FreeVC/speaker_encoder/params_data.py | 27 ++ .../vc/FreeVC/speaker_encoder/params_model.py | 10 + .../vc/FreeVC/speaker_encoder/preprocess.py | 329 ++++++++++++++++ models/vc/FreeVC/speaker_encoder/train.py | 150 +++++++ .../FreeVC/speaker_encoder/visualizations.py | 195 +++++++++ .../FreeVC/speaker_encoder/voice_encoder.py | 193 +++++++++ models/vc/FreeVC/train.py | 372 ++++++++++++++++++ models/vc/FreeVC/train_utils.py | 167 ++++++++ models/vc/FreeVC/wavlm.py | 48 +++ 37 files changed, 4158 insertions(+) create mode 100644 config/freevc.json create mode 100644 egs/vc/FreeVC/README.md create mode 100644 egs/vc/FreeVC/exp_config.json create mode 100644 egs/vc/FreeVC/freevc.json create mode 100644 egs/vc/FreeVC/run.sh create mode 100644 models/vc/FreeVC/__init__.py create mode 100644 models/vc/FreeVC/commons.py create mode 100644 models/vc/FreeVC/data.py create mode 100644 models/vc/FreeVC/hifigan.py create mode 100644 models/vc/FreeVC/inference.py create mode 100644 models/vc/FreeVC/losses.py create mode 100644 models/vc/FreeVC/mel_processing.py create mode 100644 models/vc/FreeVC/model.py create mode 100644 models/vc/FreeVC/preprocess.py create mode 100644 models/vc/FreeVC/speaker_encoder/__init__.py create mode 100644 models/vc/FreeVC/speaker_encoder/audio.py create mode 100644 models/vc/FreeVC/speaker_encoder/ckpt/pretrained_bak_5805000.pt.txt create mode 100644 models/vc/FreeVC/speaker_encoder/compute_embed.py create mode 100644 models/vc/FreeVC/speaker_encoder/config.py create mode 100644 models/vc/FreeVC/speaker_encoder/data_objects/__init__.py create mode 100644 models/vc/FreeVC/speaker_encoder/data_objects/random_cycler.py create mode 100644 models/vc/FreeVC/speaker_encoder/data_objects/speaker.py create mode 100644 models/vc/FreeVC/speaker_encoder/data_objects/speaker_batch.py create mode 100644 models/vc/FreeVC/speaker_encoder/data_objects/speaker_verification_dataset.py create mode 100644 models/vc/FreeVC/speaker_encoder/data_objects/utterance.py create mode 100644 models/vc/FreeVC/speaker_encoder/hparams.py create mode 100644 
models/vc/FreeVC/speaker_encoder/inference.py create mode 100644 models/vc/FreeVC/speaker_encoder/model.py create mode 100644 models/vc/FreeVC/speaker_encoder/params_data.py create mode 100644 models/vc/FreeVC/speaker_encoder/params_model.py create mode 100644 models/vc/FreeVC/speaker_encoder/preprocess.py create mode 100644 models/vc/FreeVC/speaker_encoder/train.py create mode 100644 models/vc/FreeVC/speaker_encoder/visualizations.py create mode 100644 models/vc/FreeVC/speaker_encoder/voice_encoder.py create mode 100644 models/vc/FreeVC/train.py create mode 100644 models/vc/FreeVC/train_utils.py create mode 100644 models/vc/FreeVC/wavlm.py diff --git a/config/freevc.json b/config/freevc.json new file mode 100644 index 00000000..085c95b3 --- /dev/null +++ b/config/freevc.json @@ -0,0 +1,97 @@ +{ + "preprocess": { + "vctk_dir": "./data/VCTK", + "vctk_16k_dir": "./data/vctk-16k", + "vctk_22k_dir": "./data/vctk-22k", + "split_dir": "./data/split", + "spk_dir": "./data/spk", + "ssl_dir": "./data/ssl", + "sr_dir": "./data/sr", + "hifigan_ckpt_path": "./ckpts/hifigan-vctk-v1", + "minh": 68, + "maxh": 92 + }, + "data": { + "max_wav_value": 32768.0, + "sampling_rate": 16000, + "filter_length": 1280, + "hop_length": 320, + "win_length": 1280, + "n_mel_channels": 80, + "mel_fmin": 0.0, + "mel_fmax": null + }, + "train": { + "log_interval": 200, + "eval_interval": 10000, + "seed": 1234, + "epochs": 10000, + "learning_rate": 2e-4, + "betas": [ + 0.8, + 0.99 + ], + "eps": 1e-9, + "batch_size": 64, + "num_workers": 16, + "fp16_run": false, + "lr_decay": 0.999875, + "segment_size": 8960, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "use_sr": true, + "max_speclen": 128 + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 10, + 8, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4 + ], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 256, + "ssl_dim": 1024, + "use_spk": true + } +} \ No newline at end of file diff --git a/egs/vc/FreeVC/README.md b/egs/vc/FreeVC/README.md new file mode 100644 index 00000000..af4eb02e --- /dev/null +++ b/egs/vc/FreeVC/README.md @@ -0,0 +1,141 @@ +# FreeVC + +This is an implementation of [FreeVC: Towards High-Quality Text-Free One-Shot Voice Conversion](https://arxiv.org/abs/2210.15418). Adapted from end-to-end framework of [VITS](https://arxiv.org/abs/2106.06103) for high-quality waveform reconstruction, and propose strategies for clean content information extraction without text annotation. It disentangle content information by imposing an information bottleneck to [WavLM](https://arxiv.org/abs/2110.13900) features, utilize the **spectrogram-resize** based data augmentation to improve the purity of extracted content information. + +There are four stages in total: + +1. Data preparation +2. Features extraction +3. Training +4. Inference/conversion + +> **NOTE:** You need to run every command of this recipe in the `Amphion` root path: +> +> ```bash +> cd Amphion +> ``` + +## 1. Data Preparation + +### Dataset Download + +For other experiments, we utilize the five datasets for training: M4Singer, Opencpop, OpenSinger, SVCC, and VCTK. 
How to download them is detailed [here](../../datasets/README.md). + +In this experiment, we only utilize two datasets: VTCK and LibriTTS + +### Configuration + +Specify the dataset path in `exp_config.json`. + +```json + "preprocess": { + "vctk_dir": "[VCTK dataset path]", + // ... + } +``` + +## 2. Features Extraction + +### Pretrained Models Download + +You should download pretrained HiFi-GAN (VCTK_V1) from [its repo](https://github.com/jik876/hifi-gan) according to the original paper. + +The code will automatically download pretrained [WavLM-Large](https://huggingface.co/microsoft/wavlm-large) model from Huggingface. You can also download it in advance: + +```bash +huggingface-cli download microsoft/wavlm-large +``` + +### Configuration + +Specify the data path and the checkpoint path for saving the processed data in `exp_config.json`: + +```json + "preprocess": { + // ... + "vctk_16k_dir": "[preprocessed VCTK 16k directory]", + "vctk_22k_dir": "[preprocessed VCTK 22k directory]", + "spk_dir": "[preprocess_spk directory]", + "ssl_dir": "[preprocess_ssl directory]", + "sr_dir": "[preprocess_sr directory]", + "hifigan_ckpt_path": "[hifigan checkpoint file path]" + // ... + }, +``` + +Note that the preprocessed data will take about 600GB disk space. + +### Run + +Run the `run.sh` as the preproces stage (set `--stage 1`). + +```bash +sh egs/vc/FreeVC/run.sh --stage 1 -c egs/vc/FreeVC/exp_config.json +``` + +> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` in default. You can change it when running `run.sh` by specifying such as `--gpu "1"`. + +## 3. Training + +### Configuration + +We provide the default hyparameters in the `config/freevc.json`. They can work on single NVIDIA-24g GPU. You can adjust them based on you GPU machines. + +```json +"model": { + "use_spk": true + // ... +}, +"train": { + "use_sr": true, + // ... + "batch_size": 32, + // ... + "learning_rate": 2.0e-4 + // ... +} +``` + +### Run + +Run the `run.sh` as the training stage (set `--stage 2`). Specify a experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/vc/FreeVC/[YourExptName]`. + +```bash +sh egs/vc/FreeVC/run.sh --stage 2 -c egs/vc/FreeVC/exp_config.json --name [YourExptName] +``` + +> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` in default. You can change it when running `run.sh` by specifying such as `--gpu "0,1,2,3"`. + +## 4. Inference/Conversion + +### Run + +For inference/conversion, you need to first create a file `convert.txt` indicating the source audio, target audio and the name of the output audio in following format: + +``` +# format +[name of the output]|[path/to/the/source/audio]|[path/to/the/target/audio] + +# an example(each reconstruction written in a line) +title1|data/vctk-16k/p225/p225_001.wav|data/vctk-16k/p226/p226_002.wav +``` + + +Then you should run `run.sh`, you need to specify the following configurations: + +| Parameters | Description | Example | +| ----------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | +| `--config` | The base configuration | `[Your path to the base configuration]` | +| `--ckpt` | The experimental directory which contains `checkpoint` | `[Your path to save logs and checkpoints]/[YourExptName]` | +| `--convert` | The convert.txt path which contains the audios to be reconstructed | `[Your path to save convert.txt]` | +| `--outdir` | The output directory to save inferred audios. 
| `[Your path to save logs and checkpoints]/[YourExptName]/result` | + +For example: + +```bash +sh egs/svc/VitsSVC/run.sh --stage 3 \ + --config egs/vc/FreeVC/exp_config.json \ + --ckpt ckpts/vc/FreeVC/[YourExptName]/G_5.ckpt \ + --convert ckpts/vc/FreeVC/[YourExptName] \ + --outdir ckpts/vc/FreeVC/[YourExptName]/result \ +``` \ No newline at end of file diff --git a/egs/vc/FreeVC/exp_config.json b/egs/vc/FreeVC/exp_config.json new file mode 100644 index 00000000..1fa17e6c --- /dev/null +++ b/egs/vc/FreeVC/exp_config.json @@ -0,0 +1,18 @@ +{ + "base_config": "config/freevc.json", + "preprocess": { + "vctk_dir": "[VCTK dataset path]", + "vctk_16k_dir": "[preprocessed VCTK 16k directory]", + "vctk_22k_dir": "[preprocessed VCTK 22k directory]", + "spk_dir": "[preprocess_spk directory]", + "ssl_dir": "[preprocess_ssl directory]", + "sr_dir": "[preprocess_sr directory]", + "hifigan_ckpt_path": "[hifigan checkpoint file path]" + }, + "model": { + "use_spk": true + }, + "train": { + "use_sr": true + } +} \ No newline at end of file diff --git a/egs/vc/FreeVC/freevc.json b/egs/vc/FreeVC/freevc.json new file mode 100644 index 00000000..c25a9f1d --- /dev/null +++ b/egs/vc/FreeVC/freevc.json @@ -0,0 +1,54 @@ +{ + "train": { + "log_interval": 200, + "eval_interval": 10000, + "seed": 1234, + "epochs": 10000, + "learning_rate": 2e-4, + "betas": [0.8, 0.99], + "eps": 1e-9, + "batch_size": 64, + "fp16_run": false, + "lr_decay": 0.999875, + "segment_size": 8960, + "init_lr_ratio": 1, + "warmup_epochs": 0, + "c_mel": 45, + "c_kl": 1.0, + "use_sr": true, + "max_speclen": 128, + "port": "8001" + }, + "data": { + "training_files":"filelists/train.txt", + "validation_files":"filelists/val.txt", + "max_wav_value": 32768.0, + "sampling_rate": 16000, + "filter_length": 1280, + "hop_length": 320, + "win_length": 1280, + "n_mel_channels": 80, + "mel_fmin": 0.0, + "mel_fmax": null + }, + "model": { + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + "upsample_rates": [10,8,2,2], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [16,16,4,4], + "n_layers_q": 3, + "use_spectral_norm": false, + "gin_channels": 256, + "ssl_dim": 1024, + "use_spk": true + } +} diff --git a/egs/vc/FreeVC/run.sh b/egs/vc/FreeVC/run.sh new file mode 100644 index 00000000..1c56d24b --- /dev/null +++ b/egs/vc/FreeVC/run.sh @@ -0,0 +1,113 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +######## Build Experiment Environment ########### +exp_dir=$(cd `dirname $0`; pwd) +work_dir=$(dirname $(dirname $(dirname $exp_dir))) + +export WORK_DIR=$work_dir +export PYTHONPATH=$work_dir +export PYTHONIOENCODING=UTF-8 + +######## Parse the Given Parameters from the Command ########### +options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,resume:,resume_from_ckpt_path:,resume_type:,ckpt:,convert:,outdir: -- "$@") +eval set -- "$options" + +while true; do + case $1 in + # Experimental Configuration File + -c | --config) shift; exp_config=$1 ; shift ;; + # Experimental Name + -n | --name) shift; exp_name=$1 ; shift ;; + # Running Stage + -s | --stage) shift; running_stage=$1 ; shift ;; + # Visible GPU machines. The default value is "0". 
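+    # Multiple GPUs can be passed as a comma-separated list, e.g. --gpu "0,1,2,3".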
+ --gpu) shift; gpu=$1 ; shift ;; + + # [Only for Training] Resume configuration + --resume) shift; resume=$1 ; shift ;; + # [Only for Training] The specific checkpoint path that you want to resume from. + --resume_from_ckpt_path) shift; resume_from_ckpt_path=$1 ; shift ;; + # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights. + --resume_type) shift; resume_type=$1 ; shift ;; + + # [Only for Inference] The path of saved checkpoint. + --ckpt) shift; ckpt=$1 ; shift ;; + # [Only for Inference] The path of convert file + --convert) shift; convert=$1 ; shift ;; + # [Only for Inference] The output dir to save inferred audios. + --outdir) shift; outdir=$1 ; shift ;; + # [Only for Inference] Whether to use timestamp in the output filename. + --use_timestamp) shift; use_timestamp=$1 ; shift ;; + --) shift ; break ;; + *) echo "Invalid option: $1" exit 1 ;; + esac +done + + +### Value check ### +if [ -z "$running_stage" ]; then + echo "[Error] Please specify the running stage" + exit 1 +fi + +if [ -z "$exp_config" ]; then + exp_config="${exp_dir}"/exp_config.json +fi +echo "Exprimental Configuration File: $exp_config" + +if [ -z "$gpu" ]; then + gpu="0" +fi + +######## Features Extraction ########### +if [ $running_stage -eq 1 ]; then + CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/models/vc/FreeVC/preprocess.py \ + --config $exp_config +fi + +######## Training ########### +if [ $running_stage -eq 2 ]; then + if [ -z "$exp_name" ]; then + echo "[Error] Please specify the experiments name" + exit 1 + fi + echo "Exprimental Name: $exp_name" + + # add default value + if [ -z "$resume_from_ckpt_path" ]; then + resume_from_ckpt_path="" + fi + + if [ -z "$resume_type" ]; then + resume_type="resume" + fi + + if [ "$resume" = true ]; then + echo "Resume from the existing experiment..." + CUDA_VISIBLE_DEVICES="$gpu" python "${work_dir}"/models/vc/FreeVC/train.py \ + --config "$exp_config" \ + --exp_name "$exp_name" \ + --log_level info \ + --resume \ + --resume_from_ckpt_path "$resume_from_ckpt_path" \ + --resume_type "$resume_type" + else + echo "Start a new experiment..." 
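+        # Fresh run: hyperparameters are read from $exp_config; TensorBoard logs and
+        # checkpoints are written under ckpts/vc/FreeVC/$exp_name (see the README above).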
+ CUDA_VISIBLE_DEVICES="$gpu" python "${work_dir}"/models/vc/FreeVC/train.py \ + --config "$exp_config" \ + --exp_name "$exp_name" \ + --log_level info + fi +fi + +######## Inference/Conversion ########### +if [ $running_stage -eq 3 ]; then + CUDA_VISIBLE_DEVICES=$gpu python "$work_dir"/models/vc/FreeVC/inference.py \ + --config $exp_config \ + --ckpt $ckpt \ + --convert $convert \ + --outdir $outdir +fi diff --git a/models/vc/FreeVC/__init__.py b/models/vc/FreeVC/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/vc/FreeVC/commons.py b/models/vc/FreeVC/commons.py new file mode 100644 index 00000000..6c6d8c8c --- /dev/null +++ b/models/vc/FreeVC/commons.py @@ -0,0 +1,17 @@ +from models.tts.vits.vits import ( + slice_segments, + rand_slice_segments, + get_padding, +) # noqa: F401 + +import torch + + +def rand_spec_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str diff --git a/models/vc/FreeVC/data.py b/models/vc/FreeVC/data.py new file mode 100644 index 00000000..41460ff1 --- /dev/null +++ b/models/vc/FreeVC/data.py @@ -0,0 +1,287 @@ +from .mel_processing import spectrogram_torch +from .commons import rand_spec_segments, slice_segments + +import os +import random +import torch +from torch.utils.data import Dataset, Sampler +import torchaudio + + +def read_txt_lines(path): + ans = [] + with open(str(path), "r") as f: + for line in f.readlines(): + line = line.strip() + if line != "": + ans.append(line) + return ans + + +class FreeVCDataset(Dataset): + def __init__(self, audiopaths, hparams): + self.audiopaths = read_txt_lines(audiopaths) + + self.max_wav_value = hparams.data.max_wav_value + self.sampling_rate = hparams.data.sampling_rate + self.filter_length = hparams.data.filter_length + self.hop_length = hparams.data.hop_length + self.win_length = hparams.data.win_length + self.sampling_rate = hparams.data.sampling_rate + self.use_sr = hparams.train.use_sr + self.use_spk = hparams.model.use_spk + self.spec_len = hparams.train.max_speclen + + self.vctk_16k_dir = hparams.preprocess.vctk_16k_dir + self.spk_dir = hparams.preprocess.spk_dir + self.ssl_dir = hparams.preprocess.ssl_dir + self.sr_dir = hparams.preprocess.sr_dir + + random.shuffle(self.audiopaths) + self._filter() + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + + lengths = [] + for audiopath in self.audiopaths: + path = os.path.join(self.vctk_16k_dir, audiopath) + lengths.append(os.path.getsize(path) // (2 * self.hop_length)) + self.lengths = lengths + + @torch.no_grad() + def load_sample(self, filename): + filepath = os.path.join(self.vctk_16k_dir, filename) + audio, sampling_rate = torchaudio.load(filepath) + if sampling_rate != self.sampling_rate: + raise ValueError( + f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR" + ) + + audio_norm = audio / self.max_wav_value + + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ).squeeze_(0) + + if self.use_spk: + spk_path = os.path.join(self.spk_dir, filename.replace(".wav", ".pt")) + spk = torch.load(spk_path) + 
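+            # Speaker embedding precomputed by preprocess_spk() in preprocess.py and
+            # stored under spk_dir as a .pt tensor.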
else: + spk = None + + if not self.use_sr: + ssl_path = os.path.join(self.ssl_dir, filename.replace(".wav", ".pt")) + ssl = torch.load(ssl_path).squeeze_(0) + else: + h = random.randint(68, 92) + ssl_path = os.path.join(self.sr_dir, filename.replace(".wav", f"_{h}.pt")) + ssl = torch.load(ssl_path).squeeze_(0) + + return ssl, spec, audio_norm, spk + + def __getitem__(self, index): + return self.load_sample(self.audiopaths[index]) + + def __len__(self): + return len(self.audiopaths) + + +class FreeVCCollate: + def __init__(self, hps): + self.hps = hps + self.use_sr = hps.train.use_sr + self.use_spk = hps.model.use_spk + + def __call__(self, batch): + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True + ) + + max_spec_len = max([x[1].size(1) for x in batch]) + max_wav_len = max([x[2].size(1) for x in batch]) + + spec_lengths = torch.LongTensor(len(batch)) + wav_lengths = torch.LongTensor(len(batch)) + if self.use_spk: + spks = torch.FloatTensor(len(batch), batch[0][3].size(0)) + else: + spks = None + + c_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len) + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + c_padded.zero_() + spec_padded.zero_() + wav_padded.zero_() + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + c = row[0] + c_padded[i, :, : c.size(1)] = c + + spec = row[1] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wav = row[2] + wav_padded[i, :, : wav.size(1)] = wav + wav_lengths[i] = wav.size(1) + + if self.use_spk: + spks[i] = row[3] # type: ignore + + spec_seglen = ( + spec_lengths[-1] + if spec_lengths[-1] < self.hps.train.max_speclen + 1 + else self.hps.train.max_speclen + 1 + ) + wav_seglen = spec_seglen * self.hps.data.hop_length + + spec_padded, ids_slice = rand_spec_segments( + spec_padded, + spec_lengths, + spec_seglen, # type: ignore + ) + wav_padded = slice_segments( + wav_padded, ids_slice * self.hps.data.hop_length, wav_seglen + ) + + c_padded = slice_segments(c_padded, ids_slice, spec_seglen)[:, :, :-1] # type: ignore + + spec_padded = spec_padded[:, :, :-1] + wav_padded = wav_padded[:, :, : -self.hps.data.hop_length] + + return c_padded, spec_padded, wav_padded, spks + + +class BucketSampler(Sampler): + """ + Maintain similar input lengths in a batch. + Length groups are specified by boundaries. + Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. + + It removes samples which are not included in the boundaries. + Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
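+
+    Indices in each bucket are repeated until the bucket size is a multiple of
+    batch_size, so __iter__ always yields index batches of exactly batch_size.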
+ """ + + def __init__( + self, + dataset, + batch_size, + boundaries, + shuffle=True, + ): + super().__init__(dataset) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.num_replicas = 1 + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + + self.shuffle = shuffle + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + for i in range(len(buckets) - 1, 0, -1): + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + + num_samples_per_bucket = [] + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + total_batch_size = self.num_replicas * self.batch_size + rem = ( + total_batch_size - (len_bucket % total_batch_size) + ) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + # # deterministically shuffle based on epoch + # g = torch.Generator() + # g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket)).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ( + ids_bucket + + ids_bucket * (rem // len_bucket) + + ids_bucket[: (rem % len_bucket)] + ) + + # # subsample + # ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [ + bucket[idx] + for idx in ids_bucket[ + j * self.batch_size : (j + 1) * self.batch_size + ] + ] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches)).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size diff --git a/models/vc/FreeVC/hifigan.py b/models/vc/FreeVC/hifigan.py new file mode 100644 index 00000000..62d778b6 --- /dev/null +++ b/models/vc/FreeVC/hifigan.py @@ -0,0 +1,246 @@ +# ruff: noqa: E741 + +import os + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from omegaconf import OmegaConf + + +# hifigan vctk-v1 +def load_hifigan(ckpt_path): + config = OmegaConf.load(os.path.join(ckpt_path, "config.json")) + ckpt = torch.load(os.path.join(ckpt_path, "generator_v1")) + + vocoder = Generator(config) + vocoder.load_state_dict(ckpt["generator"]) + vocoder.eval() + vocoder.remove_weight_norm() + return vocoder, config + + +# 
----------------------------------------- +# Copied from https://github.com/jik876/hifi-gan/tree/4769534d45265d52a904b850da5a622601885777 +# MIT License +# ----------------------------------------- +# COPY START + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, 
h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels # type:ignore + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +# COPY END +# ----------------------------------------------- diff --git a/models/vc/FreeVC/inference.py b/models/vc/FreeVC/inference.py new file mode 100644 index 00000000..e82c0161 --- /dev/null +++ b/models/vc/FreeVC/inference.py @@ -0,0 +1,91 @@ +from models.vc.FreeVC.model import SynthesizerTrn +from models.vc.FreeVC.wavlm import load_wavlm +from models.vc.FreeVC.mel_processing import mel_spectrogram_torch +from speaker_encoder.voice_encoder import SpeakerEncoder +from utils.util import load_config +from models.vc.FreeVC.train_utils import load_checkpoint +from models.vc.FreeVC.preprocess import calc_ssl_features + +import os +import argparse +import torch +import librosa +from scipy.io import wavfile +from tqdm import tqdm +from typing import Any + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str) + parser.add_argument("--ckpt", type=str, help="path to pth file") + parser.add_argument("--convert", type=str, help="path to txt file") + parser.add_argument("--outdir", type=str, help="path to output dir") + args = parser.parse_args() + + os.makedirs(args.outdir, exist_ok=True) + cfg: Any = load_config(args.config) + + print("Loading model...") + net_g = SynthesizerTrn( + cfg.data.filter_length // 2 + 1, # type:ignore + cfg.train.segment_size // cfg.data.hop_length, # type:ignore + **cfg.model, # type:ignore + ).cuda() + net_g.eval() + print("Loading checkpoint...") + _ = load_checkpoint(args.ckpt, net_g, None, True) + + print("Loading WavLM for content...") + wavlm = load_wavlm().cuda() # type:ignore + + if cfg.model.use_spk: + print("Loading speaker encoder...") + spk_path = os.path.join( + os.path.dirname(__file__), "speaker_encoder/ckpt/pretrained_bak_5805000.pt" + ) + smodel = SpeakerEncoder(spk_path) + + print("Processing text...") + titles, srcs, tgts = [], [], [] + with open(args.convert, "r") as f: + for rawline in f.readlines(): + title, src, tgt = rawline.strip().split("|") + titles.append(title) + srcs.append(src) + tgts.append(tgt) + + print("Synthesizing...") + with torch.no_grad(): + for line in tqdm(zip(titles, srcs, tgts)): + title, src, tgt = line + # tgt + wav_tgt, _ = librosa.load(tgt, sr=cfg.data.sampling_rate) + wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) + if cfg.model.use_spk: + g_tgt = smodel.embed_utterance(wav_tgt) + g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).cuda() + else: + wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).cuda() + mel_tgt = mel_spectrogram_torch( + wav_tgt, + cfg.data.filter_length, + cfg.data.n_mel_channels, + cfg.data.sampling_rate, + cfg.data.hop_length, + cfg.data.win_length, + 
cfg.data.mel_fmin, + cfg.data.mel_fmax, + ) + # src + wav_src, _ = librosa.load(src, sr=cfg.data.sampling_rate) + wav_src = torch.from_numpy(wav_src).unsqueeze(0).cuda() + c = calc_ssl_features(wavlm, wav_src) + + if cfg.model.use_spk: + audio = net_g.infer(c, g=g_tgt) + else: + audio = net_g.infer(c, mel=mel_tgt) + audio = audio[0][0].data.cpu().float().numpy() + + outpath = os.path.join(args.outdir, f"{title}.wav") + wavfile.write(outpath, cfg.data.sampling_rate, audio) diff --git a/models/vc/FreeVC/losses.py b/models/vc/FreeVC/losses.py new file mode 100644 index 00000000..9b9d991c --- /dev/null +++ b/models/vc/FreeVC/losses.py @@ -0,0 +1,62 @@ +# ruff: noqa: E741 + +# Copied from https://github.com/OlaWod/FreeVC/tree/81c169cdbfc97ff07ee2f501e9b88d543fc46126 + +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + # print(logs_p) + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/models/vc/FreeVC/mel_processing.py b/models/vc/FreeVC/mel_processing.py new file mode 100644 index 00000000..3188aca5 --- /dev/null +++ b/models/vc/FreeVC/mel_processing.py @@ -0,0 +1,147 @@ +# Copied from https://github.com/OlaWod/FreeVC/tree/81c169cdbfc97ff07ee2f501e9b88d543fc46126 + +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + 
) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=spec.dtype, device=spec.device + ) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=y.dtype, device=y.device + ) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/models/vc/FreeVC/model.py b/models/vc/FreeVC/model.py new file mode 100644 index 00000000..9034e87f --- /dev/null +++ b/models/vc/FreeVC/model.py @@ -0,0 +1,326 @@ +# ruff: noqa: E741 + +from .commons import rand_slice_segments, get_padding + +from models.tts.vits import vits + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from torch.nn import Conv1d, Conv2d +from torch.nn.utils import weight_norm, spectral_norm + + +class Encoder(vits.PosteriorEncoder): + pass + + +class SpeakerEncoder(torch.nn.Module): + def __init__( + self, + mel_n_channels=80, + model_num_layers=3, + model_hidden_size=256, + model_embedding_size=256, + ): + super().__init__() + self.lstm = nn.LSTM( + mel_n_channels, model_hidden_size, model_num_layers, batch_first=True + ) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + def forward(self, mels): + self.lstm.flatten_parameters() + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, 
keepdim=True) + + def compute_partial_slices(self, total_frames, partial_frames, partial_hop): + mel_slices = [] + for i in range(0, total_frames - partial_frames, partial_hop): + mel_range = torch.arange(i, i + partial_frames) + mel_slices.append(mel_range) + + return mel_slices + + def embed_utterance(self, mel, partial_frames=128, partial_hop=64): + mel_len = mel.size(1) + last_mel = mel[:, -partial_frames:] + + if mel_len > partial_frames: + mel_slices = self.compute_partial_slices( + mel_len, partial_frames, partial_hop + ) + mels = list(mel[:, s] for s in mel_slices) + mels.append(last_mel) + mels = torch.stack(tuple(mels), 0).squeeze(1) + + with torch.no_grad(): + partial_embeds = self(mels) + embed = torch.mean(partial_embeds, dim=0).unsqueeze(0) + # embed = embed / torch.linalg.norm(embed, 2) + else: + with torch.no_grad(): + embed = self(last_mel) + + return embed + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + ssl_dim, + use_spk, + **kwargs, + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + self.ssl_dim = ssl_dim + self.use_spk = use_spk + + self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16) + self.dec = vits.Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = Encoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = vits.ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels + ) + + if not self.use_spk: + self.enc_spk = SpeakerEncoder( + model_hidden_size=gin_channels, model_embedding_size=gin_channels + ) + + def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None): + if c_lengths is None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if spec_lengths is None: + spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) + + if not self.use_spk: + g = self.enc_spk(mel.transpose(1, 2)) # type:ignore + g = g.unsqueeze(-1) # type:ignore + + _, m_p, logs_p, _ = self.enc_p(c, c_lengths) + z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) + z_p = self.flow(z, spec_mask, g=g) + + z_slice, ids_slice = rand_slice_segments(z, spec_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + + return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, c, g=None, mel=None, c_lengths=None): + if c_lengths is None: + c_lengths = (torch.ones(c.size(0)) * 
c.size(-1)).to(c.device) + if not self.use_spk: + g = self.enc_spk.embed_utterance(mel.transpose(1, 2)) # type:ignore + g = g.unsqueeze(-1) # type:ignore + + z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g) + + return o + + +# ----------------------------------------------- +# Copied from https://github.com/jaywalnut310/vits/tree/2e561ba58618d021b5b8323d3765880f7e0ecfdb +# MIT License +# ----------------------------------------------- +# COPY START + +LRELU_SLOPE = 0.1 + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, 
fmap_rs, fmap_gs + + +# COPY END +# ----------------------------------------------- diff --git a/models/vc/FreeVC/preprocess.py b/models/vc/FreeVC/preprocess.py new file mode 100644 index 00000000..9789e5b3 --- /dev/null +++ b/models/vc/FreeVC/preprocess.py @@ -0,0 +1,319 @@ +from models.vc.FreeVC.wavlm import load_wavlm +from models.vc.FreeVC.hifigan import load_hifigan +from models.vc.FreeVC.mel_processing import mel_spectrogram_torch +from utils.util import load_config + +from speaker_encoder.voice_encoder import SpeakerEncoder +import speaker_encoder.audio + +import os +import random +from typing import Optional +import argparse + +import torch +import numpy as np +import librosa +from scipy.io import wavfile +from multiprocessing import Pool, cpu_count +from tqdm import tqdm +from glob import glob +import torchaudio +import torchvision.transforms.v2 + + +def downsample(args): + in_dir, wav_name, target = args + + speaker = wav_name[:4] + wav_path = os.path.join(in_dir, speaker, wav_name) + + if not os.path.exists(wav_path): + return + + # speaker 's5', 'p280', 'p315' are excluded, + if "_mic2.flac" not in wav_path: + return + + wav = None + + for out_dir, target_sr in target: + save_name = wav_name.replace("_mic2.flac", ".wav") + save_path = os.path.join(out_dir, speaker, save_name) + if os.path.exists(save_path): + continue + + if wav is None: + wav, src_sr = librosa.load(wav_path) + wav, _ = librosa.effects.trim(wav, top_db=20) + peak = np.abs(wav).max() + if peak > 1.0: + wav = 0.98 * wav / peak + + os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) + target_wav = librosa.resample(wav, orig_sr=src_sr, target_sr=target_sr) + wavfile.write( + save_path, target_sr, (target_wav * np.iinfo(np.int16).max).astype(np.int16) + ) + + +def resample_vctk(*, vctk_dir, vctk_16k_dir, vctk_22k_dir): + print("Start resampling VCTK dataset...") + + target = [(vctk_16k_dir, 16000), (vctk_22k_dir, 22050)] + + pool = Pool(processes=cpu_count() - 2) + + in_dir = os.path.join(vctk_dir, "wav48_silence_trimmed") + + wav_names = [] + for speaker in os.listdir(in_dir): + spk_dir = os.path.join(vctk_dir, speaker) + if os.path.isdir(spk_dir): + wav_names.extend(os.listdir(spk_dir)) + + tasks = [(vctk_dir, wav_name, target) for wav_name in wav_names] + + with tqdm(total=len(tasks)) as pbar: + for _ in pool.imap_unordered(downsample, tasks): + pbar.update() + + print("Done!") + + +def generate_split(*, vctk_16k_dir, split_dir): + print("Start generating split...") + + src_dir = vctk_16k_dir + + train = [] + val = [] + test = [] + + for speaker in os.listdir(src_dir): + wav_names = os.listdir(os.path.join(src_dir, speaker)) + random.shuffle(wav_names) + train.extend(wav_names[2:-10]) + val.extend(wav_names[:2]) + test.extend(wav_names[-10:]) + + random.shuffle(train) + random.shuffle(val) + random.shuffle(test) + + train_list = os.path.join(split_dir, "train.txt") + val_list = os.path.join(split_dir, "val.txt") + test_list = os.path.join(split_dir, "test.txt") + + os.makedirs(split_dir, exist_ok=True) + + for list_path, wav_names in zip( + [train_list, val_list, test_list], [train, val, test] + ): + with open(list_path, "w") as f: + for wav_name in wav_names: + speaker = wav_name[:4] + f.write(f"{speaker}/{wav_name}" + "\n") + + print("Done!") + + +def preprocess_spk(*, vctk_16k_dir, preprocess_spk_dir): + in_dir = vctk_16k_dir + out_dir = preprocess_spk_dir + + wav_names = [] + for speaker in os.listdir(in_dir): + spk_dir = os.path.join(in_dir, speaker) + if os.path.isdir(spk_dir): + 
wav_names.extend(os.listdir(spk_dir)) + + pretrained_spk_ckpt_path = os.path.join( + os.path.dirname(__file__), "speaker_encoder/ckpt/pretrained_bak_5805000.pt" + ) + spk_encoder = SpeakerEncoder(pretrained_spk_ckpt_path) + + for wav_name in tqdm(wav_names): + speaker = wav_name[:4] + save_path = os.path.join(out_dir, speaker, wav_name.replace(".wav", ".pt")) + + if os.path.exists(save_path): + continue + + wav_path = os.path.join(in_dir, speaker, wav_name) + spk_wav = speaker_encoder.audio.preprocess_wav(wav_path) + spk = spk_encoder.embed_utterance(spk_wav) + spk = torch.from_numpy(spk) + + os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) + torch.save(spk, save_path) + + +@torch.no_grad() +def calc_ssl_features(wavlm, wav): + return wavlm(wav).last_hidden_state.transpose(1, 2) + + +def preprocess_ssl(*, vctk_16k_dir, preprocess_ssl_dir): + print("Start preprocessing SSL features...") + + in_dir = vctk_16k_dir + out_dir = preprocess_ssl_dir + sr = 16000 + + model = load_wavlm().cuda() # type:ignore + filenames = glob(f"{in_dir}/*/*.wav", recursive=True) + + for filename in tqdm(filenames): + wav_name = os.path.basename(filename) + speaker = wav_name[:4] + + save_dir = os.path.join(out_dir, speaker) + save_path = os.path.join(save_dir, wav_name.replace(".wav", ".pt")) + if os.path.exists(save_path): + continue + + os.makedirs(save_dir, exist_ok=True) + wav, _ = librosa.load(filename, sr=sr) + wav = torch.from_numpy(wav).unsqueeze_(0).cuda() + ssl_features = calc_ssl_features(model, wav) + torch.save(ssl_features.cpu(), save_path) + + print("Done!") + + +def mel_resize(mel, height): # 68-92 + tgt = torchvision.transforms.v2.functional.resize(mel, [height, mel.size(-1)]) + if height >= mel.size(-2): + return tgt[:, : mel.size(-2), :] + else: + silence = tgt[:, -1:, :].repeat(1, mel.size(-2) - height, 1) + silence += torch.randn_like(silence) / 10 + return torch.cat((tgt, silence), 1) + + +@torch.no_grad() +def preprocess_sr( + *, + vctk_22k_dir: str, + preprocess_sr_dir: str, + hifigan_ckpt_path: str, + minh: int = 68, + maxh: int = 92, + cuda_rank: Optional[int] = None, + cuda_total: Optional[int] = None, +): + assert 68 <= minh <= maxh <= 92 + + in_dir = vctk_22k_dir + out_dir = preprocess_sr_dir + + wavlm = load_wavlm() + hifigan, hifigan_config = load_hifigan(hifigan_ckpt_path) + + device = ( + torch.device(f"cuda:{cuda_rank}") + if cuda_rank is not None + else torch.device("cuda") + ) + + wavlm = wavlm.to(device) # type:ignore + hifigan = hifigan.to(device) # type:ignore + + target_sr = 16000 + resample = torchaudio.transforms.Resample( + orig_freq=hifigan_config.sampling_rate, new_freq=target_sr + ).to(device) + + filenames = glob(f"{in_dir}/*/*.wav", recursive=True) + filenames.sort() + + if cuda_rank is not None: + assert cuda_total is not None + filenames = filenames[cuda_rank::cuda_total] + + with tqdm(total=len(filenames) * (maxh - minh + 1)) as pbar: + for filename in filenames: + wav_name = os.path.basename(filename) + speaker = wav_name[:4] + + odir = os.path.join(out_dir, speaker) + os.makedirs(odir, exist_ok=True) + + wav, sr = torchaudio.load(filename) + assert sr == hifigan_config.sampling_rate + wav = wav.to(device) + + mel = mel_spectrogram_torch( + wav, + n_fft=hifigan_config.n_fft, + num_mels=hifigan_config.num_mels, + sampling_rate=hifigan_config.sampling_rate, + hop_size=hifigan_config.hop_size, + win_size=hifigan_config.win_size, + fmin=hifigan_config.fmin, + fmax=hifigan_config.fmax, + ) + + for h in range(minh, maxh + 1): + ssl_path = os.path.join(odir, 
wav_name.replace(".wav", f"_{h}.pt")) + wav_path = os.path.join(odir, wav_name.replace(".wav", f"_{h}.wav")) + + if not os.path.exists(wav_path): + mel_rs = mel_resize(mel, h) + + wav_rs = hifigan(mel_rs)[0] + assert wav_rs.shape[0] == 1 + + wav_rs = resample(wav_rs) + + ssl_features = calc_ssl_features(wavlm, wav_rs) + torch.save(ssl_features.cpu(), ssl_path) + wavfile.write(wav_path, target_sr, wav_rs.cpu().numpy().squeeze(0)) + + pbar.update() + + +def preprocess(cfg, args): + resample_vctk( + vctk_dir=cfg.preprocess.vctk_dir, + vctk_16k_dir=cfg.preprocess.vctk_16k_dir, + vctk_22k_dir=cfg.preprocess.vctk_22k_dir, + ) + generate_split( + vctk_16k_dir=cfg.preprocess.vctk_16k_dir, + split_dir=cfg.preprocess.split_dir, + ) + preprocess_spk( + vctk_16k_dir=cfg.preprocess.vctk_16k_dir, + preprocess_spk_dir=cfg.preprocess.spk_dir, + ) + preprocess_ssl( + vctk_16k_dir=cfg.preprocess.vctk_16k_dir, + preprocess_ssl_dir=cfg.preprocess.ssl_dir, + ) + preprocess_sr( + vctk_22k_dir=cfg.preprocess.vctk_22k_dir, + preprocess_sr_dir=cfg.preprocess.sr_dir, + hifigan_ckpt_path=cfg.preprocess.hifigan_ckpt_path, + minh=cfg.preprocess.minh, + maxh=cfg.preprocess.maxh, + cuda_rank=args.cuda_rank, + cuda_total=args.cuda_total, + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str) + parser.add_argument("--cuda_rank", type=int, default=None) + parser.add_argument("--cuda_total", type=int, default=None) + + args = parser.parse_args() + cfg = load_config(args.config) + + preprocess(cfg, args) + + +if __name__ == "__main__": + main() diff --git a/models/vc/FreeVC/speaker_encoder/__init__.py b/models/vc/FreeVC/speaker_encoder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/vc/FreeVC/speaker_encoder/audio.py b/models/vc/FreeVC/speaker_encoder/audio.py new file mode 100644 index 00000000..9c637a2d --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/audio.py @@ -0,0 +1,115 @@ +from scipy.ndimage.morphology import binary_dilation +from speaker_encoder.params_data import * +from pathlib import Path +from typing import Optional, Union +import numpy as np +import webrtcvad +import librosa +import struct + +int16_max = (2**15) - 1 + + +def preprocess_wav( + fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None +): + """ + Applies the preprocessing operations used in training the Speaker Encoder to a waveform + either on disk or in memory. The waveform will be resampled to match the data hyperparameters. + + :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not + just .wav), either the waveform as a numpy array of floats. + :param source_sr: if passing an audio waveform, the sampling rate of the waveform before + preprocessing. After preprocessing, the waveform's sampling rate will match the data + hyperparameters. If passing a filepath, the sampling rate will be automatically detected and + this argument will be ignored. 
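+    :return: the preprocessed waveform as a numpy array of floats, with volume
+    normalized and long silences trimmed away.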
+ """ + # Load the wav from disk if needed + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav, source_sr = librosa.load(fpath_or_wav, sr=None) + else: + wav = fpath_or_wav + + # Resample the wav if needed + if source_sr is not None and source_sr != sampling_rate: + wav = librosa.resample(wav, source_sr, sampling_rate) + + # Apply the preprocessing: normalize volume and shorten long silences + wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) + wav = trim_long_silences(wav) + + return wav + + +def wav_to_mel_spectrogram(wav): + """ + Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. + Note: this not a log-mel spectrogram. + """ + frames = librosa.feature.melspectrogram( + y=wav, + sr=sampling_rate, + n_fft=int(sampling_rate * mel_window_length / 1000), + hop_length=int(sampling_rate * mel_window_step / 1000), + n_mels=mel_n_channels, + ) + return frames.astype(np.float32).T + + +def trim_long_silences(wav): + """ + Ensures that segments without voice in the waveform remain no longer than a + threshold determined by the VAD parameters in params.py. + + :param wav: the raw waveform as a numpy array of floats + :return: the same waveform with silences trimmed away (length <= original wav length) + """ + # Compute the voice detection window size + samples_per_window = (vad_window_length * sampling_rate) // 1000 + + # Trim the end of the audio to have a multiple of the window size + wav = wav[: len(wav) - (len(wav) % samples_per_window)] + + # Convert the float waveform to 16-bit mono PCM + pcm_wave = struct.pack( + "%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16) + ) + + # Perform voice activation detection + voice_flags = [] + vad = webrtcvad.Vad(mode=3) + for window_start in range(0, len(wav), samples_per_window): + window_end = window_start + samples_per_window + voice_flags.append( + vad.is_speech( + pcm_wave[window_start * 2 : window_end * 2], sample_rate=sampling_rate + ) + ) + voice_flags = np.array(voice_flags) + + # Smooth the voice detection with a moving average + def moving_average(array, width): + array_padded = np.concatenate( + (np.zeros((width - 1) // 2), array, np.zeros(width // 2)) + ) + ret = np.cumsum(array_padded, dtype=float) + ret[width:] = ret[width:] - ret[:-width] + return ret[width - 1 :] / width + + audio_mask = moving_average(voice_flags, vad_moving_average_width) + audio_mask = np.round(audio_mask).astype(np.bool) + + # Dilate the voiced regions + audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) + audio_mask = np.repeat(audio_mask, samples_per_window) + + return wav[audio_mask == True] + + +def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): + if increase_only and decrease_only: + raise ValueError("Both increase only and decrease only are set") + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2)) + if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): + return wav + return wav * (10 ** (dBFS_change / 20)) diff --git a/models/vc/FreeVC/speaker_encoder/ckpt/pretrained_bak_5805000.pt.txt b/models/vc/FreeVC/speaker_encoder/ckpt/pretrained_bak_5805000.pt.txt new file mode 100644 index 00000000..db714220 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/ckpt/pretrained_bak_5805000.pt.txt @@ -0,0 +1 @@ +https://github.com/liusongxiang/ppg-vc/tree/main/speaker_encoder/ckpt \ No newline at end of file diff --git a/models/vc/FreeVC/speaker_encoder/compute_embed.py 
b/models/vc/FreeVC/speaker_encoder/compute_embed.py new file mode 100644 index 00000000..cb7d6499 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/compute_embed.py @@ -0,0 +1,43 @@ +from speaker_encoder import inference as encoder +from multiprocessing.pool import Pool +from functools import partial +from pathlib import Path + +# from utils import logmmse +# from tqdm import tqdm +# import numpy as np +# import librosa + + +def embed_utterance(fpaths, encoder_model_fpath): + if not encoder.is_loaded(): + encoder.load_model(encoder_model_fpath) + + # Compute the speaker embedding of the utterance + wav_fpath, embed_fpath = fpaths + wav = np.load(wav_fpath) + wav = encoder.preprocess_wav(wav) + embed = encoder.embed_utterance(wav) + np.save(embed_fpath, embed, allow_pickle=False) + + +def create_embeddings( + outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int +): + + wav_dir = outdir_root.joinpath("audio") + metadata_fpath = synthesizer_root.joinpath("train.txt") + assert wav_dir.exists() and metadata_fpath.exists() + embed_dir = synthesizer_root.joinpath("embeds") + embed_dir.mkdir(exist_ok=True) + + # Gather the input wave filepath and the target output embed filepath + with metadata_fpath.open("r") as metadata_file: + metadata = [line.split("|") for line in metadata_file] + fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata] + + # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here. + # Embed the utterances in separate threads + func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath) + job = Pool(n_processes).imap(func, fpaths) + list(tqdm(job, "Embedding", len(fpaths), unit="utterances")) diff --git a/models/vc/FreeVC/speaker_encoder/config.py b/models/vc/FreeVC/speaker_encoder/config.py new file mode 100644 index 00000000..bde2ffb9 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/config.py @@ -0,0 +1,27 @@ +librispeech_datasets = { + "train": { + "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], + "other": ["LibriSpeech/train-other-500"], + }, + "test": {"clean": ["LibriSpeech/test-clean"], "other": ["LibriSpeech/test-other"]}, + "dev": {"clean": ["LibriSpeech/dev-clean"], "other": ["LibriSpeech/dev-other"]}, +} +libritts_datasets = { + "train": { + "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], + "other": ["LibriTTS/train-other-500"], + }, + "test": {"clean": ["LibriTTS/test-clean"], "other": ["LibriTTS/test-other"]}, + "dev": {"clean": ["LibriTTS/dev-clean"], "other": ["LibriTTS/dev-other"]}, +} +voxceleb_datasets = { + "voxceleb1": {"train": ["VoxCeleb1/wav"], "test": ["VoxCeleb1/test_wav"]}, + "voxceleb2": {"train": ["VoxCeleb2/dev/aac"], "test": ["VoxCeleb2/test_wav"]}, +} + +other_datasets = [ + "LJSpeech-1.1", + "VCTK-Corpus/wav48", +] + +anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] diff --git a/models/vc/FreeVC/speaker_encoder/data_objects/__init__.py b/models/vc/FreeVC/speaker_encoder/data_objects/__init__.py new file mode 100644 index 00000000..2f981b82 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/data_objects/__init__.py @@ -0,0 +1,6 @@ +from speaker_encoder.data_objects.speaker_verification_dataset import ( + SpeakerVerificationDataset, +) +from speaker_encoder.data_objects.speaker_verification_dataset import ( + SpeakerVerificationDataLoader, +) diff --git a/models/vc/FreeVC/speaker_encoder/data_objects/random_cycler.py b/models/vc/FreeVC/speaker_encoder/data_objects/random_cycler.py 
new file mode 100644 index 00000000..b968ebd7 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/data_objects/random_cycler.py @@ -0,0 +1,37 @@ +import random + + +class RandomCycler: + """ + Creates an internal copy of a sequence and allows access to its items in a constrained random + order. For a source sequence of n items and one or several consecutive queries of a total + of m items, the following guarantees hold (one implies the other): + - Each item will be returned between m // n and ((m - 1) // n) + 1 times. + - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. + """ + + def __init__(self, source): + if len(source) == 0: + raise Exception("Can't create RandomCycler from an empty collection") + self.all_items = list(source) + self.next_items = [] + + def sample(self, count: int): + shuffle = lambda l: random.sample(l, len(l)) + + out = [] + while count > 0: + if count >= len(self.all_items): + out.extend(shuffle(list(self.all_items))) + count -= len(self.all_items) + continue + n = min(count, len(self.next_items)) + out.extend(self.next_items[:n]) + count -= n + self.next_items = self.next_items[n:] + if len(self.next_items) == 0: + self.next_items = shuffle(list(self.all_items)) + return out + + def __next__(self): + return self.sample(1)[0] diff --git a/models/vc/FreeVC/speaker_encoder/data_objects/speaker.py b/models/vc/FreeVC/speaker_encoder/data_objects/speaker.py new file mode 100644 index 00000000..bc75b0cf --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/data_objects/speaker.py @@ -0,0 +1,43 @@ +from speaker_encoder.data_objects.random_cycler import RandomCycler +from speaker_encoder.data_objects.utterance import Utterance +from pathlib import Path + + +# Contains the set of utterances of a single speaker +class Speaker: + def __init__(self, root: Path): + self.root = root + self.name = root.name + self.utterances = None + self.utterance_cycler = None + + def _load_utterances(self): + with self.root.joinpath("_sources.txt").open("r") as sources_file: + sources = [l.split(",") for l in sources_file] + sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} + self.utterances = [ + Utterance(self.root.joinpath(f), w) for f, w in sources.items() + ] + self.utterance_cycler = RandomCycler(self.utterances) + + def random_partial(self, count, n_frames): + """ + Samples a batch of unique partial utterances from the disk in a way that all + utterances come up at least once every two cycles and in a random order every time. + + :param count: The number of partial utterances to sample from the set of utterances from + that speaker. Utterances are guaranteed not to be repeated if is not larger than + the number of utterances available. + :param n_frames: The number of frames in the partial utterance. + :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, + frames are the frames of the partial utterances and range is the range of the partial + utterance with regard to the complete utterance. 
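+
+        Illustrative return value for count=2, n_frames=160 with the 40-channel mel
+        setting from params_data (indices are made-up example values):
+
+            [(utterance_a, frames_a, (12, 172)), (utterance_b, frames_b, (0, 160))]
+            # each frames_i is a numpy array of shape (160, 40)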
+ """ + if self.utterances is None: + self._load_utterances() + + utterances = self.utterance_cycler.sample(count) + + a = [(u,) + u.random_partial(n_frames) for u in utterances] + + return a diff --git a/models/vc/FreeVC/speaker_encoder/data_objects/speaker_batch.py b/models/vc/FreeVC/speaker_encoder/data_objects/speaker_batch.py new file mode 100644 index 00000000..3cc5a52e --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/data_objects/speaker_batch.py @@ -0,0 +1,19 @@ +import numpy as np +from typing import List +from speaker_encoder.data_objects.speaker import Speaker + + +class SpeakerBatch: + def __init__( + self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int + ): + self.speakers = speakers + self.partials = { + s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers + } + + # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with + # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) + self.data = np.array( + [frames for s in speakers for _, frames, _ in self.partials[s]] + ) diff --git a/models/vc/FreeVC/speaker_encoder/data_objects/speaker_verification_dataset.py b/models/vc/FreeVC/speaker_encoder/data_objects/speaker_verification_dataset.py new file mode 100644 index 00000000..1a24f2ce --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/data_objects/speaker_verification_dataset.py @@ -0,0 +1,67 @@ +from speaker_encoder.data_objects.random_cycler import RandomCycler +from speaker_encoder.data_objects.speaker_batch import SpeakerBatch +from speaker_encoder.data_objects.speaker import Speaker +from speaker_encoder.params_data import partials_n_frames +from torch.utils.data import Dataset, DataLoader +from pathlib import Path + +# TODO: improve with a pool of speakers for data efficiency + + +class SpeakerVerificationDataset(Dataset): + def __init__(self, datasets_root: Path): + self.root = datasets_root + speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] + if len(speaker_dirs) == 0: + raise Exception( + "No speakers found. Make sure you are pointing to the directory " + "containing all preprocessed speaker directories." 
+ ) + self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] + self.speaker_cycler = RandomCycler(self.speakers) + + def __len__(self): + return int(1e10) + + def __getitem__(self, index): + return next(self.speaker_cycler) + + def get_logs(self): + log_string = "" + for log_fpath in self.root.glob("*.txt"): + with log_fpath.open("r") as log_file: + log_string += "".join(log_file.readlines()) + return log_string + + +class SpeakerVerificationDataLoader(DataLoader): + def __init__( + self, + dataset, + speakers_per_batch, + utterances_per_speaker, + sampler=None, + batch_sampler=None, + num_workers=0, + pin_memory=False, + timeout=0, + worker_init_fn=None, + ): + self.utterances_per_speaker = utterances_per_speaker + + super().__init__( + dataset=dataset, + batch_size=speakers_per_batch, + shuffle=False, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=self.collate, + pin_memory=pin_memory, + drop_last=False, + timeout=timeout, + worker_init_fn=worker_init_fn, + ) + + def collate(self, speakers): + return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) diff --git a/models/vc/FreeVC/speaker_encoder/data_objects/utterance.py b/models/vc/FreeVC/speaker_encoder/data_objects/utterance.py new file mode 100644 index 00000000..5b65eaa5 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/data_objects/utterance.py @@ -0,0 +1,26 @@ +import numpy as np + + +class Utterance: + def __init__(self, frames_fpath, wave_fpath): + self.frames_fpath = frames_fpath + self.wave_fpath = wave_fpath + + def get_frames(self): + return np.load(self.frames_fpath) + + def random_partial(self, n_frames): + """ + Crops the frames into a partial utterance of n_frames + + :param n_frames: The number of frames of the partial utterance + :return: the partial utterance frames and a tuple indicating the start and end of the + partial utterance in the complete utterance. + """ + frames = self.get_frames() + if frames.shape[0] == n_frames: + start = 0 + else: + start = np.random.randint(0, frames.shape[0] - n_frames) + end = start + n_frames + return frames[start:end], (start, end) diff --git a/models/vc/FreeVC/speaker_encoder/hparams.py b/models/vc/FreeVC/speaker_encoder/hparams.py new file mode 100644 index 00000000..2c536ae1 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/hparams.py @@ -0,0 +1,31 @@ +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. 
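+# With the 30 ms VAD windows above, 6 frames correspond to roughly 180 ms of silence.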
+vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + + +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 diff --git a/models/vc/FreeVC/speaker_encoder/inference.py b/models/vc/FreeVC/speaker_encoder/inference.py new file mode 100644 index 00000000..deac3f8c --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/inference.py @@ -0,0 +1,190 @@ +from speaker_encoder.params_data import * +from speaker_encoder.model import SpeakerEncoder +from speaker_encoder.audio import ( + preprocess_wav, +) # We want to expose this function from here +from matplotlib import cm +from speaker_encoder import audio +from pathlib import Path +import matplotlib.pyplot as plt +import numpy as np +import torch + +_model = None # type: SpeakerEncoder +_device = None # type: torch.device + + +def load_model(weights_fpath: Path, device=None): + """ + Loads the model in memory. If this function is not explicitely called, it will be run on the + first call to embed_frames() with the default weights file. + + :param weights_fpath: the path to saved model weights. + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The + model will be loaded and will run on this device. Outputs will however always be on the cpu. + If None, will default to your GPU if it"s available, otherwise your CPU. + """ + # TODO: I think the slow loading of the encoder might have something to do with the device it + # was saved on. Worth investigating. + global _model, _device + if device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + _device = torch.device(device) + _model = SpeakerEncoder(_device, torch.device("cpu")) + checkpoint = torch.load(weights_fpath) + _model.load_state_dict(checkpoint["model_state"]) + _model.eval() + print( + 'Loaded encoder "%s" trained to step %d' + % (weights_fpath.name, checkpoint["step"]) + ) + + +def is_loaded(): + return _model is not None + + +def embed_frames_batch(frames_batch): + """ + Computes embeddings for a batch of mel spectrogram. + + :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) + """ + if _model is None: + raise Exception("Model was not loaded. Call load_model() before inference.") + + frames = torch.from_numpy(frames_batch).to(_device) + embed = _model.forward(frames).detach().cpu().numpy() + return embed + + +def compute_partial_slices( + n_samples, + partial_utterance_n_frames=partials_n_frames, + min_pad_coverage=0.75, + overlap=0.5, +): + """ + Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain + partial utterances of each. Both the waveform and the mel + spectrogram slices are returned, so as to make each partial utterance waveform correspond to + its spectrogram. This function assumes that the mel spectrogram parameters used are those + defined in params_data.py. + + The returned ranges may be indexing further than the length of the waveform. It is + recommended that you pad the waveform with zeros up to wave_slices[-1].stop. + + :param n_samples: the number of samples in the waveform + :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial + utterance + :param min_pad_coverage: when reaching the last partial utterance, it may or may not have + enough frames. 
If at least of are present, + then the last partial utterance will be considered, as if we padded the audio. Otherwise, + it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial + utterance, this parameter is ignored so that the function always returns at least 1 slice. + :param overlap: by how much the partial utterance should overlap. If set to 0, the partial + utterances are entirely disjoint. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial + utterances. + """ + assert 0 <= overlap < 1 + assert 0 < min_pad_coverage <= 1 + + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partial_utterance_n_frames]) + wav_range = mel_range * samples_per_frame + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / ( + last_wav_range.stop - last_wav_range.start + ) + if coverage < min_pad_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + +def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): + """ + Computes an embedding for a single utterance. + + # TODO: handle multiple wavs to benefit from batching on GPU + :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 + :param using_partials: if True, then the utterance is split in partial utterances of + frames and the utterance embedding is computed from their + normalized average. If False, the utterance is instead computed from feeding the entire + spectogram to the network. + :param return_partials: if True, the partial embeddings will also be returned along with the + wav slices that correspond to the partial embeddings. + :param kwargs: additional arguments to compute_partial_splits() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. If is simultaneously set to False, both these values will be None + instead. 
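+
+    Example (illustrative; assumes load_model() has already been called and `wav` comes
+    from preprocess_wav()):
+
+        embed = embed_utterance(wav)                        # shape (model_embedding_size,), L2-normed
+        score = np.dot(embed, embed_utterance(other_wav))   # cosine similarity, higher = more similar speaker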
+ """ + # Process the entire utterance if not using partials + if not using_partials: + frames = audio.wav_to_mel_spectrogram(wav) + embed = embed_frames_batch(frames[None, ...])[0] + if return_partials: + return embed, None, None + return embed + + # Compute where to split the utterance into partials and pad if necessary + wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) + max_wave_length = wave_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials + frames = audio.wav_to_mel_spectrogram(wav) + frames_batch = np.array([frames[s] for s in mel_slices]) + partial_embeds = embed_frames_batch(frames_batch) + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wave_slices + return embed + + +def embed_speaker(wavs, **kwargs): + raise NotImplemented() + + +def plot_embedding_as_heatmap( + embed, ax=None, title="", shape=None, color_range=(0, 0.30) +): + if ax is None: + ax = plt.gca() + + if shape is None: + height = int(np.sqrt(len(embed))) + shape = (height, -1) + embed = embed.reshape(shape) + + cmap = cm.get_cmap() + mappable = ax.imshow(embed, cmap=cmap) + cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) + cbar.set_clim(*color_range) + + ax.set_xticks([]), ax.set_yticks([]) + ax.set_title(title) diff --git a/models/vc/FreeVC/speaker_encoder/model.py b/models/vc/FreeVC/speaker_encoder/model.py new file mode 100644 index 00000000..6fc462fd --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/model.py @@ -0,0 +1,144 @@ +from speaker_encoder.params_model import * +from speaker_encoder.params_data import * +from scipy.interpolate import interp1d +from sklearn.metrics import roc_curve +from torch.nn.utils import clip_grad_norm_ +from scipy.optimize import brentq +from torch import nn +import numpy as np +import torch + + +class SpeakerEncoder(nn.Module): + def __init__(self, device, loss_device): + super().__init__() + self.loss_device = loss_device + + # Network defition + self.lstm = nn.LSTM( + input_size=mel_n_channels, # 40 + hidden_size=model_hidden_size, # 256 + num_layers=model_num_layers, # 3 + batch_first=True, + ).to(device) + self.linear = nn.Linear( + in_features=model_hidden_size, out_features=model_embedding_size + ).to(device) + self.relu = torch.nn.ReLU().to(device) + + # Cosine similarity scaling (with fixed initial parameter values) + self.similarity_weight = nn.Parameter(torch.tensor([10.0])).to(loss_device) + self.similarity_bias = nn.Parameter(torch.tensor([-5.0])).to(loss_device) + + # Loss + self.loss_fn = nn.CrossEntropyLoss().to(loss_device) + + def do_gradient_ops(self): + # Gradient scale + self.similarity_weight.grad *= 0.01 + self.similarity_bias.grad *= 0.01 + + # Gradient clipping + clip_grad_norm_(self.parameters(), 3, norm_type=2) + + def forward(self, utterances, hidden_init=None): + """ + Computes the embeddings of a batch of utterance spectrograms. + + :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape + (batch_size, n_frames, n_channels) + :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, + batch_size, hidden_size). Will default to a tensor of zeros if None. 
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size) + """ + # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state + # and the final cell state. + out, (hidden, cell) = self.lstm(utterances, hidden_init) + + # We take only the hidden state of the last layer + embeds_raw = self.relu(self.linear(hidden[-1])) + + # L2-normalize it + embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + return embeds + + def similarity_matrix(self, embeds): + """ + Computes the similarity matrix according the section 2.1 of GE2E. + + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + utterances_per_speaker, embedding_size) + :return: the similarity matrix as a tensor of shape (speakers_per_batch, + utterances_per_speaker, speakers_per_batch) + """ + speakers_per_batch, utterances_per_speaker = embeds.shape[:2] + + # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation + centroids_incl = torch.mean(embeds, dim=1, keepdim=True) + centroids_incl = centroids_incl.clone() / torch.norm( + centroids_incl, dim=2, keepdim=True + ) + + # Exclusive centroids (1 per utterance) + centroids_excl = torch.sum(embeds, dim=1, keepdim=True) - embeds + centroids_excl /= utterances_per_speaker - 1 + centroids_excl = centroids_excl.clone() / torch.norm( + centroids_excl, dim=2, keepdim=True + ) + + # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot + # product of these vectors (which is just an element-wise multiplication reduced by a sum). + # We vectorize the computation for efficiency. + sim_matrix = torch.zeros( + speakers_per_batch, utterances_per_speaker, speakers_per_batch + ).to(self.loss_device) + mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) + for j in range(speakers_per_batch): + mask = np.where(mask_matrix[j])[0] + sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) + sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) + + ## Even more vectorized version (slower maybe because of transpose) + # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker + # ).to(self.loss_device) + # eye = np.eye(speakers_per_batch, dtype=np.int) + # mask = np.where(1 - eye) + # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2) + # mask = np.where(eye) + # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) + # sim_matrix2 = sim_matrix2.transpose(1, 2) + + sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias + return sim_matrix + + def loss(self, embeds): + """ + Computes the softmax loss according the section 2.1 of GE2E. + + :param embeds: the embeddings as a tensor of shape (speakers_per_batch, + utterances_per_speaker, embedding_size) + :return: the loss and the EER for this batch of embeddings. 
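+
+        Note: after reshaping the similarity matrix to
+        (speakers_per_batch * utterances_per_speaker, speakers_per_batch) and repeating
+        each speaker index utterances_per_speaker times as the target, CrossEntropyLoss
+        reproduces the GE2E softmax loss  L(e_ji) = -S(ji, j) + log sum_k exp(S(ji, k)).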
+ """ + speakers_per_batch, utterances_per_speaker = embeds.shape[:2] + + # Loss + sim_matrix = self.similarity_matrix(embeds) + sim_matrix = sim_matrix.reshape( + (speakers_per_batch * utterances_per_speaker, speakers_per_batch) + ) + ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) + target = torch.from_numpy(ground_truth).long().to(self.loss_device) + loss = self.loss_fn(sim_matrix, target) + + # EER (not backpropagated) + with torch.no_grad(): + inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] + labels = np.array([inv_argmax(i) for i in ground_truth]) + preds = sim_matrix.detach().cpu().numpy() + + # Snippet from https://yangcha.github.io/EER-ROC/ + fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) + eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0) + + return loss, eer diff --git a/models/vc/FreeVC/speaker_encoder/params_data.py b/models/vc/FreeVC/speaker_encoder/params_data.py new file mode 100644 index 00000000..0619c739 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/params_data.py @@ -0,0 +1,27 @@ +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms +# Number of spectrogram frames at inference +inference_n_frames = 80 # 800 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. +vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 diff --git a/models/vc/FreeVC/speaker_encoder/params_model.py b/models/vc/FreeVC/speaker_encoder/params_model.py new file mode 100644 index 00000000..29026cc2 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/params_model.py @@ -0,0 +1,10 @@ +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 + + +## Training parameters +learning_rate_init = 1e-4 +speakers_per_batch = 64 +utterances_per_speaker = 10 diff --git a/models/vc/FreeVC/speaker_encoder/preprocess.py b/models/vc/FreeVC/speaker_encoder/preprocess.py new file mode 100644 index 00000000..a6c2c98f --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/preprocess.py @@ -0,0 +1,329 @@ +from multiprocess.pool import ThreadPool +from speaker_encoder.params_data import * +from speaker_encoder.config import librispeech_datasets, anglophone_nationalites +from datetime import datetime +from speaker_encoder import audio +from pathlib import Path +from tqdm import tqdm +import numpy as np + + +class DatasetLog: + """ + Registers metadata about the dataset in a text file. 
+ """ + + def __init__(self, root, name): + self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w") + self.sample_data = dict() + + start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) + self.write_line("Creating dataset %s on %s" % (name, start_time)) + self.write_line("-----") + self._log_params() + + def _log_params(self): + from speaker_encoder import params_data + + self.write_line("Parameter values:") + for param_name in (p for p in dir(params_data) if not p.startswith("__")): + value = getattr(params_data, param_name) + self.write_line("\t%s: %s" % (param_name, value)) + self.write_line("-----") + + def write_line(self, line): + self.text_file.write("%s\n" % line) + + def add_sample(self, **kwargs): + for param_name, value in kwargs.items(): + if not param_name in self.sample_data: + self.sample_data[param_name] = [] + self.sample_data[param_name].append(value) + + def finalize(self): + self.write_line("Statistics:") + for param_name, values in self.sample_data.items(): + self.write_line("\t%s:" % param_name) + self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values))) + self.write_line( + "\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)) + ) + self.write_line("-----") + end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) + self.write_line("Finished on %s" % end_time) + self.text_file.close() + + +def _init_preprocess_dataset( + dataset_name, datasets_root, out_dir +) -> (Path, DatasetLog): + dataset_root = datasets_root.joinpath(dataset_name) + if not dataset_root.exists(): + print("Couldn't find %s, skipping this dataset." % dataset_root) + return None, None + return dataset_root, DatasetLog(out_dir, dataset_name) + + +def _preprocess_speaker_dirs( + speaker_dirs, dataset_name, datasets_root, out_dir, extension, skip_existing, logger +): + print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs))) + + # Function to preprocess utterances for one speaker + def preprocess_speaker(speaker_dir: Path): + # Give a name to the speaker that includes its dataset + speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) + + # Create an output directory with that name, as well as a txt file containing a + # reference to each source file. + speaker_out_dir = out_dir.joinpath(speaker_name) + speaker_out_dir.mkdir(exist_ok=True) + sources_fpath = speaker_out_dir.joinpath("_sources.txt") + + # There's a possibility that the preprocessing was interrupted earlier, check if + # there already is a sources file. 
+ if sources_fpath.exists(): + try: + with sources_fpath.open("r") as sources_file: + existing_fnames = {line.split(",")[0] for line in sources_file} + except: + existing_fnames = {} + else: + existing_fnames = {} + + # Gather all audio files for that speaker recursively + sources_file = sources_fpath.open("a" if skip_existing else "w") + for in_fpath in speaker_dir.glob("**/*.%s" % extension): + # Check if the target output file already exists + out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) + out_fname = out_fname.replace(".%s" % extension, ".npy") + if skip_existing and out_fname in existing_fnames: + continue + + # Load and preprocess the waveform + wav = audio.preprocess_wav(in_fpath) + if len(wav) == 0: + continue + + # Create the mel spectrogram, discard those that are too short + frames = audio.wav_to_mel_spectrogram(wav) + if len(frames) < partials_n_frames: + continue + + out_fpath = speaker_out_dir.joinpath(out_fname) + np.save(out_fpath, frames) + logger.add_sample(duration=len(wav) / sampling_rate) + sources_file.write("%s,%s\n" % (out_fname, in_fpath)) + + sources_file.close() + + # Process the utterances for each speaker + with ThreadPool(8) as pool: + list( + tqdm( + pool.imap(preprocess_speaker, speaker_dirs), + dataset_name, + len(speaker_dirs), + unit="speakers", + ) + ) + logger.finalize() + print("Done preprocessing %s.\n" % dataset_name) + + +# Function to preprocess utterances for one speaker +def __preprocess_speaker( + speaker_dir: Path, + datasets_root: Path, + out_dir: Path, + extension: str, + skip_existing: bool, +): + # Give a name to the speaker that includes its dataset + speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) + + # Create an output directory with that name, as well as a txt file containing a + # reference to each source file. + speaker_out_dir = out_dir.joinpath(speaker_name) + speaker_out_dir.mkdir(exist_ok=True) + sources_fpath = speaker_out_dir.joinpath("_sources.txt") + + # There's a possibility that the preprocessing was interrupted earlier, check if + # there already is a sources file. 
+ # if sources_fpath.exists(): + # try: + # with sources_fpath.open("r") as sources_file: + # existing_fnames = {line.split(",")[0] for line in sources_file} + # except: + # existing_fnames = {} + # else: + # existing_fnames = {} + existing_fnames = {} + # Gather all audio files for that speaker recursively + sources_file = sources_fpath.open("a" if skip_existing else "w") + + for in_fpath in speaker_dir.glob("**/*.%s" % extension): + # Check if the target output file already exists + out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) + out_fname = out_fname.replace(".%s" % extension, ".npy") + if skip_existing and out_fname in existing_fnames: + continue + + # Load and preprocess the waveform + wav = audio.preprocess_wav(in_fpath) + if len(wav) == 0: + continue + + # Create the mel spectrogram, discard those that are too short + frames = audio.wav_to_mel_spectrogram(wav) + if len(frames) < partials_n_frames: + continue + + out_fpath = speaker_out_dir.joinpath(out_fname) + np.save(out_fpath, frames) + # logger.add_sample(duration=len(wav) / sampling_rate) + sources_file.write("%s,%s\n" % (out_fname, in_fpath)) + + sources_file.close() + return len(wav) + + +def _preprocess_speaker_dirs_vox2( + speaker_dirs, dataset_name, datasets_root, out_dir, extension, skip_existing, logger +): + # from multiprocessing import Pool, cpu_count + from pathos.multiprocessing import ProcessingPool as Pool + + # Function to preprocess utterances for one speaker + def __preprocess_speaker(speaker_dir: Path): + # Give a name to the speaker that includes its dataset + speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) + + # Create an output directory with that name, as well as a txt file containing a + # reference to each source file. + speaker_out_dir = out_dir.joinpath(speaker_name) + speaker_out_dir.mkdir(exist_ok=True) + sources_fpath = speaker_out_dir.joinpath("_sources.txt") + + existing_fnames = {} + # Gather all audio files for that speaker recursively + sources_file = sources_fpath.open("a" if skip_existing else "w") + wav_lens = [] + for in_fpath in speaker_dir.glob("**/*.%s" % extension): + # Check if the target output file already exists + out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) + out_fname = out_fname.replace(".%s" % extension, ".npy") + if skip_existing and out_fname in existing_fnames: + continue + + # Load and preprocess the waveform + wav = audio.preprocess_wav(in_fpath) + if len(wav) == 0: + continue + + # Create the mel spectrogram, discard those that are too short + frames = audio.wav_to_mel_spectrogram(wav) + if len(frames) < partials_n_frames: + continue + + out_fpath = speaker_out_dir.joinpath(out_fname) + np.save(out_fpath, frames) + # logger.add_sample(duration=len(wav) / sampling_rate) + sources_file.write("%s,%s\n" % (out_fname, in_fpath)) + wav_lens.append(len(wav)) + sources_file.close() + return wav_lens + + print("%s: Preprocessing data for %d speakers." 
% (dataset_name, len(speaker_dirs))) + # Process the utterances for each speaker + # with ThreadPool(8) as pool: + # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), + # unit="speakers")) + pool = Pool(processes=20) + for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1): + for wav_len in wav_lens: + logger.add_sample(duration=wav_len / sampling_rate) + print(f"{i}/{len(speaker_dirs)} \r") + + logger.finalize() + print("Done preprocessing %s.\n" % dataset_name) + + +def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False): + for dataset_name in librispeech_datasets["train"]["other"]: + # Initialize the preprocessing + dataset_root, logger = _init_preprocess_dataset( + dataset_name, datasets_root, out_dir + ) + if not dataset_root: + return + + # Preprocess all speakers + speaker_dirs = list(dataset_root.glob("*")) + _preprocess_speaker_dirs( + speaker_dirs, + dataset_name, + datasets_root, + out_dir, + "flac", + skip_existing, + logger, + ) + + +def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False): + # Initialize the preprocessing + dataset_name = "VoxCeleb1" + dataset_root, logger = _init_preprocess_dataset( + dataset_name, datasets_root, out_dir + ) + if not dataset_root: + return + + # Get the contents of the meta file + with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile: + metadata = [line.split("\t") for line in metafile][1:] + + # Select the ID and the nationality, filter out non-anglophone speakers + nationalities = {line[0]: line[3] for line in metadata} + # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if + # nationality.lower() in anglophone_nationalites] + keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()] + print( + "VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." + % (len(keep_speaker_ids), len(nationalities)) + ) + + # Get the speaker directories for anglophone speakers only + speaker_dirs = dataset_root.joinpath("wav").glob("*") + speaker_dirs = [ + speaker_dir + for speaker_dir in speaker_dirs + if speaker_dir.name in keep_speaker_ids + ] + print( + "VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." 
+ % (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)) + ) + + # Preprocess all speakers + _preprocess_speaker_dirs( + speaker_dirs, dataset_name, datasets_root, out_dir, "wav", skip_existing, logger + ) + + +def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False): + # Initialize the preprocessing + dataset_name = "VoxCeleb2" + dataset_root, logger = _init_preprocess_dataset( + dataset_name, datasets_root, out_dir + ) + if not dataset_root: + return + + # Get the speaker directories + # Preprocess all speakers + speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*")) + _preprocess_speaker_dirs_vox2( + speaker_dirs, dataset_name, datasets_root, out_dir, "m4a", skip_existing, logger + ) diff --git a/models/vc/FreeVC/speaker_encoder/train.py b/models/vc/FreeVC/speaker_encoder/train.py new file mode 100644 index 00000000..e1dc7457 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/train.py @@ -0,0 +1,150 @@ +from speaker_encoder.visualizations import Visualizations +from speaker_encoder.data_objects import ( + SpeakerVerificationDataLoader, + SpeakerVerificationDataset, +) +from speaker_encoder.params_model import * +from speaker_encoder.model import SpeakerEncoder +from utils.profiler import Profiler +from pathlib import Path +import torch + + +def sync(device: torch.device): + # FIXME + return + # For correct profiling (cuda operations are async) + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def train( + run_id: str, + clean_data_root: Path, + models_dir: Path, + umap_every: int, + save_every: int, + backup_every: int, + vis_every: int, + force_restart: bool, + visdom_server: str, + no_visdom: bool, +): + # Create a dataset and a dataloader + dataset = SpeakerVerificationDataset(clean_data_root) + loader = SpeakerVerificationDataLoader( + dataset, + speakers_per_batch, # 64 + utterances_per_speaker, # 10 + num_workers=8, + ) + + # Setup the device on which to run the forward pass and the loss. These can be different, + # because the forward pass is faster on the GPU whereas the loss is often (depending on your + # hyperparameters) faster on the CPU. + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # FIXME: currently, the gradient is None if loss_device is cuda + loss_device = torch.device("cpu") + + # Create the model and the optimizer + model = SpeakerEncoder(device, loss_device) + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) + init_step = 1 + + # Configure file path for the model + state_fpath = models_dir.joinpath(run_id + ".pt") + backup_dir = models_dir.joinpath(run_id + "_backups") + + # Load any existing model + if not force_restart: + if state_fpath.exists(): + print( + 'Found existing model "%s", loading it and resuming training.' % run_id + ) + checkpoint = torch.load(state_fpath) + init_step = checkpoint["step"] + model.load_state_dict(checkpoint["model_state"]) + optimizer.load_state_dict(checkpoint["optimizer_state"]) + optimizer.param_groups[0]["lr"] = learning_rate_init + else: + print('No model "%s" found, starting training from scratch.' 
% run_id) + else: + print("Starting the training from scratch.") + model.train() + + # Initialize the visualization environment + vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) + vis.log_dataset(dataset) + vis.log_params() + device_name = str( + torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU" + ) + vis.log_implementation({"Device": device_name}) + + # Training loop + profiler = Profiler(summarize_every=10, disabled=False) + for step, speaker_batch in enumerate(loader, init_step): + profiler.tick("Blocking, waiting for batch (threaded)") + + # Forward pass + inputs = torch.from_numpy(speaker_batch.data).to(device) + sync(device) + profiler.tick("Data to %s" % device) + embeds = model(inputs) + sync(device) + profiler.tick("Forward pass") + embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to( + loss_device + ) + loss, eer = model.loss(embeds_loss) + sync(loss_device) + profiler.tick("Loss") + + # Backward pass + model.zero_grad() + loss.backward() + profiler.tick("Backward pass") + model.do_gradient_ops() + optimizer.step() + profiler.tick("Parameter update") + + # Update visualizations + # learning_rate = optimizer.param_groups[0]["lr"] + vis.update(loss.item(), eer, step) + + # Draw projections and save them to the backup folder + if umap_every != 0 and step % umap_every == 0: + print("Drawing and saving projections (step %d)" % step) + backup_dir.mkdir(exist_ok=True) + projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step)) + embeds = embeds.detach().cpu().numpy() + vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath) + vis.save() + + # Overwrite the latest version of the model + if save_every != 0 and step % save_every == 0: + print("Saving the model (step %d)" % step) + torch.save( + { + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, + state_fpath, + ) + + # Make a backup + if backup_every != 0 and step % backup_every == 0: + print("Making a backup (step %d)" % step) + backup_dir.mkdir(exist_ok=True) + backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step)) + torch.save( + { + "step": step + 1, + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, + backup_fpath, + ) + + profiler.tick("Extras (visualizations, saving)") diff --git a/models/vc/FreeVC/speaker_encoder/visualizations.py b/models/vc/FreeVC/speaker_encoder/visualizations.py new file mode 100644 index 00000000..fea0efb2 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/visualizations.py @@ -0,0 +1,195 @@ +from speaker_encoder.data_objects.speaker_verification_dataset import ( + SpeakerVerificationDataset, +) +from datetime import datetime +from time import perf_counter as timer +import matplotlib.pyplot as plt +import numpy as np + +# import webbrowser +import visdom +import umap + +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=np.float, + ) + / 255 +) + + +class Visualizations: + def __init__( + self, env_name=None, update_every=10, server="http://localhost", disabled=False + ): + # Tracking data + self.last_update_timestamp = timer() + self.update_every = update_every + self.step_times = [] + self.losses = [] + self.eers = [] + print("Updating the visualizations every %d steps." 
% update_every) + + # If visdom is disabled TODO: use a better paradigm for that + self.disabled = disabled + if self.disabled: + return + + # Set the environment name + now = str(datetime.now().strftime("%d-%m %Hh%M")) + if env_name is None: + self.env_name = now + else: + self.env_name = "%s (%s)" % (env_name, now) + + # Connect to visdom and open the corresponding window in the browser + try: + self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True) + except ConnectionError: + raise Exception( + 'No visdom server detected. Run the command "visdom" in your CLI to ' + "start it." + ) + # webbrowser.open("http://localhost:8097/env/" + self.env_name) + + # Create the windows + self.loss_win = None + self.eer_win = None + # self.lr_win = None + self.implementation_win = None + self.projection_win = None + self.implementation_string = "" + + def log_params(self): + if self.disabled: + return + from speaker_encoder import params_data + from speaker_encoder import params_model + + param_string = "Model parameters:
" + for param_name in (p for p in dir(params_model) if not p.startswith("__")): + value = getattr(params_model, param_name) + param_string += "\t%s: %s
" % (param_name, value) + param_string += "Data parameters:
" + for param_name in (p for p in dir(params_data) if not p.startswith("__")): + value = getattr(params_data, param_name) + param_string += "\t%s: %s
" % (param_name, value) + self.vis.text(param_string, opts={"title": "Parameters"}) + + def log_dataset(self, dataset: SpeakerVerificationDataset): + if self.disabled: + return + dataset_string = "" + dataset_string += "Speakers: %s\n" % len(dataset.speakers) + dataset_string += "\n" + dataset.get_logs() + dataset_string = dataset_string.replace("\n", "
") + self.vis.text(dataset_string, opts={"title": "Dataset"}) + + def log_implementation(self, params): + if self.disabled: + return + implementation_string = "" + for param, value in params.items(): + implementation_string += "%s: %s\n" % (param, value) + implementation_string = implementation_string.replace("\n", "
") + self.implementation_string = implementation_string + self.implementation_win = self.vis.text( + implementation_string, opts={"title": "Training implementation"} + ) + + def update(self, loss, eer, step): + # Update the tracking data + now = timer() + self.step_times.append(1000 * (now - self.last_update_timestamp)) + self.last_update_timestamp = now + self.losses.append(loss) + self.eers.append(eer) + print(".", end="") + + # Update the plots every steps + if step % self.update_every != 0: + return + time_string = "Step time: mean: %5dms std: %5dms" % ( + int(np.mean(self.step_times)), + int(np.std(self.step_times)), + ) + print( + "\nStep %6d Loss: %.4f EER: %.4f %s" + % (step, np.mean(self.losses), np.mean(self.eers), time_string) + ) + if not self.disabled: + self.loss_win = self.vis.line( + [np.mean(self.losses)], + [step], + win=self.loss_win, + update="append" if self.loss_win else None, + opts=dict( + legend=["Avg. loss"], + xlabel="Step", + ylabel="Loss", + title="Loss", + ), + ) + self.eer_win = self.vis.line( + [np.mean(self.eers)], + [step], + win=self.eer_win, + update="append" if self.eer_win else None, + opts=dict( + legend=["Avg. EER"], + xlabel="Step", + ylabel="EER", + title="Equal error rate", + ), + ) + if self.implementation_win is not None: + self.vis.text( + self.implementation_string + ("%s" % time_string), + win=self.implementation_win, + opts={"title": "Training implementation"}, + ) + + # Reset the tracking + self.losses.clear() + self.eers.clear() + self.step_times.clear() + + def draw_projections( + self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10 + ): + max_speakers = min(max_speakers, len(colormap)) + embeds = embeds[: max_speakers * utterances_per_speaker] + + n_speakers = len(embeds) // utterances_per_speaker + ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker) + colors = [colormap[i] for i in ground_truth] + + reducer = umap.UMAP() + projected = reducer.fit_transform(embeds) + plt.scatter(projected[:, 0], projected[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection (step %d)" % step) + if not self.disabled: + self.projection_win = self.vis.matplot(plt, win=self.projection_win) + if out_fpath is not None: + plt.savefig(out_fpath) + plt.clf() + + def save(self): + if not self.disabled: + self.vis.save([self.env_name]) diff --git a/models/vc/FreeVC/speaker_encoder/voice_encoder.py b/models/vc/FreeVC/speaker_encoder/voice_encoder.py new file mode 100644 index 00000000..93f70f72 --- /dev/null +++ b/models/vc/FreeVC/speaker_encoder/voice_encoder.py @@ -0,0 +1,193 @@ +from speaker_encoder.hparams import * +from speaker_encoder import audio +from pathlib import Path +from typing import Union, List +from torch import nn +from time import perf_counter as timer +import numpy as np +import torch + + +class SpeakerEncoder(nn.Module): + def __init__( + self, weights_fpath, device: Union[str, torch.device] = None, verbose=True + ): + """ + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). + If None, defaults to cuda if it is available on your machine, otherwise the model will + run on cpu. Outputs are always returned on the cpu, as numpy arrays. 
+ """ + super().__init__() + + # Define the network + self.lstm = nn.LSTM( + mel_n_channels, model_hidden_size, model_num_layers, batch_first=True + ) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + # Get the target device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + device = torch.device(device) + self.device = device + + # Load the pretrained model'speaker weights + # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt") + # if not weights_fpath.exists(): + # raise Exception("Couldn't find the voice encoder pretrained model at %s." % + # weights_fpath) + + start = timer() + checkpoint = torch.load(weights_fpath, map_location="cpu") + + self.load_state_dict(checkpoint["model_state"], strict=False) + self.to(device) + + if verbose: + print( + "Loaded the voice encoder model on %s in %.2f seconds." + % (device.type, timer() - start) + ) + + def forward(self, mels: torch.FloatTensor): + """ + Computes the embeddings of a batch of utterance spectrograms. + :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size). + Embeddings are positive and L2-normed, thus they lay in the range [0, 1]. + """ + # Pass the input through the LSTM layers and retrieve the final hidden state of the last + # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings. + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + @staticmethod + def compute_partial_slices(n_samples: int, rate, min_coverage): + """ + Computes where to split an utterance waveform and its corresponding mel spectrogram to + obtain partial utterances of each. Both the waveform and the + mel spectrogram slices are returned, so as to make each partial utterance waveform + correspond to its spectrogram. + + The returned ranges may be indexing further than the length of the waveform. It is + recommended that you pad the waveform with zeros up to wav_slices[-1].stop. + + :param n_samples: the number of samples in the waveform + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial + utterances. 
+ """ + assert 0 < min_coverage <= 1 + + # Compute how many frames separate two partial utterances + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) + assert 0 < frame_step, "The rate is too high" + assert ( + frame_step <= partials_n_frames + ), "The rate is too low, it should be %f at least" % ( + sampling_rate / (samples_per_frame * partials_n_frames) + ) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partials_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partials_n_frames]) + wav_range = mel_range * samples_per_frame + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / ( + last_wav_range.stop - last_wav_range.start + ) + if coverage < min_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + def embed_utterance( + self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75 + ): + """ + Computes an embedding for a single utterance. The utterance is divided in partial + utterances and an embedding is computed for each. The complete utterance embedding is the + L2-normed average embedding of the partial utterances. + + TODO: independent batched version of this function + + :param wav: a preprocessed utterance waveform as a numpy array of float32 + :param return_partials: if True, the partial embeddings will also be returned along with + the wav slices corresponding to each partial utterance. + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. + """ + # Compute where to split the utterance into partials and pad the waveform with zeros if + # the partial utterances cover a larger range. 
+ wav_slices, mel_slices = self.compute_partial_slices( + len(wav), rate, min_coverage + ) + max_wave_length = wav_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials and forward them through the model + mel = audio.wav_to_mel_spectrogram(wav) + mels = np.array([mel[s] for s in mel_slices]) + with torch.no_grad(): + mels = torch.from_numpy(mels).to(self.device) + partial_embeds = self(mels).cpu().numpy() + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wav_slices + return embed + + def embed_speaker(self, wavs: List[np.ndarray], **kwargs): + """ + Compute the embedding of a collection of wavs (presumably from the same speaker) by + averaging their embedding and L2-normalizing it. + + :param wavs: list of wavs a numpy arrays of float32. + :param kwargs: extra arguments to embed_utterance() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). + """ + raw_embed = np.mean( + [ + self.embed_utterance(wav, return_partials=False, **kwargs) + for wav in wavs + ], + axis=0, + ) + return raw_embed / np.linalg.norm(raw_embed, 2) diff --git a/models/vc/FreeVC/train.py b/models/vc/FreeVC/train.py new file mode 100644 index 00000000..87dd591d --- /dev/null +++ b/models/vc/FreeVC/train.py @@ -0,0 +1,372 @@ +from models.vc.FreeVC.model import SynthesizerTrn, MultiPeriodDiscriminator +from models.vc.FreeVC.data import FreeVCDataset, FreeVCCollate, BucketSampler +from models.vc.FreeVC.mel_processing import mel_spectrogram_torch, spec_to_mel_torch +from models.vc.FreeVC.commons import slice_segments +from models.vc.FreeVC.train_utils import ( + get_logger, + load_checkpoint, + latest_checkpoint_path, + summarize, + plot_spectrogram_to_numpy, + save_checkpoint, +) +from models.vc.FreeVC.losses import ( + generator_loss, + discriminator_loss, + feature_loss, + kl_loss, +) +from utils.util import clip_grad_value_ +from utils.util import load_config + + +import os +import argparse + +import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.tensorboard.writer import SummaryWriter +from torch.cuda.amp import autocast, GradScaler + + +torch.backends.cudnn.benchmark = True +global_step = 0 + + +def main(cfg, args): + global global_step + + assert torch.cuda.is_available(), "CPU training is not allowed." 
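+
+ # Single-process, single-GPU training: the code below sets up logging and
+ # TensorBoard writers, the bucketed train loader and the eval loader, the
+ # generator/discriminator pair with their AdamW optimizers, optional resumption
+ # from the latest G_*.pth / D_*.pth checkpoints, and finally the epoch loop.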
+ + logger = get_logger(args.log_dir) + logger.info(cfg) + writer = SummaryWriter(log_dir=args.log_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(args.log_dir, "eval")) + + torch.manual_seed(cfg.train.seed) + + train_dataset = FreeVCDataset( + os.path.join(cfg.preprocess.split_dir, "train.txt"), cfg + ) + train_sampler = BucketSampler( + train_dataset, + cfg.train.batch_size, + [32, 300, 400, 500, 600, 700, 800, 900, 1000], + shuffle=True, + ) + collate_fn = FreeVCCollate(cfg) + train_loader = DataLoader( + train_dataset, + num_workers=cfg.train.num_workers, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + ) + + eval_dataset = FreeVCDataset(os.path.join(cfg.preprocess.split_dir, "val.txt"), cfg) + eval_loader = DataLoader( + eval_dataset, + num_workers=cfg.train.num_workers, + shuffle=True, + batch_size=cfg.train.batch_size, + pin_memory=False, + drop_last=False, + collate_fn=collate_fn, + ) + + net_g = SynthesizerTrn( + cfg.data.filter_length // 2 + 1, + cfg.train.segment_size // cfg.data.hop_length, + **cfg.model, + ).cuda() + + net_d = MultiPeriodDiscriminator(cfg.model.use_spectral_norm).cuda() + optim_g = torch.optim.AdamW( + net_g.parameters(), + cfg.train.learning_rate, + betas=cfg.train.betas, + eps=cfg.train.eps, + ) + optim_d = torch.optim.AdamW( + net_d.parameters(), + cfg.train.learning_rate, + betas=cfg.train.betas, + eps=cfg.train.eps, + ) + + try: + _, _, _, epoch_str = load_checkpoint( + latest_checkpoint_path(args.log_dir, "G_*.pth"), net_g, optim_g + ) + _, _, _, epoch_str = load_checkpoint( + latest_checkpoint_path(args.log_dir, "D_*.pth"), net_d, optim_d + ) + global_step = (epoch_str - 1) * len(train_loader) + except Exception: + epoch_str = 1 + global_step = 0 + print(f"global_step: {global_step}") + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optim_g, gamma=cfg.train.lr_decay, last_epoch=epoch_str - 2 + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optim_d, gamma=cfg.train.lr_decay, last_epoch=epoch_str - 2 + ) + + scaler = GradScaler(enabled=cfg.train.fp16_run) + + for epoch in range(epoch_str, cfg.train.epochs + 1): + train_and_evaluate( + epoch, + cfg, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, eval_loader], + logger, + [writer, writer_eval], + ) + scheduler_g.step() + scheduler_d.step() + + +def train_and_evaluate( + epoch, cfg, nets, optims, schedulers, scaler, loaders, logger, writers +): + net_g, net_d = nets + optim_g, optim_d = optims + scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + # train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + for batch_idx, items in enumerate(train_loader): + if cfg.model.use_spk: + c, spec, y, spk = items + g = spk.cuda(non_blocking=True) + else: + c, spec, y = items + g = None + spec, y = spec.cuda(non_blocking=True), y.cuda(non_blocking=True) + c = c.cuda(non_blocking=True) + mel = spec_to_mel_torch( + spec, + cfg.data.filter_length, + cfg.data.n_mel_channels, + cfg.data.sampling_rate, + cfg.data.mel_fmin, + cfg.data.mel_fmax, + ) + + with autocast(enabled=cfg.train.fp16_run): + y_hat, ids_slice, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = net_g( + c, spec, g=g, mel=mel + ) + + y_mel = slice_segments( + mel, ids_slice, cfg.train.segment_size // cfg.data.hop_length + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + cfg.data.filter_length, + 
cfg.data.n_mel_channels, + cfg.data.sampling_rate, + cfg.data.hop_length, + cfg.data.win_length, + cfg.data.mel_fmin, + cfg.data.mel_fmax, + ) + y = slice_segments( + y, ids_slice * cfg.data.hop_length, cfg.train.segment_size + ) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, y_d_hat_g + ) + loss_disc_all = loss_disc + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + with autocast(enabled=cfg.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * cfg.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * cfg.train.c_kl + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if global_step % cfg.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, 100.0 * batch_idx / len(train_loader) + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } + scalar_dict.update( + {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl} + ) + + scalar_dict.update( + {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} + ) + scalar_dict.update( + {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} + ) + scalar_dict.update( + {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} + ) + image_dict = { + "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + "slice/mel_gen": plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy() + ), + "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), + } + summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict, + ) + + if global_step % cfg.train.eval_interval == 0: + evaluate(cfg, net_g, eval_loader, writer_eval) + save_checkpoint( + net_g, + optim_g, + cfg.train.learning_rate, + epoch, + os.path.join(args.log_dir, "G_{}.pth".format(global_step)), + ) + save_checkpoint( + net_d, + optim_d, + cfg.train.learning_rate, + epoch, + os.path.join(args.log_dir, "D_{}.pth".format(global_step)), + ) + global_step += 1 + + logger.info("====> Epoch: {}".format(epoch)) + + +def evaluate(cfg, generator, eval_loader, writer_eval): + generator.eval() + with torch.no_grad(): + for batch_idx, items in enumerate(eval_loader): + if cfg.model.use_spk: + c, spec, y, spk = items + g = spk[:1].cuda(0) + else: + c, spec, y = items + g = None + spec, y = spec[:1].cuda(0), y[:1].cuda(0) + c = c[:1].cuda(0) + break + mel = spec_to_mel_torch( + spec, + cfg.data.filter_length, + cfg.data.n_mel_channels, + cfg.data.sampling_rate, + cfg.data.mel_fmin, + cfg.data.mel_fmax, + ) + y_hat = generator.infer(c, g=g, mel=mel) + + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1).float(), + 
cfg.data.filter_length, + cfg.data.n_mel_channels, + cfg.data.sampling_rate, + cfg.data.hop_length, + cfg.data.win_length, + cfg.data.mel_fmin, + cfg.data.mel_fmax, + ) + image_dict = { + "gen/mel": plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()), + "gt/mel": plot_spectrogram_to_numpy(mel[0].cpu().numpy()), + } + audio_dict = {"gen/audio": y_hat[0], "gt/audio": y[0]} + summarize( + writer=writer_eval, + global_step=global_step, + images=image_dict, + audios=audio_dict, + audio_sampling_rate=cfg.data.sampling_rate, + ) + generator.train() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="config.json", + help="json files for configurations.", + required=True, + ) + parser.add_argument( + "--exp_name", + type=str, + default="exp_name", + help="A specific name to note the experiment", + required=True, + ) + parser.add_argument( + "--resume", + action="store_true", + help="If specified, to resume from the existing checkpoint.", + ) + parser.add_argument( + "--resume_from_ckpt_path", + type=str, + default="", + help="The specific checkpoint path that you want to resume from.", + ) + parser.add_argument( + "--resume_type", + type=str, + default="", + help="`resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights", + ) + parser.add_argument( + "--log_level", default="warning", help="logging level (debug, info, warning)" + ) + args = parser.parse_args() + args.log_dir = f"ckpts/vc/FreeVC/{args.exp_name}" + cfg = load_config(args.config) + main(cfg, args) diff --git a/models/vc/FreeVC/train_utils.py b/models/vc/FreeVC/train_utils.py new file mode 100644 index 00000000..c21ebdda --- /dev/null +++ b/models/vc/FreeVC/train_utils.py @@ -0,0 +1,167 @@ +import os +import glob +import logging +import sys +import numpy as np +import torch + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger = logging + + +def load_checkpoint(checkpoint_path, model, optimizer=None, strict=False): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if optimizer is not None: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + if strict: + assert ( + state_dict.keys() == saved_state_dict.keys() + ), "Mismatched model config and checkpoint." 
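+ # Non-strict loading: weights are copied key by key below; any key missing from
+ # the checkpoint keeps the model's current value and is reported via logger.info.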
+ new_state_dict = {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except Exception: + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + logger.info( + "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration) + ) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info( + "Saving model and optimizer state at iteration {} to {}".format( + iteration, checkpoint_path + ) + ) + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats="HWC") + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring( + fig.canvas.tostring_rgb(), dtype=np.uint8, sep="" + ) # type:ignore + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment.transpose(), aspect="auto", origin="lower", interpolation="none" + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring( + fig.canvas.tostring_rgb(), dtype=np.uint8, sep="" + ) # type:ignore + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = 
logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger diff --git a/models/vc/FreeVC/wavlm.py b/models/vc/FreeVC/wavlm.py new file mode 100644 index 00000000..f2bf703a --- /dev/null +++ b/models/vc/FreeVC/wavlm.py @@ -0,0 +1,48 @@ +import os + +from huggingface_hub import try_to_load_from_cache, snapshot_download +import torch +from transformers import WavLMModel + +REPO_ID = "microsoft/wavlm-large" + + +def rename_state_key(state_dict, key, new_key): + state_dict[new_key] = state_dict.pop(key) + + +def load_wavlm(): + # https://github.com/huggingface/transformers/issues/30469 + bin_name = "pytorch_model.bin" + bin_path = try_to_load_from_cache(repo_id=REPO_ID, filename=bin_name) + if bin_path is None: + download_wavlm() + bin_path = try_to_load_from_cache(repo_id=REPO_ID, filename=bin_name) + assert bin_path is not None + + # https://github.com/pytorch/pytorch/issues/102999 + # https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html + state_dict = torch.load(bin_path) + rename_state_key( + state_dict, + "encoder.pos_conv_embed.conv.weight_g", + "encoder.pos_conv_embed.conv.parametrizations.weight.original0", + ) + rename_state_key( + state_dict, + "encoder.pos_conv_embed.conv.weight_v", + "encoder.pos_conv_embed.conv.parametrizations.weight.original1", + ) + + model = WavLMModel.from_pretrained(os.path.dirname(bin_path), state_dict=state_dict) + assert isinstance(model, WavLMModel) + return model + + +def download_wavlm(): + snapshot_download(repo_id=REPO_ID, repo_type="model", resume_download=True) + + +if __name__ == "__main__": + download_wavlm() + print(load_wavlm()) From 9c356e9c3f9dba457b57fbdf9765bde92dbf560e Mon Sep 17 00:00:00 2001 From: Nugine Date: Tue, 7 May 2024 23:06:36 +0800 Subject: [PATCH 2/8] fix --- egs/vc/FreeVC/README.md | 4 +++ egs/vc/FreeVC/freevc.json | 54 ------------------------------------- models/vc/FreeVC/commons.py | 6 ++--- 3 files changed, 7 insertions(+), 57 deletions(-) delete mode 100644 egs/vc/FreeVC/freevc.json diff --git a/egs/vc/FreeVC/README.md b/egs/vc/FreeVC/README.md index af4eb02e..47033ca8 100644 --- a/egs/vc/FreeVC/README.md +++ b/egs/vc/FreeVC/README.md @@ -46,6 +46,10 @@ The code will automatically download pretrained [WavLM-Large](https://huggingfac huggingface-cli download microsoft/wavlm-large ``` +The pretrained speaker encoder is available at: + +The weight should be put in `models/vc/FreeVC/speaker_encoder/ckpt/` since it is excluded from the git history. 
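+
+For example (a minimal sketch, assuming the downloaded weight file is named `pretrained_bak_5805000.pt`):
+
+```bash
+mkdir -p models/vc/FreeVC/speaker_encoder/ckpt
+mv pretrained_bak_5805000.pt models/vc/FreeVC/speaker_encoder/ckpt/
+```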
+ ### Configuration Specify the data path and the checkpoint path for saving the processed data in `exp_config.json`: diff --git a/egs/vc/FreeVC/freevc.json b/egs/vc/FreeVC/freevc.json deleted file mode 100644 index c25a9f1d..00000000 --- a/egs/vc/FreeVC/freevc.json +++ /dev/null @@ -1,54 +0,0 @@ -{ - "train": { - "log_interval": 200, - "eval_interval": 10000, - "seed": 1234, - "epochs": 10000, - "learning_rate": 2e-4, - "betas": [0.8, 0.99], - "eps": 1e-9, - "batch_size": 64, - "fp16_run": false, - "lr_decay": 0.999875, - "segment_size": 8960, - "init_lr_ratio": 1, - "warmup_epochs": 0, - "c_mel": 45, - "c_kl": 1.0, - "use_sr": true, - "max_speclen": 128, - "port": "8001" - }, - "data": { - "training_files":"filelists/train.txt", - "validation_files":"filelists/val.txt", - "max_wav_value": 32768.0, - "sampling_rate": 16000, - "filter_length": 1280, - "hop_length": 320, - "win_length": 1280, - "n_mel_channels": 80, - "mel_fmin": 0.0, - "mel_fmax": null - }, - "model": { - "inter_channels": 192, - "hidden_channels": 192, - "filter_channels": 768, - "n_heads": 2, - "n_layers": 6, - "kernel_size": 3, - "p_dropout": 0.1, - "resblock": "1", - "resblock_kernel_sizes": [3,7,11], - "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], - "upsample_rates": [10,8,2,2], - "upsample_initial_channel": 512, - "upsample_kernel_sizes": [16,16,4,4], - "n_layers_q": 3, - "use_spectral_norm": false, - "gin_channels": 256, - "ssl_dim": 1024, - "use_spk": true - } -} diff --git a/models/vc/FreeVC/commons.py b/models/vc/FreeVC/commons.py index 6c6d8c8c..ac3d4a41 100644 --- a/models/vc/FreeVC/commons.py +++ b/models/vc/FreeVC/commons.py @@ -1,8 +1,8 @@ from models.tts.vits.vits import ( slice_segments, - rand_slice_segments, - get_padding, -) # noqa: F401 + rand_slice_segments, # noqa: F401 + get_padding, # noqa: F401 +) import torch From 97a1e5a8c8b854efdc596b362e046cfd004d6d5b Mon Sep 17 00:00:00 2001 From: Nugine Date: Tue, 7 May 2024 23:59:39 +0800 Subject: [PATCH 3/8] fix --- config/freevc.json | 1 - models/vc/FreeVC/data.py | 7 ++----- models/vc/FreeVC/train.py | 9 +++++---- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/config/freevc.json b/config/freevc.json index 085c95b3..4bef623a 100644 --- a/config/freevc.json +++ b/config/freevc.json @@ -12,7 +12,6 @@ "maxh": 92 }, "data": { - "max_wav_value": 32768.0, "sampling_rate": 16000, "filter_length": 1280, "hop_length": 320, diff --git a/models/vc/FreeVC/data.py b/models/vc/FreeVC/data.py index 41460ff1..ea5bcd56 100644 --- a/models/vc/FreeVC/data.py +++ b/models/vc/FreeVC/data.py @@ -22,7 +22,6 @@ class FreeVCDataset(Dataset): def __init__(self, audiopaths, hparams): self.audiopaths = read_txt_lines(audiopaths) - self.max_wav_value = hparams.data.max_wav_value self.sampling_rate = hparams.data.sampling_rate self.filter_length = hparams.data.filter_length self.hop_length = hparams.data.hop_length @@ -63,10 +62,8 @@ def load_sample(self, filename): f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR" ) - audio_norm = audio / self.max_wav_value - spec = spectrogram_torch( - audio_norm, + audio, self.filter_length, self.sampling_rate, self.hop_length, @@ -88,7 +85,7 @@ def load_sample(self, filename): ssl_path = os.path.join(self.sr_dir, filename.replace(".wav", f"_{h}.pt")) ssl = torch.load(ssl_path).squeeze_(0) - return ssl, spec, audio_norm, spk + return ssl, spec, audio, spk def __getitem__(self, index): return self.load_sample(self.audiopaths[index]) diff --git a/models/vc/FreeVC/train.py b/models/vc/FreeVC/train.py index 
87dd591d..ab81d43d 100644 --- a/models/vc/FreeVC/train.py +++ b/models/vc/FreeVC/train.py @@ -150,14 +150,16 @@ def train_and_evaluate( net_g.train() net_d.train() for batch_idx, items in enumerate(train_loader): + c, spec, y, spk = items if cfg.model.use_spk: - c, spec, y, spk = items g = spk.cuda(non_blocking=True) else: - c, spec, y = items g = None spec, y = spec.cuda(non_blocking=True), y.cuda(non_blocking=True) c = c.cuda(non_blocking=True) + + torch.cuda.synchronize() + mel = spec_to_mel_torch( spec, cfg.data.filter_length, @@ -287,11 +289,10 @@ def evaluate(cfg, generator, eval_loader, writer_eval): generator.eval() with torch.no_grad(): for batch_idx, items in enumerate(eval_loader): + c, spec, y, spk = items if cfg.model.use_spk: - c, spec, y, spk = items g = spk[:1].cuda(0) else: - c, spec, y = items g = None spec, y = spec[:1].cuda(0), y[:1].cuda(0) c = c[:1].cuda(0) From 26348552ba895d2a0062c6a6e56a827b4fa78a05 Mon Sep 17 00:00:00 2001 From: EricLaw Date: Wed, 8 May 2024 11:50:34 +0800 Subject: [PATCH 4/8] Add WER PCC NER evaluation file in models --- models/vc/FreeVC/F0_PCC.py | 68 ++++++++++++++++++++++++++++++++++++++ models/vc/FreeVC/get_gt.py | 29 ++++++++++++++++ models/vc/FreeVC/wer.py | 66 ++++++++++++++++++++++++++++++++++++ 3 files changed, 163 insertions(+) create mode 100644 models/vc/FreeVC/F0_PCC.py create mode 100644 models/vc/FreeVC/get_gt.py create mode 100644 models/vc/FreeVC/wer.py diff --git a/models/vc/FreeVC/F0_PCC.py b/models/vc/FreeVC/F0_PCC.py new file mode 100644 index 00000000..5865bda7 --- /dev/null +++ b/models/vc/FreeVC/F0_PCC.py @@ -0,0 +1,68 @@ +from tqdm import tqdm +import numpy as np +import pyworld as pw +import argparse +import librosa +import os + + +def get_f0(x, fs=16000, n_shift=160): + x = x.astype(np.float64) + frame_period = n_shift / fs * 1000 + f0, timeaxis = pw.dio(x, fs, frame_period=frame_period) + f0 = pw.stonemask(x, f0, timeaxis, fs) + return f0 + + +def compute_f0(wav, sr=16000, frame_period=10.0): + wav = wav.astype(np.float64) + f0, timeaxis = pw.harvest( + wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0) + return f0 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # parser.add_argument("--txtpath", type=str, default="samples.txt", help="path to txt file") + parser.add_argument("--src_path", type=str, default=r"data\VCTK\test_data", help="path to src audio files") + parser.add_argument("--tgt_path", type=str, default=r"data\VCTK\test_output", help="path to output audio files") + parser.add_argument("--output_path", type=str, default=r"result\pcc.txt", help="path tot the pcc output file") + args = parser.parse_args() + + pccs = [] + i = 0 + src_files = [f for f in os.listdir(args.src_path) if f.endswith('.wav')] + + for filename in src_files: + # print(filename) + path_of_src_audio = os.path.join(args.src_path, filename) + path_of_res_audio = os.path.join(args.tgt_path, os.path.splitext(filename)[0]+"_sync.wav") + # print(os.path.exists(path_of_res_audio)) + + if os.path.exists(path_of_res_audio): + # print(path_of_res_audio) + # 加载音频文件 + src = librosa.load(path_of_src_audio, sr=16000)[0] + src_f0 = get_f0(src) + tgt = librosa.load(path_of_res_audio, sr=16000)[0] + tgt_f0 = get_f0(tgt) + if sum(src_f0) == 0: + src_f0 = compute_f0(src) + tgt_f0 = compute_f0(tgt) + # print(rawline) + pcc = np.corrcoef(src_f0[:tgt_f0.shape[-1]], tgt_f0[:src_f0.shape[-1]])[0, 1] + print(pcc) + #print(i, pcc) + if not np.isnan(pcc.item()): + pccs.append(pcc.item()) + + else: + print(f"Warning: No matching 
file for {filename} in result folder") + + with open(args.output_path, "w") as f: + for pcc in pccs: + f.write(f"{pcc}\n") + pcc = sum(pccs) / len(pccs) + f.write(f"mean pcc: {pcc}") + + print("mean: ", pcc) \ No newline at end of file diff --git a/models/vc/FreeVC/get_gt.py b/models/vc/FreeVC/get_gt.py new file mode 100644 index 00000000..ab06968e --- /dev/null +++ b/models/vc/FreeVC/get_gt.py @@ -0,0 +1,29 @@ +from transformers import Wav2Vec2Processor, HubertForCTC +import os +import argparse +import torch +import librosa +from tqdm import tqdm +from glob import glob + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--txtpath", type=str, default="gt.txt", help="path to tgt txt file") + parser.add_argument("--wavdir", type=str, default=r"data\VCTK\test_data") + args = parser.parse_args() + + # load model and processor + model_text = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").cuda() + processor_text = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") + + # get transcriptions + wavs = glob(f'{args.wavdir}/*.wav') + wavs.sort() + with open(f"{args.txtpath}", "w") as f: + for path in tqdm(wavs): + wav = [librosa.load(path, sr=16000)[0]] + input_values = processor_text(wav, return_tensors="pt").input_values.cuda() # text + logits = model_text(input_values).logits + predicted_ids = torch.argmax(logits, dim=-1) + text = processor_text.batch_decode(predicted_ids)[0] + f.write(f"{path}|{text}\n") \ No newline at end of file diff --git a/models/vc/FreeVC/wer.py b/models/vc/FreeVC/wer.py new file mode 100644 index 00000000..eb35151b --- /dev/null +++ b/models/vc/FreeVC/wer.py @@ -0,0 +1,66 @@ +from transformers import Wav2Vec2Processor, HubertForCTC +import os +import argparse +import torch +import librosa +from tqdm import tqdm +from glob import glob +from jiwer import wer, cer + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--wavdir", type=str, default=r"data\VCTK\test_output") + parser.add_argument("--outdir", type=str, default="result", help="path to output dir") + parser.add_argument("--use_cuda", default=True, action="store_true") + args = parser.parse_args() + + os.makedirs(args.outdir, exist_ok=True) + + # load model and processor + model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft") + if args.use_cuda: + model = model.cuda() + processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") + + # gt + gt_dict = {} + with open("gt.txt", "r") as f: + for line in f.readlines(): + path, text = line.strip().split("|") + title = os.path.basename(path)[:-4] + gt_dict[title] = text + + # get transcriptions + wavs = glob(f'{args.wavdir}/*.wav') + wavs.sort() + trans_dict = {} + + with open(f"{args.outdir}/text.txt", "w") as f: + for path in tqdm(wavs): + wav = [librosa.load(path, sr=16000)[0]] + input_values = processor(wav, return_tensors="pt").input_values + if args.use_cuda: + input_values = input_values.cuda() + logits = model(input_values).logits + predicted_ids = torch.argmax(logits, dim=-1) + text = processor.batch_decode(predicted_ids)[0] + f.write(f"{path}|{text}\n") + title = os.path.basename(path)[:-4] + trans_dict[title] = text + + # calc + gts, trans = [], [] + for key in trans_dict.keys(): + text = trans_dict[key] + trans.append(text) + # gttext = gt_dict[key.split("-")[0]] + gttext = gt_dict[key[:8] ] + gts.append(gttext) + + wer = wer(gts, trans) + cer = cer(gts, trans) + with open(f"{args.outdir}/wer.txt", "w") as f: + 
f.write(f"wer: {wer}\n") + f.write(f"cer: {cer}\n") + print("WER:", wer) + print("CER:", cer) \ No newline at end of file From 8c5e3413ec162b6307e3fe2eec46c771b664fbf7 Mon Sep 17 00:00:00 2001 From: Nugine Date: Wed, 8 May 2024 12:06:03 +0800 Subject: [PATCH 5/8] format --- models/vc/FreeVC/F0_PCC.py | 55 +++++++++++++++++++++++++------------- models/vc/FreeVC/get_gt.py | 21 +++++++++------ models/vc/FreeVC/wer.py | 32 ++++++++++++---------- 3 files changed, 68 insertions(+), 40 deletions(-) diff --git a/models/vc/FreeVC/F0_PCC.py b/models/vc/FreeVC/F0_PCC.py index 5865bda7..c42554a2 100644 --- a/models/vc/FreeVC/F0_PCC.py +++ b/models/vc/FreeVC/F0_PCC.py @@ -1,4 +1,3 @@ -from tqdm import tqdm import numpy as np import pyworld as pw import argparse @@ -9,34 +8,52 @@ def get_f0(x, fs=16000, n_shift=160): x = x.astype(np.float64) frame_period = n_shift / fs * 1000 - f0, timeaxis = pw.dio(x, fs, frame_period=frame_period) - f0 = pw.stonemask(x, f0, timeaxis, fs) + f0, timeaxis = pw.dio(x, fs, frame_period=frame_period) # type:ignore + f0 = pw.stonemask(x, f0, timeaxis, fs) # type:ignore return f0 - - + + def compute_f0(wav, sr=16000, frame_period=10.0): wav = wav.astype(np.float64) - f0, timeaxis = pw.harvest( - wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0) + f0, timeaxis = pw.harvest( # type:ignore + wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0 + ) return f0 - + if __name__ == "__main__": parser = argparse.ArgumentParser() # parser.add_argument("--txtpath", type=str, default="samples.txt", help="path to txt file") - parser.add_argument("--src_path", type=str, default=r"data\VCTK\test_data", help="path to src audio files") - parser.add_argument("--tgt_path", type=str, default=r"data\VCTK\test_output", help="path to output audio files") - parser.add_argument("--output_path", type=str, default=r"result\pcc.txt", help="path tot the pcc output file") + parser.add_argument( + "--src_path", + type=str, + default=r"data\VCTK\test_data", + help="path to src audio files", + ) + parser.add_argument( + "--tgt_path", + type=str, + default=r"data\VCTK\test_output", + help="path to output audio files", + ) + parser.add_argument( + "--output_path", + type=str, + default=r"result\pcc.txt", + help="path tot the pcc output file", + ) args = parser.parse_args() - + pccs = [] i = 0 - src_files = [f for f in os.listdir(args.src_path) if f.endswith('.wav')] + src_files = [f for f in os.listdir(args.src_path) if f.endswith(".wav")] for filename in src_files: # print(filename) path_of_src_audio = os.path.join(args.src_path, filename) - path_of_res_audio = os.path.join(args.tgt_path, os.path.splitext(filename)[0]+"_sync.wav") + path_of_res_audio = os.path.join( + args.tgt_path, os.path.splitext(filename)[0] + "_sync.wav" + ) # print(os.path.exists(path_of_res_audio)) if os.path.exists(path_of_res_audio): @@ -50,19 +67,21 @@ def compute_f0(wav, sr=16000, frame_period=10.0): src_f0 = compute_f0(src) tgt_f0 = compute_f0(tgt) # print(rawline) - pcc = np.corrcoef(src_f0[:tgt_f0.shape[-1]], tgt_f0[:src_f0.shape[-1]])[0, 1] + pcc = np.corrcoef(src_f0[: tgt_f0.shape[-1]], tgt_f0[: src_f0.shape[-1]])[ + 0, 1 + ] print(pcc) - #print(i, pcc) + # print(i, pcc) if not np.isnan(pcc.item()): pccs.append(pcc.item()) else: print(f"Warning: No matching file for {filename} in result folder") - + with open(args.output_path, "w") as f: for pcc in pccs: f.write(f"{pcc}\n") pcc = sum(pccs) / len(pccs) f.write(f"mean pcc: {pcc}") - print("mean: ", pcc) \ No newline at end of file + print("mean: ", 
pcc) diff --git a/models/vc/FreeVC/get_gt.py b/models/vc/FreeVC/get_gt.py index ab06968e..f385f387 100644 --- a/models/vc/FreeVC/get_gt.py +++ b/models/vc/FreeVC/get_gt.py @@ -1,5 +1,4 @@ from transformers import Wav2Vec2Processor, HubertForCTC -import os import argparse import torch import librosa @@ -8,22 +7,28 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--txtpath", type=str, default="gt.txt", help="path to tgt txt file") + parser.add_argument( + "--txtpath", type=str, default="gt.txt", help="path to tgt txt file" + ) parser.add_argument("--wavdir", type=str, default=r"data\VCTK\test_data") args = parser.parse_args() # load model and processor - model_text = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").cuda() + model_text = HubertForCTC.from_pretrained( + "facebook/hubert-large-ls960-ft" + ).cuda() # type:ignore processor_text = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") - + # get transcriptions - wavs = glob(f'{args.wavdir}/*.wav') + wavs = glob(f"{args.wavdir}/*.wav") wavs.sort() with open(f"{args.txtpath}", "w") as f: for path in tqdm(wavs): wav = [librosa.load(path, sr=16000)[0]] - input_values = processor_text(wav, return_tensors="pt").input_values.cuda() # text + input_values = processor_text( + wav, return_tensors="pt" + ).input_values.cuda() # text # type:ignore logits = model_text(input_values).logits predicted_ids = torch.argmax(logits, dim=-1) - text = processor_text.batch_decode(predicted_ids)[0] - f.write(f"{path}|{text}\n") \ No newline at end of file + text = processor_text.batch_decode(predicted_ids)[0] # type:ignore + f.write(f"{path}|{text}\n") diff --git a/models/vc/FreeVC/wer.py b/models/vc/FreeVC/wer.py index eb35151b..88741f24 100644 --- a/models/vc/FreeVC/wer.py +++ b/models/vc/FreeVC/wer.py @@ -10,18 +10,20 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--wavdir", type=str, default=r"data\VCTK\test_output") - parser.add_argument("--outdir", type=str, default="result", help="path to output dir") + parser.add_argument( + "--outdir", type=str, default="result", help="path to output dir" + ) parser.add_argument("--use_cuda", default=True, action="store_true") args = parser.parse_args() - + os.makedirs(args.outdir, exist_ok=True) # load model and processor model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft") if args.use_cuda: - model = model.cuda() + model = model.cuda() # type:ignore processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") - + # gt gt_dict = {} with open("gt.txt", "r") as f: @@ -29,38 +31,40 @@ path, text = line.strip().split("|") title = os.path.basename(path)[:-4] gt_dict[title] = text - + # get transcriptions - wavs = glob(f'{args.wavdir}/*.wav') + wavs = glob(f"{args.wavdir}/*.wav") wavs.sort() trans_dict = {} - + with open(f"{args.outdir}/text.txt", "w") as f: for path in tqdm(wavs): wav = [librosa.load(path, sr=16000)[0]] - input_values = processor(wav, return_tensors="pt").input_values + input_values = processor( + wav, return_tensors="pt" + ).input_values # type:ignore if args.use_cuda: input_values = input_values.cuda() - logits = model(input_values).logits + logits = model(input_values).logits # type:ignore predicted_ids = torch.argmax(logits, dim=-1) - text = processor.batch_decode(predicted_ids)[0] + text = processor.batch_decode(predicted_ids)[0] # type:ignore f.write(f"{path}|{text}\n") title = os.path.basename(path)[:-4] trans_dict[title] = text - + # calc gts, trans = [], 
[] for key in trans_dict.keys(): text = trans_dict[key] trans.append(text) # gttext = gt_dict[key.split("-")[0]] - gttext = gt_dict[key[:8] ] + gttext = gt_dict[key[:8]] gts.append(gttext) - + wer = wer(gts, trans) cer = cer(gts, trans) with open(f"{args.outdir}/wer.txt", "w") as f: f.write(f"wer: {wer}\n") f.write(f"cer: {cer}\n") print("WER:", wer) - print("CER:", cer) \ No newline at end of file + print("CER:", cer) From 4b8ef5021ea1539627d908d40aec95d13189ba15 Mon Sep 17 00:00:00 2001 From: Nugine Date: Wed, 8 May 2024 14:32:13 +0800 Subject: [PATCH 6/8] fix readme --- egs/vc/FreeVC/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/vc/FreeVC/README.md b/egs/vc/FreeVC/README.md index 47033ca8..cb5a1fe5 100644 --- a/egs/vc/FreeVC/README.md +++ b/egs/vc/FreeVC/README.md @@ -137,9 +137,9 @@ Then you should run `run.sh`, you need to specify the following configurations: For example: ```bash -sh egs/svc/VitsSVC/run.sh --stage 3 \ +sh egs/vc/FreeVC/run.sh --stage 3 \ --config egs/vc/FreeVC/exp_config.json \ - --ckpt ckpts/vc/FreeVC/[YourExptName]/G_5.ckpt \ + --ckpt ckpts/vc/FreeVC/[YourExptName]/G_100000.ckpt \ --convert ckpts/vc/FreeVC/[YourExptName] \ --outdir ckpts/vc/FreeVC/[YourExptName]/result \ -``` \ No newline at end of file +``` From 74203ebe58c9844537eed71403e92f570ffb42ea Mon Sep 17 00:00:00 2001 From: Nugine Date: Tue, 21 May 2024 22:52:31 +0800 Subject: [PATCH 7/8] reuse hifigan --- models/vc/FreeVC/hifigan.py | 239 ++---------------------------------- 1 file changed, 10 insertions(+), 229 deletions(-) diff --git a/models/vc/FreeVC/hifigan.py b/models/vc/FreeVC/hifigan.py index 62d778b6..161e1b69 100644 --- a/models/vc/FreeVC/hifigan.py +++ b/models/vc/FreeVC/hifigan.py @@ -1,13 +1,10 @@ # ruff: noqa: E741 +from models.vocoders.gan.generator.hifigan import HiFiGAN + import os import torch -import torch.nn.functional as F -import torch.nn as nn -from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm - from omegaconf import OmegaConf @@ -16,231 +13,15 @@ def load_hifigan(ckpt_path): config = OmegaConf.load(os.path.join(ckpt_path, "config.json")) ckpt = torch.load(os.path.join(ckpt_path, "generator_v1")) - vocoder = Generator(config) + vocoder = HiFiGAN( + OmegaConf.create( + { + "model": {"hifigan": config}, + "preprocess": {"n_mel": config.num_mels}, + } + ) + ) vocoder.load_state_dict(ckpt["generator"]) vocoder.eval() vocoder.remove_weight_norm() return vocoder, config - - -# ----------------------------------------- -# Copied from https://github.com/jik876/hifi-gan/tree/4769534d45265d52a904b850da5a622601885777 -# MIT License -# ----------------------------------------- -# COPY START - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -LRELU_SLOPE = 0.1 - - -class ResBlock1(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): - super(ResBlock1, self).__init__() - self.h = h - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - 
kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - -class ResBlock2(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): - super(ResBlock2, self).__init__() - self.h = h - self.convs = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - ] - ) - self.convs.apply(init_weights) - - def forward(self, x): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_weight_norm(l) - - -class Generator(torch.nn.Module): - def __init__(self, h): - super(Generator, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = weight_norm( - Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) - ) - resblock = ResBlock1 if h.resblock == "1" else ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) - ): - self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels # type:ignore - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - # print("Removing weight norm...") - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) - - -# COPY END -# ----------------------------------------------- From d42e665c9d61fd8efe28b567022608975e82f2bc Mon Sep 17 00:00:00 2001 From: 
Nugine Date: Tue, 21 May 2024 23:29:26 +0800 Subject: [PATCH 8/8] reuse mel processing fn --- models/vc/FreeVC/mel_processing.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/models/vc/FreeVC/mel_processing.py b/models/vc/FreeVC/mel_processing.py index 3188aca5..7a1047df 100644 --- a/models/vc/FreeVC/mel_processing.py +++ b/models/vc/FreeVC/mel_processing.py @@ -1,20 +1,11 @@ # Copied from https://github.com/OlaWod/FreeVC/tree/81c169cdbfc97ff07ee2f501e9b88d543fc46126 +from utils.mel import spectral_normalize_torch + import torch import torch.utils.data from librosa.filters import mel as librosa_mel_fn -MAX_WAV_VALUE = 32768.0 - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return torch.log(torch.clamp(x, min=clip_val) * C) - def dynamic_range_decompression_torch(x, C=1): """ @@ -25,11 +16,6 @@ def dynamic_range_decompression_torch(x, C=1): return torch.exp(x) / C -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - def spectral_de_normalize_torch(magnitudes): output = dynamic_range_decompression_torch(magnitudes) return output @@ -39,6 +25,7 @@ def spectral_de_normalize_torch(magnitudes): hann_window = {} +# TODO: merge with `utils.mel.extract_linear_features` def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): if torch.min(y) < -1.0: print("min value is ", torch.min(y)) @@ -94,6 +81,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): return spec +# TODO: merge with `utils.mel.extract_mel_features` def mel_spectrogram_torch( y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False ):