
Commit 211e1d4

Training & Inference code for FAcodec (#229)

* Training & Inference code for FAcodec
* Update vocoder_trainer.py
* Added copyright statements & code source (where necessary)
* reformatted files with black formatter
* reformat
* reformat reformat
* reformat reformat reformat

1 parent a17f139 commit 211e1d4


41 files changed: +7442 −2 lines

bins/codec/inference.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch

from models.codec.facodec.facodec_inference import FAcodecInference
from utils.util import load_config


def build_inference(args, cfg):
    supported_inference = {
        "FAcodec": FAcodecInference,
    }

    inference_class = supported_inference[cfg.model_type]
    inference = inference_class(args, cfg)
    return inference


def cuda_relevant(deterministic=False):
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def build_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Acoustic model checkpoint directory. If a directory is given, "
        "search for the latest checkpoint dir in the directory. If a specific "
        "checkpoint dir is given, directly load the checkpoint.",
    )
    parser.add_argument(
        "--source",
        type=str,
        required=True,
        help="Path to the source audio file.",
    )
    parser.add_argument(
        "--reference",
        type=str,
        default=None,
        help="Path to the reference audio file; if provided, "
        "zero-shot voice conversion is also performed.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results.",
    )
    return parser


def main():
    # Parse arguments
    parser = build_parser()
    args = parser.parse_args()
    print(args)

    # Parse config
    cfg = load_config(args.config)

    # CUDA settings
    cuda_relevant()

    # Build inference
    inferencer = build_inference(args, cfg)

    # Run inference
    _ = inferencer.inference(args.source, args.output_dir)

    # Run voice conversion
    if args.reference is not None:
        _ = inferencer.voice_conversion(args.source, args.reference, args.output_dir)


if __name__ == "__main__":
    main()

bins/codec/train.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch

from models.codec.facodec.facodec_trainer import FAcodecTrainer

from utils.util import load_config


def build_trainer(args, cfg):
    supported_trainer = {
        "FAcodec": FAcodecTrainer,
    }

    trainer_class = supported_trainer[cfg.model_type]
    trainer = trainer_class(args, cfg)
    return trainer


def cuda_relevant(deterministic=False):
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        help="JSON/YAML file for configurations.",
        required=True,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        help="'resume' to continue training, 'finetune' for finetuning",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Checkpoint to resume from",
    )
    parser.add_argument(
        "--log_level", default="warning", help="Logging level (debug, info, warning)"
    )
    args = parser.parse_args()
    cfg = load_config(args.config)

    # CUDA settings
    cuda_relevant()

    # Build trainer
    trainer = build_trainer(args, cfg)

    trainer.train_loop()


if __name__ == "__main__":
    main()

config/facodec.json

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
{
    "exp_name": "facodec",
    "model_type": "FAcodec",
    "log_dir": "./runs/",
    "log_interval": 10,
    "save_interval": 1000,
    "device": "cuda",
    "epochs": 1000,
    "batch_size": 4,
    "batch_length": 100,
    "max_len": 80,
    "pretrained_model": "",
    "load_only_params": false,
    "F0_path": "modules/JDC/bst.t7",
    "dataset": "dummy",
    "preprocess_params": {
        "sr": 24000,
        "frame_rate": 80,
        "duration_range": [1.0, 25.0],
        "spect_params": {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
            "n_mels": 80
        }
    },
    "train": {
        "gradient_accumulation_step": 1,
        "batch_size": 1,
        "save_checkpoint_stride": [20],
        "random_seed": 1234,
        "max_epoch": -1,
        "max_frame_len": 80,
        "tracker": ["tensorboard"],
        "run_eval": [false],
        "sampler": {"holistic_shuffle": true, "drop_last": true},
        "dataloader": {"num_worker": 0, "pin_memory": true}
    },
    "model_params": {
        "causal": true,
        "lstm": 2,
        "norm_f0": true,
        "use_gr_content_f0": false,
        "use_gr_prosody_phone": false,
        "use_gr_timbre_prosody": false,
        "separate_prosody_encoder": true,
        "n_c_codebooks": 2,
        "timbre_norm": true,
        "use_gr_content_global_f0": true,
        "DAC": {
            "encoder_dim": 64,
            "encoder_rates": [2, 5, 5, 6],
            "decoder_dim": 1536,
            "decoder_rates": [6, 5, 5, 2],
            "sr": 24000
        }
    },
    "loss_params": {
        "base_lr": 0.0001,
        "warmup_steps": 200,
        "discriminator_iter_start": 2000,
        "lambda_spk": 1.0,
        "lambda_mel": 45,
        "lambda_f0": 1.0,
        "lambda_uv": 1.0
    }
}
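Note that the rates in this config are mutually consistent: `sr / hop_length = 24000 / 300 = 80` matches `frame_rate`, and the product of `encoder_rates` (2·5·5·6 = 300) equals `hop_length`, so the DAC encoder's total stride matches the spectrogram hop. A quick sanity check, as a standalone sketch using only values taken from this file:

```python
# Sanity-check the rate relationships in config/facodec.json.
sr, hop_length, frame_rate = 24000, 300, 80
assert sr // hop_length == frame_rate  # 24000 / 300 = 80 code frames per second

encoder_rates = [2, 5, 5, 6]
stride = 1
for r in encoder_rates:
    stride *= r  # total downsampling factor of the DAC encoder
assert stride == hop_length  # 2 * 5 * 5 * 6 = 300 samples per code frame
```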

egs/codec/FAcodec/README.md

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# FAcodec

PyTorch implementation of the training of FAcodec, which was proposed in the paper [NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models](https://arxiv.org/pdf/2403.03100).

A dedicated repository for the FAcodec model can also be found [here](https://github.com/Plachtaa/FAcodec).

This implementation makes some key improvements to the training pipeline that eliminate the need for any form of annotation, including transcripts, phoneme alignments, and speaker labels. All you need is raw speech files.
With the new training pipeline, it is possible to train the model on more languages with more diverse timbre distributions.
We release the code for training and inference, including a checkpoint pretrained on 50k hours of speech data covering over 1 million speakers.

## Model storage
We provide checkpoints pretrained on 50k hours of speech data.

| Model type | Link |
|------------|------|
| FAcodec | [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-FAcodec-blue)](https://huggingface.co/Plachta/FAcodec) |

## Demo
Try our model on [![Hugging Face](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-blue)](https://huggingface.co/spaces/Plachta/FAcodecV2)!

## Training
Prepare your data and place it under one folder; the internal file structure does not matter.
Then, change the `dataset` field in `./egs/codec/FAcodec/exp_custom_data.json` to the path of your data folder (a small script for this edit is sketched after the command below).
Finally, run the following command:
```bash
sh ./egs/codec/FAcodec/train.sh
```
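
If you prefer to script the config edit, here is a minimal sketch using only the standard library; the config path comes from this README, and `/data/my_speech` is a placeholder for your own folder:

```python
# Point the training config's "dataset" field at your raw speech folder.
import json

cfg_path = "egs/codec/FAcodec/exp_custom_data.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["dataset"] = "/data/my_speech"  # placeholder: your folder of raw speech files

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=4)
```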

## Inference
To reconstruct a speech file, run:
```bash
python ./bins/codec/inference.py --source <source_wav> --output_dir <output_dir> --checkpoint_path <checkpoint_path>
```
To use zero-shot voice conversion, run:
```bash
python ./bins/codec/inference.py --source <source_wav> --reference <reference_wav> --output_dir <output_dir> --checkpoint_path <checkpoint_path>
```

## Feature extraction
When running `./bins/codec/inference.py`, check the returned results of the `FAcodecInference` class: a tuple of `(quantized, codes)`.
- `quantized` is the quantized representation of the input speech file.
  - `quantized[0]` is the quantized representation of prosody.
  - `quantized[1]` is the quantized representation of content.
- `codes` is the discrete code representation of the input speech file.
  - `codes[0]` is the discrete code representation of prosody.
  - `codes[1]` is the discrete code representation of content.

For the cleanest content representation without any timbre, we suggest using `codes[1][:, 0, :]`, the first layer of the content codebooks.
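
As a concrete illustration, here is a minimal extraction sketch. It assumes `inference` returns the `(quantized, codes)` tuple described above and that the placeholder paths below exist locally; `FAcodecInference` and `load_config` are the same classes used in `bins/codec/inference.py`:

```python
# Minimal sketch: extract timbre-free content codes with FAcodec.
from types import SimpleNamespace

from models.codec.facodec.facodec_inference import FAcodecInference
from utils.util import load_config

cfg = load_config("config/facodec.json")  # placeholder config path
args = SimpleNamespace(
    config="config/facodec.json",
    checkpoint_path="/path/to/checkpoint",  # placeholder checkpoint dir
    source="speech.wav",                    # placeholder input file
    reference=None,
    output_dir="./out",
)

inferencer = FAcodecInference(args, cfg)
quantized, codes = inferencer.inference("speech.wav", "./out")

# First content codebook layer: the cleanest content representation per this README.
content_codes = codes[1][:, 0, :]
```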
egs/codec/FAcodec/exp_custom_data.json

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
{
    "exp_name": "facodec",
    "model_type": "FAcodec",

    "log_dir": "./runs/",
    "log_interval": 10,
    "save_interval": 1000,
    "device": "cuda",
    "epochs": 1000,
    "batch_size": 4,
    "batch_length": 100,
    "max_len": 80,
    "pretrained_model": "",
    "load_only_params": false,
    "F0_path": "modules/JDC/bst.t7",
    "dataset": "/path/to/dataset",
    "preprocess_params": {
        "sr": 24000,
        "frame_rate": 80,
        "duration_range": [1.0, 25.0],
        "spect_params": {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
            "n_mels": 80
        }
    },
    "train": {
        "gradient_accumulation_step": 1,
        "batch_size": 1,
        "save_checkpoint_stride": [
            20
        ],
        "random_seed": 1234,
        "max_epoch": -1,
        "max_frame_len": 80,
        "tracker": [
            "tensorboard"
        ],
        "run_eval": [
            false
        ],
        "sampler": {
            "holistic_shuffle": true,
            "drop_last": true
        },
        "dataloader": {
            "num_worker": 0,
            "pin_memory": true
        }
    },
    "model_params": {
        "causal": true,
        "lstm": 2,
        "norm_f0": true,
        "use_gr_content_f0": false,
        "use_gr_prosody_phone": false,
        "use_gr_timbre_prosody": false,
        "separate_prosody_encoder": true,
        "n_c_codebooks": 2,
        "timbre_norm": true,
        "use_gr_content_global_f0": true,
        "DAC": {
            "encoder_dim": 64,
            "encoder_rates": [2, 5, 5, 6],
            "decoder_dim": 1536,
            "decoder_rates": [6, 5, 5, 2],
            "sr": 24000
        }
    },
    "loss_params": {
        "base_lr": 0.0001,
        "warmup_steps": 200,
        "discriminator_iter_start": 2000,
        "lambda_spk": 1.0,
        "lambda_mel": 45,
        "lambda_f0": 1.0,
        "lambda_uv": 1.0
    }
}
