Showing 2 changed files with 314 additions and 1 deletion.
@@ -56,7 +56,6 @@ processed_data
 data
 model_ckpt
 logs
-*.ipynb
 *.lst
 source_audio
 result
@@ -0,0 +1,314 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import librosa\n",
"import safetensors.torch\n",
"from utils.util import load_config\n",
"\n",
"from models.codec.kmeans.repcodec_model import RepCodec\n",
"from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A\n",
"from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S\n",
"from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder\n",
"from transformers import Wav2Vec2BertModel\n",
"\n",
"from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p"
]
},
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from transformers import SeamlessM4TFeatureExtractor\n", | ||
"processor = SeamlessM4TFeatureExtractor.from_pretrained(\"facebook/w2v-bert-2.0\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def g2p_(text, language):\n", | ||
" if language in [\"zh\", \"en\"]:\n", | ||
" return chn_eng_g2p(text)\n", | ||
" else:\n", | ||
" return g2p(text, sentence=None, language=language)\n", | ||
"\n", | ||
"def build_t2s_model(cfg, device):\n", | ||
" t2s_model = MaskGCT_T2S(cfg=cfg)\n", | ||
" t2s_model.eval()\n", | ||
" t2s_model.to(device)\n", | ||
" return t2s_model\n", | ||
"\n", | ||
"def build_s2a_model(cfg, device):\n", | ||
" soundstorm_model = MaskGCT_S2A(cfg=cfg)\n", | ||
" soundstorm_model.eval()\n", | ||
" soundstorm_model.to(device)\n", | ||
" return soundstorm_model\n", | ||
"\n", | ||
"def build_semantic_model(device):\n", | ||
" semantic_model = Wav2Vec2BertModel.from_pretrained(\"facebook/w2v-bert-2.0\")\n", | ||
" semantic_model.eval()\n", | ||
" semantic_model.to(device)\n", | ||
" stat_mean_var = torch.load(\"./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt\")\n", | ||
" semantic_mean = stat_mean_var[\"mean\"]\n", | ||
" semantic_std = torch.sqrt(stat_mean_var[\"var\"])\n", | ||
" semantic_mean = semantic_mean.to(device)\n", | ||
" semantic_std = semantic_std.to(device)\n", | ||
" return semantic_model, semantic_mean, semantic_std\n", | ||
"\n", | ||
"def build_semantic_codec(cfg, device):\n", | ||
" semantic_codec = RepCodec(cfg=cfg)\n", | ||
" semantic_codec.eval()\n", | ||
" semantic_codec.to(device)\n", | ||
" return semantic_codec\n", | ||
"\n", | ||
"def build_acoustic_codec(cfg, device):\n", | ||
" codec_encoder = CodecEncoder(cfg=cfg.encoder)\n", | ||
" codec_decoder = CodecDecoder(cfg=cfg.decoder)\n", | ||
" codec_encoder.eval()\n", | ||
" codec_decoder.eval()\n", | ||
" codec_encoder.to(device)\n", | ||
" codec_decoder.to(device)\n", | ||
" return codec_encoder, codec_decoder" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@torch.no_grad()\n", | ||
"def extract_features(speech, processor):\n", | ||
" inputs = processor(speech, sampling_rate=16000, return_tensors=\"pt\")\n", | ||
" input_features = inputs[\"input_features\"][0]\n", | ||
" attention_mask = inputs[\"attention_mask\"][0]\n", | ||
" return input_features, attention_mask\n", | ||
"\n", | ||
"@torch.no_grad()\n", | ||
"def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):\n", | ||
" vq_emb = semantic_model(\n", | ||
" input_features=input_features,\n", | ||
" attention_mask=attention_mask,\n", | ||
" output_hidden_states=True,\n", | ||
" )\n", | ||
" feat = vq_emb.hidden_states[17] # (B, T, C)\n", | ||
" feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)\n", | ||
"\n", | ||
" semantic_code, rec_feat = semantic_codec.quantize(feat) # (B, T)\n", | ||
" return semantic_code, rec_feat\n", | ||
"\n", | ||
"@torch.no_grad()\n", | ||
"def extract_acoustic_code(speech):\n", | ||
" vq_emb = codec_encoder(speech.unsqueeze(1))\n", | ||
" _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)\n", | ||
" acoustic_code = vq.permute(\n", | ||
" 1, 2, 0\n", | ||
" )\n", | ||
" return acoustic_code\n", | ||
"\n", | ||
"@torch.no_grad()\n", | ||
"def text2semantic(prompt_speech, prompt_text, prompt_language, target_text, target_language, target_len=None, n_timesteps=50, cfg=2.5, rescale_cfg=0.75):\n", | ||
" \n", | ||
" prompt_phone_id = g2p_(prompt_text, prompt_language)[1]\n", | ||
"\n", | ||
" target_phone_id = g2p_(target_text, target_language)[1]\n", | ||
"\n", | ||
" if target_len is None:\n", | ||
" target_len = int((len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id)) / 16000 * 50)\n", | ||
" else:\n", | ||
" target_len = int(target_len * 50)\n", | ||
"\n", | ||
" prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device)\n", | ||
" target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device)\n", | ||
"\n", | ||
" phone_id = torch.cat([prompt_phone_id, target_phone_id]) \n", | ||
"\n", | ||
" input_fetures, attention_mask = extract_features(prompt_speech, processor)\n", | ||
" input_fetures = input_fetures.unsqueeze(0).to(device)\n", | ||
" attention_mask = attention_mask.unsqueeze(0).to(device)\n", | ||
" semantic_code, _ = extract_semantic_code(semantic_mean, semantic_std, input_fetures, attention_mask)\n", | ||
"\n", | ||
" predict_semantic = t2s_model.reverse_diffusion(semantic_code[:, :], target_len, phone_id.unsqueeze(0), n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg)\n", | ||
"\n", | ||
" print(\"predict semantic shape\", predict_semantic.shape)\n", | ||
"\n", | ||
" combine_semantic_code = torch.cat([semantic_code[:,:], predict_semantic], dim=-1)\n", | ||
" prompt_semantic_code = semantic_code\n", | ||
"\n", | ||
" return combine_semantic_code, prompt_semantic_code\n", | ||
"\n", | ||
"@torch.no_grad()\n", | ||
"def semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=[25,10,1,1,1,1,1,1,1,1,1,1], cfg=2.5, rescale_cfg=0.75):\n", | ||
"\n", | ||
" semantic_code = combine_semantic_code\n", | ||
" \n", | ||
" cond = s2a_model_1layer.cond_emb(semantic_code)\n", | ||
" prompt = acoustic_code[:,:,:]\n", | ||
" predict_1layer = s2a_model_1layer.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps[:1], cfg=cfg, rescale_cfg=rescale_cfg)\n", | ||
"\n", | ||
" cond = s2a_model_full.cond_emb(semantic_code)\n", | ||
" prompt = acoustic_code[:,:,:]\n", | ||
" predict_full = s2a_model_full.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg, gt_code=predict_1layer)\n", | ||
" \n", | ||
" vq_emb = codec_decoder.vq2emb(predict_full.permute(2,0,1), n_quantizers=12)\n", | ||
" recovered_audio = codec_decoder(vq_emb)\n", | ||
" prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1), n_quantizers=12)\n", | ||
" recovered_prompt_audio = codec_decoder(prompt_vq_emb)\n", | ||
" recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()\n", | ||
" recovered_audio = recovered_audio[0][0].cpu().numpy()\n", | ||
" combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])\n", | ||
"\n", | ||
" return combine_audio, recovered_audio" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def maskgct_inference(prompt_speech_path, prompt_text, target_text, language=\"en\", target_language=\"en\", target_len=None, n_timesteps=25, cfg=2.5, rescale_cfg=0.75, n_timesteps_s2a=[25,10,1,1,1,1,1,1,1,1,1,1], cfg_s2a=2.5, rescale_cfg_s2a=0.75):\n", | ||
" speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]\n", | ||
" speech = librosa.load(prompt_speech_path, sr=24000)[0]\n", | ||
"\n", | ||
" combine_semantic_code, _ = text2semantic(speech_16k, prompt_text, language, target_text, target_language, target_len, n_timesteps, cfg, rescale_cfg)\n", | ||
" acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))\n", | ||
" _, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=n_timesteps_s2a, cfg=cfg_s2a, rescale_cfg=rescale_cfg_s2a)\n", | ||
"\n", | ||
" return recovered_audio" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Build Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"device = torch.device(\"cuda:0\")\n", | ||
"cfg_path = \"./models/tts/maskgct/config/maskgct.json\"\n", | ||
"cfg = load_config(cfg_path)\n", | ||
"\n", | ||
"# 1. build semantic model (w2v-bert-2.0)\n", | ||
"semantic_model, semantic_mean, semantic_std = build_semantic_model(device)\n", | ||
"# 2. build semantic codec\n", | ||
"semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)\n", | ||
"# 3. build acoustic codec\n", | ||
"codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)\n", | ||
"# 4. build t2s model\n", | ||
"t2s_model = build_t2s_model(cfg.model.t2s_model, device)\n", | ||
"# 5. build s2a model\n", | ||
"s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)\n", | ||
"s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Load Checkpoints" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from huggingface_hub import hf_hub_download\n", | ||
"\n", | ||
"# download semantic codec ckpt\n", | ||
"semantic_code_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"semantic_codec/model.safetensors\")\n", | ||
"# download acoustic codec ckpt\n", | ||
"codec_encoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model.safetensors\")\n", | ||
"codec_decoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model_1.safetensors\")\n", | ||
"# download t2s model ckpt\n", | ||
"t2s_model_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"t2s_model/model.safetensors\")\n", | ||
"# download s2a model ckpt\n", | ||
"s2a_1layer_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_1layer/model.safetensors\")\n", | ||
"s2a_full_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_full/model.safetensors\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# load semantic codec\n", | ||
"safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)\n", | ||
"# load acoustic codec\n", | ||
"safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)\n", | ||
"safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)\n", | ||
"# load t2s model\n", | ||
"safetensors.torch.load_model(t2s_model, t2s_model_ckpt)\n", | ||
"# load s2a model\n", | ||
"safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)\n", | ||
"safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"prompt_wav_path = \"./models/tts/maskgct/wav/prompt.wav\"\n", | ||
"prompt_text = \" We do not break. We never give in. We never back down.\"\n", | ||
"target_text = \"In this paper, we introduce MaskGCT, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision.\"\n", | ||
"target_len = 18 # Specify the target duration (in seconds). If target_len = None, we use a simple rule to predict the target duration.\n", | ||
"recovered_audio = maskgct_inference(prompt_wav_path, prompt_text, target_text, \"en\", \"en\", target_len=target_len)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from IPython.display import Audio\n", | ||
"Audio(recovered_audio, rate=24000)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"fileId": "8353ad98-61bb-49ea-b655-c8f6a3264cc3", | ||
"filePath": "/opt/tiger/SpeechGeneration2/models/tts/maskgct/maskgct_demo.ipynb", | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |