Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VC Noro model #247

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions README.md
kenxxxxx marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis.

## 🚀 News
- **2024/09/01**: [Amphion](https://arxiv.org/abs/2312.09911) and [Emilia](https://arxiv.org/abs/2407.05361) got accepted by IEEE SLT 2024! 🤗
- **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.com/invite/ZxxREr3Y) to stay connected and engage with our community!
- **2024/08/20**: [SingVisio](https://arxiv.org/abs/2402.12660) got accepted by Computers & Graphics, [available here](https://www.sciencedirect.com/science/article/pii/S0097849324001936)! 🎉
- **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [![OpenDataLab](https://img.shields.io/badge/OpenDataLab-Dataset-blue)](https://opendatalab.com/Amphion/Emilia)! 👑👑👑
- **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2407.05361) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia) [![demo](https://img.shields.io/badge/WebPage-Demo-red)](https://emilia-dataset.github.io/Emilia-Demo-Page/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](preprocessors/Emilia/README.md)
- **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable codes compared to our first version. [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](egs/tts/VALLE_V2/README.md)
- **2024/03/12**: Amphion now support **NaturalSpeech3 FACodec** and release pretrained checkpoints. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2403.03100) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/naturalspeech3_facodec) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/codec/ns3_codec/README.md)
Expand Down
82 changes: 82 additions & 0 deletions bins/vc/Noro/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch
from models.vc.Noro.noro_trainer import NoroTrainer
from utils.util import load_config

def build_trainer(args, cfg):
supported_trainer = {
"VC": NoroTrainer,
}
trainer_class = supported_trainer[cfg.model_type]
trainer = trainer_class(args, cfg)
return trainer


def cuda_relevant(deterministic=False):
torch.cuda.empty_cache()
# TF32 on Ampere and above
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
# Deterministic
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
torch.use_deterministic_algorithms(deterministic)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
default="config.json",
help="json files for configurations.",
required=True,
)
parser.add_argument(
"--exp_name",
type=str,
default="exp_name",
help="A specific name to note the experiment",
required=True,
)
parser.add_argument(
"--resume", action="store_true", help="The model name to restore"
)
parser.add_argument(
"--log_level", default="warning", help="logging level (debug, info, warning)"
)
parser.add_argument(
"--resume_type",
type=str,
default="resume",
help="Resume training or finetuning.",
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help="Checkpoint for resume training or finetuning.",
)
NoroTrainer.add_arguments(parser)
args = parser.parse_args()
cfg = load_config(args.config)
print("experiment name: ", args.exp_name)
# # CUDA settings
cuda_relevant()
# Build trainer
print(f"Building {cfg.model_type} trainer")
trainer = build_trainer(args, cfg)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
print(f"Start training {cfg.model_type} model")
trainer.train_loop()


if __name__ == "__main__":
main()
76 changes: 76 additions & 0 deletions config/noro.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
{
"base_config": "config/base.json",
"model_type": "VC",
"dataset": ["mls"],
"model": {
"reference_encoder": {
"encoder_layer": 6,
"encoder_hidden": 512,
"encoder_head": 8,
"conv_filter_size": 2048,
"conv_kernel_size": 9,
"encoder_dropout": 0.2,
"use_skip_connection": false,
"use_new_ffn": true,
"ref_in_dim": 80,
"ref_out_dim": 512,
"use_query_emb": true,
"num_query_emb": 32
},
"diffusion": {
"beta_min": 0.05,
"beta_max": 20,
"sigma": 1.0,
"noise_factor": 1.0,
"ode_solve_method": "euler",
"diff_model_type": "WaveNet",
"diff_wavenet":{
"input_size": 80,
"hidden_size": 512,
"out_size": 80,
"num_layers": 47,
"cross_attn_per_layer": 3,
"dilation_cycle": 2,
"attn_head": 8,
"drop_out": 0.2
}
},
"prior_encoder": {
"encoder_layer": 6,
"encoder_hidden": 512,
"encoder_head": 8,
"conv_filter_size": 2048,
"conv_kernel_size": 9,
"encoder_dropout": 0.2,
"use_skip_connection": false,
"use_new_ffn": true,
"vocab_size": 256,
"cond_dim": 512,
"duration_predictor": {
"input_size": 512,
"filter_size": 512,
"kernel_size": 3,
"conv_layers": 30,
"cross_attn_per_layer": 3,
"attn_head": 8,
"drop_out": 0.2
},
"pitch_predictor": {
"input_size": 512,
"filter_size": 512,
"kernel_size": 5,
"conv_layers": 30,
"cross_attn_per_layer": 3,
"attn_head": 8,
"drop_out": 0.5
},
"pitch_min": 50,
"pitch_max": 1100,
"pitch_bins_num": 512
},
"vc_feature": {
"content_feature_dim": 768,
"hidden_dim": 512
}
}
}
2 changes: 1 addition & 1 deletion config/tts.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"add_blank": true
},
"model": {
"text_token_num": 512,
"text_token_num": 512
}

}
122 changes: 122 additions & 0 deletions egs/vc/Noro/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Noro: A Noise-Robust One-shot Voice Conversion System

<br>
<div align="center">
<img src="../../../imgs/vc/NoroVC.png" width="85%">
</div>
<br>

This is the official implementation of the paper: NORO: A Noise-Robust One-Shot Voice Conversion System with Hidden Speaker Representation Capabilities.

- The semantic extractor is from [Hubert](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
- The vocoder is [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture.

## Project Overview
Noro is a noise-robust one-shot voice conversion (VC) system designed to convert the timbre of speech from a source speaker to a target speaker using only a single reference speech sample, while preserving the semantic content of the original speech. Noro introduces innovative components tailored for VC using noisy reference speeches, including a dual-branch reference encoding module and a noise-agnostic contrastive speaker loss.

## Features
- **Noise-Robust Voice Conversion**: Utilizes a dual-branch reference encoding module and noise-agnostic contrastive speaker loss to maintain high-quality voice conversion in noisy environments.
- **One-shot Voice Conversion**: Achieves timbre conversion using only one reference speech sample.
- **Speaker Representation Learning**: Explores the potential of the reference encoder as a self-supervised speaker encoder.

## Installation Requirement

Set up your environment as in Amphion README (you'll need a conda environment, and we recommend using Linux).

### Prepare Hubert Model

Humbert checkpoint and kmeans can be downloaded [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert).
Set the downloded model path at `egs/vc/Noro/exp_config_base.json`.


## Usage

### Download pretrained weights
You need to download our pretrained weights from [Google Drive](https://drive.google.com/drive/folders/1NPzSIuSKO8o87g5ImNzpw_BgbhsZaxNg?usp=drive_link).

### Inference
1. Configure inference parameters:
Modify the pretrained checkpoint path, source voice path and reference voice path at `egs/vc/Noro/noro_inference.sh` file.
Currently it's at line 35.
```
checkpoint_path="path/to/checkpoint/model.safetensors"
output_dir="path/to/output/directory"
source_path="path/to/source/audio.wav"
reference_path="path/to/reference/audio.wav"
```
2. Start inference:
```bash
bash path/to/Amphion/egs/vc/noro_inference.sh
```

3. You got the reconstructed mel spectrum saved to the output direction.
Then use the [BigVGAN](https://github.com/NVIDIA/BigVGAN) to construct the wav file.

## Training from Scratch

### Data Preparation

We use the LibriLight dataset for training and evaluation. You can download it using the following commands:
```bash
wget https://dl.fbaipublicfiles.com/librilight/data/large.tar
wget https://dl.fbaipublicfiles.com/librilight/data/medium.tar
wget https://dl.fbaipublicfiles.com/librilight/data/small.tar
```

### Training the Model with Clean Reference Voice

Configure training parameters:
Our configuration file for training clean Noro model is at "egs/vc/Noro/exp_config_clean.json", and Nosiy Noro model at "egs/vc/Noro/exp_config_noisy.json".

To train your model, you need to modify the `dataset` variable in the json configurations.
Currently it's at line 40, you should modify the "data_dir" to your dataset's root directory.
```
"directory_list": [
"path/to/your/training_data_directory1",
"path/to/your/training_data_directory2",
"path/to/your/training_data_directory3"
],
```

If you want to train for the noisy noro model, you also need to set the direction path for the noisy data at "egs/vc/Noro/exp_config_noisy.json".
```
"noise_dir": "path/to/your/noise/train/directory",
"test_noise_dir": "path/to/your/noise/test/directory"
```

You can change other experiment settings in the config flies such as the learning rate, optimizer and the dataset.

**Set smaller batch_size if you are out of memory😢😢**

I used max_tokens = 3200000 to successfully run on a single card, if you'r out of memory, try smaller.

```json
"max_tokens": 3200000
```
### Resume from existing checkpoint
Our framework supports resuming from existing checkpoint.
If this is a new experiment, use the following command:
```
CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 26667 --mixed_precision fp16 \
"${work_dir}/bins/vc/train.py" \
--config $exp_config \
--exp_name $exp_name \
--log_level debug
```
To resume training or fine-tune from a checkpoint, use the following command:
Ensure the options `--resume`, `--resume_type resume`, and `--checkpoint_path` are set.

### Run the command to Train model
Start clean training:
```bash
bash path/to/Amphion/egs/vc/noro_train_clean.sh
```


Start noisy training:
```bash
bash path/to/Amphion/egs/vc/noro_train_noisy.sh
```



61 changes: 61 additions & 0 deletions egs/vc/Noro/exp_config_base.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{
"base_config": "config/noro.json",
"model_type": "VC",
"dataset": [
"mls"
],
"sample_rate": 16000,
"n_fft": 1024,
"n_mel": 80,
"hop_size": 200,
"win_size": 800,
"fmin": 0,
"fmax": 8000,
"preprocess": {
"kmeans_model_path": "path/to/kmeans_model",
"hubert_model_path": "path/to/hubert_model",
"sample_rate": 16000,
"hop_size": 200,
"f0_min": 50,
"f0_max": 500,
"frame_period": 12.5
},
"model": {
"reference_encoder": {
"encoder_layer": 6,
"encoder_hidden": 512,
"encoder_head": 8,
"conv_filter_size": 2048,
"conv_kernel_size": 9,
"encoder_dropout": 0.2,
"use_skip_connection": false,
"use_new_ffn": true,
"ref_in_dim": 80,
"ref_out_dim": 512,
"use_query_emb": true,
"num_query_emb": 32
},
"diffusion": {
"beta_min": 0.05,
"beta_max": 20,
"sigma": 1.0,
"noise_factor": 1.0,
"ode_solve_method": "euler",
"diff_model_type": "WaveNet",
"diff_wavenet":{
"input_size": 80,
"hidden_size": 512,
"out_size": 80,
"num_layers": 47,
"cross_attn_per_layer": 3,
"dilation_cycle": 2,
"attn_head": 8,
"drop_out": 0.2
}
},
"vc_feature": {
"content_feature_dim": 768,
"hidden_dim": 512
}
}
}
Loading
Loading