Add MaskGCT demo ipynb (#306)
babysor authored Oct 30, 2024
1 parent c37307d commit 2c42b47
Showing 2 changed files with 314 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
@@ -56,7 +56,6 @@ processed_data
 data
 model_ckpt
 logs
-*.ipynb
 *.lst
 source_audio
 result
314 changes: 314 additions & 0 deletions models/tts/maskgct/maskgct_demo.ipynb
@@ -0,0 +1,314 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import librosa\n",
"import safetensors\n",
"from utils.util import load_config\n",
"\n",
"from models.codec.kmeans.repcodec_model import RepCodec\n",
"from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A\n",
"from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S\n",
"from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder\n",
"from transformers import Wav2Vec2BertModel\n",
"\n",
"from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import SeamlessM4TFeatureExtractor\n",
"processor = SeamlessM4TFeatureExtractor.from_pretrained(\"facebook/w2v-bert-2.0\")"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"def g2p_(text, language):\n",
" if language in [\"zh\", \"en\"]:\n",
" return chn_eng_g2p(text)\n",
" else:\n",
" return g2p(text, sentence=None, language=language)\n",
"\n",
"def build_t2s_model(cfg, device):\n",
" t2s_model = MaskGCT_T2S(cfg=cfg)\n",
" t2s_model.eval()\n",
" t2s_model.to(device)\n",
" return t2s_model\n",
"\n",
"def build_s2a_model(cfg, device):\n",
" soundstorm_model = MaskGCT_S2A(cfg=cfg)\n",
" soundstorm_model.eval()\n",
" soundstorm_model.to(device)\n",
" return soundstorm_model\n",
"\n",
"def build_semantic_model(device):\n",
" semantic_model = Wav2Vec2BertModel.from_pretrained(\"facebook/w2v-bert-2.0\")\n",
" semantic_model.eval()\n",
" semantic_model.to(device)\n",
" stat_mean_var = torch.load(\"./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt\")\n",
" semantic_mean = stat_mean_var[\"mean\"]\n",
" semantic_std = torch.sqrt(stat_mean_var[\"var\"])\n",
" semantic_mean = semantic_mean.to(device)\n",
" semantic_std = semantic_std.to(device)\n",
" return semantic_model, semantic_mean, semantic_std\n",
"\n",
"def build_semantic_codec(cfg, device):\n",
" semantic_codec = RepCodec(cfg=cfg)\n",
" semantic_codec.eval()\n",
" semantic_codec.to(device)\n",
" return semantic_codec\n",
"\n",
"def build_acoustic_codec(cfg, device):\n",
" codec_encoder = CodecEncoder(cfg=cfg.encoder)\n",
" codec_decoder = CodecDecoder(cfg=cfg.decoder)\n",
" codec_encoder.eval()\n",
" codec_decoder.eval()\n",
" codec_encoder.to(device)\n",
" codec_decoder.to(device)\n",
" return codec_encoder, codec_decoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def extract_features(speech, processor):\n",
" inputs = processor(speech, sampling_rate=16000, return_tensors=\"pt\")\n",
" input_features = inputs[\"input_features\"][0]\n",
" attention_mask = inputs[\"attention_mask\"][0]\n",
" return input_features, attention_mask\n",
"\n",
"@torch.no_grad()\n",
"def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):\n",
" vq_emb = semantic_model(\n",
" input_features=input_features,\n",
" attention_mask=attention_mask,\n",
" output_hidden_states=True,\n",
" )\n",
" feat = vq_emb.hidden_states[17] # (B, T, C)\n",
" feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)\n",
"\n",
" semantic_code, rec_feat = semantic_codec.quantize(feat) # (B, T)\n",
" return semantic_code, rec_feat\n",
"\n",
"@torch.no_grad()\n",
"def extract_acoustic_code(speech):\n",
" vq_emb = codec_encoder(speech.unsqueeze(1))\n",
" _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)\n",
" acoustic_code = vq.permute(\n",
" 1, 2, 0\n",
" )\n",
" return acoustic_code\n",
"\n",
"@torch.no_grad()\n",
"def text2semantic(prompt_speech, prompt_text, prompt_language, target_text, target_language, target_len=None, n_timesteps=50, cfg=2.5, rescale_cfg=0.75):\n",
" \n",
" prompt_phone_id = g2p_(prompt_text, prompt_language)[1]\n",
"\n",
" target_phone_id = g2p_(target_text, target_language)[1]\n",
"\n",
" if target_len is None:\n",
" target_len = int((len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id)) / 16000 * 50)\n",
" else:\n",
" target_len = int(target_len * 50)\n",
"\n",
" prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device)\n",
" target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device)\n",
"\n",
" phone_id = torch.cat([prompt_phone_id, target_phone_id]) \n",
"\n",
" input_fetures, attention_mask = extract_features(prompt_speech, processor)\n",
" input_fetures = input_fetures.unsqueeze(0).to(device)\n",
" attention_mask = attention_mask.unsqueeze(0).to(device)\n",
" semantic_code, _ = extract_semantic_code(semantic_mean, semantic_std, input_fetures, attention_mask)\n",
"\n",
" predict_semantic = t2s_model.reverse_diffusion(semantic_code[:, :], target_len, phone_id.unsqueeze(0), n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg)\n",
"\n",
" print(\"predict semantic shape\", predict_semantic.shape)\n",
"\n",
" combine_semantic_code = torch.cat([semantic_code[:,:], predict_semantic], dim=-1)\n",
" prompt_semantic_code = semantic_code\n",
"\n",
" return combine_semantic_code, prompt_semantic_code\n",
"\n",
"@torch.no_grad()\n",
"def semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=[25,10,1,1,1,1,1,1,1,1,1,1], cfg=2.5, rescale_cfg=0.75):\n",
"\n",
" semantic_code = combine_semantic_code\n",
" \n",
" cond = s2a_model_1layer.cond_emb(semantic_code)\n",
" prompt = acoustic_code[:,:,:]\n",
" predict_1layer = s2a_model_1layer.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps[:1], cfg=cfg, rescale_cfg=rescale_cfg)\n",
"\n",
" cond = s2a_model_full.cond_emb(semantic_code)\n",
" prompt = acoustic_code[:,:,:]\n",
" predict_full = s2a_model_full.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=n_timesteps, cfg=cfg, rescale_cfg=rescale_cfg, gt_code=predict_1layer)\n",
" \n",
" vq_emb = codec_decoder.vq2emb(predict_full.permute(2,0,1), n_quantizers=12)\n",
" recovered_audio = codec_decoder(vq_emb)\n",
" prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1), n_quantizers=12)\n",
" recovered_prompt_audio = codec_decoder(prompt_vq_emb)\n",
" recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()\n",
" recovered_audio = recovered_audio[0][0].cpu().numpy()\n",
" combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])\n",
"\n",
" return combine_audio, recovered_audio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def maskgct_inference(prompt_speech_path, prompt_text, target_text, language=\"en\", target_language=\"en\", target_len=None, n_timesteps=25, cfg=2.5, rescale_cfg=0.75, n_timesteps_s2a=[25,10,1,1,1,1,1,1,1,1,1,1], cfg_s2a=2.5, rescale_cfg_s2a=0.75):\n",
" speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]\n",
" speech = librosa.load(prompt_speech_path, sr=24000)[0]\n",
"\n",
" combine_semantic_code, _ = text2semantic(speech_16k, prompt_text, language, target_text, target_language, target_len, n_timesteps, cfg, rescale_cfg)\n",
" acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))\n",
" _, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code, n_timesteps=n_timesteps_s2a, cfg=cfg_s2a, rescale_cfg=rescale_cfg_s2a)\n",
"\n",
" return recovered_audio"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Build Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\")\n",
"cfg_path = \"./models/tts/maskgct/config/maskgct.json\"\n",
"cfg = load_config(cfg_path)\n",
"\n",
"# 1. build semantic model (w2v-bert-2.0)\n",
"semantic_model, semantic_mean, semantic_std = build_semantic_model(device)\n",
"# 2. build semantic codec\n",
"semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)\n",
"# 3. build acoustic codec\n",
"codec_encoder, codec_decoder = build_acoustic_codec(cfg.model.acoustic_codec, device)\n",
"# 4. build t2s model\n",
"t2s_model = build_t2s_model(cfg.model.t2s_model, device)\n",
"# 5. build s2a model\n",
"s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)\n",
"s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load Checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import hf_hub_download\n",
"\n",
"# download semantic codec ckpt\n",
"semantic_code_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"semantic_codec/model.safetensors\")\n",
"# download acoustic codec ckpt\n",
"codec_encoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model.safetensors\")\n",
"codec_decoder_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"acoustic_codec/model_1.safetensors\")\n",
"# download t2s model ckpt\n",
"t2s_model_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"t2s_model/model.safetensors\")\n",
"# download s2a model ckpt\n",
"s2a_1layer_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_1layer/model.safetensors\")\n",
"s2a_full_ckpt = hf_hub_download(\"amphion/MaskGCT\", filename=\"s2a_model/s2a_model_full/model.safetensors\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load semantic codec\n",
"safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)\n",
"# load acoustic codec\n",
"safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)\n",
"safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)\n",
"# load t2s model\n",
"safetensors.torch.load_model(t2s_model, t2s_model_ckpt)\n",
"# load s2a model\n",
"safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)\n",
"safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)"
]
},
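{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference"
]
},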
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt_wav_path = \"./models/tts/maskgct/wav/prompt.wav\"\n",
"prompt_text = \" We do not break. We never give in. We never back down.\"\n",
"target_text = \"In this paper, we introduce MaskGCT, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision.\"\n",
"target_len = 18 # Specify the target duration (in seconds). If target_len = None, we use a simple rule to predict the target duration.\n",
"recovered_audio = maskgct_inference(prompt_wav_path, prompt_text, target_text, \"en\", \"en\", target_len=target_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Audio\n",
"Audio(recovered_audio, rate=24000)"
]
}
],
"metadata": {
"fileId": "8353ad98-61bb-49ea-b655-c8f6a3264cc3",
"filePath": "/opt/tiger/SpeechGeneration2/models/tts/maskgct/maskgct_demo.ipynb",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
