From 760a65b1f868f31584150f429de640ac44a8967d Mon Sep 17 00:00:00 2001 From: hehaorui Date: Fri, 25 Oct 2024 23:38:38 +0800 Subject: [PATCH 1/8] Add Debatts Some Code --- ...c_repcodec_8192_1q_large_101k_fix_new.json | 117 ++ .../87_SPEAKER01_2_part03.json | 9 + models/tts/debatts/t2s_model_new.py | 478 ++++++ models/tts/debatts/t2s_sft_dataset_new.py | 273 ++++ .../debatts/try_inference_small_samples.py | 466 ++++++ models/tts/debatts/utils/__init__.py | 0 models/tts/debatts/utils/audio.py | 74 + models/tts/debatts/utils/audio_slicer.py | 476 ++++++ models/tts/debatts/utils/cut_by_vad.py | 105 ++ models/tts/debatts/utils/data_utils.py | 588 +++++++ models/tts/debatts/utils/distribution.py | 270 ++++ models/tts/debatts/utils/dsp.py | 97 ++ models/tts/debatts/utils/duration.py | 86 ++ models/tts/debatts/utils/f0.py | 275 ++++ models/tts/debatts/utils/g2p/__init__.py | 139 ++ models/tts/debatts/utils/g2p/bpe_317.json | 639 ++++++++ models/tts/debatts/utils/g2p/bpe_553.json | 1221 +++++++++++++++ models/tts/debatts/utils/g2p/bpe_613.json | 1366 +++++++++++++++++ models/tts/debatts/utils/g2p/cleaners.py | 62 + models/tts/debatts/utils/g2p/english.py | 139 ++ models/tts/debatts/utils/g2p/french.py | 178 +++ models/tts/debatts/utils/g2p/german.py | 122 ++ models/tts/debatts/utils/g2p/japanese.py | 128 ++ models/tts/debatts/utils/g2p/korean.py | 167 ++ models/tts/debatts/utils/g2p/mandarin.py | 270 ++++ .../tts/debatts/utils/g2p_liwei/__init__.py | 66 + .../tts/debatts/utils/g2p_liwei/cleaners.py | 25 + models/tts/debatts/utils/g2p_liwei/english.py | 166 ++ models/tts/debatts/utils/g2p_liwei/french.py | 164 ++ .../tts/debatts/utils/g2p_liwei/g2p_liwei.py | 7 + models/tts/debatts/utils/g2p_liwei/german.py | 108 ++ .../tts/debatts/utils/g2p_liwei/japanese.py | 154 ++ models/tts/debatts/utils/g2p_liwei/korean.py | 149 ++ .../tts/debatts/utils/g2p_liwei/mandarin.py | 191 +++ .../utils/g2p_liwei/text_tokenizers.py | 80 + models/tts/debatts/utils/g2p_liwei/vacab.json | 372 +++++ models/tts/debatts/utils/hparam.py | 659 ++++++++ models/tts/debatts/utils/hubert.py | 155 ++ models/tts/debatts/utils/io.py | 182 +++ models/tts/debatts/utils/io_optim.py | 123 ++ models/tts/debatts/utils/logger.py | 43 + models/tts/debatts/utils/mel.py | 280 ++++ models/tts/debatts/utils/mert.py | 139 ++ models/tts/debatts/utils/mfa_prepare.py | 116 ++ models/tts/debatts/utils/model_summary.py | 74 + models/tts/debatts/utils/prompt_preparer.py | 68 + models/tts/debatts/utils/ssim.py | 80 + models/tts/debatts/utils/stft.py | 278 ++++ models/tts/debatts/utils/symbol_table.py | 317 ++++ models/tts/debatts/utils/tokenizer.py | 150 ++ models/tts/debatts/utils/tool.py | 84 + models/tts/debatts/utils/topk_sampling.py | 72 + models/tts/debatts/utils/trainer_utils.py | 16 + models/tts/debatts/utils/util.py | 687 +++++++++ .../debatts/utils/whisper_transcription.py | 122 ++ models/tts/debatts/utils/world.py | 92 ++ 56 files changed, 12964 insertions(+) create mode 100644 models/tts/debatts/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json create mode 100644 models/tts/debatts/speech_examples/87_SPEAKER01_2_part03.json create mode 100644 models/tts/debatts/t2s_model_new.py create mode 100644 models/tts/debatts/t2s_sft_dataset_new.py create mode 100644 models/tts/debatts/try_inference_small_samples.py create mode 100644 models/tts/debatts/utils/__init__.py create mode 100644 models/tts/debatts/utils/audio.py create mode 100644 models/tts/debatts/utils/audio_slicer.py create mode 100644 
models/tts/debatts/utils/cut_by_vad.py create mode 100644 models/tts/debatts/utils/data_utils.py create mode 100644 models/tts/debatts/utils/distribution.py create mode 100644 models/tts/debatts/utils/dsp.py create mode 100644 models/tts/debatts/utils/duration.py create mode 100644 models/tts/debatts/utils/f0.py create mode 100644 models/tts/debatts/utils/g2p/__init__.py create mode 100644 models/tts/debatts/utils/g2p/bpe_317.json create mode 100644 models/tts/debatts/utils/g2p/bpe_553.json create mode 100644 models/tts/debatts/utils/g2p/bpe_613.json create mode 100644 models/tts/debatts/utils/g2p/cleaners.py create mode 100644 models/tts/debatts/utils/g2p/english.py create mode 100644 models/tts/debatts/utils/g2p/french.py create mode 100644 models/tts/debatts/utils/g2p/german.py create mode 100644 models/tts/debatts/utils/g2p/japanese.py create mode 100644 models/tts/debatts/utils/g2p/korean.py create mode 100644 models/tts/debatts/utils/g2p/mandarin.py create mode 100644 models/tts/debatts/utils/g2p_liwei/__init__.py create mode 100644 models/tts/debatts/utils/g2p_liwei/cleaners.py create mode 100644 models/tts/debatts/utils/g2p_liwei/english.py create mode 100644 models/tts/debatts/utils/g2p_liwei/french.py create mode 100644 models/tts/debatts/utils/g2p_liwei/g2p_liwei.py create mode 100644 models/tts/debatts/utils/g2p_liwei/german.py create mode 100644 models/tts/debatts/utils/g2p_liwei/japanese.py create mode 100644 models/tts/debatts/utils/g2p_liwei/korean.py create mode 100644 models/tts/debatts/utils/g2p_liwei/mandarin.py create mode 100644 models/tts/debatts/utils/g2p_liwei/text_tokenizers.py create mode 100644 models/tts/debatts/utils/g2p_liwei/vacab.json create mode 100644 models/tts/debatts/utils/hparam.py create mode 100644 models/tts/debatts/utils/hubert.py create mode 100644 models/tts/debatts/utils/io.py create mode 100644 models/tts/debatts/utils/io_optim.py create mode 100644 models/tts/debatts/utils/logger.py create mode 100644 models/tts/debatts/utils/mel.py create mode 100644 models/tts/debatts/utils/mert.py create mode 100644 models/tts/debatts/utils/mfa_prepare.py create mode 100644 models/tts/debatts/utils/model_summary.py create mode 100644 models/tts/debatts/utils/prompt_preparer.py create mode 100644 models/tts/debatts/utils/ssim.py create mode 100644 models/tts/debatts/utils/stft.py create mode 100644 models/tts/debatts/utils/symbol_table.py create mode 100644 models/tts/debatts/utils/tokenizer.py create mode 100644 models/tts/debatts/utils/tool.py create mode 100644 models/tts/debatts/utils/topk_sampling.py create mode 100644 models/tts/debatts/utils/trainer_utils.py create mode 100644 models/tts/debatts/utils/util.py create mode 100644 models/tts/debatts/utils/whisper_transcription.py create mode 100644 models/tts/debatts/utils/world.py diff --git a/models/tts/debatts/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json b/models/tts/debatts/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json new file mode 100644 index 00000000..a193cdb0 --- /dev/null +++ b/models/tts/debatts/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json @@ -0,0 +1,117 @@ +{ + "model_type": "T2S", + "dataset": ["not"], + "preprocess": { + "hop_size": 320, + "sample_rate": 16000, + "processed_dir": "TODO", + "valid_file": "valid.json", + "train_file": "train.json", + "use_phone_cond": false, + "min_dur": 3, + "max_dur": 40, + "use_emilia_101k": true + }, + "model": { + "t2sllama": { + "phone_vocab_size": 1024, + 
"target_vocab_size": 8192, + "hidden_size": 2048, + "intermediate_size": 8192, + "num_hidden_layer": 8, + "num_attention_head": 16, + "pad_token_id": 9216, + "bos_target_id": 9217, + "eos_target_id": 9218, + "bos_phone_id": 9219, + "eos_phone_id": 9220, + "bos_prompt0_id": 9221, + "eos_prompt0_id": 9222, + "use_lang_emb": false + }, + "kmeans": { + "type": "repcodec", + "stat_mean_var_path":"./ckpt/emilia_wav2vec2bert_stats_10k.pt", + "repcodec": { + "codebook_size": 8192, + "hidden_size": 1024, + "codebook_dim": 8, + "vocos_dim": 384, + "vocos_intermediate_dim": 2048, + "vocos_num_layers": 12 + }, + "pretrained_path":"./ckpt/repcodec/emilia_50k_8192_norm_8d_86k_steps_model.safetensors" + }, + "codec": { + "encoder": { + "d_model": 96, + "up_ratios": [4, 4, 4, 5], + "out_channels": 256, + "use_tanh": false, + "pretrained_path":"./ckpt/codec_16K_320_8/pytorch_model.bin" + }, + "decoder": { + "in_channel": 256, + "upsample_initial_channel": 1536, + "up_ratios": [5, 4, 4, 4], + "num_quantizers": 8, + "codebook_size": 1024, + "codebook_dim": 8, + "quantizer_type": "fvq", + "quantizer_dropout": 0.5, + "commitment": 0.25, + "codebook_loss_weight": 1.0, + "use_l2_normlize": true, + "codebook_type": "euclidean", + "kmeans_init": false, + "kmeans_iters": 10, + "decay": 0.8, + "eps": 0.5, + "threshold_ema_dead_code": 2, + "weight_init": false, + "use_vocos": true, + "vocos_dim": 512, + "vocos_intermediate_dim": 4096, + "vocos_num_layers": 24, + "n_fft": 1280, + "hop_size": 320, + "padding": "same", + "pretrained_path": "./ckpt/codec_16K_320_8/pytorch_model_1.bin" + } + } + }, + "log_dir": "TODO", + "train": { + "max_epoch": 0, + "use_dynamic_batchsize": true, + "max_tokens": 3000, + "max_sentences": 20, + "lr_warmup_steps": 3200, + "lr_scheduler": "inverse_sqrt", + "num_train_steps": 800, + "adam": { + "lr": 1e-5 + }, + "ddp": false, + "random_seed": 114, + "batch_size": 1, + "epochs": 500, + "max_steps": 10000, + "total_training_steps": 8000, + "save_summary_steps": 500, + "save_checkpoints_steps": 300, + "valid_interval": 2000, + "keep_checkpoint_max": 100, + "gradient_accumulation_step": 10, + "tracker": ["tensorboard"], + "save_checkpoint_stride": [1], + "keep_last": [15], + "run_eval": [true], + "dataloader": { + "num_worker": 4, + "pin_memory": true + }, + "use_emilia_dataset": false + } +} + diff --git a/models/tts/debatts/speech_examples/87_SPEAKER01_2_part03.json b/models/tts/debatts/speech_examples/87_SPEAKER01_2_part03.json new file mode 100644 index 00000000..b6b0db11 --- /dev/null +++ b/models/tts/debatts/speech_examples/87_SPEAKER01_2_part03.json @@ -0,0 +1,9 @@ +{ + "key": "87_SPEAKER01_2_part03", + "text": "你方要论证的是, 他在技术上完全突破了壁垒, 在现实里真的可以落地, 而不是举出一些跟越宇宙可能没有关, 没有关系, 实际关系不大的东西告诉我, 这叫做越宇宙。我方告诉你, 不只是算力和芯片的问题, 还有包括VR的问题, 硬件上算力有问题, 题片有问题, 软软。", + "duration": 15.34, + "language": "zh", + "wav_path": "./87_SPEAKER01_2_part03.wav", + "chenci_prompt_wav_path": "./87_SPEAKER01_2_part03_213_chenci_prompt_6s.wav", + "prompt0_wav_path": "./87_SPEAKER00_1_part01.wav" +} \ No newline at end of file diff --git a/models/tts/debatts/t2s_model_new.py b/models/tts/debatts/t2s_model_new.py new file mode 100644 index 00000000..ab20e468 --- /dev/null +++ b/models/tts/debatts/t2s_model_new.py @@ -0,0 +1,478 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from multiprocessing.sharedctypes import Value +from re import T +from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +import torch +import torch.nn.functional as F +import numpy as np +import os +import json +import torch.nn as nn +import tqdm +from einops import rearrange +os.chdir('./models/tts/debatts') +import sys +sys.path.append('./models/tts/debatts') +from utils.topk_sampling import top_k_top_p_filtering +import pickle + + +class T2SLlama_new(nn.Module): + def __init__( + self, + phone_vocab_size=1024, + target_vocab_size=2048, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=12, + num_attention_heads=16, + pad_token_id=3072, + bos_target_id=3073, + eos_target_id=3074, + bos_phone_id=3075, + eos_phone_id=3076, + bos_prompt0_id=3077, + eos_prompt0_id=3078, + use_lang_emb=False, + cfg=None, + ): + super().__init__() + + phone_vocab_size = ( + cfg.phone_vocab_size + if cfg is not None and hasattr(cfg, "phone_vocab_size") + else phone_vocab_size + ) + target_vocab_size = ( + cfg.target_vocab_size + if cfg is not None and hasattr(cfg, "target_vocab_size") + else target_vocab_size + ) + hidden_size = ( + cfg.hidden_size + if cfg is not None and hasattr(cfg, "hidden_size") + else hidden_size + ) + intermediate_size = ( + cfg.intermediate_size + if cfg is not None and hasattr(cfg, "intermediate_size") + else intermediate_size + ) + num_hidden_layers = ( + cfg.num_hidden_layers + if cfg is not None and hasattr(cfg, "num_hidden_layers") + else num_hidden_layers + ) + num_attention_heads = ( + cfg.num_attention_heads + if cfg is not None and hasattr(cfg, "num_attention_heads") + else num_attention_heads + ) + pad_token_id = ( + cfg.pad_token_id + if cfg is not None and hasattr(cfg, "pad_token_id") + else pad_token_id + ) + bos_target_id = ( + cfg.bos_target_id + if cfg is not None and hasattr(cfg, "bos_target_id") + else bos_target_id + ) + eos_target_id = ( + cfg.eos_target_id + if cfg is not None and hasattr(cfg, "eos_target_id") + else eos_target_id + ) + bos_phone_id = ( + cfg.bos_phone_id + if cfg is not None and hasattr(cfg, "bos_phone_id") + else bos_phone_id + ) + eos_phone_id = ( + cfg.eos_phone_id + if cfg is not None and hasattr(cfg, "eos_phone_id") + else eos_phone_id + ) + use_lang_emb = ( + cfg.use_lang_emb + if cfg is not None and hasattr(cfg, "use_lang_emb") + else use_lang_emb + ) + bos_prompt0_id = ( + cfg.bos_prompt0_id + if cfg is not None and hasattr(cfg, "bos_prompt0_id") + else bos_prompt0_id + ) + eos_prompt0_id = ( + cfg.eos_prompt0_id + if cfg is not None and hasattr(cfg, "eos_prompt0_id") + else eos_prompt0_id + ) + + + self.config = LlamaConfig( + vocab_size=phone_vocab_size + target_vocab_size + 20, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + pad_token_id=pad_token_id, + bos_token_id=bos_target_id, + eos_token_id=eos_target_id, + bos_prompt0_id=bos_prompt0_id, + eos_prompt0_id=eos_prompt0_id + ) + self.phone_vocab_size = phone_vocab_size + self.target_vocab_size = target_vocab_size + self.hidden_size = hidden_size + self.pad_token_id = pad_token_id + self.bos_target_id = bos_target_id + self.eos_target_id = eos_target_id + self.bos_phone_id = bos_phone_id + self.eos_phone_id = eos_phone_id + self.use_lang_emb = use_lang_emb + self.bos_prompt0_id = bos_prompt0_id + self.eos_prompt0_id = eos_prompt0_id + + self.model = LlamaForCausalLM(self.config) + + if self.use_lang_emb: + self.lang_emb = nn.Embedding(25, hidden_size, 
padding_idx=0) + torch.nn.init.normal_(self.lang_emb.weight, mean=0.0, std=0.02) + + def forward( + self, prompt0_ids, prompt0_mask, phone_ids, phone_mask, target_ids, target_mask, lang_id=None, + ): + prompt0_ids, prompt0_mask, prompt0_label, prompt0_lang_mask = self.add_phone_eos_bos_label( + prompt0_ids, + prompt0_mask, + self.eos_prompt0_id, + self.bos_prompt0_id, + self.pad_token_id, + label="prompt0_id" + ) + phone_ids, phone_mask, phone_label, lang_mask = self.add_phone_eos_bos_label( + phone_ids, + phone_mask, + self.eos_phone_id, + self.bos_phone_id, + self.pad_token_id, + label="phone_id" + ) + target_ids, target_mask, target_label = self.add_target_eos_bos_label( + target_ids, + target_mask, + self.eos_target_id, + self.bos_target_id, + self.pad_token_id, + ) + + input_token_ids = torch.cat([prompt0_ids, phone_ids, target_ids], dim=-1) + attention_mask = torch.cat([prompt0_mask, phone_mask, target_mask], dim=-1) + + labels = torch.cat([prompt0_label, phone_label, target_label], dim=-1) + + # lang_id: (B,); lang_mask: (B, T) + if self.use_lang_emb: + lang_embedding = self.lang_emb(lang_id).unsqueeze(1) # (B, d) -> (B, 1, d) + lang_embedding = lang_embedding * torch.cat([prompt0_lang_mask, lang_mask, torch.zeros_like(target_mask)], dim=-1).unsqueeze(-1) # (B, T, d) + input_token_embedding = self.model.model.embed_tokens(input_token_ids) # (B, T, d) + inputs_embeds = input_token_embedding + lang_embedding + + out = self.model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + labels=labels, + return_dict=True, + ) + + else: + out = self.model( + input_token_ids, + attention_mask=attention_mask, + labels=labels, + return_dict=True, + ) + + return out + + def add_phone_eos_bos_label( + self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id, label + ): + # phone_ids: [B, T] + # phone_mask: [B, T] + + # add 0 in the left + lang_mask = F.pad(phone_mask, (1, 0), value=0) + # add 0 in the right + lang_mask = F.pad(lang_mask, (0, 1), value=0) + + if label == "phone_id": + phone_ids = phone_ids + self.target_vocab_size * phone_mask + + phone_ids = phone_ids * phone_mask + """Step-by-Step Computation: + + Pad phone_ids: + + After padding: [[101, 102, 103, 0]] + Invert and Pad phone_mask: + + Inverted mask: [[0, 0, 0]] + Padded inverted mask: [[0, 0, 0, 1]] + Calculate EOS Insertion: + + Multiply with phone_eos_id: [[0, 0, 0, 200]] + Combine: + + Combined result: [[101, 102, 103, 200]] + """ + phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad( + 1 - phone_mask, (0, 1), value=1 + ) # make pad token eos token, add eos token at the end + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask + phone_ids = phone_ids * phone_mask + pad_token_id * (1 - phone_mask) # restore pad token ids + phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask + phone_label = -100 * torch.ones_like(phone_ids) # loss for entire phone is not computed (passed to llama) + return phone_ids, phone_mask, phone_label, lang_mask + + def add_target_eos_bos_label( + self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id + ): + # target_ids: [B, T] + # target_mask: [B, T] + target_ids = target_ids * target_mask + target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad( + 1 - target_mask, (0, 1), value=1 + ) + target_mask = F.pad(target_mask, (1, 0), value=1) + target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask) + target_ids = 
F.pad(target_ids, (1, 0), value=target_bos_id) + target_mask = F.pad(target_mask, (1, 0), value=1) + target_label = target_ids * target_mask + (-100) * (1 - target_mask) # loss for target is computed on unmasked tokens + return target_ids, target_mask, target_label + + def add_phone_middle_label( + self, prompt0_ids, prompt0_mask, eos_prompt0_id, pad_token_id + ): + # prompt0_ids: [B, T] + # prompt0_mask: [B, T] + + prompt0_ids = prompt0_ids * prompt0_mask + prompt0_ids = F.pad(prompt0_ids, (0, 1), value=0) + eos_prompt0_id * F.pad( + 1 - prompt0_mask, (0, 1), value=1 + ) # Add eos_prompt0_id at the positions transitioning to padding + prompt0_mask = F.pad(prompt0_mask, (1, 0), value=1) # Pad the mask for the new eos_prompt0_id + prompt0_ids = prompt0_ids * prompt0_mask + pad_token_id * (1 - prompt0_mask) # Restore pad tokens + prompt0_ids = F.pad(prompt0_ids, (1, 0), value=eos_prompt0_id) # Add eos_prompt0_id at the beginning + prompt0_mask = F.pad(prompt0_mask, (1, 0), value=1) # Adjust the mask for the added eos_prompt0_id + prompt0_label = prompt0_ids * prompt0_mask + (-100) * (1 - prompt0_mask) # Set up labels for loss computation + + return prompt0_ids, prompt0_mask, prompt0_label + + + @torch.no_grad() + def sample_hf( + self, + phone_ids, # the phones of prompt and target should be concatenated together; in practice, phone_ids is the token sequence of the input text + prompt_ids, + prompt0_ids=None, + max_length=100000, + temperature=0.3, + top_k=30, + top_p=0.7, + repeat_penalty=3.5, + lang_ids=None + ): + if prompt0_ids is not None: + phone_mask = torch.ones_like(phone_ids) + prompt_mask = torch.ones_like(prompt_ids) + + prompt_mask_prompt0 = torch.ones_like(prompt0_ids) + + # downsample = DownsampleWithMask(downsample_factor=2) + # prompt0_ids, prompt_mask_prompt0 = downsample(prompt0_ids, prompt_mask_prompt0) + + phone_ids, _, _, _ = self.add_phone_eos_bos_label( + phone_ids, + phone_mask, + self.eos_phone_id, + self.bos_phone_id, + self.pad_token_id, + label="phone_id" + ) + prompt_ids, _, _ = self.add_target_eos_bos_label( + prompt_ids, + prompt_mask, + self.eos_target_id, + self.bos_target_id, + self.pad_token_id, + ) + prompt_ids = prompt_ids[:, :-1] # remove end token.
Make it continue mode + + prompt0_ids, _, _ = self.add_target_eos_bos_label( + prompt0_ids, + prompt_mask_prompt0, + self.eos_prompt0_id, + self.bos_prompt0_id, + self.pad_token_id, + ) + + input_token_ids = torch.cat([prompt0_ids, phone_ids, prompt_ids], dim=-1) + input_length = input_token_ids.shape[1] + + if lang_ids != None and self.use_lang_emb: + lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0) + + input_token_embedding = self.model.model.embed_tokens(input_token_ids) # (B, T, d) + # lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh'] + lang_mask = torch.ones_like(phone_ids) + lang_mask[:,0] = 0 + lang_mask[:,-1] = 0 + lang_embedding = torch.cat([self.lang_emb(lang_ids), self.lang_emb(lang_ids), torch.zeros(lang_ids.shape[0], input_token_ids.shape[1] - lang_ids.shape[1], self.hidden_size).to(input_token_ids.device)], dim=1) * torch.cat([lang_mask, torch.zeros_like(prompt_ids)], dim=-1).unsqueeze(-1) + + inputs_embeds = input_token_embedding + lang_embedding + + # if prosody_features is not None: + # + # prosody_features = prosody_features.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1) + # inputs_embeds = inputs_embeds + prosody_features + + generated_ids = self.model.generate( + # input wav phone token ids + text token ids + inputs_embeds=inputs_embeds, + do_sample=True, + max_length=max_length, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_target_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repeat_penalty, + min_new_tokens=50, + ) + gen_tokens = generated_ids[:,:-1] + else: + + input_token_embedding = self.model.model.embed_tokens(input_token_ids) + + generated_ids = self.model.generate( + input_token_ids, + do_sample=True, + max_length=max_length, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_target_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repeat_penalty, + min_new_tokens=50, + ) + gen_tokens = generated_ids[:, input_length:-1] + + return gen_tokens + + else: + phone_mask = torch.ones_like(phone_ids) + prompt_mask = torch.ones_like(prompt_ids) + phone_ids, _, _, _ = self.add_phone_eos_bos_label( + phone_ids, + phone_mask, + self.eos_phone_id, + self.bos_phone_id, + self.pad_token_id, + label="phone_ids" + ) + prompt_ids, _, _ = self.add_target_eos_bos_label( + prompt_ids, + prompt_mask, + self.eos_target_id, + self.bos_target_id, + self.pad_token_id, + ) + prompt_ids = prompt_ids[:, :-1] # remove end token. 
Make it continue mode + + input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1) + input_length = input_token_ids.shape[1] + + if lang_ids != None and self.use_lang_emb: + lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0) + # token to vector + input_token_embedding = self.model.model.embed_tokens(input_token_ids) # (B, T, d) + # lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh'] + lang_mask = torch.ones_like(phone_ids) + lang_mask[:,0] = 0 + lang_mask[:,-1] = 0 + lang_embedding = torch.cat([self.lang_emb(lang_ids), torch.zeros(lang_ids.shape[0], input_token_ids.shape[1] - lang_ids.shape[1], self.hidden_size).to(input_token_ids.device)], dim=1) * torch.cat([lang_mask, torch.zeros_like(prompt_ids)], dim=-1).unsqueeze(-1) + + + inputs_embeds = input_token_embedding + lang_embedding + + generated_ids = self.model.generate( + # input wav phone token ids + text token ids + inputs_embeds=inputs_embeds, + do_sample=True, + max_length=max_length, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_target_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repeat_penalty, + min_new_tokens=50, + ) + # assert generated_ids.size(1) > input_length, f"Generated tokens length {generated_ids.size(1)} is less than input length {input_length}, generated ids is {generated_ids}" + gen_tokens = generated_ids[:,:-1] + + else: + + input_token_embedding = self.model.model.embed_tokens(input_token_ids) + # if prosody_features is not None: + # + # prosody_features = prosody_features.unsqueeze(1).expand(-1, input_token_embedding.size(1), -1) + # inputs_embeds = input_token_embedding + prosody_features + # generated_ids = self.model.generate( + # inputs_embeds=inputs_embeds, + generated_ids = self.model.generate( + input_token_ids, + do_sample=True, + max_length=max_length, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_target_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repeat_penalty, + min_new_tokens=50 + ) + # assert generated_ids.size(1) > input_length, f"Generated tokens length {generated_ids.size(1)} is less than input length {input_length}, generated ids is {generated_ids}" + + gen_tokens = generated_ids[:, input_length:-1] + return gen_tokens + + +class DownsampleWithMask(nn.Module): + def __init__(self, downsample_factor=2): + super(DownsampleWithMask, self).__init__() + self.downsample_factor = downsample_factor + + def forward(self, x, mask): + # x shape: (batch_size, seq_len) + # mask shape: (batch_size, seq_len) + + x = x.float() + x = x.unsqueeze(1) # add channel dimension: (batch_size, 1, seq_len) + x = F.avg_pool1d(x, kernel_size=self.downsample_factor, stride=self.downsample_factor) + x = x.squeeze(1) # remove channel dimension: (batch_size, seq_len // downsample_factor) + x = x.long() + + # average pooling + mask = mask.float() # convert mask to float for pooling + mask = mask.unsqueeze(1) # add channel dimension: (batch_size, 1, seq_len) + mask = F.avg_pool1d(mask, kernel_size=self.downsample_factor, stride=self.downsample_factor) + \ No newline at end of file diff --git a/models/tts/debatts/t2s_sft_dataset_new.py b/models/tts/debatts/t2s_sft_dataset_new.py new file mode 100644 index 00000000..0e104491 --- /dev/null +++ b/models/tts/debatts/t2s_sft_dataset_new.py @@ -0,0 +1,273 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
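+# T2SDataset / T2SCollator: supervised fine-tuning data pipeline for the T2S stage. Each item loads +# the target wav and its prompt0 wav, extracts SeamlessM4TFeatureExtractor (w2v-bert-2) input +# features and attention masks for both, and returns phone ids, phone mask, speaker id and language +# id; the collator pads these fields into batched tensors.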
+ +from cmath import inf +import io +import librosa +import torch +import json +import tqdm +import numpy as np +import logging +import pickle +import os + +import time +from torch.utils.data import Dataset +import torch.nn as nn +import torch.nn.functional as F +from multiprocessing import Pool +import concurrent.futures +from pathlib import Path +from transformers import SeamlessM4TFeatureExtractor +from transformers import Wav2Vec2BertModel +os.chdir('./models/tts/debatts') +import sys +sys.path.append('./models/tts/debatts') +from utils.g2p_new.g2p import phonemizer_g2p +from utils.g2p_liwei.g2p_liwei import liwei_g2p +from torch.nn.utils.rnn import pad_sequence + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class WarningFilter(logging.Filter): + def filter(self, record): + + if record.name == "phonemizer" and record.levelno == logging.WARNING: + return False + if record.name == "qcloud_cos.cos_client" and record.levelno == logging.INFO: + return False + if record.name == "jieba" and record.levelno == logging.DEBUG: + return False + return True + +filter = WarningFilter() +logging.getLogger("phonemizer").addFilter(filter) +logging.getLogger("qcloud_cos.cos_client").addFilter(filter) +logging.getLogger("jieba").addFilter(filter) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class T2SDataset(torch.utils.data.Dataset): + def __init__( + self, + cfg=None, + ): + self.cfg = cfg + + self.meta_info_path = "Debatts-Data Summary Json" + with open(self.meta_info_path, "r") as f: + self.meta_info_data = json.load(f) + + self.wav_paths = [] + self.prompt0_paths = [] # Add prompt0 paths + self.wav_path_index2duration = [] + self.wav_path_index2phonelen = [] + self.wav_path_index2spkid = [] + self.wav_path_index2phoneid = [] + self.index2num_frames = [] + self.index2lang = [] + self.lang2id = {"en": 1, "zh": 2, "ja": 3, "fr": 4, "ko": 5, "de": 6} + + for info in self.meta_info_data: + if info["prompt0_wav_path"] == None: + continue + self.wav_paths.append(info["wav_path"]) + self.prompt0_paths.append(info["prompt0_wav_path"]) # Add prompt0 path + self.wav_path_index2duration.append(info["duration"]) + self.wav_path_index2phonelen.append(info["phone_count"]) + self.wav_path_index2spkid.append(info["speaker_id"]) + self.wav_path_index2phoneid.append(info["phone_id"]) + self.index2num_frames.append(info["duration"] * 50 + len(info["phone_id"])) + lang_id = self.lang2id[info['language']] + self.index2lang.append(lang_id) + + # self.index2num_frames.append(info["duration"] * self.cfg.preprocess.sample_rate) + + self.num_frame_indices = np.array( + sorted( + range(len(self.index2num_frames)), + key=lambda k: self.index2num_frames[k], + ) + ) + + self.processor = SeamlessM4TFeatureExtractor.from_pretrained( + "./w2v-bert-2" + ) + + def liwei_g2p(self, text, language): + return liwei_g2p(text, language) + + def __len__(self): + return self.wav_paths.__len__() + + def get_num_frames(self, index): + return self.wav_path_index2duration[index] * 50 + self.wav_path_index2phonelen[index] + + def __getitem__(self, idx): + wav_path = self.wav_paths[idx] + speech, sr = librosa.load(wav_path, sr=self.cfg.preprocess.sample_rate) + speech = np.pad(speech, (0, self.cfg.preprocess.hop_size - len(speech) % self.cfg.preprocess.hop_size), mode="constant") + # resample the speech to 16k for feature extraction + if self.cfg.preprocess.sample_rate != 16000: + speech_16k = librosa.resample( + speech, orig_sr=self.cfg.preprocess.sample_rate, target_sr=16000 + ) + 
else: + speech_16k = speech + inputs = self.processor(speech_16k, sampling_rate=16000) + # wav 2 bert convert to useful feature + input_features = inputs["input_features"][0] + attention_mask = inputs["attention_mask"][0] + + prompt0_wav_path = self.prompt0_paths[idx] # Get prompt0 path + speech_prompt0, sr_prompt0 = librosa.load(prompt0_wav_path, sr=self.cfg.preprocess.sample_rate) + speech_prompt0 = np.pad(speech_prompt0, (0, self.cfg.preprocess.hop_size - len(speech_prompt0) % self.cfg.preprocess.hop_size), mode="constant") + # resample the speech to 16k for feature extraction + if self.cfg.preprocess.sample_rate != 16000: + speech_16k_prompt0 = librosa.resample( + speech_prompt0, orig_sr=self.cfg.preprocess.sample_rate, target_sr=16000 + ) + else: + speech_16k_prompt0 = speech_prompt0 + + inputs_prompt0 = self.processor(speech_16k_prompt0, sampling_rate=16000) + + input_features_prompt0 = inputs_prompt0["input_features"][0] + attention_mask_prompt0 = inputs_prompt0["attention_mask"][0] + + # get speech mask + speech_frames = len(speech) // self.cfg.preprocess.hop_size + mask = np.ones(speech_frames) + + speech_frames_prompt0 = len(speech_prompt0) // self.cfg.preprocess.hop_size + mask_prompt0 = np.ones(speech_frames_prompt0) + + del speech, speech_16k, speech_prompt0, speech_16k_prompt0 + + lang_id = self.index2lang[idx] + phone_id = self.wav_path_index2phoneid[idx] + phone_id = torch.tensor(phone_id, dtype=torch.long) + phone_mask = np.ones(len(phone_id)) + + single_feature = dict() + + spk_id = self.wav_path_index2spkid[idx] + + single_feature.update({"spk_id": spk_id}) + single_feature.update({"lang_id": lang_id}) + + single_feature.update({"phone_id": phone_id}) + single_feature.update({"phone_mask": phone_mask}) + + single_feature.update( + { + "input_features": input_features, + "attention_mask": attention_mask, + "mask": mask, + "input_features_prompt0": input_features_prompt0, + "attention_mask_prompt0": attention_mask_prompt0, + "mask_prompt0":mask_prompt0 + } + ) + + return single_feature + + + +class T2SCollator(object): + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + for key in batch[0].keys(): + if "input_features" in key: + packed_batch_features[key] = pad_sequence( + [utt[key].float() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).float() for utt in batch], + batch_first=True + ) + if "attention_mask" in key: + packed_batch_features[key] = pad_sequence( + [utt[key].float() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).float() for utt in batch], + batch_first=True + ) + if "mask" in key: + packed_batch_features[key] = pad_sequence( + [utt[key].long() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).long() for utt in batch], + batch_first=True + ) + if "semantic_code" in key: + packed_batch_features[key] = pad_sequence( + [utt[key].float() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).float() for utt in batch], + batch_first=True + ) + if key == "phone_id": + packed_batch_features[key] = pad_sequence( + [utt[key].long() for utt in batch], batch_first=True, padding_value=1023, # phone vocab size is 1024 + ) + if key == "phone_mask": + packed_batch_features[key] = pad_sequence( + [torch.tensor(utt[key]).long() for utt in batch], batch_first=True + ) + if key == "lang_id": + packed_batch_features[key] = torch.tensor([utt[key] for utt in batch]).long() + if key == "spk_id": + packed_batch_features[key] = torch.tensor([utt[key] for utt in 
batch]).long() + if key == "spk_emb_input_features": + packed_batch_features[key] = pad_sequence( + [torch.tensor(utt[key]).float() for utt in batch], batch_first=True + ) + if key == "spk_emb_attention_mask": + packed_batch_features[key] = pad_sequence( + [torch.tensor(utt[key]).long() for utt in batch], batch_first=True + ) + else: + pass + + return packed_batch_features + + + +class DownsampleWithMask(nn.Module): + def __init__(self, downsample_factor=2): + super(DownsampleWithMask, self).__init__() + self.downsample_factor = downsample_factor + + def forward(self, x, mask): + # input from numpy.ndarray to torch.Tensor + if isinstance(x, np.ndarray): + x = torch.tensor(x, dtype=torch.float32) + if isinstance(mask, np.ndarray): + mask = torch.tensor(mask, dtype=torch.float32) + + # print(f"################## x size original {x.shape}################################") + + x = x.float() + x = x.permute(1, 0) # to (feature_dim, timestep) + x = x.unsqueeze(1) # add channel dimension: (timestep, 1, feature_dim) + + if x.size(-1) < self.downsample_factor: + raise ValueError("Input size must be larger than downsample factor") + + # print(f"################## x size before {x.shape}################################") + x = F.avg_pool1d(x, kernel_size=self.downsample_factor) + x = x.squeeze(1) # remove channel dimension: (timestep, feature_dim // downsample_factor) + x = x.long() + x = x.permute(1, 0) # to (feature_dim, timestep) + + mask = mask.float() # convert mask to float for pooling + mask = mask.unsqueeze(0).unsqueeze(0) # add channel dimension: (timestep, 1, feature_dim) + + if mask.size(-1) < self.downsample_factor: + raise ValueError("Mask size must be larger than downsample factor") + + mask = F.avg_pool1d(mask, kernel_size=self.downsample_factor, stride=self.downsample_factor) + mask = mask.squeeze(0).squeeze(0) # remove channel dimension: (timestep, feature_dim // downsample_factor) + mask = (mask >= 0.5).long() # if average > 0.5 --> 1, else 0 + + return x, mask diff --git a/models/tts/debatts/try_inference_small_samples.py b/models/tts/debatts/try_inference_small_samples.py new file mode 100644 index 00000000..0e21965f --- /dev/null +++ b/models/tts/debatts/try_inference_small_samples.py @@ -0,0 +1,466 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
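+# Inference demo: builds the T2S models, the two SoundStorm semantic-to-acoustic models and the +# codec, converts prompt/target text and prompt speech into semantic tokens (text2semantic), then +# decodes them into a waveform (semantic2acoustic) and writes the result with soundfile.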
+ +import sys +import os +os.chdir('./models/tts/debatts') +sys.path.append('./models/tts/debatts') +from utils.g2p_liwei.g2p_liwei import liwei_g2p + +from transformers import Wav2Vec2Model +from cgitb import text +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import librosa +import os +from IPython.display import Audio +import matplotlib.pyplot as plt +import soundfile as sf +import pickle +import math +import json +import accelerate +from IPython.display import Audio + +from models.codec.kmeans.kmeans_model import KMeans, KMeansEMA +from models.codec.kmeans.repcodec_model import RepCodec +from models.tts.soundstorm.soundstorm_model import SoundStorm +from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder +from transformers import Wav2Vec2BertModel +import safetensors +from utils.util import load_config +from tqdm import tqdm + +from transformers import SeamlessM4TFeatureExtractor +processor = SeamlessM4TFeatureExtractor.from_pretrained("/mntcephfs/lab_data/lijiaqi/debate/gluster-tts/w2v-bert-2") + +from transformers import AutoProcessor, AutoModel + +from models.tts.text2semantic.t2s_model import T2SLlama +from models.tts.text2semantic.t2s_model_new import T2SLlama_new +from utils.g2p_liwei.g2p_liwei import liwei_g2p +from models.tts.text2semantic.t2s_sft_dataset_new import DownsampleWithMask + +def liwei_g2p_(text, language): + return liwei_g2p(text, language) + +def build_t2s_model_new(cfg, device): + t2s_model = T2SLlama_new(phone_vocab_size=1024, + target_vocab_size=8192, + hidden_size=2048, + intermediate_size=8192, + pad_token_id=9216, + bos_target_id=9217, + eos_target_id=9218, + bos_phone_id=9219, + eos_phone_id=9220, + bos_prompt0_id=9221, + eos_prompt0_id=9222, + use_lang_emb=False) + t2s_model.eval() + t2s_model.to(device) + t2s_model.half() + return t2s_model + +def build_soundstorm(cfg, device): + soundstorm_model = SoundStorm(cfg=cfg.model.soundstorm) + soundstorm_model.eval() + soundstorm_model.to(device) + return soundstorm_model + +def build_kmeans_model(cfg, device): + if cfg.model.kmeans.type == "kmeans": + kmeans_model = KMeans(cfg=cfg.model.kmeans.kmeans) + elif cfg.model.kmeans.type == "kmeans_ema": + kmeans_model = KMeansEMA(cfg=cfg.model.kmeans.kmeans) + elif cfg.model.kmeans.type == "repcodec": + kmeans_model = RepCodec(cfg=cfg.model.kmeans.repcodec) + kmeans_model.eval() + pretrained_path =cfg.model.kmeans.pretrained_path + if ".bin" in pretrained_path: + kmeans_model.load_state_dict(torch.load(pretrained_path)) + elif ".safetensors" in pretrained_path: + safetensors.torch.load_model(kmeans_model, pretrained_path) + kmeans_model.to(device) + return kmeans_model + +def build_semantic_model(cfg, device): + semantic_model = Wav2Vec2BertModel.from_pretrained("/mntcephfs/lab_data/lijiaqi/debate/gluster-tts/w2v-bert-2") + semantic_model.eval() + semantic_model.to(device) + + layer_idx = 15 + output_idx = 17 + stat_mean_var = torch.load(cfg.model.kmeans.stat_mean_var_path) + semantic_mean = stat_mean_var["mean"] + semantic_std = torch.sqrt(stat_mean_var["var"]) + semantic_mean = semantic_mean.to(device) + semantic_std = semantic_std.to(device) + + return semantic_model, semantic_mean, semantic_std + +def build_codec_model(cfg, device): + codec_encoder = CodecEncoder(cfg=cfg.model.codec.encoder) + codec_decoder = CodecDecoder(cfg=cfg.model.codec.decoder) + if ".bin" in cfg.model.codec.encoder.pretrained_path: + codec_encoder.load_state_dict( + torch.load(cfg.model.codec.encoder.pretrained_path) + ) + 
codec_decoder.load_state_dict( + torch.load(cfg.model.codec.decoder.pretrained_path) + ) + else: + accelerate.load_checkpoint_and_dispatch(codec_encoder, cfg.model.codec.encoder.pretrained_path) + accelerate.load_checkpoint_and_dispatch(codec_decoder, cfg.model.codec.decoder.pretrained_path) + codec_encoder.eval() + codec_decoder.eval() + codec_encoder.to(device) + codec_decoder.to(device) + return codec_encoder, codec_decoder + +@torch.no_grad() +def extract_acoustic_code(speech): + vq_emb = codec_encoder(speech.unsqueeze(1)) + _, vq, _, _, _ = codec_decoder.quantizer(vq_emb) + acoustic_code = vq.permute( + 1, 2, 0 + ) # (num_quantizer, T, C) -> (T, C, num_quantizer) + return acoustic_code + +@torch.no_grad() +def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask): + vq_emb = semantic_model( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + ) + feat = vq_emb.hidden_states[17] # (B, T, C) + feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat) + + semantic_code, _ = kmeans_model.quantize(feat) # (B, T) + return semantic_code + +@torch.no_grad() +def extract_features(speech, processor): + inputs = processor(speech, sampling_rate=16000, return_tensors="pt") + input_features = inputs["input_features"][0] + attention_mask = inputs["attention_mask"][0] + return input_features, attention_mask + +@torch.no_grad() +def text2semantic(prompt0_speech, prompt0_text, prompt_speech, prompt_text, prompt_language, target_text, target_language, use_prompt_text=True, temp=1.0, top_k=1000, top_p=0.85, infer_mode = "ori"): + if use_prompt_text: + if infer_mode == "new" and prompt0_speech is not None and prompt0_speech.any(): + prompt0_phone_id = liwei_g2p_(prompt0_text, prompt_language)[1] + prompt0_phone_id = torch.tensor(prompt0_phone_id, dtype=torch.long).to(device) + + prompt_phone_id = liwei_g2p_(prompt_text, prompt_language)[1] + prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device) + + target_phone_id = liwei_g2p_(target_text, target_language)[1] + target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) + + phone_id = torch.cat([prompt_phone_id, torch.LongTensor([4]).to(device), target_phone_id]) + + else: + target_phone_id = liwei_g2p_(target_text, target_language)[1] + target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) + phone_id = target_phone_id + + input_fetures, attention_mask = extract_features(prompt_speech, processor) + input_fetures = input_fetures.unsqueeze(0).to(device) + attention_mask = attention_mask.unsqueeze(0).to(device) + semantic_code = extract_semantic_code(semantic_mean, semantic_std, input_fetures, attention_mask) + + + if infer_mode == "new": + input_fetures_prompt0, attention_mask_prompt0 = extract_features(prompt0_speech, processor) + input_fetures_prompt0 = input_fetures_prompt0.unsqueeze(0).to(device) + attention_mask_prompt0 = attention_mask_prompt0.unsqueeze(0).to(device) + attention_mask_prompt0 = attention_mask_prompt0.float() + semantic_code_prompt0 = extract_semantic_code(semantic_mean, semantic_std, input_fetures_prompt0, attention_mask_prompt0) + + if use_prompt_text: + if infer_mode =="ori": + predict_semantic = t2s_model.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :], temperature=temp, top_k=top_k, top_p=top_p) + elif infer_mode == "tune": + predict_semantic = t2s_model_tune.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :], temperature=temp, top_k=top_k, 
top_p=top_p) + elif infer_mode == "new": + predict_semantic = t2s_model_new.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :], prompt0_ids=semantic_code_prompt0[:, :], temperature=temp, top_k=top_k, top_p=top_p) + + else: + if infer_mode == "ori": + predict_semantic = t2s_model.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :1], temperature=temp, top_k=top_k, top_p=top_p) + elif infer_mode == "tune": + predict_semantic = t2s_model_tune.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :1], temperature=temp, top_k=top_k, top_p=top_p) + elif infer_mode == "new": + predict_semantic = t2s_model_new.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :1], prompt0_ids=semantic_code_prompt0[:, :1], temperature=temp, top_k=top_k, top_p=top_p) + + + combine_semantic_code = torch.cat([semantic_code[:,:], predict_semantic], dim=-1) + prompt_semantic_code = semantic_code + + # max_com_semantic_value = torch.max(combine_semantic_code).item() + # max_prompt_semantic_value = torch.max(prompt_semantic_code).item() + + # print(f"Max token value in com semantic: {max_com_semantic_value}, shape is {combine_semantic_code.shape}") + # print(f"Max token value in prompt semantic: {max_prompt_semantic_value}, shape is {prompt_semantic_code.shape}") + # print(f"combine semantic_code of t2s new is {combine_semantic_code}, shape is {combine_semantic_code.shape}") + + return combine_semantic_code, prompt_semantic_code + +@torch.no_grad() +def semantic2acoustic(combine_semantic_code, acoustic_code): + + semantic_code = combine_semantic_code + + if soundstorm_1layer.cond_code_layers == 1: + cond = soundstorm_1layer.cond_emb(semantic_code) + else: + cond = soundstorm_1layer.cond_emb[0](semantic_code[0,:,:]) + for i in range(1, soundstorm_1layer.cond_code_layers): + cond += soundstorm_1layer.cond_emb[i](semantic_code[i,:,:]) + cond = cond / math.sqrt(soundstorm_1layer.cond_code_layers) + + prompt = acoustic_code[:,:,:] + predict_1layer = soundstorm_1layer.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=[40], cfg=1.0, rescale_cfg=1.0) + + if soundstorm_full.cond_code_layers == 1: + cond = soundstorm_full.cond_emb(semantic_code) + else: + cond = soundstorm_full.cond_emb[0](semantic_code[0,:,:]) + for i in range(1, soundstorm_full.cond_code_layers): + cond += soundstorm_full.cond_emb[i](semantic_code[i,:,:]) + cond = cond / math.sqrt(soundstorm_full.cond_code_layers) + + prompt = acoustic_code[:,:,:] + predict_full = soundstorm_full.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=[40,16,10,10,10,10,10,10,10,10,10,10], cfg=1.0, rescale_cfg=1.0, gt_code=predict_1layer) + vq_emb = codec_decoder.vq2emb(predict_full.permute(2,0,1), n_quantizers=12) + recovered_audio = codec_decoder(vq_emb) + prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1), n_quantizers=12) + recovered_prompt_audio = codec_decoder(prompt_vq_emb) + recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy() + recovered_audio = recovered_audio[0][0].cpu().numpy() + combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio]) + + return combine_audio, recovered_audio + +device = torch.device("cuda:0") +cfg_soundstorm_1layer = load_config("./egs/tts/SoundStorm/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json") +cfg_soundstorm_full = load_config("./models/tts/debatts/egs/tts/SoundStorm/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json") 
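+# Build the two semantic-to-acoustic SoundStorm models (a 1-layer model and a full multi-layer +# model), the w2v-bert semantic extractor with its RepCodec quantizer, and the acoustic codec, then +# load their checkpoints from the local paths configured below.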
+ +soundstorm_1layer = build_soundstorm(cfg_soundstorm_1layer, device) +soundstorm_full = build_soundstorm(cfg_soundstorm_full, device) + +semantic_model, semantic_mean, semantic_std = build_semantic_model(cfg_soundstorm_full, device) +kmeans_model = build_kmeans_model(cfg_soundstorm_full, device) + +codec_encoder, codec_decoder = build_codec_model(cfg_soundstorm_full, device) + + +soundstorm_1layer_path = "./s2a_model/emilia_50k_8192_331k_model.safetensors" +soundstorm_full_path = "./s2a_model/emilia_50k_8192_519k_model.safetensors" +safetensors.torch.load_model(soundstorm_1layer, soundstorm_1layer_path) +safetensors.torch.load_model(soundstorm_full, soundstorm_full_path) + +t2s_cfg = load_config("./exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") +t2s_model = build_t2s_model(t2s_cfg, device) +t2s_model_ckpt_path = "/mntcephfs/lab_data/lijiaqi/debate/gluster-tts/ckpt/t2s/t2s_625ksteps_model.safetensors" +safetensors.torch.load_model(t2s_model, t2s_model_ckpt_path) +print(t2s_model.bos_target_id, t2s_model.eos_target_id, t2s_model.bos_phone_id, t2s_model.eos_phone_id, t2s_model.pad_token_id) + +t2s_cfg = load_config("./egs/tts/Text2Semantic/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") +t2s_model_tune = build_t2s_model(t2s_cfg, device) +t2s_model_tune_ckpt_path = "/mntcephfs/data/wuzhizheng/debate_/ckpt_ori_tune/epoch-0021_step-0005000_loss-4.354165/model.safetensors" +safetensors.torch.load_model(t2s_model_tune, t2s_model_tune_ckpt_path) + +t2s_cfg = load_config("./egs/tts/Text2Semantic/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") +t2s_model_new = build_t2s_model_new(t2s_cfg, device) +t2s_model_new_ckpt_path = "./s2a_model/model.safetensors" # 1900(02), 1906 +safetensors.torch.load_model(t2s_model_new, t2s_model_new_ckpt_path) + +from funasr import AutoModel +print("Loading ASR model...") +asr_model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", vad_kwargs={"max_single_segment_time": 60000}, punc_model="ct-punc", device="cuda:0") + +def adjust_punctuation(text): + """ + Adjust punctuation so that each comma is followed + by a space and the remaining punctuation marks + use their full-width forms.
+ """ + text = text.replace(',', ', ') + + punct_mapping = { + '。': '。', + '?': '?', + '!': '!', + ':': ':', + ';': ';', + '“': '“', + '”': '”', + '‘': '‘', + '’': '’' + } + for punct, full_punct in punct_mapping.items(): + text = text.replace(punct, full_punct) + return text + +import random +import zhconv +def generate_text_data(wav_file): + idx = random.randint(0, 7000) + speech = librosa.load(wav_file, sr=16000)[0] + txt_json_path = wav_file.replace(".wav", ".json") + txt_json_param_path = wav_file.replace(".wav", "_asr_param.json") + if os.path.exists(txt_json_path): + + with open(txt_json_path, 'r', encoding='utf-8') as file: + json_data = json.load(file) + + if "text" in json_data: + txt = json_data["text"] + txt = adjust_punctuation(txt) + + elif os.path.exists(txt_json_param_path): + with open(txt_json_param_path, 'r', encoding='utf-8') as file: + json_data = json.load(file) + if "text" in json_data: + txt = json_data["text"] + txt = adjust_punctuation(txt) + + else: + res = asr_model.generate(input=wav_file, batch_size_s=300) + txt = res[0]["text"] + txt = zhconv.convert(txt, 'zh-cn') + txt = adjust_punctuation(txt) + + json_data["text"] = txt + with open(txt_json_path, 'w', encoding='utf-8') as file: + json.dump(json_data, file, ensure_ascii=False, indent=4) + + # If no JSON file is found, generate new text and save it to a new JSON file + else: + res = asr_model.generate(input=wav_file, batch_size_s=300) + txt = res[0]["text"] + txt = zhconv.convert(txt, 'zh-cn') + txt = adjust_punctuation(txt) + # txt = re.sub(" ", "", txt) + + json_data = {"text": txt} + with open(txt_json_path, 'w', encoding='utf-8') as file: + json.dump(json_data, file, ensure_ascii=False, indent=4) + + return wav_file, txt, wav_file + + +def infer(speech_path, prompt_text, target_wav_path, target_text, target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="ori", idx = 0, epoch=0, spk_prompt_type=""): + if idx != 0: + save_dir = os.path.join("The Path to Store Generated Speech", f"{infer_mode}/{spk_prompt_type}") + if not os.path.exists(save_dir): + os.mkdir(save_dir) + save_path = os.path.join(save_dir, f"{os.path.splitext(os.path.basename(target_wav_path))[0]}_infer_{infer_mode}_{idx}_epoch_{epoch}_{spk_prompt_type}.wav") + else: + save_dir = os.path.join("The Path to Store Generated Speech", f"{infer_mode}/{spk_prompt_type}") + if not os.path.exists(save_dir): + os.mkdir(save_dir) + save_path = os.path.join(save_dir, f"{os.path.splitext(os.path.basename(target_wav_path))[0]}_infer_{infer_mode}_epoch_{epoch}_{spk_prompt_type}.wav") + + if os.path.exists(save_path): + return save_path + + print(f"HERE COMES INFER!!! 
{infer_mode}") + print(f"IN INFER PROMPT text is {prompt_text}") + print(f"IN INFER Target text is {target_text}") + speech_16k = librosa.load(speech_path, sr=16000)[0] + speech = librosa.load(speech_path, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] + + if infer_mode == "new": + speech_16k_prompt0 = librosa.load(speech_path_prompt0, sr=16000)[0] + speech_prompt0 = librosa.load(speech_path_prompt0, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] + # combine_semantic_code, _ = text2semantic_new(speech_16k_prompt0, prompt0_text, speech_16k, prompt_text, target_language, target_text, target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) + combine_semantic_code, _ = text2semantic(prompt0_speech=speech_16k_prompt0, prompt0_text=prompt0_text, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language=target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode = infer_mode) + + else: + combine_semantic_code, _ = text2semantic(prompt0_speech=None, prompt0_text=None, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language = target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) + acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device)) + # print(acoustic_code.shape) + combine_audio, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code) + + + if not concat_prompt: + combine_audio = combine_audio[speech.shape[-1]:] + # sf.write(os.path.join(save_path, "{}.wav".format(uid)), recovered_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) + sf.write(save_path, combine_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) + return save_path + +def infer_small(speech_path, prompt_text, target_text, target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="ori", save_path=None): + + if os.path.exists(save_path): + return save_path + + speech_16k = librosa.load(speech_path, sr=16000)[0] + speech = librosa.load(speech_path, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] + + if infer_mode == "new": + speech_16k_prompt0 = librosa.load(speech_path_prompt0, sr=16000)[0] + speech_prompt0 = librosa.load(speech_path_prompt0, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] + # combine_semantic_code, _ = text2semantic_new(speech_16k_prompt0, prompt0_text, speech_16k, prompt_text, target_language, target_text, target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) + combine_semantic_code, _ = text2semantic(prompt0_speech=speech_16k_prompt0, prompt0_text=prompt0_text, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language=target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode = infer_mode) + + else: + combine_semantic_code, _ = text2semantic(prompt0_speech=None, prompt0_text=None, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language = target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) + acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device)) + combine_audio, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code) + + + if not concat_prompt: + combine_audio = combine_audio[speech.shape[-1]:] 
+ # sf.write(os.path.join(save_path, "{}.wav".format(uid)), recovered_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) + sf.write(save_path, combine_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) + return save_path + +def get_prompt0_wav_path(wav_path): + base_path = wav_path.split('chenci')[0].rstrip('_') + json_path = f"{base_path}.json" + + try: + with open(json_path, 'r', encoding='utf-8') as file: + data = json.load(file) + prompt0_wav_path = data['prompt0_wav_path'] + basename = os.path.basename(prompt0_wav_path) + prompt0_wav_path = os.path.join("Debatts-Data Test Directory", basename) + return prompt0_wav_path + except FileNotFoundError: + print(f"File not found: {json_path}") + except KeyError: + print(f"Cannot find 'prompt0_wav_path' in json") + except json.JSONDecodeError as e: + print(f"Error in reading json: {e}") + +##################################### EVALUATION ################################################################ +from funasr import AutoModel +import torch.nn.functional as F +import torch + +from models.tts.soundstorm.try_inference_new import evaluation +from models.tts.soundstorm.try_inference_new import evaluation_new +from models.tts.soundstorm.try_inference_new import extract_emotion_similarity + +prompt0_wav_path = "./models/tts/debatts/speech_examples/87_SPEAKER01_2_part03_213.wav" +prompt0_text = generate_text_data(prompt0_wav_path)[1] + +spk_prompt_wav_path = "The Speaker Identity Path" +spk_prompt_text = generate_text_data(spk_prompt_wav_path)[1] + +target_text = "The Target Text to Generate" +save_path_dir = "The Path to Save Generated Speech" +wav_filename = "The Filename of Generated Speech" +save_path = os.path.join(save_path_dir, wav_filename) +save_path = infer_small(speech_path=spk_prompt_wav_path, prompt_text = spk_prompt_text, target_text=target_text, speech_path_prompt0 = prompt0_wav_path, prompt0_text = prompt0_text, infer_mode = "new", save_path = save_path) diff --git a/models/tts/debatts/utils/__init__.py b/models/tts/debatts/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/tts/debatts/utils/audio.py b/models/tts/debatts/utils/audio.py new file mode 100644 index 00000000..374d5091 --- /dev/null +++ b/models/tts/debatts/utils/audio.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
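+# Audio I/O helpers: load_audio_torch normalizes integer/float audio into [-1, 1] float tensors and +# resamples to the requested rate; energy() computes per-frame spectral magnitude norms via STFT.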
+ +import torch +import numpy as np +from numpy import linalg as LA +import librosa +import soundfile as sf +import librosa.filters + + +def load_audio_torch(wave_file, fs): + """Load audio data into torch tensor + + Args: + wave_file (str): path to wave file + fs (int): sample rate + + Returns: + audio (tensor): audio data in tensor + fs (int): sample rate + """ + + audio, sample_rate = librosa.load(wave_file, sr=fs, mono=True) + # audio: (T,) + assert len(audio) > 2 + + # Check the audio type (for soundfile loading backbone) - float, 8bit or 16bit + if np.issubdtype(audio.dtype, np.integer): + max_mag = -np.iinfo(audio.dtype).min + else: + max_mag = max(np.amax(audio), -np.amin(audio)) + max_mag = ( + (2**31) + 1 + if max_mag > (2**15) + else ((2**15) + 1 if max_mag > 1.01 else 1.0) + ) + + # Normalize the audio + audio = torch.FloatTensor(audio.astype(np.float32)) / max_mag + + if (torch.isnan(audio) | torch.isinf(audio)).any(): + return [], sample_rate or fs or 48000 + + # Resample the audio to our target samplerate + if fs is not None and fs != sample_rate: + audio = torch.from_numpy( + librosa.core.resample(audio.numpy(), orig_sr=sample_rate, target_sr=fs) + ) + sample_rate = fs + + return audio, fs + + +def _stft(y, cfg): + return librosa.stft( + y=y, n_fft=cfg.n_fft, hop_length=cfg.hop_size, win_length=cfg.win_size + ) + + +def energy(wav, cfg): + D = _stft(wav, cfg) + magnitudes = np.abs(D).T # [F, T] + return LA.norm(magnitudes, axis=1) + + +def get_energy_from_tacotron(audio, _stft): + audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + mel, energy = _stft.mel_spectrogram(audio) + energy = torch.squeeze(energy, 0).numpy().astype(np.float32) + return mel, energy diff --git a/models/tts/debatts/utils/audio_slicer.py b/models/tts/debatts/utils/audio_slicer.py new file mode 100644 index 00000000..28474596 --- /dev/null +++ b/models/tts/debatts/utils/audio_slicer.py @@ -0,0 +1,476 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import json +import numpy as np +from tqdm import tqdm +import torch +import torchaudio + +from utils.io import save_audio +from utils.audio import load_audio_torch + + +# This function is obtained from librosa. 
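The get_rms helper that follows frames the padded signal with numpy stride tricks and returns per-frame root-mean-square energy; since it centers and zero-pads the same way librosa does, it is expected to agree closely with librosa.feature.rms. A small sanity-check sketch that can be run once this module is imported (the comparison is illustrative only):

import numpy as np
import librosa

y = np.random.default_rng(0).standard_normal(16000).astype(np.float32)
rms_here = get_rms(y, frame_length=2048, hop_length=512).squeeze(0)  # defined below
rms_ref = librosa.feature.rms(y=y, frame_length=2048, hop_length=512).squeeze(0)
print(np.max(np.abs(rms_here - rms_ref)))  # expected to be close to 0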
+def get_rms( + y, + *, + frame_length=2048, + hop_length=512, + pad_mode="constant", +): + padding = (int(frame_length // 2), int(frame_length // 2)) + y = np.pad(y, padding, mode=pad_mode) + + axis = -1 + # put our new within-frame axis at the end for now + out_strides = y.strides + tuple([y.strides[axis]]) + # Reduce the shape on the framing axis + x_shape_trimmed = list(y.shape) + x_shape_trimmed[axis] -= frame_length - 1 + out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) + xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides) + if axis < 0: + target_axis = axis - 1 + else: + target_axis = axis + 1 + xw = np.moveaxis(xw, -1, target_axis) + # Downsample along the target axis + slices = [slice(None)] * xw.ndim + slices[axis] = slice(0, None, hop_length) + x = xw[tuple(slices)] + + # Calculate power + power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) + + return np.sqrt(power) + + +class Slicer: + """ + Copy from: https://github.com/openvpi/audio-slicer/blob/main/slicer2.py + """ + + def __init__( + self, + sr: int, + threshold: float = -40.0, + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 10, + max_sil_kept: int = 5000, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + begin = begin * self.hop_size + if len(waveform.shape) > 1: + end = min(waveform.shape[1], end * self.hop_size) + return waveform[:, begin:end], begin, end + else: + end = min(waveform.shape[0], end * self.hop_size) + return waveform[begin:end], begin, end + + # @timeit + def slice(self, waveform, return_chunks_positions=False): + if len(waveform.shape) > 1: + # (#channle, wave_len) -> (wave_len) + samples = waveform.mean(axis=0) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return [waveform] + rms_list = get_rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. 
+ if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start : i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + if len(sil_tags) == 0: + return [waveform] + else: + chunks = [] + chunks_pos_of_waveform = [] + + if sil_tags[0][0] > 0: + chunk, begin, end = self._apply_slice(waveform, 0, sil_tags[0][0]) + chunks.append(chunk) + chunks_pos_of_waveform.append((begin, end)) + + for i in range(len(sil_tags) - 1): + chunk, begin, end = self._apply_slice( + waveform, sil_tags[i][1], sil_tags[i + 1][0] + ) + chunks.append(chunk) + chunks_pos_of_waveform.append((begin, end)) + + if sil_tags[-1][1] < total_frames: + chunk, begin, end = self._apply_slice( + waveform, sil_tags[-1][1], total_frames + ) + chunks.append(chunk) + chunks_pos_of_waveform.append((begin, end)) + + return ( + chunks + if not return_chunks_positions + else ( + chunks, + chunks_pos_of_waveform, + ) + ) + + +def split_utterances_from_audio( + wav_file, + output_dir, + max_duration_of_utterance=10.0, + min_interval=300, + db_threshold=-40, +): + """ + Split a long audio into utterances accoring to the silence (VAD). + + max_duration_of_utterance (second): + The maximum duration of every utterance (seconds) + min_interval (millisecond): + The smaller min_interval is, the more sliced audio clips this script is likely to generate. 
+ """ + print("File:", wav_file.split("/")[-1]) + waveform, fs = torchaudio.load(wav_file) + + slicer = Slicer(sr=fs, min_interval=min_interval, threshold=db_threshold) + chunks, positions = slicer.slice(waveform, return_chunks_positions=True) + + durations = [(end - begin) / fs for begin, end in positions] + print( + "Slicer's min silence part is {}ms, min and max duration of sliced utterances is {}s and {}s".format( + min_interval, min(durations), max(durations) + ) + ) + + res_chunks, res_positions = [], [] + for i, chunk in enumerate(chunks): + if len(chunk.shape) == 1: + chunk = chunk[None, :] + + begin, end = positions[i] + assert end - begin == chunk.shape[-1] + + max_wav_len = max_duration_of_utterance * fs + if chunk.shape[-1] <= max_wav_len: + res_chunks.append(chunk) + res_positions.append(positions[i]) + else: + # TODO: to reserve overlapping and conduct fade-in, fade-out + + # Get segments number + number = 2 + while chunk.shape[-1] // number >= max_wav_len: + number += 1 + seg_len = chunk.shape[-1] // number + + # Split + for num in range(number): + s = seg_len * num + t = min(s + seg_len, chunk.shape[-1]) + + seg_begin = begin + s + seg_end = begin + t + + res_chunks.append(chunk[:, s:t]) + res_positions.append((seg_begin, seg_end)) + + # Save utterances + os.makedirs(output_dir, exist_ok=True) + res = {"fs": int(fs)} + for i, chunk in enumerate(res_chunks): + filename = "{:04d}.wav".format(i) + res[filename] = [int(p) for p in res_positions[i]] + save_audio(os.path.join(output_dir, filename), chunk, fs) + + # Save positions + with open(os.path.join(output_dir, "positions.json"), "w") as f: + json.dump(res, f, indent=4, ensure_ascii=False) + return res + + +def is_silence( + wavform, + fs, + threshold=-40.0, + min_interval=300, + hop_size=10, + min_length=5000, +): + """ + Detect whether the given wavform is a silence + + wavform: (T, ) + """ + threshold = 10 ** (threshold / 20.0) + + hop_size = round(fs * hop_size / 1000) + win_size = min(round(min_interval), 4 * hop_size) + min_length = round(fs * min_length / 1000 / hop_size) + + if wavform.shape[0] <= min_length: + return True + + # (#Frame,) + rms_array = get_rms(y=wavform, frame_length=win_size, hop_length=hop_size).squeeze( + 0 + ) + return (rms_array < threshold).all() + + +def split_audio( + wav_file, target_sr, output_dir, max_duration_of_segment=10.0, overlap_duration=1.0 +): + """ + Split a long audio into segments. + + target_sr: + The target sampling rate to save the segments. 
+ max_duration_of_utterance (second): + The maximum duration of every utterance (second) + overlap_duraion: + Each segment has "overlap duration" (second) overlap with its previous and next segment + """ + # (#channel, T) -> (T,) + waveform, fs = torchaudio.load(wav_file) + waveform = torchaudio.functional.resample( + waveform, orig_freq=fs, new_freq=target_sr + ) + waveform = torch.mean(waveform, dim=0) + + # waveform, _ = load_audio_torch(wav_file, target_sr) + assert len(waveform.shape) == 1 + + assert overlap_duration < max_duration_of_segment + length = int(max_duration_of_segment * target_sr) + stride = int((max_duration_of_segment - overlap_duration) * target_sr) + chunks = [] + for i in range(0, len(waveform), stride): + # (length,) + chunks.append(waveform[i : i + length]) + if i + length >= len(waveform): + break + + # Save segments + os.makedirs(output_dir, exist_ok=True) + results = [] + for i, chunk in enumerate(chunks): + uid = "{:04d}".format(i) + filename = os.path.join(output_dir, "{}.wav".format(uid)) + results.append( + {"Uid": uid, "Path": filename, "Duration": len(chunk) / target_sr} + ) + save_audio( + filename, + chunk, + target_sr, + turn_up=not is_silence(chunk, target_sr), + add_silence=False, + ) + + return results + + +def merge_segments_torchaudio(wav_files, fs, output_path, overlap_duration=1.0): + """Merge the given wav_files (may have overlaps) into a long audio + + fs: + The sampling rate of the wav files. + output_path: + The output path to save the merged audio. + overlap_duration (float, optional): + Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0. + """ + + waveforms = [] + for file in wav_files: + # (T,) + waveform, _ = load_audio_torch(file, fs) + waveforms.append(waveform) + + if len(waveforms) == 1: + save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False) + return + + overlap_len = int(overlap_duration * fs) + fade_out = torchaudio.transforms.Fade(fade_out_len=overlap_len) + fade_in = torchaudio.transforms.Fade(fade_in_len=overlap_len) + fade_in_and_out = torchaudio.transforms.Fade(fade_out_len=overlap_len) + + segments_lens = [len(wav) for wav in waveforms] + merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1) + merged_waveform = torch.zeros(merged_waveform_len) + + start = 0 + for index, wav in enumerate( + tqdm(waveforms, desc="Merge for {}".format(output_path)) + ): + wav_len = len(wav) + + if index == 0: + wav = fade_out(wav) + elif index == len(waveforms) - 1: + wav = fade_in(wav) + else: + wav = fade_in_and_out(wav) + + merged_waveform[start : start + wav_len] = wav + start += wav_len - overlap_len + + save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True) + + +def merge_segments_encodec(wav_files, fs, output_path, overlap_duration=1.0): + """Merge the given wav_files (may have overlaps) into a long audio + + fs: + The sampling rate of the wav files. + output_path: + The output path to save the merged audio. + overlap_duration (float, optional): + Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0. 
+ """ + + waveforms = [] + for file in wav_files: + # (T,) + waveform, _ = load_audio_torch(file, fs) + waveforms.append(waveform) + + if len(waveforms) == 1: + save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False) + return + + device = waveforms[0].device + dtype = waveforms[0].dtype + shape = waveforms[0].shape[:-1] + + overlap_len = int(overlap_duration * fs) + segments_lens = [len(wav) for wav in waveforms] + merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1) + + sum_weight = torch.zeros(merged_waveform_len, device=device, dtype=dtype) + out = torch.zeros(*shape, merged_waveform_len, device=device, dtype=dtype) + offset = 0 + + for frame in waveforms: + frame_length = frame.size(-1) + t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=torch.float32)[ + 1:-1 + ] + weight = 0.5 - (t - 0.5).abs() + weighted_frame = frame * weight + + cur = out[..., offset : offset + frame_length] + cur += weighted_frame[..., : cur.size(-1)] + out[..., offset : offset + frame_length] = cur + + cur = sum_weight[offset : offset + frame_length] + cur += weight[..., : cur.size(-1)] + sum_weight[offset : offset + frame_length] = cur + + offset += frame_length - overlap_len + + assert sum_weight.min() > 0 + merged_waveform = out / sum_weight + save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True) diff --git a/models/tts/debatts/utils/cut_by_vad.py b/models/tts/debatts/utils/cut_by_vad.py new file mode 100644 index 00000000..0d41a4a1 --- /dev/null +++ b/models/tts/debatts/utils/cut_by_vad.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" This code is modified from https://github.com/facebookresearch/libri-light/blob/main/data_preparation/cut_by_vad.py""" +import pathlib +import soundfile as sf +import numpy as np +import json +import multiprocessing +import tqdm + + +def save(seq, fname, index, extension): + """save audio sequences to file""" + output = np.hstack(seq) + file_name = fname.parent / (fname.stem + f"_{index:04}{extension}") + fname.parent.mkdir(exist_ok=True, parents=True) + sf.write(file_name, output, samplerate=16000) + + +def cut_sequence(path, vad, path_out, target_len_sec, out_extension): + """cut audio sequences based on VAD""" + data, samplerate = sf.read(path) + + assert len(data.shape) == 1 + assert samplerate == 16000 + + to_stitch = [] + length_accumulated = 0.0 + + i = 0 + # Iterate over VAD segments + for start, end in vad: + start_index = int(start * samplerate) + end_index = int(end * samplerate) + slice = data[start_index:end_index] + + # Save slices that exceed the target length or if there's already accumulated audio + if ( + length_accumulated + (end - start) > target_len_sec + and length_accumulated > 0 + ): + save(to_stitch, path_out, i, out_extension) + to_stitch = [] + i += 1 + length_accumulated = 0 + + # Add the current slice to the list to be stitched + to_stitch.append(slice) + length_accumulated += end - start + + # Save any remaining slices + if to_stitch: + save(to_stitch, path_out, i, out_extension) + + +def cut_book(task): + """process each book in the dataset""" + path_book, root_out, target_len_sec, extension = task + + speaker = pathlib.Path(path_book.parent.name) + + for i, meta_file_path in enumerate(path_book.glob("*.json")): + with open(meta_file_path, "r") as f: + meta = json.loads(f.read()) + book_id = meta["book_meta"]["id"] + vad = 
meta["voice_activity"] + + sound_file = meta_file_path.parent / (meta_file_path.stem + ".flac") + + path_out = root_out / speaker / book_id / (meta_file_path.stem) + cut_sequence(sound_file, vad, path_out, target_len_sec, extension) + + +def cut_segments( + input_dir, output_dir, target_len_sec=30, n_process=32, out_extension=".wav" +): + """Main function to cut segments from audio files""" + + pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) + list_dir = pathlib.Path(input_dir).glob("*/*") + list_dir = [x for x in list_dir if x.is_dir()] + + print(f"{len(list_dir)} directories detected") + print(f"Launching {n_process} processes") + + # Create tasks for multiprocessing + tasks = [ + (path_book, output_dir, target_len_sec, out_extension) for path_book in list_dir + ] + + # Process tasks in parallel using multiprocessing + with multiprocessing.Pool(processes=n_process) as pool: + for _ in tqdm.tqdm(pool.imap_unordered(cut_book, tasks), total=len(tasks)): + pass + + +if __name__ == "__main__": + input_dir = "/path/to/input_dir" + output_dir = "/path/to/output_dir" + target_len_sec = 10 + n_process = 16 + cut_segments(input_dir, output_dir, target_len_sec, n_process) diff --git a/models/tts/debatts/utils/data_utils.py b/models/tts/debatts/utils/data_utils.py new file mode 100644 index 00000000..8c0bc2ff --- /dev/null +++ b/models/tts/debatts/utils/data_utils.py @@ -0,0 +1,588 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + +import numpy as np +from scipy.interpolate import interp1d +from tqdm import tqdm +from sklearn.preprocessing import StandardScaler + + +def intersperse(lst, item): + """ + Insert an item in between any two consecutive elements of the given list, including beginning and end of list + + Example: + >>> intersperse(0, [1, 74, 5, 31]) + [0, 1, 0, 74, 0, 5, 0, 31, 0] + """ + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def load_content_feature_path(meta_data, processed_dir, feat_dir): + utt2feat_path = {} + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + feat_path = os.path.join( + processed_dir, utt_info["Dataset"], feat_dir, f'{utt_info["Uid"]}.npy' + ) + utt2feat_path[utt] = feat_path + + return utt2feat_path + + +def load_source_content_feature_path(meta_data, feat_dir): + utt2feat_path = {} + for utt in meta_data: + feat_path = os.path.join(feat_dir, f"{utt}.npy") + utt2feat_path[utt] = feat_path + + return utt2feat_path + + +def get_spk_map(spk2id_path, utt2spk_path): + utt2spk = {} + with open(spk2id_path, "r") as spk2id_file: + spk2id = json.load(spk2id_file) + with open(utt2spk_path, encoding="utf-8") as f: + for line in f.readlines(): + utt, spk = line.strip().split("\t") + utt2spk[utt] = spk + return spk2id, utt2spk + + +def get_target_f0_median(f0_dir): + total_f0 = [] + for utt in os.listdir(f0_dir): + if not utt.endswith(".npy"): + continue + f0_feat_path = os.path.join(f0_dir, utt) + f0 = np.load(f0_feat_path) + total_f0 += f0.tolist() + + total_f0 = np.array(total_f0) + voiced_position = np.where(total_f0 != 0) + return np.median(total_f0[voiced_position]) + + +def get_conversion_f0_factor(source_f0, target_median, source_median=None): + """Align the median between source f0 and target f0 + + Note: Here we use multiplication, whose factor is target_median/source_median + + Reference: Frequency and pitch interval + 
http://blog.ccyg.studio/article/be12c2ee-d47c-4098-9782-ca76da3035e4/ + """ + if source_median is None: + voiced_position = np.where(source_f0 != 0) + source_median = np.median(source_f0[voiced_position]) + factor = target_median / source_median + return source_median, factor + + +def transpose_key(frame_pitch, trans_key): + # Transpose by user's argument + print("Transpose key = {} ...\n".format(trans_key)) + + transed_pitch = frame_pitch * 2 ** (trans_key / 12) + return transed_pitch + + +def pitch_shift_to_target(frame_pitch, target_pitch_median, source_pitch_median=None): + # Loading F0 Base (median) and shift + source_pitch_median, factor = get_conversion_f0_factor( + frame_pitch, target_pitch_median, source_pitch_median + ) + print( + "Auto transposing: source f0 median = {:.1f}, target f0 median = {:.1f}, factor = {:.2f}".format( + source_pitch_median, target_pitch_median, factor + ) + ) + transed_pitch = frame_pitch * factor + return transed_pitch + + +def load_frame_pitch( + meta_data, + processed_dir, + pitch_dir, + use_log_scale=False, + return_norm=False, + interoperate=False, + utt2spk=None, +): + utt2pitch = {} + utt2uv = {} + if utt2spk is None: + pitch_scaler = StandardScaler() + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch_path = os.path.join( + processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy' + ) + pitch = np.load(pitch_path) + assert len(pitch) > 0 + uv = pitch != 0 + utt2uv[utt] = uv + if use_log_scale: + nonzero_idxes = np.where(pitch != 0)[0] + pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes]) + utt2pitch[utt] = pitch + pitch_scaler.partial_fit(pitch.reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + if return_norm: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + utt2pitch[utt] = normalized_pitch + pitch_statistic = {"mean": mean, "std": std} + else: + spk2utt = {} + pitch_statistic = [] + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + pitch_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + pitch_path = os.path.join( + processed_dir, dataset, pitch_dir, f"{uid}.npy" + ) + pitch = np.load(pitch_path) + assert len(pitch) > 0 + uv = pitch != 0 + utt2uv[utt] = uv + if use_log_scale: + nonzero_idxes = np.where(pitch != 0)[0] + pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes]) + utt2pitch[utt] = pitch + pitch_scaler.partial_fit(pitch.reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + if return_norm: + for utt in spk2utt[spk]: + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + utt2pitch[utt] = normalized_pitch + pitch_statistic.append({"spk": spk, "mean": mean, "std": std}) + + return utt2pitch, utt2uv, pitch_statistic + + +# discard +def load_phone_pitch( + meta_data, + processed_dir, + pitch_dir, + utt2dur, + use_log_scale=False, + return_norm=False, + interoperate=True, + utt2spk=None, +): + print("Load Phone Pitch") + utt2pitch = {} + utt2uv = {} + if utt2spk is None: + pitch_scaler = StandardScaler() + for utt_info in tqdm(meta_data): + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch_path = os.path.join( + processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy' + ) + 
frame_pitch = np.load(pitch_path) + assert len(frame_pitch) > 0 + uv = frame_pitch != 0 + utt2uv[utt] = uv + phone_pitch = phone_average_pitch(frame_pitch, utt2dur[utt], interoperate) + if use_log_scale: + nonzero_idxes = np.where(phone_pitch != 0)[0] + phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes]) + utt2pitch[utt] = phone_pitch + pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + max_value = np.finfo(np.float64).min + min_value = np.finfo(np.float64).max + if return_norm: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + max_value = max(max_value, max(normalized_pitch)) + min_value = min(min_value, min(normalized_pitch)) + utt2pitch[utt] = normalized_pitch + phone_normalized_pitch_path = os.path.join( + processed_dir, + utt_info["Dataset"], + "phone_level_" + pitch_dir, + f'{utt_info["Uid"]}.npy', + ) + pitch_statistic = { + "mean": mean, + "std": std, + "min_value": min_value, + "max_value": max_value, + } + else: + spk2utt = {} + pitch_statistic = [] + for utt_info in tqdm(meta_data): + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + pitch_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + pitch_path = os.path.join( + processed_dir, dataset, pitch_dir, f"{uid}.npy" + ) + frame_pitch = np.load(pitch_path) + assert len(frame_pitch) > 0 + uv = frame_pitch != 0 + utt2uv[utt] = uv + phone_pitch = phone_average_pitch( + frame_pitch, utt2dur[utt], interoperate + ) + if use_log_scale: + nonzero_idxes = np.where(phone_pitch != 0)[0] + phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes]) + utt2pitch[utt] = phone_pitch + pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + max_value = np.finfo(np.float64).min + min_value = np.finfo(np.float64).max + + if return_norm: + for utt in spk2utt[spk]: + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + max_value = max(max_value, max(normalized_pitch)) + min_value = min(min_value, min(normalized_pitch)) + utt2pitch[utt] = normalized_pitch + pitch_statistic.append( + { + "spk": spk, + "mean": mean, + "std": std, + "min_value": min_value, + "max_value": max_value, + } + ) + + return utt2pitch, utt2uv, pitch_statistic + + +def phone_average_pitch(pitch, dur, interoperate=False): + pos = 0 + + if interoperate: + nonzero_ids = np.where(pitch != 0)[0] + interp_fn = interp1d( + nonzero_ids, + pitch[nonzero_ids], + fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), + bounds_error=False, + ) + pitch = interp_fn(np.arange(0, len(pitch))) + phone_pitch = np.zeros(len(dur)) + + for i, d in enumerate(dur): + d = int(d) + if d > 0 and pos < len(pitch): + phone_pitch[i] = np.mean(pitch[pos : pos + d]) + else: + phone_pitch[i] = 0 + pos += d + return phone_pitch + + +def load_energy( + meta_data, + processed_dir, + energy_dir, + use_log_scale=False, + return_norm=False, + utt2spk=None, +): + utt2energy = {} + if utt2spk is None: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + energy_path = os.path.join( + processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy' + ) + if not os.path.exists(energy_path): + continue + 
energy = np.load(energy_path) + assert len(energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(energy != 0)[0] + energy[nonzero_idxes] = np.log(energy[nonzero_idxes]) + utt2energy[utt] = energy + + if return_norm: + with open( + os.path.join( + processed_dir, utt_info["Dataset"], energy_dir, "statistics.json" + ) + ) as f: + stats = json.load(f) + mean, std = ( + stats[utt_info["Dataset"] + "_" + utt_info["Singer"]][ + "voiced_positions" + ]["mean"], + stats["LJSpeech_LJSpeech"]["voiced_positions"]["std"], + ) + for utt in utt2energy.keys(): + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + + energy_statistic = {"mean": mean, "std": std} + else: + spk2utt = {} + energy_statistic = [] + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + energy_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + energy_path = os.path.join( + processed_dir, dataset, energy_dir, f"{uid}.npy" + ) + if not os.path.exists(energy_path): + continue + frame_energy = np.load(energy_path) + assert len(frame_energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(frame_energy != 0)[0] + frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) + utt2energy[utt] = frame_energy + energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) + + mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] + if return_norm: + for utt in spk2utt[spk]: + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + energy_statistic.append({"spk": spk, "mean": mean, "std": std}) + + return utt2energy, energy_statistic + + +def load_frame_energy( + meta_data, + processed_dir, + energy_dir, + use_log_scale=False, + return_norm=False, + interoperate=False, + utt2spk=None, +): + utt2energy = {} + if utt2spk is None: + energy_scaler = StandardScaler() + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + energy_path = os.path.join( + processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy' + ) + frame_energy = np.load(energy_path) + assert len(frame_energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(frame_energy != 0)[0] + frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) + utt2energy[utt] = frame_energy + energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) + + mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] + if return_norm: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + energy_statistic = {"mean": mean, "std": std} + + else: + spk2utt = {} + energy_statistic = [] + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + energy_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + energy_path = os.path.join( + processed_dir, dataset, energy_dir, f"{uid}.npy" + ) + frame_energy = np.load(energy_path) + assert len(frame_energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(frame_energy != 0)[0] + frame_energy[nonzero_idxes] = 
np.log(frame_energy[nonzero_idxes]) + utt2energy[utt] = frame_energy + energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) + + mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] + if return_norm: + for utt in spk2utt[spk]: + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + energy_statistic.append({"spk": spk, "mean": mean, "std": std}) + + return utt2energy, energy_statistic + + +def align_length(feature, target_len, pad_value=0.0): + feature_len = feature.shape[-1] + dim = len(feature.shape) + # align 1-D data + if dim == 2: + if target_len > feature_len: + feature = np.pad( + feature, + ((0, 0), (0, target_len - feature_len)), + constant_values=pad_value, + ) + else: + feature = feature[:, :target_len] + # align 2-D data + elif dim == 1: + if target_len > feature_len: + feature = np.pad( + feature, (0, target_len - feature_len), constant_values=pad_value + ) + else: + feature = feature[:target_len] + else: + raise NotImplementedError + return feature + + +def align_whisper_feauture_length( + feature, target_len, fast_mapping=True, source_hop=320, target_hop=256 +): + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + # print( + # "Mapping source's {} frames => target's {} frames".format( + # target_hop, source_hop + # ) + # ) + + max_source_len = 1500 + target_len = min(target_len, max_source_len * source_hop // target_hop) + + width = feature.shape[-1] + + if fast_mapping: + source_len = target_len * target_hop // source_hop + 1 + feature = feature[:source_len] + + else: + source_len = max_source_len + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(feature, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + assert len(down_sampling_feats) >= target_len + + # (target_len, dim) + feat = down_sampling_feats[:target_len] + + return feat + + +def align_content_feature_length(feature, target_len, source_hop=320, target_hop=256): + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + # print( + # "Mapping source's {} frames => target's {} frames".format( + # target_hop, source_hop + # ) + # ) + + # (source_len, 256) + source_len, width = feature.shape + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(feature, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + + err = abs(target_len - len(down_sampling_feats)) + if err > 4: ## why 4 not 3? 
+ print("target_len:", target_len) + print("raw feature:", feature.shape) + print("up_sampling:", up_sampling_feats.shape) + print("down_sampling_feats:", down_sampling_feats.shape) + exit() + if len(down_sampling_feats) < target_len: + # (1, dim) -> (err, dim) + end = down_sampling_feats[-1][None, :].repeat(err, axis=0) + down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) + + # (target_len, dim) + feat = down_sampling_feats[:target_len] + + return feat + + +def remove_outlier(values): + values = np.array(values) + p25 = np.percentile(values, 25) + p75 = np.percentile(values, 75) + lower = p25 - 1.5 * (p75 - p25) + upper = p75 + 1.5 * (p75 - p25) + normal_indices = np.logical_and(values > lower, values < upper) + return values[normal_indices] diff --git a/models/tts/debatts/utils/distribution.py b/models/tts/debatts/utils/distribution.py new file mode 100644 index 00000000..de3000e9 --- /dev/null +++ b/models/tts/debatts/utils/distribution.py @@ -0,0 +1,270 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn.functional as F + +from torch.distributions import Normal + + +def log_sum_exp(x): + """numerically stable log_sum_exp implementation that prevents overflow""" + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +def discretized_mix_logistic_loss( + y_hat, y, num_classes=256, log_scale_min=-7.0, reduce=True +): + """Discretized mixture of logistic distributions loss + + Note that it is assumed that input is scaled to [-1, 1]. + + Args: + y_hat (Tensor): Predicted output (B x C x T) + y (Tensor): Target (B x T x 1). + num_classes (int): Number of classes + log_scale_min (float): Log scale minimum value + reduce (bool): If True, the losses are averaged or summed for each + minibatch. + + Returns + Tensor: loss + """ + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. 
(B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(torch.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - torch.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + # tf equivalent + """ + log_probs = tf.where(x < -0.999, log_cdf_plus, + tf.where(x > 0.999, log_one_minus_cdf_min, + tf.where(cdf_delta > 1e-5, + tf.log(tf.maximum(cdf_delta, 1e-12)), + log_pdf_mid - np.log(127.5)))) + """ + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log( + torch.clamp(cdf_delta, min=1e-12) + ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) + inner_cond = (y > 0.999).float() + inner_out = ( + inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + ) + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.sum(log_sum_exp(log_probs)) + else: + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def to_one_hot(tensor, n, fill_with=1.0): + # we perform one hot encore with respect to the last axis + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() + if tensor.is_cuda: + one_hot = one_hot.cuda() + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot + + +def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0, clamp_log_scale=False): + """ + Sample from discretized mixture of logistic distributions + + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + + Returns: + Tensor: sample in range of [-1, 1]. 
+ """ + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1) + if clamp_log_scale: + log_scales = torch.clamp(log_scales, min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u)) + + x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0) + + return x + + +# we can easily define discretized version of the gaussian loss, however, +# use continuous version as same as the https://clarinet-demo.github.io/ +def mix_gaussian_loss(y_hat, y, log_scale_min=-7.0, reduce=True): + """Mixture of continuous gaussian distributions loss + + Note that it is assumed that input is scaled to [-1, 1]. + + Args: + y_hat (Tensor): Predicted output (B x C x T) + y (Tensor): Target (B x T x 1). + log_scale_min (float): Log scale minimum value + reduce (bool): If True, the losses are averaged or summed for each + minibatch. + Returns + Tensor: loss + """ + assert y_hat.dim() == 3 + C = y_hat.size(1) + if C == 2: + nr_mix = 1 + else: + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. + if C == 2: + # special case for C == 2, just for compatibility + logit_probs = None + means = y_hat[:, :, 0:1] + log_scales = torch.clamp(y_hat[:, :, 1:2], min=log_scale_min) + else: + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp( + y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min + ) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + dist = Normal(loc=0.0, scale=torch.exp(log_scales)) + # do we need to add a trick to avoid log(0)? + log_probs = dist.log_prob(centered_y) + + if nr_mix > 1: + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + if nr_mix == 1: + return -torch.sum(log_probs) + else: + return -torch.sum(log_sum_exp(log_probs)) + else: + if nr_mix == 1: + return -log_probs + else: + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_mix_gaussian(y, log_scale_min=-7.0): + """ + Sample from (discretized) mixture of gaussian distributions + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. 
+ """ + C = y.size(1) + if C == 2: + nr_mix = 1 + else: + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + + if C == 2: + logit_probs = None + else: + logit_probs = y[:, :, :nr_mix] + + if nr_mix > 1: + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + + # Select means and log scales + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1) + else: + if C == 2: + means, log_scales = y[:, :, 0], y[:, :, 1] + elif C == 3: + means, log_scales = y[:, :, 1], y[:, :, 2] + else: + assert False, "shouldn't happen" + + scales = torch.exp(log_scales) + dist = Normal(loc=means, scale=scales) + x = dist.sample() + + x = torch.clamp(x, min=-1.0, max=1.0) + return x diff --git a/models/tts/debatts/utils/dsp.py b/models/tts/debatts/utils/dsp.py new file mode 100644 index 00000000..18f9466f --- /dev/null +++ b/models/tts/debatts/utils/dsp.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +# ZERO = 1e-12 + + +def gaussian_normalize_mel_channel(mel, mu, sigma): + """ + Shift to Standorm Normal Distribution + + Args: + mel: (n_mels, frame_len) + mu: (n_mels,), mean value + sigma: (n_mels,), sd value + Return: + Tensor like mel + """ + mu = np.expand_dims(mu, -1) + sigma = np.expand_dims(sigma, -1) + return (mel - mu) / sigma + + +def de_gaussian_normalize_mel_channel(mel, mu, sigma): + """ + + Args: + mel: (n_mels, frame_len) + mu: (n_mels,), mean value + sigma: (n_mels,), sd value + Return: + Tensor like mel + """ + mu = np.expand_dims(mu, -1) + sigma = np.expand_dims(sigma, -1) + return sigma * mel + mu + + +def decompress(audio_compressed, bits): + mu = 2**bits - 1 + audio = np.sign(audio_compressed) / mu * ((1 + mu) ** np.abs(audio_compressed) - 1) + return audio + + +def compress(audio, bits): + mu = 2**bits - 1 + audio_compressed = np.sign(audio) * np.log(1 + mu * np.abs(audio)) / np.log(mu + 1) + return audio_compressed + + +def label_to_audio(quant, bits): + classes = 2**bits + audio = 2 * quant / (classes - 1.0) - 1.0 + return audio + + +def audio_to_label(audio, bits): + """Normalized audio data tensor to digit array + + Args: + audio (tensor): audio data + bits (int): data bits + + Returns: + array: digit array of audio data + """ + classes = 2**bits + # initialize an increasing array with values from -1 to 1 + bins = np.linspace(-1, 1, classes) + # change value in audio tensor to digits + quant = np.digitize(audio, bins) - 1 + return quant + + +def label_to_onehot(x, bits): + """Converts a class vector (integers) to binary class matrix. + Args: + x: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + Returns: + A binary matrix representation of the input. The classes axis + is placed last. 
+ """ + classes = 2**bits + + result = torch.zeros((x.shape[0], classes), dtype=torch.float32) + for i in range(x.shape[0]): + result[i, x[i]] = 1 + + output_shape = x.shape + (classes,) + output = torch.reshape(result, output_shape) + return output diff --git a/models/tts/debatts/utils/duration.py b/models/tts/debatts/utils/duration.py new file mode 100644 index 00000000..c9544b40 --- /dev/null +++ b/models/tts/debatts/utils/duration.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import os +import tgt + + +def get_alignment(tier, cfg): + sample_rate = cfg["sample_rate"] + hop_size = cfg["hop_size"] + + sil_phones = ["sil", "sp", "spn"] + + phones = [] + durations = [] + start_time = 0 + end_time = 0 + end_idx = 0 + + for t in tier._objects: + s, e, p = t.start_time, t.end_time, t.text + + # Trim leading silences + if phones == []: + if p in sil_phones: + continue + else: + start_time = s + + if p not in sil_phones: + # For ordinary phones + phones.append(p) + end_time = e + end_idx = len(phones) + else: + # For silent phones + phones.append(p) + + durations.append( + int( + np.round(e * sample_rate / hop_size) + - np.round(s * sample_rate / hop_size) + ) + ) + + # Trim tailing silences + phones = phones[:end_idx] + durations = durations[:end_idx] + + return phones, durations, start_time, end_time + + +def get_duration(utt, wav, cfg): + speaker = utt["Singer"] + basename = utt["Uid"] + dataset = utt["Dataset"] + sample_rate = cfg["sample_rate"] + + # print(cfg.processed_dir, dataset, speaker, basename) + wav_path = os.path.join( + cfg.processed_dir, dataset, "raw_data", speaker, "{}.wav".format(basename) + ) + text_path = os.path.join( + cfg.processed_dir, dataset, "raw_data", speaker, "{}.lab".format(basename) + ) + tg_path = os.path.join( + cfg.processed_dir, dataset, "TextGrid", speaker, "{}.TextGrid".format(basename) + ) + + # Read raw text + with open(text_path, "r") as f: + raw_text = f.readline().strip("\n") + + # Get alignments + textgrid = tgt.io.read_textgrid(tg_path) + phone, duration, start, end = get_alignment( + textgrid.get_tier_by_name("phones"), cfg + ) + text = "{" + " ".join(phone) + "}" + if start >= end: + return None + + return duration, text, int(sample_rate * start), int(sample_rate * end) diff --git a/models/tts/debatts/utils/f0.py b/models/tts/debatts/utils/f0.py new file mode 100644 index 00000000..169b1403 --- /dev/null +++ b/models/tts/debatts/utils/f0.py @@ -0,0 +1,275 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
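The first helper in this file, f0_to_coarse, maps f0 in Hz onto the mel scale and quantizes it into pitch_bin - 1 integer bins, with unvoiced frames (f0 == 0) collapsing to bin 1. A standalone numpy sketch of the same mapping; the bin count and f0 range below are illustrative defaults rather than values taken from any config:

import numpy as np

def coarse_pitch(f0_hz, pitch_bin=256, f0_min=50.0, f0_max=1100.0):
    # Same mel-scale quantization as f0_to_coarse below (range values assumed).
    mel_min = 1127 * np.log(1 + f0_min / 700)
    mel_max = 1127 * np.log(1 + f0_max / 700)
    mel = 1127 * np.log(1 + np.asarray(f0_hz, dtype=np.float64) / 700)
    mel[mel > 0] = (mel[mel > 0] - mel_min) * (pitch_bin - 2) / (mel_max - mel_min) + 1
    mel = np.clip(mel, 1, pitch_bin - 1)
    return np.rint(mel).astype(np.int32)

print(coarse_pitch([0.0, 110.0, 220.0, 440.0]))  # 0 Hz (unvoiced) maps to bin 1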
+ +import librosa +import numpy as np +import torch +import parselmouth +import torchcrepe +import pyworld as pw + + +def f0_to_coarse(f0, pitch_bin, f0_min, f0_max): + """ + Convert f0 (Hz) to pitch (mel scale), and then quantize the mel-scale pitch to the + range from [1, 2, 3, ..., pitch_bin-1] + + Reference: https://en.wikipedia.org/wiki/Mel_scale + + Args: + f0 (array or Tensor): Hz + pitch_bin (int): the vocabulary size + f0_min (int): the minimum f0 (Hz) + f0_max (int): the maximum f0 (Hz) + + Returns: + quantized f0 (array or Tensor) + """ + f0_mel_min = 1127 * np.log(1 + f0_min / 700) + f0_mel_max = 1127 * np.log(1 + f0_max / 700) + + is_torch = isinstance(f0, torch.Tensor) + f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (pitch_bin - 2) / ( + f0_mel_max - f0_mel_min + ) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > pitch_bin - 1] = pitch_bin - 1 + f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int32) + assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( + f0_coarse.max(), + f0_coarse.min(), + ) + return f0_coarse + + +def interpolate(f0): + """Interpolate the unvoiced part. Thus the f0 can be passed to a subtractive synthesizer. + Args: + f0: A numpy array of shape (seq_len,) + Returns: + f0: Interpolated f0 of shape (seq_len,) + uv: Unvoiced part of shape (seq_len,) + """ + uv = f0 == 0 + if len(f0[~uv]) > 0: + # interpolate the unvoiced f0 + f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) + uv = uv.astype("float") + uv = np.min(np.array([uv[:-2], uv[1:-1], uv[2:]]), axis=0) + uv = np.pad(uv, (1, 1)) + return f0, uv + + +def get_log_f0(f0): + f0[np.where(f0 == 0)] = 1 + log_f0 = np.log(f0) + return log_f0 + + +def get_f0_features_using_pyin(audio, cfg): + """Using pyin to extract the f0 feature. + Args: + audio + fs + win_length + hop_length + f0_min + f0_max + Returns: + f0: numpy array of shape (frame_len,) + """ + f0, voiced_flag, voiced_probs = librosa.pyin( + y=audio, + fmin=cfg.f0_min, + fmax=cfg.f0_max, + sr=cfg.sample_rate, + win_length=cfg.win_size, + hop_length=cfg.hop_size, + ) + # Set nan to 0 + f0[voiced_flag == False] = 0 + return f0 + + +def get_f0_features_using_parselmouth(audio, cfg, speed=1): + """Using parselmouth to extract the f0 feature. + Args: + audio + mel_len + hop_length + fs + f0_min + f0_max + speed(default=1) + Returns: + f0: numpy array of shape (frame_len,) + pitch_coarse: numpy array of shape (frame_len,) + """ + hop_size = int(np.round(cfg.hop_size * speed)) + + # Calculate the time step for pitch extraction + time_step = hop_size / cfg.sample_rate * 1000 + + f0 = ( + parselmouth.Sound(audio, cfg.sample_rate) + .to_pitch_ac( + time_step=time_step / 1000, + voicing_threshold=0.6, + pitch_floor=cfg.f0_min, + pitch_ceiling=cfg.f0_max, + ) + .selected_array["frequency"] + ) + return f0 + + +def get_f0_features_using_dio(audio, cfg): + """Using dio to extract the f0 feature. 
+ Args: + audio + mel_len + fs + hop_length + f0_min + f0_max + Returns: + f0: numpy array of shape (frame_len,) + """ + # Get the raw f0 + _f0, t = pw.dio( + audio.astype("double"), + cfg.sample_rate, + f0_floor=cfg.f0_min, + f0_ceil=cfg.f0_max, + channels_in_octave=2, + frame_period=(1000 * cfg.hop_size / cfg.sample_rate), + ) + # Get the f0 + f0 = pw.stonemask(audio.astype("double"), _f0, t, cfg.sample_rate) + return f0 + + +def get_f0_features_using_harvest(audio, mel_len, fs, hop_length, f0_min, f0_max): + """Using harvest to extract the f0 feature. + Args: + audio + mel_len + fs + hop_length + f0_min + f0_max + Returns: + f0: numpy array of shape (frame_len,) + """ + f0, _ = pw.harvest( + audio.astype("double"), + fs, + f0_floor=f0_min, + f0_ceil=f0_max, + frame_period=(1000 * hop_length / fs), + ) + f0 = f0.astype("float")[:mel_len] + return f0 + + +def get_f0_features_using_crepe( + audio, mel_len, fs, hop_length, hop_length_new, f0_min, f0_max, threshold=0.3 +): + """Using torchcrepe to extract the f0 feature. + Args: + audio + mel_len + fs + hop_length + hop_length_new + f0_min + f0_max + threshold(default=0.3) + Returns: + f0: numpy array of shape (frame_len,) + """ + # Currently, crepe only supports 16khz audio + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + audio_16k = librosa.resample(audio, orig_sr=fs, target_sr=16000) + audio_16k_torch = torch.FloatTensor(audio_16k).unsqueeze(0).to(device) + + # Get the raw pitch + f0, pd = torchcrepe.predict( + audio_16k_torch, + 16000, + hop_length_new, + f0_min, + f0_max, + pad=True, + model="full", + batch_size=1024, + device=device, + return_periodicity=True, + ) + + # Filter, de-silence, set up threshold for unvoiced part + pd = torchcrepe.filter.median(pd, 3) + pd = torchcrepe.threshold.Silence(-60.0)(pd, audio_16k_torch, 16000, hop_length_new) + f0 = torchcrepe.threshold.At(threshold)(f0, pd) + f0 = torchcrepe.filter.mean(f0, 3) + + # Convert unvoiced part to 0hz + f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0) + + # Interpolate f0 + nzindex = torch.nonzero(f0[0]).squeeze() + f0 = torch.index_select(f0[0], dim=0, index=nzindex).cpu().numpy() + time_org = 0.005 * nzindex.cpu().numpy() + time_frame = np.arange(mel_len) * hop_length / fs + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + return f0 + + +def get_f0(audio, cfg, use_interpolate=False, return_uv=False): + if cfg.pitch_extractor == "dio": + f0 = get_f0_features_using_dio(audio, cfg) + elif cfg.pitch_extractor == "pyin": + f0 = get_f0_features_using_pyin(audio, cfg) + elif cfg.pitch_extractor == "parselmouth": + f0 = get_f0_features_using_parselmouth(audio, cfg) + + if use_interpolate: + f0, uv = interpolate(f0) + else: + uv = f0 == 0 + + if return_uv: + return f0, uv + + return f0 + + +def get_cents(f0_hz): + """ + F_{cent} = 1200 * log2 (F/440) + + Reference: + APSIPA'17, Perceptual Evaluation of Singing Quality + """ + voiced_f0 = f0_hz[f0_hz != 0] + return 1200 * np.log2(voiced_f0 / 440) + + +def get_pitch_derivatives(f0_hz): + """ + f0_hz: (,T) + """ + f0_cent = get_cents(f0_hz) + return f0_cent[1:] - f0_cent[:-1] + + +def get_pitch_sub_median(f0_hz): + """ + f0_hz: (,T) + """ + f0_cent = get_cents(f0_hz) + return f0_cent - np.median(f0_cent) diff --git a/models/tts/debatts/utils/g2p/__init__.py b/models/tts/debatts/utils/g2p/__init__.py new file mode 100644 index 00000000..3a2b3988 --- /dev/null +++ b/models/tts/debatts/utils/g2p/__init__.py @@ -0,0 +1,139 @@ +""" from https://github.com/keithito/tacotron """ 
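PhonemeBpeTokenizer defined below cleans raw text into a phoneme string, replaces spaces with "_", encodes the result with a HuggingFace tokenizers BPE model, and then applies the Mandarin-specific phone and tone merges. A minimal usage sketch, assuming models/tts/debatts is the working directory so that the default "./utils/g2p/bpe_613.json" path and the utils.g2p imports resolve (the sample sentence is illustrative):

from utils.g2p import PhonemeBpeTokenizer

tokenizer = PhonemeBpeTokenizer()  # loads ./utils/g2p/bpe_613.json by default
phonemes, token_ids = tokenizer.tokenize("你好世界", language="zh")
print(phonemes)   # cleaned phoneme string with "_" as the word separator
print(token_ids)  # BPE ids after phone and tone merging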
+ +import sys +import utils.g2p.cleaners +from tokenizers import Tokenizer +import json +import re + + +class PhonemeBpeTokenizer: + def __init__(self, tokenizer_path="./utils/g2p/bpe_613.json"): + self.tokenizer = Tokenizer.from_file(tokenizer_path) + self.tokenizer_path = tokenizer_path + + with open(tokenizer_path, "r") as f: + json_data = f.read() + data = json.loads(json_data) + self.vocab = data["model"]["vocab"] + + def tokenize(self, text, language): + # 1. convert text to phoneme + phonemes = _clean_text(text, ["cje_cleaners"]) + # print('clean text: ', phonemes) + + # 2. replace blank space " " with "_" + phonemes = phonemes.replace(" ", "_") + + # 3. tokenize phonemes + phoneme_tokens = self.tokenizer.encode(phonemes).ids + # print('encode: ', phoneme_tokens) + + # 4. connect single phoneme because of "`" or "⁼" + if language == "zh": + phoneme_tokens = _connect_phone(phoneme_tokens) + # print('encode phoneme: ', phoneme_tokens) + + # 5. connect tones with previous phoneme + phoneme_tokens = _connect_tone(phoneme_tokens, self.vocab) + # print('connect tones: ', phoneme_tokens) + + # 6. decode tokens [optional] + # decoded_text = self.tokenizer.decode(phoneme_tokens) + # print('decoded: ', decoded_text) + + # if not len(phoneme_tokens): + # raise ValueError("Empty text is given") + + return phonemes, phoneme_tokens + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(utils.g2p.cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text) + + return text + + +def _connect_phone(phoneme_tokens): + sublist = [ + [32, 66, 67], # "p⁼wo" + [32, 66], # "p⁼" + [34, 66], # "t⁼" + [27, 66], # "k⁼" + [78, 66], # "tʃ⁼" + [81, 17, 66, 55, 17], # "ts`⁼ɹ`" + [81, 17, 66], # "ts`⁼" + [81, 17, 61, 55, 17], # "ts`ʰɹ`" + [81, 17, 61], # "ts`ʰ" + [33, 17, 55, 17], # "s`ɹ`" + [33, 17], # "s`" + [55, 17, 55, 17], # "ɹ`ɹ`" + [55, 17], # "ɹ`" + [81, 66, 55], # "ts⁼ɹ" + [81, 66], # "ts⁼" + [48, 55, 17], # "əɹ`" + ] + value = [ + 70, # "p⁼wo" + 68, # "p⁼" + 74, # "t⁼" + 76, # "k⁼" + 79, # "tʃ⁼" + 91, # "ts`⁼ɹ`" + 85, # "ts`⁼" + 92, # "ts`ʰɹ`" + 86, # "ts`ʰ" + 89, # "s`ɹ`" + 87, # "s`" + 90, # "ɹ`ɹ`" + 88, # "ɹ`" + 93, # "ts⁼ɹ" + 82, # "ts⁼" + 113, # "əɹ`" + ] + token_str = ",".join(map(str, phoneme_tokens)) + new_lst_str = [] + for idx, sub in enumerate(sublist): + sub_str = "," + ",".join(map(str, sub)) + "," + if sub_str in token_str: + replace_str = "," + str(value[idx]) + "," + token_str = token_str.replace(sub_str, replace_str) + + new_lst = list(map(int, token_str.split(","))) + return new_lst + + +def _connect_tone(phoneme_tokens, vocab): + tone_list = ["→", "↑", "↓↑", "↓"] + tone_token = [] + last_single_token = 0 + base = 0 + pattern = r"\[[^\[\]]*\]" # Exclude "[" and "]" + for tone, idx in vocab.items(): + if re.match(pattern, tone): + base = idx + 1 + if tone in tone_list: + tone_token.append(idx) + last_single_token = idx + + pre_token = None + cur_token = None + res_token = [] + for t in phoneme_tokens: + cur_token = t + if t in tone_token: + cur_token = ( + last_single_token + + (pre_token - base) * len(tone_list) + + tone_token.index(t) + + 1 + ) + res_token.pop() + res_token.append(cur_token) + pre_token = t + + return res_token diff --git a/models/tts/debatts/utils/g2p/bpe_317.json b/models/tts/debatts/utils/g2p/bpe_317.json new file mode 100644 index 00000000..f244ce16 --- /dev/null +++ b/models/tts/debatts/utils/g2p/bpe_317.json @@ -0,0 +1,639 @@ +{ + "version": "1.0", + "truncation": null, + "padding": 
null, + "added_tokens": [ + { + "id": 0, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "[CLS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "[SEP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "[PAD]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 4, + "content": "[MASK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": null, + "decoder": null, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": "[UNK]", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "vocab": { + "[UNK]": 0, + "[CLS]": 1, + "[SEP]": 2, + "[PAD]": 3, + "[MASK]": 4, + "!": 5, + "#": 6, + "*": 7, + ",": 8, + "-": 9, + ".": 10, + "=": 11, + "?": 12, + "N": 13, + "Q": 14, + "^": 15, + "_": 16, + "`": 17, + "a": 18, + "b": 19, + "d": 20, + "e": 21, + "f": 22, + "g": 23, + "h": 24, + "i": 25, + "j": 26, + "k": 27, + "l": 28, + "m": 29, + "n": 30, + "o": 31, + "p": 32, + "s": 33, + "t": 34, + "u": 35, + "v": 36, + "w": 37, + "x": 38, + "y": 39, + "z": 40, + "~": 41, + "æ": 42, + "ç": 43, + "ð": 44, + "ŋ": 45, + "ɑ": 46, + "ɔ": 47, + "ə": 48, + "ɛ": 49, + "ɥ": 50, + "ɪ": 51, + "ɫ": 52, + "ɯ": 53, + "ɸ": 54, + "ɹ": 55, + "ɾ": 56, + "ʃ": 57, + "ʊ": 58, + "ʑ": 59, + "ʒ": 60, + "ʰ": 61, + "ˈ": 62, + "ˌ": 63, + "θ": 64, + "…": 65, + "⁼": 66, + "→": 67, + "↑": 68, + "↓↑": 69, + "↓": 70, + "!→": 71, + "!↑": 72, + "!↓↑": 73, + "!↓": 74, + "#→": 75, + "#↑": 76, + "#↓↑": 77, + "#↓": 78, + "*→": 79, + "*↑": 80, + "*↓↑": 81, + "*↓": 82, + ",→": 83, + ",↑": 84, + ",↓↑": 85, + ",↓": 86, + "-→": 87, + "-↑": 88, + "-↓↑": 89, + "-↓": 90, + ".→": 91, + ".↑": 92, + ".↓↑": 93, + ".↓": 94, + "=→": 95, + "=↑": 96, + "=↓↑": 97, + "=↓": 98, + "?→": 99, + "?↑": 100, + "?↓↑": 101, + "?↓": 102, + "N→": 103, + "N↑": 104, + "N↓↑": 105, + "N↓": 106, + "Q→": 107, + "Q↑": 108, + "Q↓↑": 109, + "Q↓": 110, + "^→": 111, + "^↑": 112, + "^↓↑": 113, + "^↓": 114, + "_→": 115, + "_↑": 116, + "_↓↑": 117, + "_↓": 118, + "`→": 119, + "`↑": 120, + "`↓↑": 121, + "`↓": 122, + "a→": 123, + "a↑": 124, + "a↓↑": 125, + "a↓": 126, + "b→": 127, + "b↑": 128, + "b↓↑": 129, + "b↓": 130, + "d→": 131, + "d↑": 132, + "d↓↑": 133, + "d↓": 134, + "e→": 135, + "e↑": 136, + "e↓↑": 137, + "e↓": 138, + "f→": 139, + "f↑": 140, + "f↓↑": 141, + "f↓": 142, + "g→": 143, + "g↑": 144, + "g↓↑": 145, + "g↓": 146, + "h→": 147, + "h↑": 148, + "h↓↑": 149, + "h↓": 150, + "i→": 151, + "i↑": 152, + "i↓↑": 153, + "i↓": 154, + "j→": 155, + "j↑": 156, + "j↓↑": 157, + "j↓": 158, + "k→": 159, + "k↑": 160, + "k↓↑": 161, + "k↓": 162, + "l→": 163, + "l↑": 164, + "l↓↑": 165, + "l↓": 166, + "m→": 167, + "m↑": 168, + "m↓↑": 169, + "m↓": 170, + "n→": 171, + "n↑": 172, + "n↓↑": 173, + "n↓": 174, + "o→": 175, + "o↑": 176, + "o↓↑": 177, + "o↓": 178, + "p→": 179, + "p↑": 180, + "p↓↑": 181, + "p↓": 182, + "s→": 183, + "s↑": 184, + "s↓↑": 185, + "s↓": 186, + "t→": 187, + "t↑": 188, + "t↓↑": 189, + "t↓": 190, + "u→": 191, + "u↑": 192, + "u↓↑": 193, + "u↓": 194, + "v→": 195, + "v↑": 196, + "v↓↑": 197, + "v↓": 198, + "w→": 199, + "w↑": 
200, + "w↓↑": 201, + "w↓": 202, + "x→": 203, + "x↑": 204, + "x↓↑": 205, + "x↓": 206, + "y→": 207, + "y↑": 208, + "y↓↑": 209, + "y↓": 210, + "z→": 211, + "z↑": 212, + "z↓↑": 213, + "z↓": 214, + "~→": 215, + "~↑": 216, + "~↓↑": 217, + "~↓": 218, + "æ→": 219, + "æ↑": 220, + "æ↓↑": 221, + "æ↓": 222, + "ç→": 223, + "ç↑": 224, + "ç↓↑": 225, + "ç↓": 226, + "ð→": 227, + "ð↑": 228, + "ð↓↑": 229, + "ð↓": 230, + "ŋ→": 231, + "ŋ↑": 232, + "ŋ↓↑": 233, + "ŋ↓": 234, + "ɑ→": 235, + "ɑ↑": 236, + "ɑ↓↑": 237, + "ɑ↓": 238, + "ɔ→": 239, + "ɔ↑": 240, + "ɔ↓↑": 241, + "ɔ↓": 242, + "ə→": 243, + "ə↑": 244, + "ə↓↑": 245, + "ə↓": 246, + "ɛ→": 247, + "ɛ↑": 248, + "ɛ↓↑": 249, + "ɛ↓": 250, + "ɥ→": 251, + "ɥ↑": 252, + "ɥ↓↑": 253, + "ɥ↓": 254, + "ɪ→": 255, + "ɪ↑": 256, + "ɪ↓↑": 257, + "ɪ↓": 258, + "ɫ→": 259, + "ɫ↑": 260, + "ɫ↓↑": 261, + "ɫ↓": 262, + "ɯ→": 263, + "ɯ↑": 264, + "ɯ↓↑": 265, + "ɯ↓": 266, + "ɸ→": 267, + "ɸ↑": 268, + "ɸ↓↑": 269, + "ɸ↓": 270, + "ɹ→": 271, + "ɹ↑": 272, + "ɹ↓↑": 273, + "ɹ↓": 274, + "ɾ→": 275, + "ɾ↑": 276, + "ɾ↓↑": 277, + "ɾ↓": 278, + "ʃ→": 279, + "ʃ↑": 280, + "ʃ↓↑": 281, + "ʃ↓": 282, + "ʊ→": 283, + "ʊ↑": 284, + "ʊ↓↑": 285, + "ʊ↓": 286, + "ʑ→": 287, + "ʑ↑": 288, + "ʑ↓↑": 289, + "ʑ↓": 290, + "ʒ→": 291, + "ʒ↑": 292, + "ʒ↓↑": 293, + "ʒ↓": 294, + "ʰ→": 295, + "ʰ↑": 296, + "ʰ↓↑": 297, + "ʰ↓": 298, + "ˈ→": 299, + "ˈ↑": 300, + "ˈ↓↑": 301, + "ˈ↓": 302, + "ˌ→": 303, + "ˌ↑": 304, + "ˌ↓↑": 305, + "ˌ↓": 306, + "θ→": 306, + "θ↑": 307, + "θ↓↑": 308, + "θ↓": 309, + "…→": 310, + "…↑": 311, + "…↓↑": 312, + "…↓": 313, + "⁼→": 314, + "⁼↑": 315, + "⁼↓↑": 316, + "⁼↓": 317 + }, + "merges": [ + "↓ ↑", + "! →", + "! ↑", + "! ↓↑", + "! ↓", + "# →", + "# ↑", + "# ↓↑", + "# ↓", + "* →", + "* ↑", + "* ↓↑", + "* ↓", + ", →", + ", ↑", + ", ↓↑", + ", ↓", + "- →", + "- ↑", + "- ↓↑", + "- ↓", + ". →", + ". ↑", + ". ↓↑", + ". ↓", + "= →", + "= ↑", + "= ↓↑", + "= ↓", + "? →", + "? ↑", + "? ↓↑", + "? 
↓", + "N →", + "N ↑", + "N ↓↑", + "N ↓", + "Q →", + "Q ↑", + "Q ↓↑", + "Q ↓", + "^ →", + "^ ↑", + "^ ↓↑", + "^ ↓", + "_ →", + "_ ↑", + "_ ↓↑", + "_ ↓", + "` →", + "` ↑", + "` ↓↑", + "` ↓", + "a →", + "a ↑", + "a ↓↑", + "a ↓", + "b →", + "b ↑", + "b ↓↑", + "b ↓", + "d →", + "d ↑", + "d ↓↑", + "d ↓", + "e →", + "e ↑", + "e ↓↑", + "e ↓", + "f →", + "f ↑", + "f ↓↑", + "f ↓", + "g →", + "g ↑", + "g ↓↑", + "g ↓", + "h →", + "h ↑", + "h ↓↑", + "h ↓", + "i →", + "i ↑", + "i ↓↑", + "i ↓", + "j →", + "j ↑", + "j ↓↑", + "j ↓", + "k →", + "k ↑", + "k ↓↑", + "k ↓", + "l →", + "l ↑", + "l ↓↑", + "l ↓", + "m →", + "m ↑", + "m ↓↑", + "m ↓", + "n →", + "n ↑", + "n ↓↑", + "n ↓", + "o →", + "o ↑", + "o ↓↑", + "o ↓", + "p →", + "p ↑", + "p ↓↑", + "p ↓", + "s →", + "s ↑", + "s ↓↑", + "s ↓", + "t →", + "t ↑", + "t ↓↑", + "t ↓", + "u →", + "u ↑", + "u ↓↑", + "u ↓", + "v →", + "v ↑", + "v ↓↑", + "v ↓", + "w →", + "w ↑", + "w ↓↑", + "w ↓", + "x →", + "x ↑", + "x ↓↑", + "x ↓", + "y →", + "y ↑", + "y ↓↑", + "y ↓", + "z →", + "z ↑", + "z ↓↑", + "z ↓", + "~ →", + "~ ↑", + "~ ↓↑", + "~ ↓", + "æ →", + "æ ↑", + "æ ↓↑", + "æ ↓", + "ç →", + "ç ↑", + "ç ↓↑", + "ç ↓", + "ð →", + "ð ↑", + "ð ↓↑", + "ð ↓", + "ŋ →", + "ŋ ↑", + "ŋ ↓↑", + "ŋ ↓", + "ɑ →", + "ɑ ↑", + "ɑ ↓↑", + "ɑ ↓", + "ɔ →", + "ɔ ↑", + "ɔ ↓↑", + "ɔ ↓", + "ə →", + "ə ↑", + "ə ↓↑", + "ə ↓", + "ɛ →", + "ɛ ↑", + "ɛ ↓↑", + "ɛ ↓", + "ɥ →", + "ɥ ↑", + "ɥ ↓↑", + "ɥ ↓", + "ɪ →", + "ɪ ↑", + "ɪ ↓↑", + "ɪ ↓", + "ɫ →", + "ɫ ↑", + "ɫ ↓↑", + "ɫ ↓", + "ɯ →", + "ɯ ↑", + "ɯ ↓↑", + "ɯ ↓", + "ɸ →", + "ɸ ↑", + "ɸ ↓↑", + "ɸ ↓", + "ɹ →", + "ɹ ↑", + "ɹ ↓↑", + "ɹ ↓", + "ɾ →", + "ɾ ↑", + "ɾ ↓↑", + "ɾ ↓", + "ʃ →", + "ʃ ↑", + "ʃ ↓↑", + "ʃ ↓", + "ʊ →", + "ʊ ↑", + "ʊ ↓↑", + "ʊ ↓", + "ʑ →", + "ʑ ↑", + "ʑ ↓↑", + "ʑ ↓", + "ʒ →", + "ʒ ↑", + "ʒ ↓↑", + "ʒ ↓", + "ʰ →", + "ʰ ↑", + "ʰ ↓↑", + "ʰ ↓", + "ˈ →", + "ˈ ↑", + "ˈ ↓↑", + "ˈ ↓", + "ˌ →", + "ˌ ↑", + "ˌ ↓↑", + "ˌ ↓", + "θ →", + "θ ↑", + "θ ↓↑", + "θ ↓", + "… →", + "… ↑", + "… ↓↑", + "… ↓", + "⁼ →", + "⁼ ↑", + "⁼ ↓↑", + "⁼ ↓" + ] + } +} \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p/bpe_553.json b/models/tts/debatts/utils/g2p/bpe_553.json new file mode 100644 index 00000000..6cff14d3 --- /dev/null +++ b/models/tts/debatts/utils/g2p/bpe_553.json @@ -0,0 +1,1221 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "[CLS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "[SEP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "[PAD]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 4, + "content": "[MASK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": null, + "decoder": null, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": "[UNK]", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "vocab": { + "[UNK]": 0, + "[CLS]": 1, + "[SEP]": 2, + "[PAD]": 3, + "[MASK]": 4, + "!": 5, + "#": 6, + "*": 7, + ",": 8, + "-": 9, + ".": 10, + "=": 11, + "?": 12, + 
"N": 13, + "Q": 14, + "^": 15, + "_": 16, + "`": 17, + "a": 18, + "b": 19, + "d": 20, + "e": 21, + "f": 22, + "g": 23, + "h": 24, + "i": 25, + "j": 26, + "k": 27, + "l": 28, + "m": 29, + "n": 30, + "o": 31, + "p": 32, + "s": 33, + "t": 34, + "u": 35, + "v": 36, + "w": 37, + "x": 38, + "y": 39, + "z": 40, + "~": 41, + "æ": 42, + "ç": 43, + "ð": 44, + "ŋ": 45, + "ɑ": 46, + "ɔ": 47, + "ə": 48, + "ɛ": 49, + "ɥ": 50, + "ɪ": 51, + "ɫ": 52, + "ɯ": 53, + "ɸ": 54, + "ɹ": 55, + "ɾ": 56, + "ʃ": 57, + "ʊ": 58, + "ʑ": 59, + "ʒ": 60, + "ʰ": 61, + "ˈ": 62, + "ˌ": 63, + "θ": 64, + "…": 65, + "⁼": 66, + + "wo": 67, + "p⁼": 68, + "pʰ": 69, + "p⁼wo": 70, + "pʰwo": 71, + "mwo": 72, + "fwo": 73, + + "t⁼": 74, + "tʰ": 75, + "k⁼": 76, + "kʰ": 77, + + "tʃ": 78, + "tʃ⁼": 79, + "tʃʰ": 80, + + "ts": 81, + "ts⁼": 82, + "ts`": 83, + "tsʰ" :84, + + "ts`⁼": 85, + "ts`ʰ": 86, + + "s`": 87, + "ɹ`": 88, + "s`ɹ`": 89, + "ɹ`ɹ`": 90, + "ts`⁼ɹ`": 91, + "ts`ʰɹ`": 92, + + "ts⁼ɹ": 93, + "tsʰɹ": 94, + "sɹ": 95, + + "aɪ": 96, + "eɪ": 97, + "ɑʊ": 98, + "oʊ": 99, + + "jɛ": 100, + "jɛn": 101, + "ɥæ": 102, + "ɥæn": 103, + + "an": 104, + "in": 105, + "ɥn": 106, + "ən": 107, + "ɑŋ": 108, + "iŋ": 109, + "ʊŋ": 110, + "jʊŋ": 110, + "əŋ": 112, + "əɹ`": 113, + + + "→": 114, + "↑": 115, + "↓↑": 116, + "↓": 117, + "!→": 118, + "!↑": 119, + "!↓↑": 120, + "!↓": 121, + "#→": 122, + "#↑": 123, + "#↓↑": 124, + "#↓": 125, + "*→": 126, + "*↑": 127, + "*↓↑": 128, + "*↓": 129, + ",→": 130, + ",↑": 131, + ",↓↑": 132, + ",↓": 133, + "-→": 134, + "-↑": 135, + "-↓↑": 136, + "-↓": 137, + ".→": 138, + ".↑": 139, + ".↓↑": 140, + ".↓": 141, + "=→": 142, + "=↑": 143, + "=↓↑": 144, + "=↓": 145, + "?→": 146, + "?↑": 147, + "?↓↑": 148, + "?↓": 149, + "N→": 150, + "N↑": 151, + "N↓↑": 152, + "N↓": 153, + "Q→": 154, + "Q↑": 155, + "Q↓↑": 156, + "Q↓": 157, + "^→": 158, + "^↑": 159, + "^↓↑": 160, + "^↓": 161, + "_→": 162, + "_↑": 163, + "_↓↑": 164, + "_↓": 165, + "`→": 166, + "`↑": 167, + "`↓↑": 168, + "`↓": 169, + "a→": 170, + "a↑": 171, + "a↓↑": 172, + "a↓": 173, + "b→": 174, + "b↑": 175, + "b↓↑": 176, + "b↓": 177, + "d→": 178, + "d↑": 179, + "d↓↑": 180, + "d↓": 181, + "e→": 182, + "e↑": 183, + "e↓↑": 184, + "e↓": 185, + "f→": 186, + "f↑": 187, + "f↓↑": 188, + "f↓": 189, + "g→": 190, + "g↑": 191, + "g↓↑": 192, + "g↓": 193, + "h→": 194, + "h↑": 195, + "h↓↑": 196, + "h↓": 197, + "i→": 198, + "i↑": 199, + "i↓↑": 200, + "i↓": 201, + "j→": 202, + "j↑": 203, + "j↓↑": 204, + "j↓": 205, + "k→": 206, + "k↑": 207, + "k↓↑": 208, + "k↓": 209, + "l→": 210, + "l↑": 211, + "l↓↑": 212, + "l↓": 213, + "m→": 214, + "m↑": 215, + "m↓↑": 216, + "m↓": 217, + "n→": 218, + "n↑": 219, + "n↓↑": 220, + + "n↓": 221, + "o→": 222, + "o↑": 223, + "o↓↑": 224, + "o↓": 225, + "p→": 226, + "p↑": 227, + "p↓↑": 228, + "p↓": 229, + "s→": 230, + "s↑": 231, + "s↓↑": 232, + "s↓": 233, + "t→": 234, + "t↑": 235, + "t↓↑": 236, + "t↓": 237, + "u→": 238, + "u↑": 239, + "u↓↑": 240, + "u↓": 241, + "v→": 242, + "v↑": 243, + "v↓↑": 244, + "v↓": 245, + "w→": 246, + "w↑": 247, + "w↓↑": 248, + "w↓": 249, + "x→": 250, + "x↑": 251, + "x↓↑": 252, + "x↓": 253, + "y→": 254, + "y↑": 255, + "y↓↑": 256, + "y↓": 257, + "z→": 258, + "z↑": 258, + "z↓↑": 260, + "z↓": 261, + "~→": 262, + "~↑": 263, + "~↓↑": 264, + "~↓": 265, + "æ→": 266, + "æ↑": 267, + "æ↓↑": 268, + "æ↓": 269, + "ç→": 270, + "ç↑": 271, + "ç↓↑": 272, + "ç↓": 273, + "ð→": 274, + "ð↑": 275, + "ð↓↑": 276, + "ð↓": 277, + "ŋ→": 278, + "ŋ↑": 279, + "ŋ↓↑": 280, + "ŋ↓": 281, + "ɑ→": 282, + "ɑ↑": 283, + "ɑ↓↑": 284, + "ɑ↓": 285, + "ɔ→": 286, + "ɔ↑": 287, + "ɔ↓↑": 288, + "ɔ↓": 
289, + "ə→": 290, + "ə↑": 291, + "ə↓↑": 292, + "ə↓": 293, + "ɛ→": 294, + "ɛ↑": 295, + "ɛ↓↑": 296, + "ɛ↓": 297, + "ɥ→": 298, + "ɥ↑": 299, + "ɥ↓↑": 300, + "ɥ↓": 301, + "ɪ→": 302, + "ɪ↑": 303, + "ɪ↓↑": 304, + "ɪ↓": 305, + "ɫ→": 306, + "ɫ↑": 307, + "ɫ↓↑": 308, + "ɫ↓": 309, + "ɯ→": 310, + "ɯ↑": 311, + "ɯ↓↑": 312, + "ɯ↓": 313, + "ɸ→": 314, + "ɸ↑": 315, + "ɸ↓↑": 316, + "ɸ↓": 317, + "ɹ→": 318, + "ɹ↑": 319, + "ɹ↓↑": 320, + "ɹ↓": 321, + "ɾ→": 322, + "ɾ↑": 323, + "ɾ↓↑": 324, + "ɾ↓": 325, + "ʃ→": 326, + "ʃ↑": 327, + "ʃ↓↑": 328, + "ʃ↓": 329, + "ʊ→": 330, + "ʊ↑": 331, + "ʊ↓↑": 332, + "ʊ↓": 333, + "ʑ→": 334, + "ʑ↑": 335, + "ʑ↓↑": 336, + "ʑ↓": 337, + "ʒ→": 338, + "ʒ↑": 339, + "ʒ↓↑": 340, + "ʒ↓": 341, + "ʰ→": 342, + "ʰ↑": 343, + "ʰ↓↑": 344, + "ʰ↓": 345, + "ˈ→": 346, + "ˈ↑": 347, + "ˈ↓↑": 348, + "ˈ↓": 349, + "ˌ→": 350, + "ˌ↑": 351, + "ˌ↓↑": 352, + "ˌ↓": 353, + "θ→": 354, + "θ↑": 355, + "θ↓↑": 356, + "θ↓": 357, + "…→": 358, + "…↑": 359, + "…↓↑": 360, + "…↓": 361, + "⁼→": 352, + "⁼↑": 363, + "⁼↓↑": 364, + "⁼↓": 365, + + "wo→": 366, + "wo↑": 367, + "wo↓↑": 368, + "wo↓": 369, + + "p⁼→": 370, + "p⁼↑": 370, + "p⁼↓↑": 372, + "p⁼↓": 373, + + "pʰ→": 374, + "pʰ↑": 375, + "pʰ↓↑": 376, + "pʰ↓": 377, + + "p⁼wo→": 378, + "p⁼wo↑": 379, + "p⁼wo↓↑": 380, + "p⁼wo↓": 381, + + "pʰwo→": 382, + "pʰwo↑": 383, + "pʰwo↓↑": 384, + "pʰwo↓": 385, + + "mwo→": 386, + "mwo↑": 387, + "mwo↓↑": 388, + "mwo↓": 389, + + "fwo→": 390, + "fwo↑": 391, + "fwo↓↑": 392, + "fwo↓": 393, + + "t⁼→": 394, + "t⁼↑": 395, + "t⁼↓↑": 396, + "t⁼↓": 397, + + "tʰ→": 398, + "tʰ↑": 399, + "tʰ↓↑": 400, + "tʰ↓": 401, + + "k⁼→": 402, + "k⁼↑": 403, + "k⁼↓↑": 404, + "k⁼↓": 405, + + "kʰ→": 406, + "kʰ↑": 407, + "kʰ↓↑": 408, + "kʰ↓": 409, + + "tʃ→": 410, + "tʃ↑": 411, + "tʃ↓↑": 412, + "tʃ↓": 413, + + "tʃ⁼→": 414, + "tʃ⁼↑": 415, + "tʃ⁼↓↑": 416, + "tʃ⁼↓": 417, + + "tʃʰ→": 418, + "tʃʰ↑": 419, + "tʃʰ↓↑": 420, + "tʃʰ↓": 421, + + "ts→": 422, + "ts↑": 423, + "ts↓↑": 424, + "ts↓": 425, + + "ts⁼→": 426, + "ts⁼↑": 427, + "ts⁼↓↑": 428, + "ts⁼↓": 429, + + "ts`→": 430, + "ts`↑": 431, + "ts`↓↑": 432, + "ts`↓": 433, + + "tsʰ→": 434, + "tsʰ↑": 435, + "tsʰ↓↑": 436, + "tsʰ↓": 437, + + "ts`⁼→": 438, + "ts`⁼↑": 439, + "ts`⁼↓↑": 440, + "ts`⁼↓": 441, + + "ts`ʰ→": 442, + "ts`ʰ↑": 443, + "ts`ʰ↓↑": 444, + "ts`ʰ↓": 445, + + "s`→": 446, + "s`↑": 447, + "s`↓↑": 448, + "s`↓": 449, + + "ɹ`→": 450, + "ɹ`↑": 451, + "ɹ`↓↑": 452, + "ɹ`↓": 453, + + "s`ɹ`→": 454, + "s`ɹ`↑": 455, + "s`ɹ`↓↑": 456, + "s`ɹ`↓": 457, + + "ɹ`ɹ`→": 458, + "ɹ`ɹ`↑": 459, + "ɹ`ɹ`↓↑": 460, + "ɹ`ɹ`↓": 461, + + "ts`⁼ɹ`→": 462, + "ts`⁼ɹ`↑": 463, + "ts`⁼ɹ`↓↑": 464, + "ts`⁼ɹ`↓": 465, + + "ts`ʰɹ`→": 466, + "ts`ʰɹ`↑": 467, + "ts`ʰɹ`↓↑": 468, + "ts`ʰɹ`↓": 469, + + "ts⁼ɹ→": 470, + "ts⁼ɹ↑": 471, + "ts⁼ɹ↓↑": 472, + "ts⁼ɹ↓": 473, + + "tsʰɹ→": 474, + "tsʰɹ↑": 475, + "tsʰɹ↓↑": 476, + "tsʰɹ↓": 477, + + "sɹ→": 478, + "sɹ↑": 479, + "sɹ↓↑": 480, + "sɹ↓": 481, + + "aɪ→": 482, + "aɪ↑": 483, + "aɪ↓↑": 484, + "aɪ↓": 485, + + "eɪ→": 486, + "eɪ↑": 487, + "eɪ↓↑": 488, + "eɪ↓": 489, + + "ɑʊ→": 490, + "ɑʊ↑": 491, + "ɑʊ↓↑": 492, + "ɑʊ↓": 493, + + "oʊ→": 494, + "oʊ↑": 495, + "oʊ↓↑": 496, + "oʊ↓": 497, + + "jɛ→": 498, + "jɛ↑": 499, + "jɛ↓↑": 500, + "jɛ↓": 501, + + "jɛn→": 502, + "jɛn↑": 503, + "jɛn↓↑": 504, + "jɛn↓": 505, + + "ɥæ→": 506, + "ɥæ↑": 507, + "ɥæ↓↑": 508, + "ɥæ↓": 509, + + "ɥæn→": 510, + "ɥæn↑": 511, + "ɥæn↓↑": 512, + "ɥæn↓": 513, + + "an→": 514, + "an↑": 515, + "an↓↑": 516, + "an↓": 517, + + "in→": 518, + "in↑": 519, + "in↓↑": 520, + "in↓": 521, + + "ɥn→": 522, + "ɥn↑": 523, + "ɥn↓↑": 524, + "ɥn↓": 525, + + "ən→": 526, + "ən↑": 527, + "ən↓↑": 528, + 
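Some symbols in this vocabulary appear to share an id (for example "ʊŋ" and "jʊŋ" are both 110, and "z→" and "z↑" are both 258 above), which makes any id-to-symbol lookup ambiguous. A quick way to list such collisions before relying on the file:

    import json
    from collections import defaultdict

    # Path per this patch; swap in bpe_317.json or bpe_613.json to check those too.
    with open("models/tts/debatts/utils/g2p/bpe_553.json") as f:
        vocab = json.load(f)["model"]["vocab"]

    by_id = defaultdict(list)
    for sym, idx in vocab.items():
        by_id[idx].append(sym)

    for idx, syms in sorted(by_id.items()):
        if len(syms) > 1:
            print(idx, syms)  # e.g. 110 ['ʊŋ', 'jʊŋ']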
"ən↓": 529, + + "ɑŋ→": 530, + "ɑŋ↑": 531, + "ɑŋ↓↑": 532, + "ɑŋ↓": 533, + + "iŋ→": 534, + "iŋ↑": 535, + "iŋ↓↑": 536, + "iŋ↓": 537, + + "ʊŋ→": 538, + "ʊŋ↑": 539, + "ʊŋ↓↑": 540, + "ʊŋ↓": 541, + + "jʊŋ→": 542, + "jʊŋ↑": 543, + "jʊŋ↓↑": 544, + "jʊŋ↓": 545, + + "əŋ→": 546, + "əŋ↑": 547, + "əŋ↓↑": 548, + "əŋ↓": 549, + + "əɹ`→": 550, + "əɹ`↑": 551, + "əɹ`↓↑": 552, + "əɹ`↓": 553 + }, + "merges": [ + "↓ ↑", + "w o", + "p ⁼", + "p ʰ", + "p⁼ wo", + "pʰ wo", + "m wo", + "f wo", + + "t ⁼", + "t ʰ", + "k ⁼", + "k ʰ", + + "t ʃ", + "tʃ ⁼", + "tʃ ʰ", + + "t s", + "ts ⁼", + "ts `", + "ts ʰ", + + "ts` ⁼", + "ts` ʰ", + + "s `", + "ɹ `", + "s` ɹ`", + "ɹ` ɹ`", + "ts`⁼ ɹ`", + "ts`ʰ ɹ`", + + "ts⁼ ɹ", + "tsʰ ɹ", + "s ɹ", + + "a ɪ", + "e ɪ", + "ɑ ʊ", + "o ʊ", + + "j ɛ", + "jɛ n", + "ɥ æ", + "ɥæ n", + + "a n", + "i n", + "ɥ n", + "ə n", + "ɑ ŋ", + "i ŋ", + "ʊ ŋ", + "j ʊŋ", + "ə ŋ", + "ə ɹ`", + + "! →", + "! ↑", + "! ↓↑", + "! ↓", + "# →", + "# ↑", + "# ↓↑", + "# ↓", + "* →", + "* ↑", + "* ↓↑", + "* ↓", + ", →", + ", ↑", + ", ↓↑", + ", ↓", + "- →", + "- ↑", + "- ↓↑", + "- ↓", + ". →", + ". ↑", + ". ↓↑", + ". ↓", + "= →", + "= ↑", + "= ↓↑", + "= ↓", + "? →", + "? ↑", + "? ↓↑", + "? ↓", + "N →", + "N ↑", + "N ↓↑", + "N ↓", + "Q →", + "Q ↑", + "Q ↓↑", + "Q ↓", + "^ →", + "^ ↑", + "^ ↓↑", + "^ ↓", + "_ →", + "_ ↑", + "_ ↓↑", + "_ ↓", + "a →", + "a ↑", + "a ↓↑", + "a ↓", + "b →", + "b ↑", + "b ↓↑", + "b ↓", + "d →", + "d ↑", + "d ↓↑", + "d ↓", + "e →", + "e ↑", + "e ↓↑", + "e ↓", + "f →", + "f ↑", + "f ↓↑", + "f ↓", + "g →", + "g ↑", + "g ↓↑", + "g ↓", + "h →", + "h ↑", + "h ↓↑", + "h ↓", + "i →", + "i ↑", + "i ↓↑", + "i ↓", + "j →", + "j ↑", + "j ↓↑", + "j ↓", + "k →", + "k ↑", + "k ↓↑", + "k ↓", + "l →", + "l ↑", + "l ↓↑", + "l ↓", + "m →", + "m ↑", + "m ↓↑", + "m ↓", + "n →", + "n ↑", + "n ↓↑", + "n ↓", + "o →", + "o ↑", + "o ↓↑", + "o ↓", + "p →", + "p ↑", + "p ↓↑", + "p ↓", + "s →", + "s ↑", + "s ↓↑", + "s ↓", + "t →", + "t ↑", + "t ↓↑", + "t ↓", + "u →", + "u ↑", + "u ↓↑", + "u ↓", + "v →", + "v ↑", + "v ↓↑", + "v ↓", + "w →", + "w ↑", + "w ↓↑", + "w ↓", + "x →", + "x ↑", + "x ↓↑", + "x ↓", + "y →", + "y ↑", + "y ↓↑", + "y ↓", + "z →", + "z ↑", + "z ↓↑", + "z ↓", + "~ →", + "~ ↑", + "~ ↓↑", + "~ ↓", + "æ →", + "æ ↑", + "æ ↓↑", + "æ ↓", + "ç →", + "ç ↑", + "ç ↓↑", + "ç ↓", + "ð →", + "ð ↑", + "ð ↓↑", + "ð ↓", + "ŋ →", + "ŋ ↑", + "ŋ ↓↑", + "ŋ ↓", + "ɑ →", + "ɑ ↑", + "ɑ ↓↑", + "ɑ ↓", + "ɔ →", + "ɔ ↑", + "ɔ ↓↑", + "ɔ ↓", + "ə →", + "ə ↑", + "ə ↓↑", + "ə ↓", + "ɛ →", + "ɛ ↑", + "ɛ ↓↑", + "ɛ ↓", + "ɥ →", + "ɥ ↑", + "ɥ ↓↑", + "ɥ ↓", + "ɪ →", + "ɪ ↑", + "ɪ ↓↑", + "ɪ ↓", + "ɫ →", + "ɫ ↑", + "ɫ ↓↑", + "ɫ ↓", + "ɯ →", + "ɯ ↑", + "ɯ ↓↑", + "ɯ ↓", + "ɸ →", + "ɸ ↑", + "ɸ ↓↑", + "ɸ ↓", + "ɹ →", + "ɹ ↑", + "ɹ ↓↑", + "ɹ ↓", + "ɾ →", + "ɾ ↑", + "ɾ ↓↑", + "ɾ ↓", + "ʃ →", + "ʃ ↑", + "ʃ ↓↑", + "ʃ ↓", + "ʊ →", + "ʊ ↑", + "ʊ ↓↑", + "ʊ ↓", + "ʑ →", + "ʑ ↑", + "ʑ ↓↑", + "ʑ ↓", + "ʒ →", + "ʒ ↑", + "ʒ ↓↑", + "ʒ ↓", + "ʰ →", + "ʰ ↑", + "ʰ ↓↑", + "ʰ ↓", + "ˈ →", + "ˈ ↑", + "ˈ ↓↑", + "ˈ ↓", + "ˌ →", + "ˌ ↑", + "ˌ ↓↑", + "ˌ ↓", + "θ →", + "θ ↑", + "θ ↓↑", + "θ ↓", + "… →", + "… ↑", + "… ↓↑", + "… ↓", + "⁼ →", + "⁼ ↑", + "⁼ ↓↑", + "⁼ ↓", + "wo →", + "wo ↑", + "wo ↓↑", + "wo ↓", + + "p⁼ →", + "p⁼ ↑", + "p⁼ ↓↑", + "p⁼ ↓", + + "pʰ →", + "pʰ ↑", + "pʰ ↓↑", + "pʰ ↓", + + "p⁼wo →", + "p⁼wo ↑", + "p⁼wo ↓↑", + "p⁼wo ↓", + + "pʰwo →", + "pʰwo ↑", + "pʰwo ↓↑", + "pʰwo ↓", + + "mwo →", + "mwo ↑", + "mwo ↓↑", + "mwo ↓", + + "fwo →", + "fwo ↑", + "fwo ↓↑", + "fwo ↓", + + "t⁼ →", + "t⁼ ↑", + "t⁼ ↓↑", + "t⁼ ↓", + + "tʰ →", + "tʰ ↑", + "tʰ ↓↑", + "tʰ ↓", + + "k⁼ 
→", + "k⁼ ↑", + "k⁼ ↓↑", + "k⁼ ↓", + + "kʰ →", + "kʰ ↑", + "kʰ ↓↑", + "kʰ ↓", + + "tʃ →", + "tʃ ↑", + "tʃ ↓↑", + "tʃ ↓", + + "tʃ⁼ →", + "tʃ⁼ ↑", + "tʃ⁼ ↓↑", + "tʃ⁼ ↓", + + "tʃʰ →", + "tʃʰ ↑", + "tʃʰ ↓↑", + "tʃʰ ↓", + + "ts →", + "ts ↑", + "ts ↓↑", + "ts ↓", + + "ts⁼ →", + "ts⁼ ↑", + "ts⁼ ↓↑", + "ts⁼ ↓", + + "ts` →", + "ts` ↑", + "ts` ↓↑", + "ts` ↓", + + "tsʰ →", + "tsʰ ↑", + "tsʰ ↓↑", + "tsʰ ↓", + + "ts`⁼ →", + "ts`⁼ ↑", + "ts`⁼ ↓↑", + "ts`⁼ ↓", + + "ts`ʰ →", + "ts`ʰ ↑", + "ts`ʰ ↓↑", + "ts`ʰ ↓", + + "s` →", + "s` ↑", + "s` ↓↑", + "s` ↓", + + "ɹ` →", + "ɹ` ↑", + "ɹ` ↓↑", + "ɹ` ↓", + + "s`ɹ` →", + "s`ɹ` ↑", + "s`ɹ` ↓↑", + "s`ɹ` ↓", + + "ɹ`ɹ` →", + "ɹ`ɹ` ↑", + "ɹ`ɹ` ↓↑", + "ɹ`ɹ` ↓", + + "ts`⁼ɹ` →", + "ts`⁼ɹ` ↑", + "ts`⁼ɹ` ↓↑", + "ts`⁼ɹ` ↓", + + "ts`ʰɹ` →", + "ts`ʰɹ` ↑", + "ts`ʰɹ` ↓↑", + "ts`ʰɹ` ↓", + + "ts⁼ɹ →", + "ts⁼ɹ ↑", + "ts⁼ɹ ↓↑", + "ts⁼ɹ ↓", + + "tsʰɹ →", + "tsʰɹ ↑", + "tsʰɹ ↓↑", + "tsʰɹ ↓", + + "sɹ →", + "sɹ ↑", + "sɹ ↓↑", + "sɹ ↓", + + "aɪ →", + "aɪ ↑", + "aɪ ↓↑", + "aɪ ↓", + + "eɪ →", + "eɪ ↑", + "eɪ ↓↑", + "eɪ ↓", + + "ɑʊ →", + "ɑʊ ↑", + "ɑʊ ↓↑", + "ɑʊ ↓", + + "oʊ →", + "oʊ ↑", + "oʊ ↓↑", + "oʊ ↓", + + "jɛ →", + "jɛ ↑", + "jɛ ↓↑", + "jɛ ↓", + + "jɛn →", + "jɛn ↑", + "jɛn ↓↑", + "jɛn ↓", + + "ɥæ →", + "ɥæ ↑", + "ɥæ ↓↑", + "ɥæ ↓", + + "ɥæn →", + "ɥæn ↑", + "ɥæn ↓↑", + "ɥæn ↓", + + "an →", + "an ↑", + "an ↓↑", + "an ↓", + + "in →", + "in ↑", + "in ↓↑", + "in ↓", + + "ɥn →", + "ɥn ↑", + "ɥn ↓↑", + "ɥn ↓", + + "ən →", + "ən ↑", + "ən ↓↑", + "ən ↓", + + "ɑŋ →", + "ɑŋ ↑", + "ɑŋ ↓↑", + "ɑŋ ↓", + + "iŋ →", + "iŋ ↑", + "iŋ ↓↑", + "iŋ ↓", + + "ʊŋ →", + "ʊŋ ↑", + "ʊŋ ↓↑", + "ʊŋ ↓", + + "jʊŋ →", + "jʊŋ ↑", + "jʊŋ ↓↑", + "jʊŋ ↓", + + "əŋ →", + "əŋ ↑", + "əŋ ↓↑", + "əŋ ↓", + + "əɹ` →", + "əɹ` ↑", + "əɹ` ↓↑", + "əɹ` ↓" + ] + } +} \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p/bpe_613.json b/models/tts/debatts/utils/g2p/bpe_613.json new file mode 100644 index 00000000..ff6063f8 --- /dev/null +++ b/models/tts/debatts/utils/g2p/bpe_613.json @@ -0,0 +1,1366 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "[CLS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "[SEP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "[PAD]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 4, + "content": "[MASK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": null, + "decoder": null, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": "[UNK]", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "vocab": { + "[UNK]": 0, + "[CLS]": 1, + "[SEP]": 2, + "[PAD]": 3, + "[MASK]": 4, + "!": 5, + "#": 6, + "*": 7, + ",": 8, + "-": 9, + ".": 10, + "=": 11, + "?": 12, + "N": 13, + "Q": 14, + "^": 15, + "_": 16, + "`": 17, + "a": 18, + "b": 19, + "d": 20, + "e": 21, + "f": 22, + "g": 23, + "h": 24, + "i": 25, + "j": 26, + "k": 27, + "l": 28, + "m": 29, + "n": 30, + "o": 31, + "p": 32, + "s": 33, + "t": 34, + "u": 35, + "v": 36, 
+ "w": 37, + "x": 38, + "y": 39, + "z": 40, + "~": 41, + "æ": 42, + "ç": 43, + "ð": 44, + "ŋ": 45, + "ɑ": 46, + "ɔ": 47, + "ə": 48, + "ɛ": 49, + "ɥ": 50, + "ɪ": 51, + "ɫ": 52, + "ɯ": 53, + "ɸ": 54, + "ɹ": 55, + "ɾ": 56, + "ʃ": 57, + "ʊ": 58, + "ʑ": 59, + "ʒ": 60, + "ʰ": 61, + "ˈ": 62, + "ˌ": 63, + "θ": 64, + "…": 65, + "⁼": 66, + + "wo": 67, + "p⁼": 68, + "pʰ": 69, + "p⁼wo": 70, + "pʰwo": 71, + "mwo": 72, + "fwo": 73, + + "t⁼": 74, + "tʰ": 75, + "k⁼": 76, + "kʰ": 77, + + "tʃ": 78, + "tʃ⁼": 79, + "tʃʰ": 80, + + "ts": 81, + "ts⁼": 82, + "ts`": 83, + "tsʰ" :84, + + "ts`⁼": 85, + "ts`ʰ": 86, + + "s`": 87, + "ɹ`": 88, + "s`ɹ`": 89, + "ɹ`ɹ`": 90, + "ts`⁼ɹ`": 91, + "ts`ʰɹ`": 92, + + "ts⁼ɹ": 93, + "tsʰɹ": 94, + "sɹ": 95, + + "aɪ": 96, + "eɪ": 97, + "ɑʊ": 98, + "oʊ": 99, + + "jɛ": 100, + "jɛn": 101, + "ɥæ": 102, + "ɥæn": 103, + + "an": 104, + "in": 105, + "ɥn": 106, + "ən": 107, + "ɑŋ": 108, + "iŋ": 109, + "ʊŋ": 110, + "jʊŋ": 111, + "əŋ": 112, + "əɹ`": 113, + + "ja": 114, + "iɛ": 115, + "iɑʊ": 116, + "joʊ": 117, + "iɑŋ": 118, + "wa": 119, + "waɪ": 120, + "weɪ": 121, + "wan": 122, + "wən": 123, + "uɑŋ": 124, + "ɥɛ": 125, + + "→": 126, + "↑": 127, + "↓↑": 128, + "↓": 129, + "!→": 130, + "!↑": 131, + "!↓↑": 132, + "!↓": 133, + "#→": 134, + "#↑": 135, + "#↓↑": 136, + "#↓": 137, + "*→": 138, + "*↑": 139, + "*↓↑": 140, + "*↓": 141, + ",→": 142, + ",↑": 143, + ",↓↑": 144, + ",↓": 145, + "-→": 146, + "-↑": 147, + "-↓↑": 148, + "-↓": 149, + ".→": 150, + ".↑": 151, + ".↓↑": 152, + ".↓": 153, + "=→": 154, + "=↑": 155, + "=↓↑": 156, + "=↓": 157, + "?→": 158, + "?↑": 159, + "?↓↑": 160, + "?↓": 161, + "N→": 162, + "N↑": 163, + "N↓↑": 164, + "N↓": 165, + "Q→": 166, + "Q↑": 167, + "Q↓↑": 168, + "Q↓": 169, + "^→": 170, + "^↑": 171, + "^↓↑": 172, + "^↓": 173, + "_→": 174, + "_↑": 175, + "_↓↑": 176, + "_↓": 177, + "`→": 178, + "`↑": 179, + "`↓↑": 180, + "`↓": 181, + "a→": 182, + "a↑": 183, + "a↓↑": 184, + "a↓": 185, + "b→": 186, + "b↑": 187, + "b↓↑": 188, + "b↓": 189, + "d→": 190, + "d↑": 191, + "d↓↑": 192, + "d↓": 193, + "e→": 194, + "e↑": 195, + "e↓↑": 196, + "e↓": 197, + "f→": 198, + "f↑": 199, + "f↓↑": 200, + "f↓": 201, + "g→": 202, + "g↑": 203, + "g↓↑": 204, + "g↓": 205, + "h→": 206, + "h↑": 207, + "h↓↑": 208, + "h↓": 209, + "i→": 210, + "i↑": 211, + "i↓↑": 212, + "i↓": 213, + "j→": 214, + "j↑": 215, + "j↓↑": 216, + "j↓": 217, + "k→": 218, + "k↑": 219, + "k↓↑": 220, + "k↓": 221, + "l→": 222, + "l↑": 223, + "l↓↑": 224, + "l↓": 225, + "m→": 226, + "m↑": 227, + "m↓↑": 228, + "m↓": 229, + "n→": 230, + "n↑": 231, + "n↓↑": 232, + + "n↓": 233, + "o→": 234, + "o↑": 235, + "o↓↑": 236, + "o↓": 237, + "p→": 238, + "p↑": 239, + "p↓↑": 240, + "p↓": 241, + "s→": 242, + "s↑": 243, + "s↓↑": 244, + "s↓": 245, + "t→": 246, + "t↑": 247, + "t↓↑": 248, + "t↓": 249, + "u→": 250, + "u↑": 251, + "u↓↑": 252, + "u↓": 253, + "v→": 254, + "v↑": 255, + "v↓↑": 256, + "v↓": 257, + "w→": 258, + "w↑": 259, + "w↓↑": 260, + "w↓": 261, + "x→": 262, + "x↑": 263, + "x↓↑": 264, + "x↓": 265, + "y→": 266, + "y↑": 267, + "y↓↑": 268, + "y↓": 269, + "z→": 270, + "z↑": 271, + "z↓↑": 272, + "z↓": 273, + "~→": 274, + "~↑": 275, + "~↓↑": 276, + "~↓": 277, + "æ→": 278, + "æ↑": 279, + "æ↓↑": 280, + "æ↓": 281, + "ç→": 282, + "ç↑": 283, + "ç↓↑": 284, + "ç↓": 285, + "ð→": 286, + "ð↑": 287, + "ð↓↑": 288, + "ð↓": 289, + "ŋ→": 290, + "ŋ↑": 291, + "ŋ↓↑": 292, + "ŋ↓": 293, + "ɑ→": 294, + "ɑ↑": 295, + "ɑ↓↑": 296, + "ɑ↓": 297, + "ɔ→": 298, + "ɔ↑": 299, + "ɔ↓↑": 300, + "ɔ↓": 301, + "ə→": 302, + "ə↑": 303, + "ə↓↑": 304, + "ə↓": 305, + "ɛ→": 306, + "ɛ↑": 307, + "ɛ↓↑": 308, 
+ "ɛ↓": 309, + "ɥ→": 310, + "ɥ↑": 311, + "ɥ↓↑": 312, + "ɥ↓": 313, + "ɪ→": 314, + "ɪ↑": 315, + "ɪ↓↑": 316, + "ɪ↓": 317, + "ɫ→": 318, + "ɫ↑": 319, + "ɫ↓↑": 320, + "ɫ↓": 321, + "ɯ→": 322, + "ɯ↑": 323, + "ɯ↓↑": 324, + "ɯ↓": 325, + "ɸ→": 326, + "ɸ↑": 327, + "ɸ↓↑": 328, + "ɸ↓": 329, + "ɹ→": 330, + "ɹ↑": 331, + "ɹ↓↑": 332, + "ɹ↓": 333, + "ɾ→": 334, + "ɾ↑": 335, + "ɾ↓↑": 336, + "ɾ↓": 337, + "ʃ→": 338, + "ʃ↑": 339, + "ʃ↓↑": 340, + "ʃ↓": 341, + "ʊ→": 342, + "ʊ↑": 343, + "ʊ↓↑": 344, + "ʊ↓": 345, + "ʑ→": 346, + "ʑ↑": 347, + "ʑ↓↑": 348, + "ʑ↓": 349, + "ʒ→": 350, + "ʒ↑": 351, + "ʒ↓↑": 352, + "ʒ↓": 353, + "ʰ→": 354, + "ʰ↑": 355, + "ʰ↓↑": 356, + "ʰ↓": 357, + "ˈ→": 358, + "ˈ↑": 359, + "ˈ↓↑": 360, + "ˈ↓": 361, + "ˌ→": 362, + "ˌ↑": 363, + "ˌ↓↑": 364, + "ˌ↓": 365, + "θ→": 366, + "θ↑": 367, + "θ↓↑": 368, + "θ↓": 369, + "…→": 370, + "…↑": 371, + "…↓↑": 372, + "…↓": 373, + "⁼→": 374, + "⁼↑": 375, + "⁼↓↑": 376, + "⁼↓": 377, + + "wo→": 378, + "wo↑": 379, + "wo↓↑": 380, + "wo↓": 381, + + "p⁼→": 382, + "p⁼↑": 383, + "p⁼↓↑": 384, + "p⁼↓": 385, + + "pʰ→": 386, + "pʰ↑": 387, + "pʰ↓↑": 388, + "pʰ↓": 389, + + "p⁼wo→": 390, + "p⁼wo↑": 391, + "p⁼wo↓↑": 392, + "p⁼wo↓": 393, + + "pʰwo→": 394, + "pʰwo↑": 395, + "pʰwo↓↑": 396, + "pʰwo↓": 397, + + "mwo→": 398, + "mwo↑": 399, + "mwo↓↑": 400, + "mwo↓": 401, + + "fwo→": 402, + "fwo↑": 403, + "fwo↓↑": 404, + "fwo↓": 405, + + "t⁼→": 406, + "t⁼↑": 407, + "t⁼↓↑": 408, + "t⁼↓": 409, + + "tʰ→": 410, + "tʰ↑": 411, + "tʰ↓↑": 412, + "tʰ↓": 413, + + "k⁼→": 414, + "k⁼↑": 415, + "k⁼↓↑": 416, + "k⁼↓": 417, + + "kʰ→": 418, + "kʰ↑": 419, + "kʰ↓↑": 420, + "kʰ↓": 421, + + "tʃ→": 422, + "tʃ↑": 423, + "tʃ↓↑": 424, + "tʃ↓": 425, + + "tʃ⁼→": 426, + "tʃ⁼↑": 427, + "tʃ⁼↓↑": 428, + "tʃ⁼↓": 429, + + "tʃʰ→": 430, + "tʃʰ↑": 431, + "tʃʰ↓↑": 432, + "tʃʰ↓": 433, + + "ts→": 434, + "ts↑": 435, + "ts↓↑": 436, + "ts↓": 437, + + "ts⁼→": 438, + "ts⁼↑": 439, + "ts⁼↓↑": 440, + "ts⁼↓": 441, + + "ts`→": 442, + "ts`↑": 443, + "ts`↓↑": 444, + "ts`↓": 445, + + "tsʰ→": 446, + "tsʰ↑": 447, + "tsʰ↓↑": 448, + "tsʰ↓": 449, + + "ts`⁼→": 450, + "ts`⁼↑": 451, + "ts`⁼↓↑": 452, + "ts`⁼↓": 453, + + "ts`ʰ→": 454, + "ts`ʰ↑": 455, + "ts`ʰ↓↑": 456, + "ts`ʰ↓": 457, + + "s`→": 458, + "s`↑": 459, + "s`↓↑": 460, + "s`↓": 461, + + "ɹ`→": 462, + "ɹ`↑": 463, + "ɹ`↓↑": 464, + "ɹ`↓": 465, + + "s`ɹ`→": 466, + "s`ɹ`↑": 467, + "s`ɹ`↓↑": 468, + "s`ɹ`↓": 469, + + "ɹ`ɹ`→": 470, + "ɹ`ɹ`↑": 471, + "ɹ`ɹ`↓↑": 472, + "ɹ`ɹ`↓": 473, + + "ts`⁼ɹ`→": 474, + "ts`⁼ɹ`↑": 475, + "ts`⁼ɹ`↓↑": 476, + "ts`⁼ɹ`↓": 477, + + "ts`ʰɹ`→": 478, + "ts`ʰɹ`↑": 479, + "ts`ʰɹ`↓↑": 480, + "ts`ʰɹ`↓": 481, + + "ts⁼ɹ→": 482, + "ts⁼ɹ↑": 483, + "ts⁼ɹ↓↑": 484, + "ts⁼ɹ↓": 485, + + "tsʰɹ→": 486, + "tsʰɹ↑": 487, + "tsʰɹ↓↑": 488, + "tsʰɹ↓": 489, + + "sɹ→": 490, + "sɹ↑": 491, + "sɹ↓↑": 492, + "sɹ↓": 493, + + "aɪ→": 494, + "aɪ↑": 495, + "aɪ↓↑": 496, + "aɪ↓": 497, + + "eɪ→": 498, + "eɪ↑": 499, + "eɪ↓↑": 500, + "eɪ↓": 501, + + "ɑʊ→": 502, + "ɑʊ↑": 503, + "ɑʊ↓↑": 504, + "ɑʊ↓": 505, + + "oʊ→": 506, + "oʊ↑": 507, + "oʊ↓↑": 508, + "oʊ↓": 509, + + "jɛ→": 510, + "jɛ↑": 511, + "jɛ↓↑": 512, + "jɛ↓": 513, + + "jɛn→": 514, + "jɛn↑": 515, + "jɛn↓↑": 516, + "jɛn↓": 517, + + "ɥæ→": 518, + "ɥæ↑": 519, + "ɥæ↓↑": 520, + "ɥæ↓": 521, + + "ɥæn→": 522, + "ɥæn↑": 523, + "ɥæn↓↑": 524, + "ɥæn↓": 525, + + "an→": 526, + "an↑": 527, + "an↓↑": 528, + "an↓": 529, + + "in→": 530, + "in↑": 531, + "in↓↑": 532, + "in↓": 533, + + "ɥn→": 534, + "ɥn↑": 535, + "ɥn↓↑": 536, + "ɥn↓": 537, + + "ən→": 538, + "ən↑": 539, + "ən↓↑": 540, + "ən↓": 541, + + "ɑŋ→": 542, + "ɑŋ↑": 543, + "ɑŋ↓↑": 544, + "ɑŋ↓": 545, + + "iŋ→": 546, + "iŋ↑": 547, 
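The tone-suffixed ids in this table follow the arithmetic used by _connect_tone in utils/g2p/__init__.py: the special tokens [UNK] through [MASK] occupy ids 0-4, so base is 5; the bare tone marks sit at ids 126-129; and a phoneme id p followed by the tone at index t maps to 129 + (p - 5) * 4 + t + 1. A worked check against the entries above:

    # Reproduces the _connect_tone id arithmetic for bpe_613.json.
    BASE = 5          # first id after the [UNK]/[CLS]/[SEP]/[PAD]/[MASK] block
    LAST_TONE = 129   # id of "↓", the last bare tone mark
    TONES = ["→", "↑", "↓↑", "↓"]

    def tone_id(phone_id: int, tone: str) -> int:
        return LAST_TONE + (phone_id - BASE) * len(TONES) + TONES.index(tone) + 1

    assert tone_id(18, "↓") == 185    # "a" (18)  + "↓"  -> "a↓"
    assert tone_id(25, "→") == 210    # "i" (25)  + "→"  -> "i→"
    assert tone_id(64, "↓↑") == 368   # "θ" (64)  + "↓↑" -> "θ↓↑"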
+ "iŋ↓↑": 548, + "iŋ↓": 549, + + "ʊŋ→": 550, + "ʊŋ↑": 551, + "ʊŋ↓↑": 552, + "ʊŋ↓": 553, + + "jʊŋ→": 554, + "jʊŋ↑": 555, + "jʊŋ↓↑": 556, + "jʊŋ↓": 557, + + "əŋ→": 558, + "əŋ↑": 559, + "əŋ↓↑": 560, + "əŋ↓": 561, + + "əɹ`→": 562, + "əɹ`↑": 563, + "əɹ`↓↑": 564, + "əɹ`↓": 565, + + "ja→": 566, + "ja↑": 567, + "ja↓↑": 568, + "ja↓": 569, + + "iɛ→": 570, + "iɛ↑": 571, + "iɛ↓↑": 572, + "iɛ↓": 573, + + "iɑʊ→": 574, + "iɑʊ↑": 575, + "iɑʊ↓↑": 576, + "iɑʊ↓": 577, + + "joʊ→": 578, + "joʊ↑": 579, + "joʊ↓↑": 580, + "joʊ↓": 581, + + "iɑŋ→": 582, + "iɑŋ↑": 583, + "iɑŋ↓↑": 584, + "iɑŋ↓": 585, + + "wa→": 586, + "wa↑": 587, + "wa↓↑": 588, + "wa↓": 589, + + "waɪ→": 590, + "waɪ↑": 591, + "waɪ↓↑": 592, + "waɪ↓": 593, + + "weɪ→": 594, + "weɪ↑": 595, + "weɪ↓↑": 596, + "weɪ↓": 597, + + "wan→": 598, + "wan↑": 599, + "wan↓↑": 600, + "wan↓": 601, + + "wən→": 602, + "wən↑": 603, + "wən↓↑": 604, + "wən↓": 605, + + "uɑŋ→": 606, + "uɑŋ↑": 607, + "uɑŋ↓↑": 608, + "uɑŋ↓": 609, + + "ɥɛ→": 610, + "ɥɛ↑": 611, + "ɥɛ↓↑": 612, + "ɥɛ↓": 613, + + "ɐ": 614, + "ɕ": 615, + "ʌ": 616, + "c": 617, + "q": 618, + "r": 619, + "ɡ": 620, + "jɐ": 621, + "jʌ": 622, + "jo": 623, + "ju": 624, + "je": 625, + "we": 626, + "wi": 627, + "ɯj": 628, + "wɐ": 629, + "wɛ": 630, + "wʌ": 631, + "tɕ": 632, + "tɕʰ": 633, + + + "ʁ": 634, + "ɲ": 635, + "œ": 636, + ":": 637, + "jˈ": 638, + "ˈn": 639, + "wˈ": 640, + + + "tˈ": 641, + "ˈŋ": 642, + "ˈɪ": 643 + }, + "merges": [ + "↓ ↑", + "w o", + "p ⁼", + "p ʰ", + "p⁼ wo", + "pʰ wo", + "m wo", + "f wo", + + "t ⁼", + "t ʰ", + "k ⁼", + "k ʰ", + + "t ʃ", + "tʃ ⁼", + "tʃ ʰ", + + "t s", + "ts ⁼", + "ts `", + "ts ʰ", + + "ts` ⁼", + "ts` ʰ", + + "s `", + "ɹ `", + "s` ɹ`", + "ɹ` ɹ`", + "ts`⁼ ɹ`", + "ts`ʰ ɹ`", + + "ts⁼ ɹ", + "tsʰ ɹ", + "s ɹ", + + "a ɪ", + "e ɪ", + "ɑ ʊ", + "o ʊ", + + "j ɛ", + "jɛ n", + "ɥ æ", + "ɥæ n", + + "a n", + "i n", + "ɥ n", + "ə n", + "ɑ ŋ", + "i ŋ", + "ʊ ŋ", + "j ʊŋ", + "ə ŋ", + "ə ɹ`", + + "j a", + "i ɛ", + "i ɑʊ", + "jo ʊ", + "i ɑŋ", + "w a", + "w aɪ", + "we ɪ", + "w an", + "w ən", + "u ɑŋ", + "ɥ ɛ", + + "! →", + "! ↑", + "! ↓↑", + "! ↓", + "# →", + "# ↑", + "# ↓↑", + "# ↓", + "* →", + "* ↑", + "* ↓↑", + "* ↓", + ", →", + ", ↑", + ", ↓↑", + ", ↓", + "- →", + "- ↑", + "- ↓↑", + "- ↓", + ". →", + ". ↑", + ". ↓↑", + ". ↓", + "= →", + "= ↑", + "= ↓↑", + "= ↓", + "? →", + "? ↑", + "? ↓↑", + "? 
↓", + "N →", + "N ↑", + "N ↓↑", + "N ↓", + "Q →", + "Q ↑", + "Q ↓↑", + "Q ↓", + "^ →", + "^ ↑", + "^ ↓↑", + "^ ↓", + "_ →", + "_ ↑", + "_ ↓↑", + "_ ↓", + "a →", + "a ↑", + "a ↓↑", + "a ↓", + "b →", + "b ↑", + "b ↓↑", + "b ↓", + "d →", + "d ↑", + "d ↓↑", + "d ↓", + "e →", + "e ↑", + "e ↓↑", + "e ↓", + "f →", + "f ↑", + "f ↓↑", + "f ↓", + "g →", + "g ↑", + "g ↓↑", + "g ↓", + "h →", + "h ↑", + "h ↓↑", + "h ↓", + "i →", + "i ↑", + "i ↓↑", + "i ↓", + "j →", + "j ↑", + "j ↓↑", + "j ↓", + "k →", + "k ↑", + "k ↓↑", + "k ↓", + "l →", + "l ↑", + "l ↓↑", + "l ↓", + "m →", + "m ↑", + "m ↓↑", + "m ↓", + "n →", + "n ↑", + "n ↓↑", + "n ↓", + "o →", + "o ↑", + "o ↓↑", + "o ↓", + "p →", + "p ↑", + "p ↓↑", + "p ↓", + "s →", + "s ↑", + "s ↓↑", + "s ↓", + "t →", + "t ↑", + "t ↓↑", + "t ↓", + "u →", + "u ↑", + "u ↓↑", + "u ↓", + "v →", + "v ↑", + "v ↓↑", + "v ↓", + "w →", + "w ↑", + "w ↓↑", + "w ↓", + "x →", + "x ↑", + "x ↓↑", + "x ↓", + "y →", + "y ↑", + "y ↓↑", + "y ↓", + "z →", + "z ↑", + "z ↓↑", + "z ↓", + "~ →", + "~ ↑", + "~ ↓↑", + "~ ↓", + "æ →", + "æ ↑", + "æ ↓↑", + "æ ↓", + "ç →", + "ç ↑", + "ç ↓↑", + "ç ↓", + "ð →", + "ð ↑", + "ð ↓↑", + "ð ↓", + "ŋ →", + "ŋ ↑", + "ŋ ↓↑", + "ŋ ↓", + "ɑ →", + "ɑ ↑", + "ɑ ↓↑", + "ɑ ↓", + "ɔ →", + "ɔ ↑", + "ɔ ↓↑", + "ɔ ↓", + "ə →", + "ə ↑", + "ə ↓↑", + "ə ↓", + "ɛ →", + "ɛ ↑", + "ɛ ↓↑", + "ɛ ↓", + "ɥ →", + "ɥ ↑", + "ɥ ↓↑", + "ɥ ↓", + "ɪ →", + "ɪ ↑", + "ɪ ↓↑", + "ɪ ↓", + "ɫ →", + "ɫ ↑", + "ɫ ↓↑", + "ɫ ↓", + "ɯ →", + "ɯ ↑", + "ɯ ↓↑", + "ɯ ↓", + "ɸ →", + "ɸ ↑", + "ɸ ↓↑", + "ɸ ↓", + "ɹ →", + "ɹ ↑", + "ɹ ↓↑", + "ɹ ↓", + "ɾ →", + "ɾ ↑", + "ɾ ↓↑", + "ɾ ↓", + "ʃ →", + "ʃ ↑", + "ʃ ↓↑", + "ʃ ↓", + "ʊ →", + "ʊ ↑", + "ʊ ↓↑", + "ʊ ↓", + "ʑ →", + "ʑ ↑", + "ʑ ↓↑", + "ʑ ↓", + "ʒ →", + "ʒ ↑", + "ʒ ↓↑", + "ʒ ↓", + "ʰ →", + "ʰ ↑", + "ʰ ↓↑", + "ʰ ↓", + "ˈ →", + "ˈ ↑", + "ˈ ↓↑", + "ˈ ↓", + "ˌ →", + "ˌ ↑", + "ˌ ↓↑", + "ˌ ↓", + "θ →", + "θ ↑", + "θ ↓↑", + "θ ↓", + "… →", + "… ↑", + "… ↓↑", + "… ↓", + "⁼ →", + "⁼ ↑", + "⁼ ↓↑", + "⁼ ↓", + "wo →", + "wo ↑", + "wo ↓↑", + "wo ↓", + + "p⁼ →", + "p⁼ ↑", + "p⁼ ↓↑", + "p⁼ ↓", + + "pʰ →", + "pʰ ↑", + "pʰ ↓↑", + "pʰ ↓", + + "p⁼wo →", + "p⁼wo ↑", + "p⁼wo ↓↑", + "p⁼wo ↓", + + "pʰwo →", + "pʰwo ↑", + "pʰwo ↓↑", + "pʰwo ↓", + + "mwo →", + "mwo ↑", + "mwo ↓↑", + "mwo ↓", + + "fwo →", + "fwo ↑", + "fwo ↓↑", + "fwo ↓", + + "t⁼ →", + "t⁼ ↑", + "t⁼ ↓↑", + "t⁼ ↓", + + "tʰ →", + "tʰ ↑", + "tʰ ↓↑", + "tʰ ↓", + + "k⁼ →", + "k⁼ ↑", + "k⁼ ↓↑", + "k⁼ ↓", + + "kʰ →", + "kʰ ↑", + "kʰ ↓↑", + "kʰ ↓", + + "tʃ →", + "tʃ ↑", + "tʃ ↓↑", + "tʃ ↓", + + "tʃ⁼ →", + "tʃ⁼ ↑", + "tʃ⁼ ↓↑", + "tʃ⁼ ↓", + + "tʃʰ →", + "tʃʰ ↑", + "tʃʰ ↓↑", + "tʃʰ ↓", + + "ts →", + "ts ↑", + "ts ↓↑", + "ts ↓", + + "ts⁼ →", + "ts⁼ ↑", + "ts⁼ ↓↑", + "ts⁼ ↓", + + "ts` →", + "ts` ↑", + "ts` ↓↑", + "ts` ↓", + + "tsʰ →", + "tsʰ ↑", + "tsʰ ↓↑", + "tsʰ ↓", + + "ts`⁼ →", + "ts`⁼ ↑", + "ts`⁼ ↓↑", + "ts`⁼ ↓", + + "ts`ʰ →", + "ts`ʰ ↑", + "ts`ʰ ↓↑", + "ts`ʰ ↓", + + "s` →", + "s` ↑", + "s` ↓↑", + "s` ↓", + + "ɹ` →", + "ɹ` ↑", + "ɹ` ↓↑", + "ɹ` ↓", + + "s`ɹ` →", + "s`ɹ` ↑", + "s`ɹ` ↓↑", + "s`ɹ` ↓", + + "ɹ`ɹ` →", + "ɹ`ɹ` ↑", + "ɹ`ɹ` ↓↑", + "ɹ`ɹ` ↓", + + "ts`⁼ɹ` →", + "ts`⁼ɹ` ↑", + "ts`⁼ɹ` ↓↑", + "ts`⁼ɹ` ↓", + + "ts`ʰɹ` →", + "ts`ʰɹ` ↑", + "ts`ʰɹ` ↓↑", + "ts`ʰɹ` ↓", + + "ts⁼ɹ →", + "ts⁼ɹ ↑", + "ts⁼ɹ ↓↑", + "ts⁼ɹ ↓", + + "tsʰɹ →", + "tsʰɹ ↑", + "tsʰɹ ↓↑", + "tsʰɹ ↓", + + "sɹ →", + "sɹ ↑", + "sɹ ↓↑", + "sɹ ↓", + + "aɪ →", + "aɪ ↑", + "aɪ ↓↑", + "aɪ ↓", + + "eɪ →", + "eɪ ↑", + "eɪ ↓↑", + "eɪ ↓", + + "ɑʊ →", + "ɑʊ ↑", + "ɑʊ ↓↑", + "ɑʊ ↓", + + "oʊ →", + "oʊ ↑", + "oʊ ↓↑", + "oʊ ↓", + + "jɛ →", + "jɛ ↑", + "jɛ ↓↑", + "jɛ ↓", + + 
"jɛn →", + "jɛn ↑", + "jɛn ↓↑", + "jɛn ↓", + + "ɥæ →", + "ɥæ ↑", + "ɥæ ↓↑", + "ɥæ ↓", + + "ɥæn →", + "ɥæn ↑", + "ɥæn ↓↑", + "ɥæn ↓", + + "an →", + "an ↑", + "an ↓↑", + "an ↓", + + "in →", + "in ↑", + "in ↓↑", + "in ↓", + + "ɥn →", + "ɥn ↑", + "ɥn ↓↑", + "ɥn ↓", + + "ən →", + "ən ↑", + "ən ↓↑", + "ən ↓", + + "ɑŋ →", + "ɑŋ ↑", + "ɑŋ ↓↑", + "ɑŋ ↓", + + "iŋ →", + "iŋ ↑", + "iŋ ↓↑", + "iŋ ↓", + + "ʊŋ →", + "ʊŋ ↑", + "ʊŋ ↓↑", + "ʊŋ ↓", + + "jʊŋ →", + "jʊŋ ↑", + "jʊŋ ↓↑", + "jʊŋ ↓", + + "əŋ →", + "əŋ ↑", + "əŋ ↓↑", + "əŋ ↓", + + "əɹ` →", + "əɹ` ↑", + "əɹ` ↓↑", + "əɹ` ↓", + + + "j ɐ", + "j ʌ", + "j o", + "j u", + "j e", + "w e", + "w i", + "ɯ j", + "w ɐ", + "w ɛ", + "w ʌ", + "t ɕ", + "tɕ ʰ", + + + "j ˈ", + "ˈ n", + "w ˈ", + + + "t ˈ", + "ˈ ŋ", + "ˈ ɪ" + ] + } +} \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p/cleaners.py b/models/tts/debatts/utils/g2p/cleaners.py new file mode 100644 index 00000000..7d96e84a --- /dev/null +++ b/models/tts/debatts/utils/g2p/cleaners.py @@ -0,0 +1,62 @@ +import re +from utils.g2p.japanese import japanese_to_ipa +from utils.g2p.mandarin import chinese_to_ipa +from utils.g2p.english import english_to_ipa +from utils.g2p.french import french_to_ipa +from utils.g2p.korean import korean_to_ipa +from utils.g2p.german import german_to_ipa + +patterns = [ + r"\[EN\](.*?)\[EN\]", + r"\[ZH\](.*?)\[ZH\]", + r"\[JA\](.*?)\[JA\]", + r"\[FR\](.*?)\[FR\]", + r"\[KR\](.*?)\[KR\]", + r"\[DE\](.*?)\[DE\]", +] + + +def cje_cleaners(text): + matches = [] + for pattern in patterns: + matches.extend(re.finditer(pattern, text)) + + matches.sort(key=lambda x: x.start()) # Sort matches by their start positions + + outputs = "" + for match in matches: + text_segment = text[match.start() : match.end()] + phone = clean_one(text_segment) + outputs += phone + + return outputs + + +def clean_one(text): + if text.find("[ZH]") != -1: + text = re.sub( + r"\[ZH\](.*?)\[ZH\]", lambda x: chinese_to_ipa(x.group(1)) + " ", text + ) + if text.find("[JA]") != -1: + text = re.sub( + r"\[JA\](.*?)\[JA\]", lambda x: japanese_to_ipa(x.group(1)) + " ", text + ) + if text.find("[EN]") != -1: + text = re.sub( + r"\[EN\](.*?)\[EN\]", lambda x: english_to_ipa(x.group(1)) + " ", text + ) + if text.find("[FR]") != -1: + text = re.sub( + r"\[FR\](.*?)\[FR\]", lambda x: french_to_ipa(x.group(1)) + " ", text + ) + if text.find("[KR]") != -1: + text = re.sub( + r"\[KR\](.*?)\[KR\]", lambda x: korean_to_ipa(x.group(1)) + " ", text + ) + if text.find("[DE]") != -1: + text = re.sub( + r"\[DE\](.*?)\[DE\]", lambda x: german_to_ipa(x.group(1)) + " ", text + ) + text = re.sub(r"\s+$", "", text) + text = re.sub(r"([^\.,!\?\-…~])$", r"\1.", text) + return text diff --git a/models/tts/debatts/utils/g2p/english.py b/models/tts/debatts/utils/g2p/english.py new file mode 100644 index 00000000..a24a86a5 --- /dev/null +++ b/models/tts/debatts/utils/g2p/english.py @@ -0,0 +1,139 @@ +""" from https://github.com/keithito/tacotron """ + +import re +from unidecode import unidecode +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + +# List of (ipa, ipa2) pairs +_ipa_to_ipa2 = [ + (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def collapse_whitespace(text): + return re.sub(r"\s+", " ", text) + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words( + num, andword="", zero="oh", group=2 + ).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +# Normalize numbers pronunciation +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + + +# Mark 'ɫ' after 'l' +def mark_dark_l(text): + return re.sub(r"l([^aeiouæɑɔəɛɪʊ ]*(?: |$))", lambda x: "ɫ" + x.group(1), text) + + +def _english_to_ipa(text): + import eng_to_ipa as ipa + + text = unidecode(text).lower() + text = expand_abbreviations(text) + text = normalize_numbers(text) + phonemes = ipa.convert(text) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +# Add some special operation +def english_to_ipa(text): + text = _english_to_ipa(text) + text = mark_dark_l(text) + for regex, replacement in _ipa_to_ipa2: + text = re.sub(regex, replacement, text) + return text.replace("...", "…") diff --git a/models/tts/debatts/utils/g2p/french.py b/models/tts/debatts/utils/g2p/french.py new file mode 100644 index 00000000..9c059e08 --- /dev/null +++ b/models/tts/debatts/utils/g2p/french.py @@ -0,0 +1,178 @@ +"""https://github.com/bootphon/phonemizer""" + +import re +from phonemizer import phonemize +from phonemizer.separator import Separator + +# List of 
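The English frontend above expands abbreviations and spells out numbers before handing the text to eng_to_ipa. A sketch of just the normalization step; it assumes the G2P requirements (inflect, unidecode, and the other modules imported by utils/g2p) are installed and that it runs from models/tts/debatts:

    from utils.g2p.english import expand_abbreviations, normalize_numbers

    text = "dr. smith paid $3.50 on the 2nd try."
    text = expand_abbreviations(text)  # "dr." -> "doctor"
    text = normalize_numbers(text)     # "$3.50" -> "three dollars, fifty cents", "2nd" -> "second"
    print(text)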
(regular expression, replacement) pairs for abbreviations in french: +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("M", "monsieur"), + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ("N.B", "nota bene"), + ("M", "monsieur"), + ("p.c.q", "parce que"), + ("Pr", "professeur"), + ("qqch", "quelque chose"), + ("rdv", "rendez-vous"), + ("max", "maximum"), + ("min", "minimum"), + ("no", "numéro"), + ("adr", "adresse"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "companie"), + ("jr", "junior"), + ("sgt", "sergent"), + ("capt", "capitain"), + ("col", "colonel"), + ("av", "avenue"), + ("av. J.-C", "avant Jésus-Christ"), + ("apr. J.-C", "après Jésus-Christ"), + ("art", "article"), + ("boul", "boulevard"), + ("c.-à-d", "c’est-à-dire"), + ("etc", "et cetera"), + ("ex", "exemple"), + ("excl", "exclusivement"), + ("boul", "boulevard"), + ] +] + [ + (re.compile("\\b%s" % x[0]), x[1]) + for x in [ + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ] +] + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": ".", + "…": ".", + "$": ".", + "“": "", + "”": "", + "‘": "", + "’": "", + "(": "", + ")": "", + "(": "", + ")": "", + "《": "", + "》": "", + "【": "", + "】": "", + "[": "", + "]": "", + "—": "", + "~": "-", + "~": "-", + "「": "", + "」": "", + "¿": "", + "¡": "", +} + +_special_map = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ø", "ɸ"), + ("ː", ":"), + ("j", "jˈ"), # To avoid incorrect connect + ("n", "ˈn"), # To avoid incorrect connect + ("w", "wˈ"), # To avoid incorrect connect + ("ã", "a~"), + ("ɑ̃", "ɑ~"), + ("ɔ̃", "ɔ~"), + ("ɛ̃", "ɛ~"), + ("œ̃", "œ~"), + ] +] + + +def collapse_whitespace(text): + # Regular expression matching whitespace: + _whitespace_re = re.compile(r"\s+") + return re.sub(_whitespace_re, " ", text).strip() + + +def remove_punctuation_at_begin(text): + return re.sub(r"^[,.!?]+", "", text) + + +def remove_aux_symbols(text): + text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) + return text + + +def replace_symbols(text): + text = text.replace(";", ",") + text = text.replace("-", " ") + text = text.replace(":", ",") + text = text.replace("&", " et ") + return text + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def replace_punctuation(text): + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + return replaced_text + + +def text_normalize(text): + text = expand_abbreviations(text) + text = replace_punctuation(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = remove_punctuation_at_begin(text) + text = collapse_whitespace(text) + text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text) + return text + + +# special map +def special_map(text): + for regex, replacement in _special_map: + text = re.sub(regex, replacement, text) + return text + + +def french_to_ipa(text): + text = text_normalize(text) + + ipa = phonemize( + text.strip(), + language="fr-fr", + backend="espeak", + separator=Separator(phone=None, word=" ", syllable="|"), + strip=True, + preserve_punctuation=True, + njobs=4, + ) + + # remove "(en)" and "(fr)" tag + ipa = ipa.replace("(en)", "").replace("(fr)", "") + + ipa = special_map(ipa) + + return ipa diff --git 
a/models/tts/debatts/utils/g2p/german.py b/models/tts/debatts/utils/g2p/german.py new file mode 100644 index 00000000..3f9259a3 --- /dev/null +++ b/models/tts/debatts/utils/g2p/german.py @@ -0,0 +1,122 @@ +"""https://github.com/bootphon/phonemizer""" + +import re +from phonemizer import phonemize +from phonemizer.separator import Separator + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": ".", + "…": ".", + "$": ".", + "“": "", + "”": "", + "‘": "", + "’": "", + "(": "", + ")": "", + "(": "", + ")": "", + "《": "", + "》": "", + "【": "", + "】": "", + "[": "", + "]": "", + "—": "", + "~": "-", + "~": "-", + "「": "", + "」": "", + "¿": "", + "¡": "", +} + +_special_map = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ø", "ɸ"), + ("ː", ":"), + ("ɜ", "ʒ"), + ("ɑ̃", "ɑ~"), + ("j", "jˈ"), # To avoid incorrect connect + ("n", "ˈn"), # To avoid incorrect connect + ("t", "tˈ"), # To avoid incorrect connect + ("ŋ", "ˈŋ"), # To avoid incorrect connect + ("ɪ", "ˈɪ"), # To avoid incorrect connect + ] +] + + +def collapse_whitespace(text): + # Regular expression matching whitespace: + _whitespace_re = re.compile(r"\s+") + return re.sub(_whitespace_re, " ", text).strip() + + +def remove_punctuation_at_begin(text): + return re.sub(r"^[,.!?]+", "", text) + + +def remove_aux_symbols(text): + text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) + return text + + +def replace_symbols(text): + text = text.replace(";", ",") + text = text.replace("-", " ") + text = text.replace(":", ",") + return text + + +def replace_punctuation(text): + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + return replaced_text + + +def text_normalize(text): + text = replace_punctuation(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = remove_punctuation_at_begin(text) + text = collapse_whitespace(text) + text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text) + return text + + +# special map +def special_map(text): + for regex, replacement in _special_map: + text = re.sub(regex, replacement, text) + return text + + +def german_to_ipa(text): + text = text_normalize(text) + + ipa = phonemize( + text.strip(), + language="de", + backend="espeak", + separator=Separator(phone=None, word=" ", syllable="|"), + strip=True, + preserve_punctuation=True, + njobs=4, + ) + + # remove "(en)" and "(fr)" tag + ipa = ipa.replace("(en)", "").replace("(de)", "") + + ipa = special_map(ipa) + + return ipa diff --git a/models/tts/debatts/utils/g2p/japanese.py b/models/tts/debatts/utils/g2p/japanese.py new file mode 100644 index 00000000..43d17d27 --- /dev/null +++ b/models/tts/debatts/utils/g2p/japanese.py @@ -0,0 +1,128 @@ +"""from https://github.com/Plachtaa/VALL-E-X/g2p""" + +import re +from unidecode import unidecode + +# Regular expression matching Japanese without punctuation marks: +_japanese_characters = re.compile( + r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) + +# Regular expression matching non-Japanese characters or punctuation marks: +_japanese_marks = re.compile( + r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) + +# List of (symbol, Japanese) pairs for marks: +_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]] + +# List of (romaji, ipa2) pairs for marks: +_romaji_to_ipa2 = [ + (re.compile("%s" % x[0]), x[1]) + 
for x in [ + ("u", "ɯ"), + ("ʧ", "tʃ"), + ("j", "dʑ"), + ("y", "j"), + ("ni", "n^i"), + ("nj", "n^"), + ("hi", "çi"), + ("hj", "ç"), + ("f", "ɸ"), + ("I", "i*"), + ("U", "ɯ*"), + ("r", "ɾ"), + ] +] + +# List of (consonant, sokuon) pairs: +_real_sokuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"Q([↑↓]*[kg])", r"k#\1"), + (r"Q([↑↓]*[tdjʧ])", r"t#\1"), + (r"Q([↑↓]*[sʃ])", r"s\1"), + (r"Q([↑↓]*[pb])", r"p#\1"), + ] +] + +# List of (consonant, hatsuon) pairs: +_real_hatsuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"N([↑↓]*[pbm])", r"m\1"), + (r"N([↑↓]*[ʧʥj])", r"n^\1"), + (r"N([↑↓]*[tdn])", r"n\1"), + (r"N([↑↓]*[kg])", r"ŋ\1"), + ] +] + + +def symbols_to_japanese(text): + for regex, replacement in _symbols_to_japanese: + text = re.sub(regex, replacement, text) + return text + + +def japanese_to_romaji_with_accent(text): + """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html""" + import pyopenjtalk + + text = symbols_to_japanese(text) + sentences = re.split(_japanese_marks, text) + marks = re.findall(_japanese_marks, text) + text = "" + for i, sentence in enumerate(sentences): + if re.match(_japanese_characters, sentence): + if text != "": + text += " " + labels = pyopenjtalk.extract_fullcontext(sentence) + for n, label in enumerate(labels): + phoneme = re.search(r"\-([^\+]*)\+", label).group(1) + if phoneme not in ["sil", "pau"]: + text += ( + phoneme.replace("ch", "ʧ").replace("sh", "ʃ").replace("cl", "Q") + ) + else: + continue + # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) + a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) + a2 = int(re.search(r"\+(\d+)\+", label).group(1)) + a3 = int(re.search(r"\+(\d+)/", label).group(1)) + if re.search(r"\-([^\+]*)\+", labels[n + 1]).group(1) in ["sil", "pau"]: + a2_next = -1 + else: + a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) + # Accent phrase boundary + if a3 == 1 and a2_next == 1: + text += " " + # Falling + elif a1 == 0 and a2_next == a2 + 1: + text += "↓" + # Rising + elif a2 == 1 and a2_next == 2: + text += "↑" + if i < len(marks): + text += unidecode(marks[i]).replace(" ", "") + return text + + +def get_real_sokuon(text): + for regex, replacement in _real_sokuon: + text = re.sub(regex, replacement, text) + return text + + +def get_real_hatsuon(text): + for regex, replacement in _real_hatsuon: + text = re.sub(regex, replacement, text) + return text + + +def japanese_to_ipa(text): + text = japanese_to_romaji_with_accent(text).replace("...", "…") + text = get_real_sokuon(text) + text = get_real_hatsuon(text) + for regex, replacement in _romaji_to_ipa2: + text = re.sub(regex, replacement, text) + return text diff --git a/models/tts/debatts/utils/g2p/korean.py b/models/tts/debatts/utils/g2p/korean.py new file mode 100644 index 00000000..60ba2e13 --- /dev/null +++ b/models/tts/debatts/utils/g2p/korean.py @@ -0,0 +1,167 @@ +"""https://github.com/bootphon/phonemizer""" + +import re + +# from g2pkk import G2p +# from jamo import hangul_to_jamo + +english_dictionary = { + "KOREA": "코리아", + "IDOL": "아이돌", + "IT": "아이티", + "IQ": "아이큐", + "UP": "업", + "DOWN": "다운", + "PC": "피씨", + "CCTV": "씨씨티비", + "SNS": "에스엔에스", + "AI": "에이아이", + "CEO": "씨이오", + "A": "에이", + "B": "비", + "C": "씨", + "D": "디", + "E": "이", + "F": "에프", + "G": "지", + "H": "에이치", + "I": "아이", + "J": "제이", + "K": "케이", + "L": "엘", + "M": "엠", + "N": "엔", + "O": "오", + "P": "피", + "Q": "큐", + "R": "알", + "S": "에스", + "T": "티", + "U": "유", + "V": "브이", + "W": "더블유", + "X": "엑스", + "Y": "와이", + "Z": 
"제트", +} + +# List of (jamo, ipa) pairs: (need to update) +_jamo_to_ipa = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ㅏ", "ɐ"), + ("ㅑ", "jɐ"), + ("ㅓ", "ʌ"), + ("ㅕ", "jʌ"), + ("ㅗ", "o"), + ("ㅛ", "jo"), + ("ᅮ", "u"), + ("ㅠ", "ju"), + ("ᅳ", "ɯ"), + ("ㅣ", "i"), + ("ㅔ", "e"), + ("ㅐ", "ɛ"), + ("ㅖ", "je"), + ("ㅒ", "jɛ"), # lost + ("ㅚ", "we"), + ("ㅟ", "wi"), + ("ㅢ", "ɯj"), + ("ㅘ", "wɐ"), + ("ㅙ", "wɛ"), # lost + ("ㅝ", "wʌ"), + ("ㅞ", "wɛ"), # lost + ("ㄱ", "q"), # 'ɡ' or 'k' + ("ㄴ", "n"), + ("ㄷ", "t"), # d + ("ㄹ", "ɫ"), # 'ᄅ' is 'r', 'ᆯ' is 'ɫ' + ("ㅁ", "m"), + ("ㅂ", "p"), + ("ㅅ", "s"), # 'ᄉ'is 't', 'ᆺ'is 's' + ("ㅇ", "ŋ"), # 'ᄋ' is None, 'ᆼ' is 'ŋ' + ("ㅈ", "tɕ"), + ("ㅊ", "tɕʰ"), # tʃh + ("ㅋ", "kʰ"), # kh + ("ㅌ", "tʰ"), # th + ("ㅍ", "pʰ"), # ph + ("ㅎ", "h"), + ("ㄲ", "k*"), # q + ("ㄸ", "t*"), # t + ("ㅃ", "p*"), # p + ("ㅆ", "s*"), # 'ᄊ' is 's', 'ᆻ' is 't' + ("ㅉ", "tɕ*"), # tɕ ? + ] +] + +_special_map = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ʃ", "ɕ"), + ("tɕh", "tɕʰ"), + ("kh", "kʰ"), + ("th", "tʰ"), + ("ph", "pʰ"), + ] +] + + +def normalize(text): + text = text.strip() + text = re.sub( + "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text + ) + text = normalize_english(text) + text = text.lower() + return text + + +def normalize_english(text): + def fn(m): + word = m.group() + if word in english_dictionary: + return english_dictionary.get(word) + return word + + text = re.sub("([A-Za-z]+)", fn, text) + return text + + +# Convert jamo to IPA +def jamo_to_ipa(text): + res = "" + for t in text: + for regex, replacement in _jamo_to_ipa: + t = re.sub(regex, replacement, t) + res += t + return res + + +# special map +def special_map(text): + for regex, replacement in _special_map: + text = re.sub(regex, replacement, text) + return text + + +def korean_to_ipa(text): + text = normalize(text) + + # espeak-ng + from phonemizer import phonemize + from phonemizer.separator import Separator + + ipa = phonemize( + text, + language="ko", + backend="espeak", + separator=Separator(phone=None, word=" ", syllable="|"), + strip=True, + preserve_punctuation=True, + njobs=4, + ) + ipa = special_map(ipa) + # # hangul charactier + # g2p = G2p() + # text = g2p(text) + # text = list(hangul_to_jamo(text)) # '하늘' --> ['ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆯ'] + # ipa = jamo_to_ipa(text) + return ipa diff --git a/models/tts/debatts/utils/g2p/mandarin.py b/models/tts/debatts/utils/g2p/mandarin.py new file mode 100644 index 00000000..78eef537 --- /dev/null +++ b/models/tts/debatts/utils/g2p/mandarin.py @@ -0,0 +1,270 @@ +"""from https://github.com/Plachtaa/VALL-E-X/g2p""" + +import re +import jieba +import cn2an + +# List of (Latin alphabet, bopomofo) pairs: +_latin_to_bopomofo = [ + (re.compile("%s" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("a", "ㄟˉ"), + ("b", "ㄅㄧˋ"), + ("c", "ㄙㄧˉ"), + ("d", "ㄉㄧˋ"), + ("e", "ㄧˋ"), + ("f", "ㄝˊㄈㄨˋ"), + ("g", "ㄐㄧˋ"), + ("h", "ㄝˇㄑㄩˋ"), + ("i", "ㄞˋ"), + ("j", "ㄐㄟˋ"), + ("k", "ㄎㄟˋ"), + ("l", "ㄝˊㄛˋ"), + ("m", "ㄝˊㄇㄨˋ"), + ("n", "ㄣˉ"), + ("o", "ㄡˉ"), + ("p", "ㄆㄧˉ"), + ("q", "ㄎㄧㄡˉ"), + ("r", "ㄚˋ"), + ("s", "ㄝˊㄙˋ"), + ("t", "ㄊㄧˋ"), + ("u", "ㄧㄡˉ"), + ("v", "ㄨㄧˉ"), + ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"), + ("x", "ㄝˉㄎㄨˋㄙˋ"), + ("y", "ㄨㄞˋ"), + ("z", "ㄗㄟˋ"), + ] +] + +# List of (bopomofo, romaji) pairs: +_bopomofo_to_romaji = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ㄅㄛ", "p⁼wo"), + ("ㄆㄛ", "pʰwo"), + ("ㄇㄛ", "mwo"), + ("ㄈㄛ", "fwo"), + ("ㄅ", "p⁼"), + ("ㄆ", "pʰ"), + ("ㄇ", "m"), + ("ㄈ", "f"), + ("ㄉ", "t⁼"), + ("ㄊ", "tʰ"), + ("ㄋ", "n"), + ("ㄌ", "l"), + ("ㄍ", "k⁼"), + ("ㄎ", "kʰ"), + ("ㄏ", "h"), + ("ㄐ", "ʧ⁼"), + ("ㄑ", 
"ʧʰ"), + ("ㄒ", "ʃ"), + ("ㄓ", "ʦ`⁼"), + ("ㄔ", "ʦ`ʰ"), + ("ㄕ", "s`"), + ("ㄖ", "ɹ`"), + ("ㄗ", "ʦ⁼"), + ("ㄘ", "ʦʰ"), + ("ㄙ", "s"), + ("ㄚ", "a"), + ("ㄛ", "o"), + ("ㄜ", "ə"), + ("ㄝ", "e"), + ("ㄞ", "ai"), + ("ㄟ", "ei"), + ("ㄠ", "au"), + ("ㄡ", "ou"), + ("ㄧㄢ", "yeNN"), + ("ㄢ", "aNN"), + ("ㄧㄣ", "iNN"), + ("ㄣ", "əNN"), + ("ㄤ", "aNg"), + ("ㄧㄥ", "iNg"), + ("ㄨㄥ", "uNg"), + ("ㄩㄥ", "yuNg"), + ("ㄥ", "əNg"), + ("ㄦ", "əɻ"), + ("ㄧ", "i"), + ("ㄨ", "u"), + ("ㄩ", "ɥ"), + ("ˉ", "→"), + ("ˊ", "↑"), + ("ˇ", "↓↑"), + ("ˋ", "↓"), + ("˙", ""), + (",", ","), + ("。", "."), + ("!", "!"), + ("?", "?"), + ("—", "-"), + ] +] + +# List of (romaji, ipa) pairs: +_romaji_to_ipa = [ + (re.compile("%s" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("ʃy", "ʃ"), + ("ʧʰy", "ʧʰ"), + ("ʧ⁼y", "ʧ⁼"), + ("NN", "n"), + ("Ng", "ŋ"), + ("y", "j"), + ("h", "x"), + ] +] + +# List of (bopomofo, ipa) pairs: +_bopomofo_to_ipa = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ㄅㄛ", "p⁼wo"), + ("ㄆㄛ", "pʰwo"), + ("ㄇㄛ", "mwo"), + ("ㄈㄛ", "fwo"), + ("ㄧㄢ", "jɛn"), + ("ㄩㄢ", "ɥæn"), + ("ㄧㄣ", "in"), + ("ㄩㄣ", "ɥn"), + ("ㄧㄥ", "iŋ"), + ("ㄨㄥ", "ʊŋ"), + ("ㄩㄥ", "jʊŋ"), + # Add + ("ㄧㄚ", "ia"), + ("ㄧㄝ", "iɛ"), + ("ㄧㄠ", "iɑʊ"), + ("ㄧㄡ", "ioʊ"), + ("ㄧㄤ", "iɑŋ"), + ("ㄨㄚ", "ua"), + ("ㄨㄛ", "uo"), + ("ㄨㄞ", "uaɪ"), + ("ㄨㄟ", "ueɪ"), + ("ㄨㄢ", "uan"), + ("ㄨㄣ", "uən"), + ("ㄨㄤ", "uɑŋ"), + ("ㄩㄝ", "ɥɛ"), + # End + ("ㄅ", "p⁼"), + ("ㄆ", "pʰ"), + ("ㄇ", "m"), + ("ㄈ", "f"), + ("ㄉ", "t⁼"), + ("ㄊ", "tʰ"), + ("ㄋ", "n"), + ("ㄌ", "l"), + ("ㄍ", "k⁼"), + ("ㄎ", "kʰ"), + ("ㄏ", "x"), + ("ㄐ", "tʃ⁼"), + ("ㄑ", "tʃʰ"), + ("ㄒ", "ʃ"), + ("ㄓ", "ts`⁼"), + ("ㄔ", "ts`ʰ"), + ("ㄕ", "s`"), + ("ㄖ", "ɹ`"), + ("ㄗ", "ts⁼"), + ("ㄘ", "tsʰ"), + ("ㄙ", "s"), + ("ㄚ", "a"), + ("ㄛ", "o"), + ("ㄜ", "ə"), + ("ㄝ", "ɛ"), + ("ㄞ", "aɪ"), + ("ㄟ", "eɪ"), + ("ㄠ", "ɑʊ"), + ("ㄡ", "oʊ"), + ("ㄢ", "an"), + ("ㄣ", "ən"), + ("ㄤ", "ɑŋ"), + ("ㄥ", "əŋ"), + ("ㄦ", "əɻ"), + ("ㄧ", "i"), + ("ㄨ", "u"), + ("ㄩ", "ɥ"), + ("ˉ", "→"), + ("ˊ", "↑"), + ("ˇ", "↓↑"), + ("ˋ", "↓"), + ("˙", ""), + (",", ","), + ("。", "."), + ("!", "!"), + ("?", "?"), + ("—", "-"), + ] +] + + +# Convert numbers to Chinese pronunciation +def number_to_chinese(text): + numbers = re.findall(r"\d+(?:\.?\d+)?", text) + for number in numbers: + text = text.replace(number, cn2an.an2cn(number), 1) + return text + + +# Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo) +def chinese_to_bopomofo(text): + from pypinyin import lazy_pinyin, BOPOMOFO + + text = text.replace("、", ",").replace(";", ",").replace(":", ",") + words = jieba.lcut(text, cut_all=False) + text = "" + for word in words: + bopomofos = lazy_pinyin(word, BOPOMOFO) + if not re.search("[\u4e00-\u9fff]", word): + text += word + continue + for i in range(len(bopomofos)): + bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i]) + if text != "": + text += " " + text += "".join(bopomofos) + return text + + +# Convert latin pronunciation to pinyin (bopomofo) +def latin_to_bopomofo(text): + for regex, replacement in _latin_to_bopomofo: + text = re.sub(regex, replacement, text) + return text + + +# Convert pinyin (bopomofo) to Romaji (not used) +def bopomofo_to_romaji(text): + for regex, replacement in _bopomofo_to_romaji: + text = re.sub(regex, replacement, text) + return text + + +# Convert pinyin (bopomofo) to IPA +def bopomofo_to_ipa(text): + for regex, replacement in _bopomofo_to_ipa: + text = re.sub(regex, replacement, text) + return text + + +# Convert Chinese to Romaji (not used) +def chinese_to_romaji(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + 
text = latin_to_bopomofo(text) + text = bopomofo_to_romaji(text) + text = re.sub("i([aoe])", r"y\1", text) + text = re.sub("u([aoəe])", r"w\1", text) + text = re.sub("([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ`\2", text).replace("ɻ", "ɹ`") + text = re.sub("([ʦs][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text) + return text + + +# Convert Chinese to IPA +def chinese_to_ipa(text): + text = number_to_chinese(text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_ipa(text) + text = re.sub("i([aoe])", r"j\1", text) + text = re.sub("u([aoəe])", r"w\1", text) + text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ`\2", text).replace("ɻ", "ɹ`") + text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text) + return text diff --git a/models/tts/debatts/utils/g2p_liwei/__init__.py b/models/tts/debatts/utils/g2p_liwei/__init__.py new file mode 100644 index 00000000..fa08d7a8 --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/__init__.py @@ -0,0 +1,66 @@ +from utils.g2p_liwei import cleaners +from tokenizers import Tokenizer +from utils.g2p_liwei.text_tokenizers import TextTokenizer +import json +import re + +class PhonemeBpeTokenizer: + + def __init__(self, vacab_path="/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_liwei/vacab.json"): + self.lang2backend = { + 'zh': "cmn", + 'ja': "ja", + "en": "en-us", + "fr": "fr-fr", + "ko": "ko", + "de": "de", + } + self.text_tokenizers = {} + self.int_text_tokenizers() + # TODO + vacab_path="/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_liwei/vacab.json" + with open(vacab_path, 'rb') as f: + json_data = f.read() + data = json.loads(json_data) + self.vocab = data['vocab'] + + def int_text_tokenizers(self): + for key, value in self.lang2backend.items(): + self.text_tokenizers[key] = TextTokenizer(language=value) + + def tokenize(self, text, language): + + # 1. convert text to phoneme + phonemes = self._clean_text(text, language, ['cjekfd_cleaners']) + # print('clean text: ', phonemes) + + # 2. tokenize phonemes + phoneme_tokens = self.phoneme2token(phonemes) + # print('encode: ', phoneme_tokens) + + # # 3. 
decode tokens [optional] + # decoded_text = self.tokenizer.decode(phoneme_tokens) + # print('decoded: ', decoded_text) + + return phonemes, phoneme_tokens + + def _clean_text(self, text, language, cleaner_names): + + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text, language, self.text_tokenizers) + return text + + def phoneme2token(self, phonemes): + # 使用的是国际音标,可以将音素转化成token。实际上输入的phone id也是将音频先asr成文本再转成token的,使用的是同一套vocab体系 + tokens = [] + if isinstance(phonemes, list): + for phone in phonemes: + phonemes_split = phone.split("|") + tokens.append([self.vocab[p] for p in phonemes_split if p in self.vocab]) + else: + phonemes_split = phonemes.split("|") + tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab] + return tokens \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_liwei/cleaners.py b/models/tts/debatts/utils/g2p_liwei/cleaners.py new file mode 100644 index 00000000..ee99d3be --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/cleaners.py @@ -0,0 +1,25 @@ +import re +from utils.g2p_liwei.japanese import japanese_to_ipa +from utils.g2p_liwei.mandarin import chinese_to_ipa +from utils.g2p_liwei.english import english_to_ipa +from utils.g2p_liwei.french import french_to_ipa +from utils.g2p_liwei.korean import korean_to_ipa +from utils.g2p_liwei.german import german_to_ipa + +def cjekfd_cleaners(text, language, text_tokenizers): + + if language == 'zh': + return chinese_to_ipa(text, text_tokenizers['zh']) + elif language == 'ja': + return japanese_to_ipa(text, text_tokenizers['ja']) + elif language == 'en': + return english_to_ipa(text, text_tokenizers['en']) + elif language == 'fr': + return french_to_ipa(text, text_tokenizers['fr']) + elif language == 'ko': + return korean_to_ipa(text, text_tokenizers['ko']) + elif language == 'de': + return german_to_ipa(text, text_tokenizers['de']) + else: + raise Exception('Unknown language: %s' % language) + return None diff --git a/models/tts/debatts/utils/g2p_liwei/english.py b/models/tts/debatts/utils/g2p_liwei/english.py new file mode 100644 index 00000000..5951814e --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/english.py @@ -0,0 +1,166 @@ +import re +from unidecode import unidecode +import inflect +''' + Text clean time +''' +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_percent_number_re = re.compile(r'([0-9\.\,]*[0-9]+%)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_fraction_re = re.compile(r'([0-9]+)/([0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\b' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), + ('etc', 'et cetera'), + ('btw', 'by the way'), +]] + +_special_map = [ + ('t|ɹ', 'tɹ'), + ('d|ɹ', 'dɹ'), + ('t|s', 'ts'), + ('d|z', 'dz'), + ('ɪ|ɹ', 'ɪɹ'), + ('ɐ', 'ɚ'), + ('ᵻ', 'ɪ'), + ('əl', 'l'), + ('x', 'k'), + ('ɬ', 
'l'), + ('ʔ', 't'), + ('n̩', 'n'), + ('oː|ɹ', 'oːɹ') +] + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + +def _expand_percent(m): + return m.group(1).replace('%', ' percent ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + +def fraction_to_words(numerator, denominator): + if numerator == 1 and denominator == 2: + return "one half" + if numerator == 1 and denominator == 4: + return "one quarter" + if denominator == 2: + return _inflect.number_to_words(numerator) + " halves" + if denominator == 4: + return _inflect.number_to_words(numerator) + " quarters" + return _inflect.number_to_words(numerator) + " " + _inflect.ordinal(_inflect.number_to_words(denominator)) + +def _expand_fraction(m): + numerator = int(m.group(1)) + denominator = int(m.group(2)) + return fraction_to_words(numerator, denominator) + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return ' two thousand ' + elif num > 2000 and num < 2010: + return ' two thousand ' + _inflect.number_to_words(num % 100) + ' ' + elif num % 100 == 0: + return ' ' + _inflect.number_to_words(num // 100) + ' hundred ' + else: + return ' ' + _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + ' ' + else: + return ' ' + _inflect.number_to_words(num, andword='') + ' ' + +# Normalize numbers pronunciation +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_fraction_re, _expand_fraction, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_percent_number_re, _expand_percent, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + +def _english_to_ipa(text): + # text = unidecode(text).lower() + text = expand_abbreviations(text) + text = normalize_numbers(text) + return text + +# special map +def special_map(text): + for regex, replacement in _special_map: + regex = regex.replace("|", "\|") + while re.search(r'(^|[_|]){}([_|]|$)'.format(regex), text): + text = re.sub(r'(^|[_|]){}([_|]|$)'.format(regex), r'\1{}\2'.format(replacement), text) + # text = re.sub(r'([,.!?])', r'|\1', text) + return text + +# Add some special operation +def english_to_ipa(text, text_tokenizer): + if type(text) == str: + text = _english_to_ipa(text) + else: + text = [_english_to_ipa(t) for t in text] + phonemes = text_tokenizer(text) + if type(text) == str: + return special_map(phonemes) + else: + result_ph = [] + 
for phone in phonemes: + result_ph.append(special_map(phone)) + return result_ph \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_liwei/french.py b/models/tts/debatts/utils/g2p_liwei/french.py new file mode 100644 index 00000000..b1337860 --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/french.py @@ -0,0 +1,164 @@ +'''https://github.com/bootphon/phonemizer''' +import re +from phonemizer import phonemize +from phonemizer.separator import Separator + +# List of (regular expression, replacement) pairs for abbreviations in french: +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("M", "monsieur"), + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ("N.B", "nota bene"), + ("M", "monsieur"), + ("p.c.q", "parce que"), + ("Pr", "professeur"), + ("qqch", "quelque chose"), + ("rdv", "rendez-vous"), + ("max", "maximum"), + ("min", "minimum"), + ("no", "numéro"), + ("adr", "adresse"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "companie"), + ("jr", "junior"), + ("sgt", "sergent"), + ("capt", "capitain"), + ("col", "colonel"), + ("av", "avenue"), + ("av. J.-C", "avant Jésus-Christ"), + ("apr. J.-C", "après Jésus-Christ"), + ("art", "article"), + ("boul", "boulevard"), + ("c.-à-d", "c’est-à-dire"), + ("etc", "et cetera"), + ("ex", "exemple"), + ("excl", "exclusivement"), + ("boul", "boulevard"), + ] +] + [ + (re.compile("\\b%s" % x[0]), x[1]) + for x in [ + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ] +] + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": ".", + "…": ".", + "$": ".", + "“": "", + "”": "", + "‘": "", + "’": "", + "(": "", + ")": "", + "(": "", + ")": "", + "《": "", + "》": "", + "【": "", + "】": "", + "[": "", + "]": "", + "—": "", + "~": "-", + "~": "-", + "「": "", + "」": "", + "¿" : "", + "¡" : "" +} + +_special_map = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ø', 'ɸ'), + ('ː', ':'), + ('j', 'jˈ'), # To avoid incorrect connect + ('n', 'ˈn'), # To avoid incorrect connect + ('w', 'wˈ'), # To avoid incorrect connect + ('ã', 'a~'), + ('ɑ̃', 'ɑ~'), + ('ɔ̃', 'ɔ~'), + ('ɛ̃', 'ɛ~'), + ('œ̃', 'œ~'), +]] + +def collapse_whitespace(text): + # Regular expression matching whitespace: + _whitespace_re = re.compile(r"\s+") + return re.sub(_whitespace_re, " ", text).strip() + +def remove_punctuation_at_begin(text): + return re.sub(r'^[,.!?]+', '', text) + +def remove_aux_symbols(text): + text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) + return text + +def replace_symbols(text): + text = text.replace(";", ",") + text = text.replace("-", " ") + text = text.replace(":", ",") + text = text.replace("&", " et ") + return text + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + +def replace_punctuation(text): + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + return replaced_text + +def text_normalize(text): + text = expand_abbreviations(text) + text = replace_punctuation(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = remove_punctuation_at_begin(text) + text = collapse_whitespace(text) + text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text) + return text + +# special map +def special_map(text): + for regex, replacement in _special_map: + 
text = re.sub(regex, replacement, text) + return text + +def french_to_ipa(text): + text = text_normalize(text) + + ipa = phonemize(text.strip(), + language="fr-fr", + backend="espeak", + separator=Separator(phone=None, word=' ', syllable='|'), + strip=True, + preserve_punctuation=True, + njobs=4) + + # remove "(en)" and "(fr)" tag + ipa = ipa.replace("(en)", "").replace("(fr)", "") + + ipa = special_map(ipa) + + return ipa + \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_liwei/g2p_liwei.py b/models/tts/debatts/utils/g2p_liwei/g2p_liwei.py new file mode 100644 index 00000000..70cfb054 --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/g2p_liwei.py @@ -0,0 +1,7 @@ +from utils.g2p_liwei import PhonemeBpeTokenizer +import tqdm + +text_tokenizer = PhonemeBpeTokenizer() + +def liwei_g2p(text, language): + return text_tokenizer.tokenize(text=text, language=language) \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_liwei/german.py b/models/tts/debatts/utils/g2p_liwei/german.py new file mode 100644 index 00000000..ef54e4ef --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/german.py @@ -0,0 +1,108 @@ +'''https://github.com/bootphon/phonemizer''' +import re +from phonemizer import phonemize +from phonemizer.separator import Separator + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": ".", + "…": ".", + "$": ".", + "“": "", + "”": "", + "‘": "", + "’": "", + "(": "", + ")": "", + "(": "", + ")": "", + "《": "", + "》": "", + "【": "", + "】": "", + "[": "", + "]": "", + "—": "", + "~": "-", + "~": "-", + "「": "", + "」": "", + "¿" : "", + "¡" : "" +} + +_special_map = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ø', 'ɸ'), + ('ː', ':'), + ('ɜ', 'ʒ'), + ('ɑ̃', 'ɑ~'), + ('j', 'jˈ'), # To avoid incorrect connect + ('n', 'ˈn'), # To avoid incorrect connect + ('t', 'tˈ'), # To avoid incorrect connect + ('ŋ', 'ˈŋ'), # To avoid incorrect connect + ('ɪ', 'ˈɪ'), # To avoid incorrect connect +]] + +def collapse_whitespace(text): + # Regular expression matching whitespace: + _whitespace_re = re.compile(r"\s+") + return re.sub(_whitespace_re, " ", text).strip() + +def remove_punctuation_at_begin(text): + return re.sub(r'^[,.!?]+', '', text) + +def remove_aux_symbols(text): + text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) + return text + +def replace_symbols(text): + text = text.replace(";", ",") + text = text.replace("-", " ") + text = text.replace(":", ",") + return text + +def replace_punctuation(text): + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) + return replaced_text + +def text_normalize(text): + text = replace_punctuation(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = remove_punctuation_at_begin(text) + text = collapse_whitespace(text) + text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text) + return text + +# special map +def special_map(text): + for regex, replacement in _special_map: + text = re.sub(regex, replacement, text) + return text + +def german_to_ipa(text): + text = text_normalize(text) + + ipa = phonemize(text.strip(), + language="de", + backend="espeak", + separator=Separator(phone=None, word=' ', syllable='|'), + strip=True, + preserve_punctuation=True, + njobs=4) + + # remove "(en)" and "(fr)" tag + ipa = ipa.replace("(en)", "").replace("(de)", "") + + ipa = special_map(ipa) + + return ipa \ No newline at end of file diff --git 
a/models/tts/debatts/utils/g2p_liwei/japanese.py b/models/tts/debatts/utils/g2p_liwei/japanese.py new file mode 100644 index 00000000..b33dd502 --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/japanese.py @@ -0,0 +1,154 @@ +"""from https://github.com/Plachtaa/VALL-E-X/g2p""" +import re +from unidecode import unidecode + +''' + Text clean time +''' + +# Regular expression matching Japanese without punctuation marks: +_japanese_characters = re.compile( + r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + +# Regular expression matching non-Japanese characters or punctuation marks: +_japanese_marks = re.compile( + r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + +# List of (symbol, Japanese) pairs for marks: +_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('%', 'パーセント') +]] + +# List of (romaji, ipa2) pairs for marks: +_romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('u', 'ɯ'), + ('ʧ', 'tʃ'), + ('j', 'dʑ'), + ('y', 'j'), + ('ni', 'n^i'), + ('nj', 'n^'), + ('hi', 'çi'), + ('hj', 'ç'), + ('f', 'ɸ'), + ('I', 'i*'), + ('U', 'ɯ*'), + ('r', 'ɾ') +]] + +# List of (consonant, sokuon) pairs: +_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ + (r'Q([↑↓]*[kg])', r'k#\1'), + (r'Q([↑↓]*[tdjʧ])', r't#\1'), + (r'Q([↑↓]*[sʃ])', r's\1'), + (r'Q([↑↓]*[pb])', r'p#\1') +]] + +# List of (consonant, hatsuon) pairs: +_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ + (r'N([↑↓]*[pbm])', r'm\1'), + (r'N([↑↓]*[ʧʥj])', r'n^\1'), + (r'N([↑↓]*[tdn])', r'n\1'), + (r'N([↑↓]*[kg])', r'ŋ\1') +]] + + +def symbols_to_japanese(text): + for regex, replacement in _symbols_to_japanese: + text = re.sub(regex, replacement, text) + return text + + +def japanese_to_romaji_with_accent(text): + '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' + import pyopenjtalk + text = symbols_to_japanese(text) + sentences = re.split(_japanese_marks, text) + marks = re.findall(_japanese_marks, text) + text = '' + for i, sentence in enumerate(sentences): + if re.match(_japanese_characters, sentence): + if text != '': + text += ' ' + labels = pyopenjtalk.extract_fullcontext(sentence) + for n, label in enumerate(labels): + phoneme = re.search(r'\-([^\+]*)\+', label).group(1) + if phoneme not in ['sil', 'pau']: + text += phoneme.replace('ch', 'ʧ').replace('sh', + 'ʃ').replace('cl', 'Q') + else: + continue + # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) + a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) + a2 = int(re.search(r"\+(\d+)\+", label).group(1)) + a3 = int(re.search(r"\+(\d+)/", label).group(1)) + if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']: + a2_next = -1 + else: + a2_next = int( + re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) + # Accent phrase boundary + if a3 == 1 and a2_next == 1: + text += ' ' + # Falling + elif a1 == 0 and a2_next == a2 + 1: + text += '↓' + # Rising + elif a2 == 1 and a2_next == 2: + text += '↑' + if i < len(marks): + text += unidecode(marks[i]).replace(' ', '') + return text + + +def get_real_sokuon(text): + for regex, replacement in _real_sokuon: + text = re.sub(regex, replacement, text) + return text + + +def get_real_hatsuon(text): + for regex, replacement in _real_hatsuon: + text = re.sub(regex, replacement, text) + return text + +def japanese_to_ipa(text): + text = japanese_to_romaji_with_accent(text).replace('...', '…') + text = get_real_sokuon(text) + text = 
get_real_hatsuon(text) + for regex, replacement in _romaji_to_ipa2: + text = re.sub(regex, replacement, text) + return text + +''' + Phoneme merge time +''' +def _connect_tone(phoneme_tokens, vocab): + + tone_list = ["→", "↑", "↓↑", "↓"] + tone_token = [] + last_single_token = 0 + base = 0 + pattern = r"\[[^\[\]]*\]" # Exclude "[" and "]" + for tone, idx in vocab.items(): + if re.match(pattern, tone): + base = idx + 1 + if tone in tone_list: + tone_token.append(idx) + last_single_token = idx + + pre_token = None + cur_token = None + res_token = [] + for t in phoneme_tokens: + cur_token = t + if t in tone_token: + cur_token = last_single_token + (pre_token - base) * len(tone_list) + tone_token.index(t) + 1 + res_token.pop() + res_token.append(cur_token) + pre_token = t + + return res_token + +def japanese_merge_phoneme(phoneme_tokens, vocab): + phoneme_tokens = _connect_tone(phoneme_tokens, vocab) + return phoneme_tokens \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_liwei/korean.py b/models/tts/debatts/utils/g2p_liwei/korean.py new file mode 100644 index 00000000..ebf9ecf5 --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/korean.py @@ -0,0 +1,149 @@ +'''https://github.com/bootphon/phonemizer''' +import re +# from g2pkk import G2p +# from jamo import hangul_to_jamo + +english_dictionary = { + "KOREA": "코리아", + "IDOL": "아이돌", + "IT": "아이티", + "IQ": "아이큐", + "UP": "업", + "DOWN": "다운", + "PC": "피씨", + "CCTV": "씨씨티비", + "SNS": "에스엔에스", + "AI": "에이아이", + "CEO": "씨이오", + "A": "에이", + "B": "비", + "C": "씨", + "D": "디", + "E": "이", + "F": "에프", + "G": "지", + "H": "에이치", + "I": "아이", + "J": "제이", + "K": "케이", + "L": "엘", + "M": "엠", + "N": "엔", + "O": "오", + "P": "피", + "Q": "큐", + "R": "알", + "S": "에스", + "T": "티", + "U": "유", + "V": "브이", + "W": "더블유", + "X": "엑스", + "Y": "와이", + "Z": "제트", +} + +# List of (jamo, ipa) pairs: (need to update) +_jamo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㅏ', 'ɐ'), + ('ㅑ', 'jɐ'), + ('ㅓ', 'ʌ'), + ('ㅕ', 'jʌ'), + ('ㅗ', 'o'), + ('ㅛ', 'jo'), + ('ᅮ', 'u'), + ('ㅠ', 'ju'), + ('ᅳ', 'ɯ'), + ('ㅣ', 'i'), + ('ㅔ', 'e'), + ('ㅐ', 'ɛ'), + ('ㅖ', 'je'), + ('ㅒ', 'jɛ'), # lost + ('ㅚ', 'we'), + ('ㅟ', 'wi'), + ('ㅢ', 'ɯj'), + ('ㅘ', 'wɐ'), + ('ㅙ', 'wɛ'), # lost + ('ㅝ', 'wʌ'), + ('ㅞ', 'wɛ'), # lost + ('ㄱ', 'q'), # 'ɡ' or 'k' + ('ㄴ', 'n'), + ('ㄷ', 't'), # d + ('ㄹ', 'ɫ'), # 'ᄅ' is 'r', 'ᆯ' is 'ɫ' + ('ㅁ', 'm'), + ('ㅂ', 'p'), + ('ㅅ', 's'), # 'ᄉ'is 't', 'ᆺ'is 's' + ('ㅇ', 'ŋ'), # 'ᄋ' is None, 'ᆼ' is 'ŋ' + ('ㅈ', 'tɕ'), + ('ㅊ', 'tɕʰ'), # tʃh + ('ㅋ', 'kʰ'), # kh + ('ㅌ', 'tʰ'), # th + ('ㅍ', 'pʰ'), # ph + ('ㅎ', 'h'), + ('ㄲ', 'k*'), # q + ('ㄸ', 't*'), # t + ('ㅃ', 'p*'), # p + ('ㅆ', 's*'), # 'ᄊ' is 's', 'ᆻ' is 't' + ('ㅉ', 'tɕ*'), # tɕ ? 
+]] + +_special_map = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ʃ', 'ɕ'), + ('tɕh', 'tɕʰ'), + ('kh', 'kʰ'), + ('th', 'tʰ'), + ('ph', 'pʰ'), +]] + +def normalize(text): + text = text.strip() + text = re.sub("[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text) + text = normalize_english(text) + text = text.lower() + return text + +def normalize_english(text): + def fn(m): + word = m.group() + if word in english_dictionary: + return english_dictionary.get(word) + return word + + text = re.sub("([A-Za-z]+)", fn, text) + return text + +# Convert jamo to IPA +def jamo_to_ipa(text): + res = "" + for t in text: + for regex, replacement in _jamo_to_ipa: + t = re.sub(regex, replacement, t) + res += t + return res + +# special map +def special_map(text): + for regex, replacement in _special_map: + text = re.sub(regex, replacement, text) + return text + +def korean_to_ipa(text): + text = normalize(text) + + # espeak-ng + from phonemizer import phonemize + from phonemizer.separator import Separator + ipa = phonemize(text, + language="ko", + backend="espeak", + separator=Separator(phone=None, word=' ', syllable='|'), + strip=True, + preserve_punctuation=True, + njobs=4) + ipa = special_map(ipa) + # # hangul charactier + # g2p = G2p() + # text = g2p(text) + # text = list(hangul_to_jamo(text)) # '하늘' --> ['ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆯ'] + # ipa = jamo_to_ipa(text) + return ipa diff --git a/models/tts/debatts/utils/g2p_liwei/mandarin.py b/models/tts/debatts/utils/g2p_liwei/mandarin.py new file mode 100644 index 00000000..075cddf7 --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/mandarin.py @@ -0,0 +1,191 @@ +import re +import jieba +import cn2an + +''' + Text clean time +''' +# List of (Latin alphabet, bopomofo) pairs: +_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('a', 'ㄟˉ'), + ('b', 'ㄅㄧˋ'), + ('c', 'ㄙㄧˉ'), + ('d', 'ㄉㄧˋ'), + ('e', 'ㄧˋ'), + ('f', 'ㄝˊㄈㄨˋ'), + ('g', 'ㄐㄧˋ'), + ('h', 'ㄝˇㄑㄩˋ'), + ('i', 'ㄞˋ'), + ('j', 'ㄐㄟˋ'), + ('k', 'ㄎㄟˋ'), + ('l', 'ㄝˊㄛˋ'), + ('m', 'ㄝˊㄇㄨˋ'), + ('n', 'ㄣˉ'), + ('o', 'ㄡˉ'), + ('p', 'ㄆㄧˉ'), + ('q', 'ㄎㄧㄡˉ'), + ('r', 'ㄚˋ'), + ('s', 'ㄝˊㄙˋ'), + ('t', 'ㄊㄧˋ'), + ('u', 'ㄧㄡˉ'), + ('v', 'ㄨㄧˉ'), + ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), + ('x', 'ㄝˉㄎㄨˋㄙˋ'), + ('y', 'ㄨㄞˋ'), + ('z', 'ㄗㄟˋ') +]] + +# List of (bopomofo, ipa) pairs: +_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㄅㄛ', 'p⁼wo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄧㄢ', '|jɛn'), + ('ㄩㄢ', '|ɥæn'), + ('ㄧㄣ', '|in'), + ('ㄩㄣ', '|ɥn'), + ('ㄧㄥ', '|iŋ'), + ('ㄨㄥ', '|ʊŋ'), + ('ㄩㄥ', '|jʊŋ'), + # Add + ('ㄧㄚ', '|ia'), + ('ㄧㄝ', '|iɛ'), + ('ㄧㄠ', '|iɑʊ'), + ('ㄧㄡ', '|ioʊ'), + ('ㄧㄤ', '|iɑŋ'), + ('ㄨㄚ', '|ua'), + ('ㄨㄛ', '|uo'), + ('ㄨㄞ', '|uaɪ'), + ('ㄨㄟ', '|ueɪ'), + ('ㄨㄢ', '|uan'), + ('ㄨㄣ', '|uən'), + ('ㄨㄤ', '|uɑŋ'), + ('ㄩㄝ', '|ɥɛ'), + # End + ('ㄅ', 'p⁼'), + ('ㄆ', 'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't⁼'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k⁼'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'x'), + ('ㄐ', 'tʃ⁼'), + ('ㄑ', 'tʃʰ'), + ('ㄒ', 'ʃ'), + ('ㄓ', 'ts`⁼'), + ('ㄔ', 'ts`ʰ'), + ('ㄕ', 's`'), + ('ㄖ', 'ɹ`'), + ('ㄗ', 'ts⁼'), + ('ㄘ', 'tsʰ'), + ('ㄙ', '|s'), + ('ㄚ', '|a'), + ('ㄛ', '|o'), + ('ㄜ', '|ə'), + ('ㄝ', '|ɛ'), + ('ㄞ', '|aɪ'), + ('ㄟ', '|eɪ'), + ('ㄠ', '|ɑʊ'), + ('ㄡ', '|oʊ'), + ('ㄢ', '|an'), + ('ㄣ', '|ən'), + ('ㄤ', '|ɑŋ'), + ('ㄥ', '|əŋ'), + ('ㄦ', 'əɹ'), + ('ㄧ', '|i'), + ('ㄨ', '|u'), + ('ㄩ', '|ɥ'), + ('ˉ', '→|'), + ('ˊ', '↑|'), + ('ˇ', '↓↑|'), + ('ˋ', '↓|'), + ('˙', '|'), +]] + +# Convert numbers to Chinese pronunciation +def number_to_chinese(text): + # numbers = re.findall(r'\d+(?:\.?\d+)?', text) 
+ # for number in numbers: + # text = text.replace(number, cn2an.an2cn(number), 1) + text = cn2an.transform(text, "an2cn") + return text + +def normalization(text): + text = text.replace(",", ",") + text = text.replace("。", ".") + text = text.replace("!", "!") + text = text.replace("?", "?") + text = text.replace(";", ";") + text = text.replace(":", ":") + text = text.replace("、", ",") + text = text.replace("‘", "'") + text = text.replace("’", "'") + text = text.replace("⋯", "…") + text = text.replace("···", "…") + text = text.replace("・・・", "…") + text = text.replace("...", "…") + text = re.sub(r"\s+", "", text) + text = re.sub(r'[^\u4e00-\u9fff\s_,\.\?!;:\'…]', '', text) + text = re.sub(r'\s*([,\.\?!;:\'…])\s*', r'\1', text) + return text + +# Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo) +def chinese_to_bopomofo(text): + from pypinyin import lazy_pinyin, BOPOMOFO + words = jieba.lcut(text, cut_all=False) + text = '' + for word in words: + bopomofos = lazy_pinyin(word, BOPOMOFO) + if not re.search('[\u4e00-\u9fff]', word): + text += word + continue + for i in range(len(bopomofos)): + bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) + if text != '': + text += '|' + text += '|'.join(bopomofos) + return text + +# Convert latin pronunciation to pinyin (bopomofo) +def latin_to_bopomofo(text): + for regex, replacement in _latin_to_bopomofo: + text = re.sub(regex, replacement, text) + return text + +# Convert pinyin (bopomofo) to IPA +def bopomofo_to_ipa(text): + for regex, replacement in _bopomofo_to_ipa: + text = re.sub(regex, replacement, text) + return text + +def _chinese_to_ipa(text): + text = number_to_chinese(text.strip()) + text = normalization(text) + # print("Normalized text: ", text) + text = chinese_to_bopomofo(text) + text = latin_to_bopomofo(text) + text = bopomofo_to_ipa(text) + text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', + r'\1ɹ\2', text) + text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) + text = re.sub(r'^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]', '', text) + text = re.sub(r'([,\.\?!;:\'…])', r'|\1|', text) + text = re.sub(r'\|+', '|', text) + text = text.rstrip('|') + return text + +# Convert Chinese to IPA +def chinese_to_ipa(text, text_tokenizer): + # phonemes = text_tokenizer(text.strip()) + if type(text) == str: + return _chinese_to_ipa(text) + else: + result_ph = [] + for t in text: + result_ph.append(_chinese_to_ipa(t)) + return result_ph diff --git a/models/tts/debatts/utils/g2p_liwei/text_tokenizers.py b/models/tts/debatts/utils/g2p_liwei/text_tokenizers.py new file mode 100644 index 00000000..943632df --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/text_tokenizers.py @@ -0,0 +1,80 @@ +import re +import os +from typing import List, Pattern, Union +from phonemizer.utils import list2str, str2list +from phonemizer.backend import EspeakBackend +from phonemizer.backend.espeak.language_switch import LanguageSwitch +from phonemizer.backend.espeak.words_mismatch import WordMismatch +from phonemizer.punctuation import Punctuation +from phonemizer.separator import Separator + + + +class TextTokenizer: + """Phonemize Text.""" + + def __init__( + self, + language="en-us", + backend="espeak", + separator=Separator(word="|_|", syllable="-", phone="|"), + preserve_punctuation=True, + with_stress: bool = False, + tie: Union[bool, str] = False, + language_switch: LanguageSwitch = "remove-flags", + words_mismatch: WordMismatch = "ignore", + ) -> None: + self.preserve_punctuation_marks = ",.?!;:'…" + self.backend = EspeakBackend( + 
language, + punctuation_marks=self.preserve_punctuation_marks, + preserve_punctuation=preserve_punctuation, + with_stress=with_stress, + tie=tie, + language_switch=language_switch, + words_mismatch=words_mismatch, + ) + + self.separator = separator + + # convert chinese punctuation to english punctuation + def convert_chinese_punctuation(self, text: str) -> str: + text = text.replace(",", ",") + text = text.replace("。", ".") + text = text.replace("!", "!") + text = text.replace("?", "?") + text = text.replace(";", ";") + text = text.replace(":", ":") + text = text.replace("、", ",") + text = text.replace("‘", "'") + text = text.replace("’", "'") + text = text.replace("⋯", "…") + text = text.replace("···", "…") + text = text.replace("・・・", "…") + text = text.replace("...", "…") + return text + + def __call__(self, text, strip=True) -> List[str]: + + text_type = type(text) + normalized_text = [] + for line in str2list(text): + line = self.convert_chinese_punctuation(line.strip()) + line = re.sub(r'[^\w\s_,\.\?!;:\'…]', '', line) + line = re.sub(r'\s*([,\.\?!;:\'…])\s*', r'\1', line) + line = re.sub(r"\s+", " ", line) + normalized_text.append(line) + # print("Normalized test: ", normalized_text[0]) + phonemized = self.backend.phonemize( + normalized_text, separator=self.separator, strip=strip, njobs=1 + ) + if text_type == str: + phonemized = re.sub(r'([,\.\?!;:\'…])', r'|\1|', list2str(phonemized)) + phonemized = re.sub(r'\|+', '|', phonemized) + phonemized = phonemized.rstrip('|') + else: + for i in range(len(phonemized)): + phonemized[i] = re.sub(r'([,\.\?!;:\'…])', r'|\1|', phonemized[i]) + phonemized[i] = re.sub(r'\|+', '|', phonemized[i]) + phonemized[i] = phonemized[i].rstrip('|') + return phonemized \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_liwei/vacab.json b/models/tts/debatts/utils/g2p_liwei/vacab.json new file mode 100644 index 00000000..25762bfc --- /dev/null +++ b/models/tts/debatts/utils/g2p_liwei/vacab.json @@ -0,0 +1,372 @@ +{ + "vocab": { + ",": 0, + ".": 1, + "?": 2, + "!": 3, + "_": 4, + "iː": 5, + "ɪ": 6, + "ɜː": 7, + "ɚ": 8, + "oːɹ": 9, + "ɔː": 10, + "ɔːɹ": 11, + "ɑː": 12, + "uː": 13, + "ʊ": 14, + "ɑːɹ": 15, + "ʌ": 16, + "ɛ": 17, + "æ": 18, + "eɪ": 19, + "aɪ": 20, + "ɔɪ": 21, + "aʊ": 22, + "oʊ": 23, + "ɪɹ": 24, + "ɛɹ": 25, + "ʊɹ": 26, + "p": 27, + "b": 28, + "t": 29, + "d": 30, + "k": 31, + "ɡ": 32, + "f": 33, + "v": 34, + "θ": 35, + "ð": 36, + "s": 37, + "z": 38, + "ʃ": 39, + "ʒ": 40, + "h": 41, + "tʃ": 42, + "dʒ": 43, + "m": 44, + "n": 45, + "ŋ": 46, + "j": 47, + "w": 48, + "ɹ": 49, + "l": 50, + "tɹ": 51, + "dɹ": 52, + "ts": 53, + "dz": 54, + "i": 55, + "ɔ": 56, + "ə": 57, + "ɾ": 58, + "iə": 59, + "r": 60, + "u": 61, + "oː": 62, + "ɛː": 63, + "ɪː": 64, + "aɪə": 65, + "aɪɚ": 66, + "ɑ̃": 67, + "ç": 68, + "ɔ̃": 69, + "ææ": 70, + "ɐɐ": 71, + "ɡʲ": 72, + "nʲ": 73, + "iːː": 74, + + "p⁼": 75, + "pʰ": 76, + "t⁼": 77, + "tʰ": 78, + "k⁼": 79, + "kʰ": 80, + "x": 81, + "tʃ⁼": 82, + "tʃʰ": 83, + "ts`⁼": 84, + "ts`ʰ": 85, + "s`": 86, + "ɹ`": 87, + "ts⁼": 88, + "tsʰ": 89, + "p⁼wo": 90, + "p⁼wo→": 91, + "p⁼wo↑": 92, + "p⁼wo↓↑": 93, + "p⁼wo↓": 94, + "pʰwo": 95, + "pʰwo→": 96, + "pʰwo↑": 97, + "pʰwo↓↑": 98, + "pʰwo↓": 99, + "mwo": 100, + "mwo→": 101, + "mwo↑": 102, + "mwo↓↑": 103, + "mwo↓": 104, + "fwo": 105, + "fwo→": 106, + "fwo↑": 107, + "fwo↓↑": 108, + "fwo↓": 109, + "jɛn": 110, + "jɛn→": 111, + "jɛn↑": 112, + "jɛn↓↑": 113, + "jɛn↓": 114, + "ɥæn": 115, + "ɥæn→": 116, + "ɥæn↑": 117, + "ɥæn↓↑": 118, + "ɥæn↓": 119, + "in": 120, + "in→": 121, + "in↑": 122, + 
"in↓↑": 123, + "in↓": 124, + "ɥn": 125, + "ɥn→": 126, + "ɥn↑": 127, + "ɥn↓↑": 128, + "ɥn↓": 129, + "iŋ": 130, + "iŋ→": 131, + "iŋ↑": 132, + "iŋ↓↑": 133, + "iŋ↓": 134, + "ʊŋ": 135, + "ʊŋ→": 136, + "ʊŋ↑": 137, + "ʊŋ↓↑": 138, + "ʊŋ↓": 139, + "jʊŋ": 140, + "jʊŋ→": 141, + "jʊŋ↑": 142, + "jʊŋ↓↑": 143, + "jʊŋ↓": 144, + "ia": 145, + "ia→": 146, + "ia↑": 147, + "ia↓↑": 148, + "ia↓": 149, + "iɛ": 150, + "iɛ→": 151, + "iɛ↑": 152, + "iɛ↓↑": 153, + "iɛ↓": 154, + "iɑʊ": 155, + "iɑʊ→": 156, + "iɑʊ↑": 157, + "iɑʊ↓↑": 158, + "iɑʊ↓": 159, + "ioʊ": 160, + "ioʊ→": 161, + "ioʊ↑": 162, + "ioʊ↓↑": 163, + "ioʊ↓": 164, + "iɑŋ": 165, + "iɑŋ→": 166, + "iɑŋ↑": 167, + "iɑŋ↓↑": 168, + "iɑŋ↓": 169, + "ua": 170, + "ua→": 171, + "ua↑": 172, + "ua↓↑": 173, + "ua↓": 174, + "uo": 175, + "uo→": 176, + "uo↑": 177, + "uo↓↑": 178, + "uo↓": 179, + "uaɪ": 180, + "uaɪ→": 181, + "uaɪ↑": 182, + "uaɪ↓↑": 183, + "uaɪ↓": 184, + "ueɪ": 185, + "ueɪ→": 186, + "ueɪ↑": 187, + "ueɪ↓↑": 188, + "ueɪ↓": 189, + "uan": 190, + "uan→": 191, + "uan↑": 192, + "uan↓↑": 193, + "uan↓": 194, + "uən": 195, + "uən→": 196, + "uən↑": 197, + "uən↓↑": 198, + "uən↓": 199, + "uɑŋ": 200, + "uɑŋ→": 201, + "uɑŋ↑": 202, + "uɑŋ↓↑": 203, + "uɑŋ↓": 204, + "ɥɛ": 205, + "ɥɛ→": 206, + "ɥɛ↑": 207, + "ɥɛ↓↑": 208, + "ɥɛ↓": 209, + "a": 210, + "a→": 211, + "a↑": 212, + "a↓↑": 213, + "a↓": 214, + "o": 215, + "o→": 216, + "o↑": 217, + "o↓↑": 218, + "o↓": 219, + "ə→": 220, + "ə↑": 221, + "ə↓↑": 222, + "ə↓": 223, + "ɛ→": 224, + "ɛ↑": 225, + "ɛ↓↑": 226, + "ɛ↓": 227, + "aɪ→": 228, + "aɪ↑": 229, + "aɪ↓↑": 230, + "aɪ↓": 231, + "eɪ→": 232, + "eɪ↑": 233, + "eɪ↓↑": 234, + "eɪ↓": 235, + "ɑʊ": 236, + "ɑʊ→": 237, + "ɑʊ↑": 238, + "ɑʊ↓↑": 239, + "ɑʊ↓": 240, + "oʊ→": 241, + "oʊ↑": 242, + "oʊ↓↑": 243, + "oʊ↓": 244, + "an": 245, + "an→": 246, + "an↑": 247, + "an↓↑": 248, + "an↓": 249, + "ən": 250, + "ən→": 251, + "ən↑": 252, + "ən↓↑": 253, + "ən↓": 254, + "ɑŋ": 255, + "ɑŋ→": 256, + "ɑŋ↑": 257, + "ɑŋ↓↑": 258, + "ɑŋ↓": 259, + "əŋ": 260, + "əŋ→": 261, + "əŋ↑": 262, + "əŋ↓↑": 263, + "əŋ↓": 264, + "əɹ": 265, + "əɹ→": 266, + "əɹ↑": 267, + "əɹ↓↑": 268, + "əɹ↓": 269, + "i→": 270, + "i↑": 271, + "i↓↑": 272, + "i↓": 273, + "u→": 274, + "u↑": 275, + "u↓↑": 276, + "u↓": 277, + "ɥ": 278, + "ɥ→": 279, + "ɥ↑": 280, + "ɥ↓↑": 281, + "ɥ↓": 282, + "ts`⁼ɹ": 283, + "ts`⁼ɹ→": 284, + "ts`⁼ɹ↑": 285, + "ts`⁼ɹ↓↑": 286, + "ts`⁼ɹ↓": 287, + "ts`ʰɹ": 288, + "ts`ʰɹ→": 289, + "ts`ʰɹ↑": 290, + "ts`ʰɹ↓↑": 291, + "ts`ʰɹ↓": 292, + "s`ɹ": 293, + "s`ɹ→": 294, + "s`ɹ↑": 295, + "s`ɹ↓↑": 296, + "s`ɹ↓": 297, + "ɹ`ɹ": 298, + "ɹ`ɹ→": 299, + "ɹ`ɹ↑": 300, + "ɹ`ɹ↓↑": 301, + "ɹ`ɹ↓": 302, + "ts⁼ɹ": 303, + "ts⁼ɹ→": 304, + "ts⁼ɹ↑": 305, + "ts⁼ɹ↓↑": 306, + "ts⁼ɹ↓": 307, + "tsʰɹ": 308, + "tsʰɹ→": 309, + "tsʰɹ↑": 310, + "tsʰɹ↓↑": 311, + "tsʰɹ↓": 312, + "sɹ": 313, + "sɹ→": 314, + "sɹ↑": 315, + "sɹ↓↑": 316, + "sɹ↓": 317, + + "ɯ": 318, + "e": 319, + "aː": 320, + "ɯː": 321, + "eː": 322, + "ç": 323, + "ɸ": 324, + "ɰᵝ": 325, + "ɴ": 326, + "g": 327, + "dʑ": 328, + "q": 329, + "ː": 330, + "bj": 331, + "tɕ": 332, + "dej": 333, + "tej": 334, + "gj": 335, + "gɯ": 336, + "çj": 337, + "kj": 338, + "kɯ": 339, + "mj": 340, + "nj": 341, + "pj": 342, + "ɾj": 343, + "ɕ": 344, + "tsɯ": 345, + + "ɐ": 346, + "ɑ": 347, + "ɒ": 348, + "ɜ": 349, + "ɫ": 350, + "ʑ": 351, + "ʲ": 352, + + "y": 353, + "ø": 354, + "œ": 355, + "ʁ": 356, + "̃": 357, + "ɲ": 358, + + ":": 359, + ";": 360, + "'": 361, + "…": 362 + } +} \ No newline at end of file diff --git a/models/tts/debatts/utils/hparam.py b/models/tts/debatts/utils/hparam.py new file mode 100644 index 00000000..c5dd35c6 --- 
/dev/null +++ b/models/tts/debatts/utils/hparam.py @@ -0,0 +1,659 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long +"""Hyperparameter values.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import numbers +import re +import six + +# Define the regular expression for parsing a single clause of the input +# (delimited by commas). A legal clause looks like: +# []? = +# where is either a single token or [] enclosed list of tokens. +# For example: "var[1] = a" or "x = [1,2,3]" +PARAM_RE = re.compile( + r""" + (?P[a-zA-Z][\w\.]*) # variable name: "var" or "x" + (\[\s*(?P\d+)\s*\])? # (optional) index: "1" or None + \s*=\s* + ((?P[^,\[]*) # single value: "a" or None + | + \[(?P[^\]]*)\]) # list of values: None or "1,2,3" + ($|,\s*)""", + re.VERBOSE, +) + + +def _parse_fail(name, var_type, value, values): + """Helper function for raising a value error for bad assignment.""" + raise ValueError( + "Could not parse hparam '%s' of type '%s' with value '%s' in %s" + % (name, var_type.__name__, value, values) + ) + + +def _reuse_fail(name, values): + """Helper function for raising a value error for reuse of name.""" + raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values)) + + +def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary): + """Update results_dictionary with a scalar value. + + Used to update the results_dictionary to be returned by parse_values when + encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("s" or "arr"). + parse_fn: Function for parsing the actual value. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + m_dict['index']: List index value (or None) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has already been used. + """ + try: + parsed_value = parse_fn(m_dict["val"]) + except ValueError: + _parse_fail(name, var_type, m_dict["val"], values) + + # If no index is provided + if not m_dict["index"]: + if name in results_dictionary: + _reuse_fail(name, values) + results_dictionary[name] = parsed_value + else: + if name in results_dictionary: + # The name has already been used as a scalar, then it + # will be in this dictionary and map to a non-dictionary. + if not isinstance(results_dictionary.get(name), dict): + _reuse_fail(name, values) + else: + results_dictionary[name] = {} + + index = int(m_dict["index"]) + # Make sure the index position hasn't already been assigned a value. + if index in results_dictionary[name]: + _reuse_fail("{}[{}]".format(name, index), values) + results_dictionary[name][index] = parsed_value + + +def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary): + """Update results_dictionary from a list of values. + + Used to update results_dictionary to be returned by parse_values when + encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) + + Mutates results_dictionary. 
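+
+    For example, parsing the clause "arr=[1,2,3]" with parse_fn=int should leave
+    results_dictionary["arr"] == [1, 2, 3] (an illustrative description of the
+    intended behaviour, inferred from the implementation below).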
+ + Args: + name: Name of variable in assignment ("arr"). + parse_fn: Function for parsing individual values. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has an index or the values cannot be parsed. + """ + if m_dict["index"] is not None: + raise ValueError("Assignment of a list to a list index.") + elements = filter(None, re.split("[ ,]", m_dict["vals"])) + # Make sure the name hasn't already been assigned a value + if name in results_dictionary: + raise _reuse_fail(name, values) + try: + results_dictionary[name] = [parse_fn(e) for e in elements] + except ValueError: + _parse_fail(name, var_type, m_dict["vals"], values) + + +def _cast_to_type_if_compatible(name, param_type, value): + """Cast hparam to the provided type, if compatible. + + Args: + name: Name of the hparam to be cast. + param_type: The type of the hparam. + value: The value to be cast, if compatible. + + Returns: + The result of casting `value` to `param_type`. + + Raises: + ValueError: If the type of `value` is not compatible with param_type. + * If `param_type` is a string type, but `value` is not. + * If `param_type` is a boolean, but `value` is not, or vice versa. + * If `param_type` is an integer type, but `value` is not. + * If `param_type` is a float type, but `value` is not a numeric type. + """ + fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % ( + name, + param_type, + value, + ) + + # Some callers use None, for which we can't do any casting/checking. :( + if issubclass(param_type, type(None)): + return value + + # Avoid converting a non-string type to a string. + if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance( + value, (six.string_types, six.binary_type) + ): + raise ValueError(fail_msg) + + # Avoid converting a number or string type to a boolean or vice versa. + if issubclass(param_type, bool) != isinstance(value, bool): + raise ValueError(fail_msg) + + # Avoid converting float to an integer (the reverse is fine). + if issubclass(param_type, numbers.Integral) and not isinstance( + value, numbers.Integral + ): + raise ValueError(fail_msg) + + # Avoid converting a non-numeric type to a numeric type. + if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number): + raise ValueError(fail_msg) + + return param_type(value) + + +def parse_values(values, type_map, ignore_unknown=False): + """Parses hyperparameter values from a string into a python map. + + `values` is a string containing comma-separated `name=value` pairs. + For each pair, the value of the hyperparameter named `name` is set to + `value`. + + If a hyperparameter name appears multiple times in `values`, a ValueError + is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). + + If a hyperparameter name in both an index assignment and scalar assignment, + a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). + + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must be first explicit added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended. 
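+
+    For example, parse_values("a.b=1", {"a.b": int}) is expected to return
+    {"a.b": 1}; on an HParams object such a dotted name is then only reachable
+    via getattr/setattr (illustrative, based on the grammar described above).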
+ + The `value` in `name=value` must follows the syntax according to the + type of the parameter: + + * Scalar integer: A Python-parsable integer point value. E.g.: 1, + 100, -12. + * Scalar float: A Python-parsable floating point value. E.g.: 1.0, + -.54e89. + * Boolean: Either true or false. + * Scalar string: A non-empty sequence of characters, excluding comma, + spaces, and square brackets. E.g.: foo, bar_1. + * List: A comma separated list of scalar values of the parameter type + enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. + + When index assignment is used, the corresponding type_map key should be the + list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not + "arr[1]"). + + Args: + values: String. Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + type_map: A dictionary mapping hyperparameter names to types. Note every + parameter name in values must be a key in type_map. The values must + conform to the types indicated, where a value V is said to conform to a + type T if either V has type T, or V is a list of elements of type T. + Hence, for a multidimensional parameter 'x' taking float values, + 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + ignore_unknown: Bool. Whether values that are missing a type in type_map + should be ignored. If set to True, a ValueError will not be raised for + unknown hyperparameter type. + + Returns: + A python map mapping each name to either: + * A scalar value. + * A list of scalar values. + * A dictionary mapping index numbers to scalar values. + (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") + + Raises: + ValueError: If there is a problem with input. + * If `values` cannot be parsed. + * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). + * If the same rvalue is assigned two different values (e.g. 'a=1,a=2', + 'a[1]=1,a[1]=2', or 'a=1,a=[1]') + """ + results_dictionary = {} + pos = 0 + while pos < len(values): + m = PARAM_RE.match(values, pos) + if not m: + raise ValueError("Malformed hyperparameter value: %s" % values[pos:]) + # Check that there is a comma between parameters and move past it. + pos = m.end() + # Parse the values. + m_dict = m.groupdict() + name = m_dict["name"] + if name not in type_map: + if ignore_unknown: + continue + raise ValueError("Unknown hyperparameter type for %s" % name) + type_ = type_map[name] + + # Set up correct parsing function (depending on whether type_ is a bool) + if type_ == bool: + + def parse_bool(value): + if value in ["true", "True"]: + return True + elif value in ["false", "False"]: + return False + else: + try: + return bool(int(value)) + except ValueError: + _parse_fail(name, type_, value, values) + + parse = parse_bool + else: + parse = type_ + + # If a singe value is provided + if m_dict["val"] is not None: + _process_scalar_value( + name, parse, type_, m_dict, values, results_dictionary + ) + + # If the assigned value is a list: + elif m_dict["vals"] is not None: + _process_list_value(name, parse, type_, m_dict, values, results_dictionary) + + else: # Not assigned a list or value + _parse_fail(name, type_, "", values) + + return results_dictionary + + +class HParams(object): + """Class to hold a set of hyperparameters as name-value pairs. + + A `HParams` object holds hyperparameters used to build and train a model, + such as the number of hidden units in a neural net layer or the learning rate + to use when training. 
+ + You first create a `HParams` object by specifying the names and values of the + hyperparameters. + + To make them easily accessible the parameter names are added as direct + attributes of the class. A typical usage is as follows: + + ```python + # Create a HParams object specifying names and values of the model + # hyperparameters: + hparams = HParams(learning_rate=0.1, num_hidden_units=100) + + # The hyperparameter are available as attributes of the HParams object: + hparams.learning_rate ==> 0.1 + hparams.num_hidden_units ==> 100 + ``` + + Hyperparameters have type, which is inferred from the type of their value + passed at construction type. The currently supported types are: integer, + float, boolean, string, and list of integer, float, boolean, or string. + + You can override hyperparameter values by calling the + [`parse()`](#HParams.parse) method, passing a string of comma separated + `name=value` pairs. This is intended to make it possible to override + any hyperparameter values from a single command-line flag to which + the user passes 'hyper-param=value' pairs. It avoids having to define + one flag for each hyperparameter. + + The syntax expected for each value depends on the type of the parameter. + See `parse()` for a description of the syntax. + + Example: + + ```python + # Define a command line flag to pass name=value pairs. + # For example using argparse: + import argparse + parser = argparse.ArgumentParser(description='Train my model.') + parser.add_argument('--hparams', type=str, + help='Comma separated list of "name=value" pairs.') + args = parser.parse_args() + ... + def my_program(): + # Create a HParams object specifying the names and values of the + # model hyperparameters: + hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, + activations=['relu', 'tanh']) + + # Override hyperparameters values by parsing the command line + hparams.parse(args.hparams) + + # If the user passed `--hparams=learning_rate=0.3` on the command line + # then 'hparams' has the following attributes: + hparams.learning_rate ==> 0.3 + hparams.num_hidden_units ==> 100 + hparams.activations ==> ['relu', 'tanh'] + + # If the hyperparameters are in json format use parse_json: + hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') + ``` + """ + + _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. + + def __init__(self, model_structure=None, **kwargs): + """Create an instance of `HParams` from keyword arguments. + + The keyword arguments specify name-values pairs for the hyperparameters. + The parameter types are inferred from the type of the values passed. + + The parameter names are added as attributes of `HParams` object, so they + can be accessed directly with the dot notation `hparams._name_`. + + Example: + + ```python + # Define 3 hyperparameters: 'learning_rate' is a float parameter, + # 'num_hidden_units' an integer parameter, and 'activation' a string + # parameter. + hparams = tf.HParams( + learning_rate=0.1, num_hidden_units=100, activation='relu') + + hparams.activation ==> 'relu' + ``` + + Note that a few names are reserved and cannot be used as hyperparameter + names. If you use one of the reserved name the constructor raises a + `ValueError`. + + Args: + model_structure: An instance of ModelStructure, defining the feature + crosses to be used in the Trial. + **kwargs: Key-value pairs where the key is the hyperparameter name and + the value is the value for the parameter. 
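+
+        Note (inferred from add_hparam below): a keyword value may also be a
+        non-empty list, in which case the element type is recorded and later
+        assignments via set_hparam/parse must also be lists of that type.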
+ + Raises: + ValueError: If both `hparam_def` and initialization values are provided, + or if one of the arguments is invalid. + + """ + # Register the hyperparameters and their type in _hparam_types. + # This simplifies the implementation of parse(). + # _hparam_types maps the parameter name to a tuple (type, bool). + # The type value is the type of the parameter for scalar hyperparameters, + # or the type of the list elements for multidimensional hyperparameters. + # The bool value is True if the value is a list, False otherwise. + self._hparam_types = {} + self._model_structure = model_structure + for name, value in six.iteritems(kwargs): + self.add_hparam(name, value) + + def add_hparam(self, name, value): + """Adds {name, value} pair to hyperparameters. + + Args: + name: Name of the hyperparameter. + value: Value of the hyperparameter. Can be one of the following types: + int, float, string, int list, float list, or string list. + + Raises: + ValueError: if one of the arguments is invalid. + """ + # Keys in kwargs are unique, but 'name' could the name of a pre-existing + # attribute of this object. In that case we refuse to use it as a + # hyperparameter name. + if getattr(self, name, None) is not None: + raise ValueError("Hyperparameter name is reserved: %s" % name) + if isinstance(value, (list, tuple)): + if not value: + raise ValueError( + "Multi-valued hyperparameters cannot be empty: %s" % name + ) + self._hparam_types[name] = (type(value[0]), True) + else: + self._hparam_types[name] = (type(value), False) + setattr(self, name, value) + + def set_hparam(self, name, value): + """Set the value of an existing hyperparameter. + + This function verifies that the type of the value matches the type of the + existing hyperparameter. + + Args: + name: Name of the hyperparameter. + value: New value of the hyperparameter. + + Raises: + KeyError: If the hyperparameter doesn't exist. + ValueError: If there is a type mismatch. + """ + param_type, is_list = self._hparam_types[name] + if isinstance(value, list): + if not is_list: + raise ValueError( + "Must not pass a list for single-valued parameter: %s" % name + ) + setattr( + self, + name, + [_cast_to_type_if_compatible(name, param_type, v) for v in value], + ) + else: + if is_list: + raise ValueError( + "Must pass a list for multi-valued parameter: %s." % name + ) + setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Does nothing if it isn't present. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + + def parse(self, values): + """Override existing hyperparameter values, parsing new values from a string. + + See parse_values for more detail on the allowed format for values. + + Args: + values: String. Comma separated list of `name=value` pairs where 'value' + must follow the syntax described above. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values` cannot be parsed or a hyperparameter in `values` + doesn't exist. + """ + type_map = {} + for name, t in self._hparam_types.items(): + param_type, _ = t + type_map[name] = param_type + + values_map = parse_values(values, type_map) + return self.override_from_dict(values_map) + + def override_from_dict(self, values_dict): + """Override existing hyperparameter values, parsing new values from a dictionary. + + Args: + values_dict: Dictionary of name:value pairs. 
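+
+        A minimal illustrative usage sketch (mirroring the class-level example
+        above; resulting attribute values shown as comments):
+
+        ```python
+        hparams = HParams(learning_rate=0.1, num_hidden_units=100)
+        hparams.override_from_dict({"learning_rate": 0.3})
+        # hparams.learning_rate ==> 0.3
+        ```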
+ + Returns: + The `HParams` instance. + + Raises: + KeyError: If a hyperparameter in `values_dict` doesn't exist. + ValueError: If `values_dict` cannot be parsed. + """ + for name, value in values_dict.items(): + self.set_hparam(name, value) + return self + + def set_model_structure(self, model_structure): + self._model_structure = model_structure + + def get_model_structure(self): + return self._model_structure + + def to_json(self, indent=None, separators=None, sort_keys=False): + """Serializes the hyperparameters into JSON. + + Args: + indent: If a non-negative integer, JSON array elements and object members + will be pretty-printed with that indent level. An indent level of 0, or + negative, will only insert newlines. `None` (the default) selects the + most compact representation. + separators: Optional `(item_separator, key_separator)` tuple. Default is + `(', ', ': ')`. + sort_keys: If `True`, the output dictionaries will be sorted by key. + + Returns: + A JSON string. + """ + + def remove_callables(x): + """Omit callable elements from input with arbitrary nesting.""" + if isinstance(x, dict): + return { + k: remove_callables(v) + for k, v in six.iteritems(x) + if not callable(v) + } + elif isinstance(x, list): + return [remove_callables(i) for i in x if not callable(i)] + return x + + return json.dumps( + remove_callables(self.values()), + indent=indent, + separators=separators, + sort_keys=sort_keys, + ) + + def parse_json(self, values_json): + """Override existing hyperparameter values, parsing new values from a json object. + + Args: + values_json: String containing a json object of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + KeyError: If a hyperparameter in `values_json` doesn't exist. + ValueError: If `values_json` cannot be parsed. + """ + values_map = json.loads(values_json) + return self.override_from_dict(values_map) + + def values(self): + """Return the hyperparameter values as a Python dictionary. + + Returns: + A dictionary with hyperparameter names as keys. The values are the + hyperparameter values. + """ + return {n: getattr(self, n) for n in self._hparam_types.keys()} + + def get(self, key, default=None): + """Returns the value of `key` if it exists, else `default`.""" + if key in self._hparam_types: + # Ensure that default is compatible with the parameter type. + if default is not None: + param_type, is_param_list = self._hparam_types[key] + type_str = "list<%s>" % param_type if is_param_list else str(param_type) + fail_msg = ( + "Hparam '%s' of type '%s' is incompatible with " + "default=%s" % (key, type_str, default) + ) + + is_default_list = isinstance(default, list) + if is_param_list != is_default_list: + raise ValueError(fail_msg) + + try: + if is_default_list: + for value in default: + _cast_to_type_if_compatible(key, param_type, value) + else: + _cast_to_type_if_compatible(key, param_type, default) + except ValueError as e: + raise ValueError("%s. %s" % (fail_msg, e)) + + return getattr(self, key) + + return default + + def __contains__(self, key): + return key in self._hparam_types + + def __str__(self): + return str(sorted(self.values().items())) + + def __repr__(self): + return "%s(%s)" % (type(self).__name__, self.__str__()) + + @staticmethod + def _get_kind_name(param_type, is_list): + """Returns the field name given parameter type and is_list. + + Args: + param_type: Data type of the hparam. + is_list: Whether this is a list. + + Returns: + A string representation of the field name. 
+ + Raises: + ValueError: If parameter type is not recognized. + """ + if issubclass(param_type, bool): + # This check must happen before issubclass(param_type, six.integer_types), + # since Python considers bool to be a subclass of int. + typename = "bool" + elif issubclass(param_type, six.integer_types): + # Setting 'int' and 'long' types to be 'int64' to ensure the type is + # compatible with both Python2 and Python3. + typename = "int64" + elif issubclass(param_type, (six.string_types, six.binary_type)): + # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is + # compatible with both Python2 and Python3. + typename = "bytes" + elif issubclass(param_type, float): + typename = "float" + else: + raise ValueError("Unsupported parameter type: %s" % str(param_type)) + + suffix = "list" if is_list else "value" + return "_".join([typename, suffix]) diff --git a/models/tts/debatts/utils/hubert.py b/models/tts/debatts/utils/hubert.py new file mode 100644 index 00000000..84b509fb --- /dev/null +++ b/models/tts/debatts/utils/hubert.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.0/preprocess_hubert_f0.py + +import os +import librosa +import torch +import numpy as np +from fairseq import checkpoint_utils +from tqdm import tqdm +import torch + + +def load_hubert_model(hps): + # Load model + ckpt_path = hps.hubert_file + print("Load Hubert Model...") + + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [ckpt_path], + suffix="", + ) + model = models[0] + model.eval() + + if torch.cuda.is_available(): + model = model.cuda() + + return model + + +def get_hubert_content(hmodel, wav_16k_tensor): + feats = wav_16k_tensor + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav_16k_tensor.device), + "padding_mask": padding_mask.to(wav_16k_tensor.device), + "output_layer": 9, # layer 9 + } + with torch.no_grad(): + logits = hmodel.extract_features(**inputs) + feats = hmodel.final_proj(logits[0]).squeeze(0) + + return feats + + +def content_vector_encoder(model, audio_path, default_sampling_rate=16000): + """ + # content vector default sr: 16000 + """ + + wav16k, sr = librosa.load(audio_path, sr=default_sampling_rate) + device = next(model.parameters()).device + wav16k = torch.from_numpy(wav16k).to(device) + + # (1, 256, frame_len) + content_feature = get_hubert_content(model, wav_16k_tensor=wav16k) + + return content_feature.cpu().detach().numpy() + + +def repeat_expand_2d(content, target_len): + """ + content : [hubert_dim(256), src_len] + target: [hubert_dim(256), target_len] + """ + src_len = content.shape[-1] + target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to( + content.device + ) + temp = torch.arange(src_len + 1) * target_len / src_len + current_pos = 0 + for i in range(target_len): + if i < temp[current_pos + 1]: + target[:, i] = content[:, current_pos] + else: + current_pos += 1 + target[:, i] = content[:, current_pos] + + return target + + +def get_mapped_features(raw_content_features, mapping_features): + """ + Content Vector: frameshift = 20ms, hop_size = 480 in 24k + + Now it's only used for mapping to bigvgan's mels (sr = 24k, hop_size = 
256, frameshift ~= 10.7 ms) + """ + source_hop = 480 + target_hop = 256 + + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + print( + "Mapping source's {} frames => target's {} frames".format( + target_hop, source_hop + ) + ) + + results = [] + for index, mapping_feat in enumerate(tqdm(mapping_features)): + # mappping_feat: (mels_frame_len, n_mels) + target_len = len(mapping_feat) + + # (source_len, 256) + raw_feats = raw_content_features[index][0].cpu().numpy().T + source_len, width = raw_feats.shape + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + + err = abs(target_len - len(down_sampling_feats)) + if err > 3: + print("index:", index) + print("mels:", mapping_feat.shape) + print("raw content vector:", raw_feats.shape) + print("up_sampling:", up_sampling_feats.shape) + print("down_sampling_feats:", down_sampling_feats.shape) + exit() + if len(down_sampling_feats) < target_len: + # (1, dim) -> (err, dim) + end = down_sampling_feats[-1][None, :].repeat(err, axis=0) + down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) + + # (target_len, dim) + feats = down_sampling_feats[:target_len] + results.append(feats) + + return results + + +def extract_hubert_features_of_dataset(datasets, model, out_dir): + for utt in tqdm(datasets): + uid = utt["Uid"] + audio_path = utt["Path"] + + content_vector_feature = content_vector_encoder(model, audio_path) # (T, 256) + + save_path = os.path.join(out_dir, uid + ".npy") + np.save(save_path, content_vector_feature) diff --git a/models/tts/debatts/utils/io.py b/models/tts/debatts/utils/io.py new file mode 100644 index 00000000..a93e75c6 --- /dev/null +++ b/models/tts/debatts/utils/io.py @@ -0,0 +1,182 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import numpy as np +import torch +import torchaudio + + +def save_feature(process_dir, feature_dir, item, feature, overrides=True): + """Save features to path + + Args: + process_dir (str): directory to store features + feature_dir (_type_): directory to store one type of features (mel, energy, ...) + item (str): uid + feature (tensor): feature tensor + overrides (bool, optional): whether to override existing files. Defaults to True. 
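    Example (a minimal sketch; the directory and uid below are placeholders):

    ```python
    import numpy as np

    mel = np.random.randn(80, 100).astype(np.float32)
    # Writes processed/mel/utt_0001.npy, overwriting any existing file
    # under the default overrides=True.
    save_feature("processed", "mel", "utt_0001", mel)
    ```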
+ """ + process_dir = os.path.join(process_dir, feature_dir) + os.makedirs(process_dir, exist_ok=True) + out_path = os.path.join(process_dir, item + ".npy") + + if os.path.exists(out_path): + if overrides: + np.save(out_path, feature) + else: + np.save(out_path, feature) + + +def save_txt(process_dir, feature_dir, item, feature, overrides=True): + process_dir = os.path.join(process_dir, feature_dir) + os.makedirs(process_dir, exist_ok=True) + out_path = os.path.join(process_dir, item + ".txt") + + if os.path.exists(out_path): + if overrides: + f = open(out_path, "w") + f.writelines(feature) + f.close() + else: + f = open(out_path, "w") + f.writelines(feature) + f.close() + + +def save_audio(path, waveform, fs, add_silence=False, turn_up=False, volume_peak=0.9): + """Save audio to path with processing (turn up volume, add silence) + Args: + path (str): path to save audio + waveform (numpy array): waveform to save + fs (int): sampling rate + add_silence (bool, optional): whether to add silence to beginning and end. Defaults to False. + turn_up (bool, optional): whether to turn up volume. Defaults to False. + volume_peak (float, optional): volume peak. Defaults to 0.9. + """ + if turn_up: + # continue to turn up to volume_peak + ratio = volume_peak / max(waveform.max(), abs(waveform.min())) + waveform = waveform * ratio + + if add_silence: + silence_len = fs // 20 + silence = np.zeros((silence_len,), dtype=waveform.dtype) + result = np.concatenate([silence, waveform, silence]) + waveform = result + + waveform = torch.as_tensor(waveform, dtype=torch.float32, device="cpu") + if len(waveform.size()) == 1: + waveform = waveform[None, :] + elif waveform.size(0) != 1: + # Stereo to mono + waveform = torch.mean(waveform, dim=0, keepdim=True) + torchaudio.save(path, waveform, fs, encoding="PCM_S", bits_per_sample=16) + + +def save_torch_audio(process_dir, feature_dir, item, wav_torch, fs, overrides=True): + """Save torch audio to path without processing + Args: + process_dir (str): directory to store features + feature_dir (_type_): directory to store one type of features (mel, energy, ...) + item (str): uid + wav_torch (tensor): feature tensor + fs (int): sampling rate + overrides (bool, optional): whether to override existing files. Defaults to True. + """ + if wav_torch.shape != 2: + wav_torch = wav_torch.unsqueeze(0) + + process_dir = os.path.join(process_dir, feature_dir) + os.makedirs(process_dir, exist_ok=True) + out_path = os.path.join(process_dir, item + ".wav") + + torchaudio.save(out_path, wav_torch, fs) + + +async def async_load_audio(path, sample_rate: int = 24000): + r""" + Args: + path: The source loading path. + sample_rate: The target sample rate, will automatically resample if necessary. + + Returns: + waveform: The waveform object. Should be [1 x sequence_len]. + """ + + async def use_torchaudio_load(path): + return torchaudio.load(path) + + waveform, sr = await use_torchaudio_load(path) + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if sr != sample_rate: + waveform = torchaudio.functional.resample(waveform, sr, sample_rate) + + if torch.any(torch.isnan(waveform) or torch.isinf(waveform)): + raise ValueError("NaN or Inf found in waveform.") + return waveform + + +async def async_save_audio( + path, + waveform, + sample_rate: int = 24000, + add_silence: bool = False, + volume_peak: float = 0.9, +): + r""" + Args: + path: The target saving path. + waveform: The waveform object. Should be [n_channel x sequence_len]. + sample_rate: Sample rate. 
+ add_silence: If ``true``, concat 0.05s silence to beginning and end. + volume_peak: Turn up volume for larger number, vice versa. + """ + + async def use_torchaudio_save(path, waveform, sample_rate): + torchaudio.save( + path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16 + ) + + waveform = torch.as_tensor(waveform, device="cpu", dtype=torch.float32) + shape = waveform.size()[:-1] + + ratio = abs(volume_peak) / max(waveform.max(), abs(waveform.min())) + waveform = waveform * ratio + + if add_silence: + silence_len = sample_rate // 20 + silence = torch.zeros((*shape, silence_len), dtype=waveform.type()) + waveform = torch.concatenate((silence, waveform, silence), dim=-1) + + if waveform.dim() == 1: + waveform = waveform[None] + + await use_torchaudio_save(path, waveform, sample_rate) + + +def load_mel_extrema(cfg, dataset_name, split): + dataset_dir = os.path.join( + cfg.OUTPUT_PATH, + "preprocess/{}_version".format(cfg.data.process_version), + dataset_name, + ) + + min_file = os.path.join( + dataset_dir, + "mel_min_max", + split.split("_")[-1], + "mel_min.npy", + ) + max_file = os.path.join( + dataset_dir, + "mel_min_max", + split.split("_")[-1], + "mel_max.npy", + ) + mel_min = np.load(min_file) + mel_max = np.load(max_file) + return mel_min, mel_max diff --git a/models/tts/debatts/utils/io_optim.py b/models/tts/debatts/utils/io_optim.py new file mode 100644 index 00000000..e9afaa06 --- /dev/null +++ b/models/tts/debatts/utils/io_optim.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torchaudio +import json +import os +import numpy as np +import librosa +import whisper +from torch.nn.utils.rnn import pad_sequence + + +class TorchaudioDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): + """ + Args: + cfg: config + dataset: dataset name + + """ + assert isinstance(dataset, str) + + self.sr = sr + self.cfg = cfg + + if metadata is None: + self.train_metadata_path = os.path.join( + cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file + ) + self.valid_metadata_path = os.path.join( + cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file + ) + self.metadata = self.get_metadata() + else: + self.metadata = metadata + + if accelerator is not None: + self.device = accelerator.device + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + def get_metadata(self): + metadata = [] + with open(self.train_metadata_path, "r", encoding="utf-8") as t: + metadata.extend(json.load(t)) + with open(self.valid_metadata_path, "r", encoding="utf-8") as v: + metadata.extend(json.load(v)) + return metadata + + def __len__(self): + return len(self.metadata) + + def __getitem__(self, index): + utt_info = self.metadata[index] + wav_path = utt_info["Path"] + + wav, sr = torchaudio.load(wav_path) + + # resample + if sr != self.sr: + wav = torchaudio.functional.resample(wav, sr, self.sr) + # downmixing + if wav.shape[0] > 1: + wav = torch.mean(wav, dim=0, keepdim=True) + assert wav.shape[0] == 1 + wav = wav.squeeze(0) + # record the length of wav without padding + length = wav.shape[0] + # wav: (T) + return utt_info, wav, length + + +class LibrosaDataset(TorchaudioDataset): + def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): + super().__init__(cfg, dataset, sr, accelerator, metadata) + + def 
__getitem__(self, index): + utt_info = self.metadata[index] + wav_path = utt_info["Path"] + + wav, _ = librosa.load(wav_path, sr=self.sr) + # wav: (T) + wav = torch.from_numpy(wav) + + # record the length of wav without padding + length = wav.shape[0] + return utt_info, wav, length + + +class FFmpegDataset(TorchaudioDataset): + def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): + super().__init__(cfg, dataset, sr, accelerator, metadata) + + def __getitem__(self, index): + utt_info = self.metadata[index] + wav_path = utt_info["Path"] + + # wav: (T,) + wav = whisper.load_audio(wav_path, sr=16000) # sr = 16000 + # convert to torch tensor + wav = torch.from_numpy(wav) + # record the length of wav without padding + length = wav.shape[0] + + return utt_info, wav, length + + +def collate_batch(batch_list): + """ + Args: + batch_list: list of (metadata, wav, length) + """ + metadata = [item[0] for item in batch_list] + # wavs: (B, T) + wavs = pad_sequence([item[1] for item in batch_list], batch_first=True) + lens = [item[2] for item in batch_list] + + return metadata, wavs, lens diff --git a/models/tts/debatts/utils/logger.py b/models/tts/debatts/utils/logger.py new file mode 100644 index 00000000..4d31167d --- /dev/null +++ b/models/tts/debatts/utils/logger.py @@ -0,0 +1,43 @@ +import logging +import time +import os + + +def init_logger(name): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + # Add file handler to save logs to a file + log_date = time.strftime("%Y-%m-%d", time.localtime()) + log_time = time.strftime("%H-%M-%S", time.localtime()) + + os.makedirs(f"logs/{log_date}", exist_ok=True) + + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + fh = logging.FileHandler(f"logs/{log_date}/{log_time}.log") + fh.setFormatter(formatter) + logger.addHandler(fh) + + # 创建一个自定义的日志格式器,将特定级别的日志设置为红色 + class ColorFormatter(logging.Formatter): + def format(self, record): + if record.levelno >= logging.ERROR: + record.msg = "\033[1;31m" + str(record.msg) + "\033[0m" + elif record.levelno >= logging.WARNING: + record.msg = "\033[1;33m" + str(record.msg) + "\033[0m" + elif record.levelno >= logging.INFO: + record.msg = "\033[1;34m" + str(record.msg) + "\033[0m" + elif record.levelno >= logging.DEBUG: + record.msg = "\033[1;32m" + str(record.msg) + "\033[0m" + return super().format(record) + + color_formatter = ColorFormatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ch = logging.StreamHandler() + ch.setFormatter(color_formatter) + logger.addHandler(ch) + + return logger diff --git a/models/tts/debatts/utils/mel.py b/models/tts/debatts/utils/mel.py new file mode 100644 index 00000000..3894b73c --- /dev/null +++ b/models/tts/debatts/utils/mel.py @@ -0,0 +1,280 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
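The datasets above return variable-length waveforms, so `collate_batch` pads them before batching. A minimal sketch of wiring one of them into a `DataLoader`, assuming the metadata entries carry a `"Path"` key as read in `__getitem__` (the path below is a placeholder):

```python
from torch.utils.data import DataLoader

# Passing `metadata` directly bypasses the train/valid json files,
# so `cfg` is not used in this code path.
metadata = [{"Uid": "utt_0001", "Path": "/path/to/utt_0001.wav"}]
dataset = LibrosaDataset(cfg=None, dataset="demo", sr=16000, metadata=metadata)

loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collate_batch,  # pads waveforms to the longest item in the batch
)

for utt_infos, wavs, lens in loader:
    # wavs: (B, T_max); lens holds the unpadded lengths.
    print(wavs.shape, lens)
```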
+ +import torch +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + # Min value: ln(1e-5) = -11.5129 + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def extract_linear_features(y, cfg, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + spec = torch.squeeze(spec, 0) + return spec + + +def mel_spectrogram_torch(y, cfg, center=False): + """ + TODO: to merge this funtion with the extract_mel_features below + """ + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + if cfg.fmax not in mel_basis: + mel = librosa_mel_fn( + sr=cfg.sample_rate, + n_fft=cfg.n_fft, + n_mels=cfg.n_mel, + fmin=cfg.fmin, + fmax=cfg.fmax, + ) + mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +mel_basis = {} +hann_window = {} + + +def extract_mel_features( + y, + cfg, + center=False, +): + """Extract mel features + + Args: + y (tensor): audio data in tensor + cfg (dict): configuration in cfg.preprocess + center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False. 
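        A minimal sketch of the `cfg` fields this function reads (the values
        shown are illustrative, not the project defaults):

        ```python
        from types import SimpleNamespace
        import torch

        cfg = SimpleNamespace(
            sample_rate=16000, n_fft=1024, hop_size=320, win_size=1024,
            n_mel=80, fmin=0, fmax=8000,
        )
        wav = torch.rand(1, 16000) * 2 - 1      # (B, T), roughly in [-1, 1]
        mel = extract_mel_features(wav, cfg)    # -> (n_mel, frames) for B == 1
        ```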
+ + Returns: + tensor: a tensor containing the mel feature calculated based on STFT result + """ + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + if cfg.fmax not in mel_basis: + mel = librosa_mel_fn( + sr=cfg.sample_rate, + n_fft=cfg.n_fft, + n_mels=cfg.n_mel, + fmin=cfg.fmin, + fmax=cfg.fmax, + ) + mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + return spec.squeeze(0) + + +def extract_mel_features_tts( + y, + cfg, + center=False, + taco=False, + _stft=None, +): + """Extract mel features + + Args: + y (tensor): audio data in tensor + cfg (dict): configuration in cfg.preprocess + center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False. + taco: use tacotron mel + + Returns: + tensor: a tensor containing the mel feature calculated based on STFT result + """ + if not taco: + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + if cfg.fmax not in mel_basis: + mel = librosa_mel_fn( + sr=cfg.sample_rate, + n_fft=cfg.n_fft, + n_mels=cfg.n_mel, + fmin=cfg.fmin, + fmax=cfg.fmax, + ) + mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + else: + audio = torch.clip(y, -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + spec, energy = _stft.mel_spectrogram(audio) + + return spec.squeeze(0) + + +def amplitude_phase_spectrum(y, cfg): + hann_window = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + stft_spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window, + center=False, + 
return_complex=True, + ) + + stft_spec = torch.view_as_real(stft_spec) + if stft_spec.size()[0] == 1: + stft_spec = stft_spec.squeeze(0) + + if len(list(stft_spec.size())) == 4: + rea = stft_spec[:, :, :, 0] # [batch_size, n_fft//2+1, frames] + imag = stft_spec[:, :, :, 1] # [batch_size, n_fft//2+1, frames] + else: + rea = stft_spec[:, :, 0] # [n_fft//2+1, frames] + imag = stft_spec[:, :, 1] # [n_fft//2+1, frames] + + log_amplitude = torch.log( + torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5 + ) # [n_fft//2+1, frames] + phase = torch.atan2(imag, rea) # [n_fft//2+1, frames] + + return log_amplitude, phase, rea, imag diff --git a/models/tts/debatts/utils/mert.py b/models/tts/debatts/utils/mert.py new file mode 100644 index 00000000..4181429f --- /dev/null +++ b/models/tts/debatts/utils/mert.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://huggingface.co/m-a-p/MERT-v1-330M + +import torch +from tqdm import tqdm +import numpy as np + +from transformers import Wav2Vec2FeatureExtractor +from transformers import AutoModel +import torchaudio +import torchaudio.transforms as T +from sklearn.preprocessing import StandardScaler + + +def mert_encoder(model, processor, audio_path, hps): + """ + # mert default sr: 24000 + """ + with torch.no_grad(): + resample_rate = processor.sampling_rate + device = next(model.parameters()).device + + input_audio, sampling_rate = torchaudio.load(audio_path) + input_audio = input_audio.squeeze() + + if sampling_rate != resample_rate: + resampler = T.Resample(sampling_rate, resample_rate) + input_audio = resampler(input_audio) + + inputs = processor( + input_audio, sampling_rate=resample_rate, return_tensors="pt" + ).to( + device + ) # {input_values: tensor, attention_mask: tensor} + + outputs = model(**inputs, output_hidden_states=True) # list: len is 25 + + # [25 layer, Time steps, 1024 feature_dim] + # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() + # mert_features.append(all_layer_hidden_states) + + feature = outputs.hidden_states[ + hps.mert_feature_layer + ].squeeze() # [1, frame len, 1024] -> [frame len, 1024] + + return feature.cpu().detach().numpy() + + +def mert_features_normalization(raw_mert_features): + normalized_mert_features = list() + + mert_features = np.array(raw_mert_features) + scaler = StandardScaler().fit(mert_features) + for raw_mert_feature in raw_mert_feature: + normalized_mert_feature = scaler.transform(raw_mert_feature) + normalized_mert_features.append(normalized_mert_feature) + return normalized_mert_features + + +def get_mapped_mert_features(raw_mert_features, mapping_features, fast_mapping=True): + source_hop = 320 + target_hop = 256 + + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + print( + "Mapping source's {} frames => target's {} frames".format( + target_hop, source_hop + ) + ) + + mert_features = [] + for index, mapping_feat in enumerate(tqdm(mapping_features)): + # mapping_feat: (mels_frame_len, n_mels) + target_len = mapping_feat.shape[0] + + # (frame_len, 1024) + raw_feats = raw_mert_features[index].cpu().numpy() + source_len, width = raw_feats.shape + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) + # (const, dim) -> 
(const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + + err = abs(target_len - len(down_sampling_feats)) + if err > 3: + print("index:", index) + print("mels:", mapping_feat.shape) + print("raw mert vector:", raw_feats.shape) + print("up_sampling:", up_sampling_feats.shape) + print("const:", const) + print("down_sampling_feats:", down_sampling_feats.shape) + exit() + if len(down_sampling_feats) < target_len: + # (1, dim) -> (err, dim) + end = down_sampling_feats[-1][None, :].repeat(err, axis=0) + down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) + + # (target_len, dim) + feats = down_sampling_feats[:target_len] + mert_features.append(feats) + + return mert_features + + +def load_mert_model(hps): + print("Loading MERT Model: ", hps.mert_model) + + # Load model + model_name = hps.mert_model + model = AutoModel.from_pretrained(model_name, trust_remote_code=True) + + if torch.cuda.is_available(): + model = model.cuda() + + # model = model.eval() + + preprocessor = Wav2Vec2FeatureExtractor.from_pretrained( + model_name, trust_remote_code=True + ) + return model, preprocessor + + +# loading the corresponding preprocessor config +# def load_preprocessor (model_name="m-a-p/MERT-v1-330M"): +# print('load_preprocessor...') +# preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(model_name,trust_remote_code=True) +# return preprocessor diff --git a/models/tts/debatts/utils/mfa_prepare.py b/models/tts/debatts/utils/mfa_prepare.py new file mode 100644 index 00000000..b79ba862 --- /dev/null +++ b/models/tts/debatts/utils/mfa_prepare.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
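The frame-rate mapping above (and the analogous `get_mapped_features` in `utils/hubert.py`) reduces to repeating each source frame `source_hop` times and then block-averaging groups of `target_hop` frames, after dividing both hops by their gcd. A self-contained sketch of just that resampling step, with an illustrative feature shape:

```python
import numpy as np

source_hop, target_hop = 320, 256
factor = np.gcd(source_hop, target_hop)   # 64
source_hop //= factor                     # 5
target_hop //= factor                     # 4

feats = np.random.randn(100, 1024)        # (source_len, dim), e.g. MERT features
upsampled = np.repeat(feats, source_hop, axis=0)                 # (500, 1024)
const = feats.shape[0] * source_hop // target_hop * target_hop   # multiple of 4
downsampled = np.average(
    upsampled[:const].reshape(-1, target_hop, feats.shape[1]), axis=1
)
print(downsampled.shape)                  # (125, 1024), aligned to the mel hop
```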
+ +""" This code is modified from https://montreal-forced-aligner.readthedocs.io/en/latest/user_guide/performance.html""" + +import os +import subprocess +from multiprocessing import Pool +from tqdm import tqdm +import torchaudio +from pathlib import Path + + +def remove_empty_dirs(path): + """remove empty directories in a given path""" + # Check if the given path is a directory + if not os.path.isdir(path): + print(f"{path} is not a directory") + return + + # Walk through all directories and subdirectories + for root, dirs, _ in os.walk(path, topdown=False): + for dir in dirs: + dir_path = os.path.join(root, dir) + # Check if the directory is empty + if not os.listdir(dir_path): + os.rmdir(dir_path) # "Removed empty directory + + +def process_single_wav_file(task): + """process a single wav file""" + wav_file, output_dir = task + speaker_id, book_name, filename = Path(wav_file).parts[-3:] + + output_book_dir = Path(output_dir, speaker_id) + output_book_dir.mkdir(parents=True, exist_ok=True) + new_filename = f"{speaker_id}_{book_name}_{filename}" + + new_wav_file = Path(output_book_dir, new_filename) + command = [ + "ffmpeg", + "-nostdin", + "-hide_banner", + "-loglevel", + "error", + "-nostats", + "-i", + wav_file, + "-acodec", + "pcm_s16le", + "-ar", + "16000", + new_wav_file, + ] + subprocess.check_call( + command + ) # Run the command to convert the file to 16kHz and 16-bit PCM + os.remove(wav_file) + + +def process_wav_files(wav_files, output_dir, n_process): + """process wav files in parallel""" + tasks = [(wav_file, output_dir) for wav_file in wav_files] + print(f"Processing {len(tasks)} files") + with Pool(processes=n_process) as pool: + for _ in tqdm( + pool.imap_unordered(process_single_wav_file, tasks), total=len(tasks) + ): + pass + print("Removing empty directories...") + remove_empty_dirs(output_dir) + print("Done!") + + +def get_wav_files(dataset_path): + """get all wav files in the dataset""" + wav_files = [] + for speaker_id in os.listdir(dataset_path): + speaker_dir = os.path.join(dataset_path, speaker_id) + if not os.path.isdir(speaker_dir): + continue + for book_name in os.listdir(speaker_dir): + book_dir = os.path.join(speaker_dir, book_name) + if not os.path.isdir(book_dir): + continue + for file in os.listdir(book_dir): + if file.endswith(".wav"): + wav_files.append(os.path.join(book_dir, file)) + print("Found {} wav files".format(len(wav_files))) + return wav_files + + +def filter_wav_files_by_length(wav_files, max_len_sec=15): + """filter wav files by length""" + print("original wav files: {}".format(len(wav_files))) + filtered_wav_files = [] + for audio_file in wav_files: + metadata = torchaudio.info(str(audio_file)) + audio_length = metadata.num_frames / metadata.sample_rate + if audio_length <= max_len_sec: + filtered_wav_files.append(audio_file) + else: + os.remove(audio_file) + print("filtered wav files: {}".format(len(filtered_wav_files))) + return filtered_wav_files + + +if __name__ == "__main__": + dataset_path = "/path/to/output/directory" + n_process = 16 + max_len_sec = 15 + wav_files = get_wav_files(dataset_path) + filtered_wav_files = filter_wav_files_by_length(wav_files, max_len_sec) + process_wav_files(filtered_wav_files, dataset_path, n_process) diff --git a/models/tts/debatts/utils/model_summary.py b/models/tts/debatts/utils/model_summary.py new file mode 100644 index 00000000..ec72b0d1 --- /dev/null +++ b/models/tts/debatts/utils/model_summary.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import humanfriendly +import numpy as np +import torch + + +def get_human_readable_count(number: int) -> str: + """Return human_readable_count + + Originated from: + https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py + + Abbreviates an integer number with K, M, B, T for thousands, millions, + billions and trillions, respectively. + Examples: + >>> get_human_readable_count(123) + '123 ' + >>> get_human_readable_count(1234) # (one thousand) + '1 K' + >>> get_human_readable_count(2e6) # (two million) + '2 M' + >>> get_human_readable_count(3e9) # (three billion) + '3 B' + >>> get_human_readable_count(4e12) # (four trillion) + '4 T' + >>> get_human_readable_count(5e15) # (more than trillion) + '5,000 T' + Args: + number: a positive integer number + Return: + A string formatted according to the pattern described above. + """ + assert number >= 0 + labels = [" ", "K", "M", "B", "T"] + num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) + num_groups = int(np.ceil(num_digits / 3)) + num_groups = min(num_groups, len(labels)) + shift = -3 * (num_groups - 1) + number = number * (10**shift) + index = num_groups - 1 + return f"{number:.2f} {labels[index]}" + + +def to_bytes(dtype) -> int: + return int(str(dtype)[-2:]) // 8 + + +def model_summary(model: torch.nn.Module) -> str: + message = "Model structure:\n" + message += str(model) + tot_params = sum(p.numel() for p in model.parameters()) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params) + tot_params = get_human_readable_count(tot_params) + num_params = get_human_readable_count(num_params) + message += "\n\nModel summary:\n" + message += f" Class Name: {model.__class__.__name__}\n" + message += f" Total Number of model parameters: {tot_params}\n" + message += ( + f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n" + ) + num_bytes = humanfriendly.format_size( + sum( + p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad + ) + ) + message += f" Size: {num_bytes}\n" + dtype = next(iter(model.parameters())).dtype + message += f" Type: {dtype}" + return message diff --git a/models/tts/debatts/utils/prompt_preparer.py b/models/tts/debatts/utils/prompt_preparer.py new file mode 100644 index 00000000..945a5e24 --- /dev/null +++ b/models/tts/debatts/utils/prompt_preparer.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
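`model_summary` works on any `torch.nn.Module`; a minimal sketch of calling it on a small model (the layer sizes are arbitrary and `model_summary` from the module above is assumed to be in scope):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(80, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 80),
)

# Prints the module structure, total and trainable parameter counts,
# the trainable-parameter size in bytes, and the parameter dtype.
print(model_summary(model))
```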
+ +import torch + + +class PromptPreparer: + def prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): + if self.prefix_mode == 0: + y_emb, prefix_len = self._handle_prefix_mode_0(y, codes, nar_stage) + elif self.prefix_mode == 1: + y_emb, prefix_len = self._handle_prefix_mode_1(y, y_lens, codes, nar_stage) + elif self.prefix_mode in [2, 4]: + y_emb, prefix_len = self._handle_prefix_mode_2_4( + y, y_lens, codes, nar_stage, y_prompts_codes + ) + else: + raise ValueError("Invalid prefix mode") + + return y_emb, prefix_len + + def _handle_prefix_mode_0(self, y, codes, nar_stage): + prefix_len = 0 + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, nar_stage): + y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) + return y_emb, 0 + + def _handle_prefix_mode_1(self, y, y_lens, codes, nar_stage): + int_low = (0.25 * y_lens.min()).type(torch.int64).item() + prefix_len = torch.randint(int_low, int_low * 2, size=()).item() + prefix_len = min(prefix_len, 225) + + y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) + y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + return y_emb, prefix_len + + def _handle_prefix_mode_2_4(self, y, y_lens, codes, nar_stage, y_prompts_codes): + if self.prefix_mode == 2: + prefix_len = min(225, int(0.25 * y_lens.min().item())) + + y_prompts_codes = [] + for b in range(codes.shape[0]): + start = self.rng.randint(0, y_lens[b].item() - prefix_len) + y_prompts_codes.append( + torch.clone(codes[b, start : start + prefix_len]) + ) + codes[b, start : start + prefix_len, nar_stage] = self.audio_token_num + y_prompts_codes = torch.stack(y_prompts_codes, dim=0) + else: + prefix_len = y_prompts_codes.shape[1] + + y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[..., j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + + return y_emb, prefix_len diff --git a/models/tts/debatts/utils/ssim.py b/models/tts/debatts/utils/ssim.py new file mode 100644 index 00000000..a0b95007 --- /dev/null +++ b/models/tts/debatts/utils/ssim.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
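`PromptPreparer` is written as a mixin: it expects the host model to provide `prefix_mode`, `num_quantizers`, `audio_token_num`, and the `nar_audio_embeddings` tables. A minimal sketch of a host class exercising prefix mode 1 (all sizes and attribute values here are assumptions for illustration, not the Debatts configuration):

```python
import torch
from torch import nn

class TinyHost(PromptPreparer):
    def __init__(self, audio_token_num=1024, num_quantizers=4, dim=8):
        self.prefix_mode = 1
        self.num_quantizers = num_quantizers
        self.audio_token_num = audio_token_num
        # One embedding table per quantizer level, sized audio_token_num + 1
        # so the audio_token_num id written by prefix mode 2 stays in range.
        self.nar_audio_embeddings = nn.ModuleList(
            nn.Embedding(audio_token_num + 1, dim) for _ in range(num_quantizers)
        )

host = TinyHost()
B, T = 2, 100
codes = torch.randint(0, 1024, (B, T, 4))   # (B, T, num_quantizers)
y = codes[..., 0]                            # first-quantizer tokens
y_lens = torch.full((B,), T)

y_emb, prefix_len = host.prepare_prompts(
    y, y_lens, codes, nar_stage=2, y_prompts_codes=None
)
print(y_emb.shape, prefix_len)               # (B, T, dim), prefix_len in [25, 50)
```

In prefix mode 1, a random prefix of 25% to 50% of the shortest utterance (capped at 225 frames) is carved out of the target itself and summed across quantizer levels to act as the acoustic prompt.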
+ +# This code is modified from https://github.com/Po-Hsun-Su/pytorch-ssim + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp + + +def gaussian(window_size, sigma): + gauss = torch.Tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable( + _2D_window.expand(channel, 1, window_size, window_size).contiguous() + ) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + ) + sigma2_sq = ( + F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + ) + sigma12 = ( + F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) + - mu1_mu2 + ) + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, fake, real, bias=6.0): + fake = fake[:, None, :, :] + bias # [B, 1, T, n_mels] + real = real[:, None, :, :] + bias # [B, 1, T, n_mels] + self.window = self.window.to(dtype=fake.dtype, device=fake.device) + loss = 1 - _ssim( + fake, real, self.window, self.window_size, self.channel, self.size_average + ) + return loss diff --git a/models/tts/debatts/utils/stft.py b/models/tts/debatts/utils/stft.py new file mode 100644 index 00000000..bcec4c84 --- /dev/null +++ b/models/tts/debatts/utils/stft.py @@ -0,0 +1,278 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import numpy as np +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from librosa.filters import mel as librosa_mel_fn + +import torch +import numpy as np +import librosa.util as librosa_util +from scipy.signal import get_window + + +def window_sumsquare( + window, + n_frames, + hop_length, + win_length, + n_fft, + dtype=np.float32, + norm=None, +): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. 
+ + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data.cuda(), + torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(), + stride=self.hop_length, + padding=0, + ).cpu() + + cutoff = int((self.filter_length / 2) + 1) + real_part = 
forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length, + hop_length, + win_length, + n_mel_channels, + sampling_rate, + mel_fmin, + mel_fmax, + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes): + output = dynamic_range_compression(magnitudes) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1 + assert torch.max(y.data) <= 1 + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output) + energy = torch.norm(magnitudes, dim=1) + + return mel_output, energy diff --git a/models/tts/debatts/utils/symbol_table.py b/models/tts/debatts/utils/symbol_table.py new file mode 100644 index 00000000..a0e736fe --- /dev/null +++ b/models/tts/debatts/utils/symbol_table.py @@ -0,0 +1,317 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
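`TacotronSTFT.spectral_normalize` and `spectral_de_normalize` are thin wrappers around the log-compression helpers above; a small sketch of the round trip, showing the `1e-5` clip floor (about `-11.51` after compression):

```python
import torch

x = torch.tensor([0.0, 1e-6, 1e-3, 0.5, 2.0])
compressed = dynamic_range_compression(x)        # log(clamp(x, min=1e-5))
restored = dynamic_range_decompression(compressed)

print(compressed)   # first two entries hit the floor log(1e-5) ~= -11.5129
print(restored)     # values below the clip floor come back as 1e-5
```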
+ +# This code is modified from +# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/utils/symbol_table.py + +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import TypeVar +from typing import Union + +Symbol = TypeVar("Symbol") + + +@dataclass(repr=False) +class SymbolTable(Generic[Symbol]): + """SymbolTable that maps symbol IDs, found on the FSA arcs to + actual objects. These objects can be arbitrary Python objects + that can serve as keys in a dictionary (i.e. they need to be + hashable and immutable). + + The SymbolTable can only be read to/written from disk if the + symbols are strings. + """ + + _id2sym: Dict[int, Symbol] = field(default_factory=dict) + """Map an integer to a symbol. + """ + + _sym2id: Dict[Symbol, int] = field(default_factory=dict) + """Map a symbol to an integer. + """ + + _next_available_id: int = 1 + """A helper internal field that helps adding new symbols + to the table efficiently. + """ + + eps: Symbol = "" + """Null symbol, always mapped to index 0. + """ + + def __post_init__(self): + assert all(self._sym2id[sym] == idx for idx, sym in self._id2sym.items()) + assert all(self._id2sym[idx] == sym for sym, idx in self._sym2id.items()) + assert 0 not in self._id2sym or self._id2sym[0] == self.eps + + self._next_available_id = max(self._id2sym, default=0) + 1 + self._id2sym.setdefault(0, self.eps) + self._sym2id.setdefault(self.eps, 0) + + @staticmethod + def from_str(s: str) -> "SymbolTable": + """Build a symbol table from a string. + + The string consists of lines. Every line has two fields separated + by space(s), tab(s) or both. The first field is the symbol and the + second the integer id of the symbol. + + Args: + s: + The input string with the format described above. + Returns: + An instance of :class:`SymbolTable`. + """ + id2sym: Dict[int, str] = dict() + sym2id: Dict[str, int] = dict() + + for line in s.split("\n"): + fields = line.split() + if len(fields) == 0: + continue # skip empty lines + assert ( + len(fields) == 2 + ), f"Expect a line with 2 fields. Given: {len(fields)}" + sym, idx = fields[0], int(fields[1]) + assert sym not in sym2id, f"Duplicated symbol {sym}" + assert idx not in id2sym, f"Duplicated id {idx}" + id2sym[idx] = sym + sym2id[sym] = idx + + eps = id2sym.get(0, "") + + return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) + + @staticmethod + def from_file(filename: str) -> "SymbolTable": + """Build a symbol table from file. + + Every line in the symbol table file has two fields separated by + space(s), tab(s) or both. The following is an example file: + + .. code-block:: + + 0 + a 1 + b 2 + c 3 + + Args: + filename: + Name of the symbol table file. Its format is documented above. + + Returns: + An instance of :class:`SymbolTable`. + + """ + with open(filename, "r", encoding="utf-8") as f: + return SymbolTable.from_str(f.read().strip()) + + def to_str(self) -> str: + """ + Returns: + Return a string representation of this object. You can pass + it to the method ``from_str`` to recreate an identical object. + """ + s = "" + for idx, symbol in sorted(self._id2sym.items()): + s += f"{symbol} {idx}\n" + return s + + def to_file(self, filename: str): + """Serialize the SymbolTable to a file. + + Every line in the symbol table file has two fields separated by + space(s), tab(s) or both. The following is an example file: + + .. 
code-block:: + + 0 + a 1 + b 2 + c 3 + + Args: + filename: + Name of the symbol table file. Its format is documented above. + """ + with open(filename, "w") as f: + for idx, symbol in sorted(self._id2sym.items()): + print(symbol, idx, file=f) + + def add(self, symbol: Symbol, index: Optional[int] = None) -> int: + """Add a new symbol to the SymbolTable. + + Args: + symbol: + The symbol to be added. + index: + Optional int id to which the symbol should be assigned. + If it is not available, a ValueError will be raised. + + Returns: + The int id to which the symbol has been assigned. + """ + # Already in the table? Return its ID. + if symbol in self._sym2id: + return self._sym2id[symbol] + # Specific ID not provided - use next available. + if index is None: + index = self._next_available_id + # Specific ID provided but not available. + if index in self._id2sym: + raise ValueError( + f"Cannot assign id '{index}' to '{symbol}' - " + f"already occupied by {self._id2sym[index]}" + ) + self._sym2id[symbol] = index + self._id2sym[index] = symbol + + # Update next available ID if needed + if self._next_available_id <= index: + self._next_available_id = index + 1 + + return index + + def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: + """Get a symbol for an id or get an id for a symbol + + Args: + k: + If it is an id, it tries to find the symbol corresponding + to the id; if it is a symbol, it tries to find the id + corresponding to the symbol. + + Returns: + An id or a symbol depending on the given `k`. + """ + if isinstance(k, int): + return self._id2sym[k] + else: + return self._sym2id[k] + + def merge(self, other: "SymbolTable") -> "SymbolTable": + """Create a union of two SymbolTables. + Raises an AssertionError if the same IDs are occupied by + different symbols. + + Args: + other: + A symbol table to merge with ``self``. + + Returns: + A new symbol table. 
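        Example (a minimal sketch):

        ```python
        a = SymbolTable.from_str("a 1\nb 2")
        b = SymbolTable.from_str("b 2\nc 3")
        c = a.merge(b)
        c.get("c")     # ==> 3
        c.get(1)       # ==> 'a'
        "c" in c       # ==> True
        ```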
+ """ + self._check_compatible(other) + return SymbolTable( + _id2sym={**self._id2sym, **other._id2sym}, + _sym2id={**self._sym2id, **other._sym2id}, + eps=self.eps, + ) + + def _check_compatible(self, other: "SymbolTable") -> None: + # Epsilon compatibility + assert self.eps == other.eps, ( + f"Mismatched epsilon symbol: " f"{self.eps} != {other.eps}" + ) + # IDs compatibility + common_ids = set(self._id2sym).intersection(other._id2sym) + for idx in common_ids: + assert self[idx] == other[idx], ( + f"ID conflict for id: {idx}, " + f'self[idx] = "{self[idx]}", ' + f'other[idx] = "{other[idx]}"' + ) + # Symbols compatibility + common_symbols = set(self._sym2id).intersection(other._sym2id) + for sym in common_symbols: + assert self[sym] == other[sym], ( + f"ID conflict for id: {sym}, " + f'self[sym] = "{self[sym]}", ' + f'other[sym] = "{other[sym]}"' + ) + + def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: + return self.get(item) + + def __contains__(self, item: Union[int, Symbol]) -> bool: + if isinstance(item, int): + return item in self._id2sym + else: + return item in self._sym2id + + def __len__(self) -> int: + return len(self._id2sym) + + def __eq__(self, other: "SymbolTable") -> bool: + if len(self) != len(other): + return False + + for s in self.symbols: + if self[s] != other[s]: + return False + + return True + + @property + def ids(self) -> List[int]: + """Returns a list of integer IDs corresponding to the symbols.""" + ans = list(self._id2sym.keys()) + ans.sort() + return ans + + @property + def symbols(self) -> List[Symbol]: + """Returns a list of symbols (e.g., strings) corresponding to + the integer IDs. + """ + ans = list(self._sym2id.keys()) + ans.sort() + return ans + + +class TextToken: + def __init__( + self, + text_tokens: List[str], + add_eos: bool = True, + add_bos: bool = True, + pad_symbol: str = "", + bos_symbol: str = "", + eos_symbol: str = "", + ): + self.pad_symbol = pad_symbol + self.add_eos = add_eos + self.add_bos = add_bos + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + unique_tokens = [pad_symbol] + if add_bos: + unique_tokens.append(bos_symbol) + if add_eos: + unique_tokens.append(eos_symbol) + unique_tokens.extend(sorted(text_tokens)) + + self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} + self.idx2token = unique_tokens + + def get_token_id_seq(self, text): + tokens_seq = [p for p in text] + seq = ( + ([self.bos_symbol] if self.add_bos else []) + + tokens_seq + + ([self.eos_symbol] if self.add_eos else []) + ) + + token_ids = [self.token2idx[token] for token in seq] + token_lens = len(tokens_seq) + self.add_eos + self.add_bos + + return token_ids, token_lens diff --git a/models/tts/debatts/utils/tokenizer.py b/models/tts/debatts/utils/tokenizer.py new file mode 100644 index 00000000..7eeef586 --- /dev/null +++ b/models/tts/debatts/utils/tokenizer.py @@ -0,0 +1,150 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from +# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/tokenizer.py + +import re +from typing import Any, Dict, List, Optional, Pattern, Union + +import torch +import torchaudio +from encodec import EncodecModel +from encodec.utils import convert_audio + + +class AudioTokenizer: + """EnCodec audio tokenizer for encoding and decoding audio. 
+ + Attributes: + device: The device on which the codec model is loaded. + codec: The pretrained EnCodec model. + sample_rate: Sample rate of the model. + channels: Number of audio channels in the model. + """ + + def __init__(self, device: Any = None) -> None: + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + remove_encodec_weight_norm(model) + + if not device: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + + self._device = device + + self.codec = model.to(device) + self.sample_rate = model.sample_rate + self.channels = model.channels + + @property + def device(self): + return self._device + + def encode(self, wav: torch.Tensor) -> torch.Tensor: + """Encode the audio waveform. + + Args: + wav: A tensor representing the audio waveform. + + Returns: + A tensor representing the encoded audio. + """ + return self.codec.encode(wav.to(self.device)) + + def decode(self, frames: torch.Tensor) -> torch.Tensor: + """Decode the encoded audio frames. + + Args: + frames: A tensor representing the encoded audio frames. + + Returns: + A tensor representing the decoded audio waveform. + """ + return self.codec.decode(frames) + + +def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): + """ + Tokenize the audio waveform using the given AudioTokenizer. + + Args: + tokenizer: An instance of AudioTokenizer. + audio_path: Path to the audio file. + + Returns: + A tensor of encoded frames from the audio. + + Raises: + FileNotFoundError: If the audio file is not found. + RuntimeError: If there's an error processing the audio data. + """ + # try: + # Load and preprocess the audio waveform + wav, sr = torchaudio.load(audio_path) + wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) + wav = wav.unsqueeze(0) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = tokenizer.encode(wav) + return encoded_frames + + # except FileNotFoundError: + # raise FileNotFoundError(f"Audio file not found at {audio_path}") + # except Exception as e: + # raise RuntimeError(f"Error processing audio data: {e}") + + +def remove_encodec_weight_norm(model): + from encodec.modules import SConv1d + from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +def extract_encodec_token(wav_path): + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + + wav, sr = 
torchaudio.load(wav_path) + wav = convert_audio(wav, sr, model.sample_rate, model.channels) + wav = wav.unsqueeze(0) + if torch.cuda.is_available(): + model = model.cuda() + wav = wav.cuda() + with torch.no_grad(): + encoded_frames = model.encode(wav) + codes_ = torch.cat( + [encoded[0] for encoded in encoded_frames], dim=-1 + ) # [B, n_q, T] + codes = codes_.cpu().numpy()[0, :, :].T # [T, 8] + + return codes diff --git a/models/tts/debatts/utils/tool.py b/models/tts/debatts/utils/tool.py new file mode 100644 index 00000000..8af5e9ba --- /dev/null +++ b/models/tts/debatts/utils/tool.py @@ -0,0 +1,84 @@ +import json +import os +import io +import scipy +import os +import shutil +from pydub import AudioSegment +import soundfile as sf + + +def load_cfg(cfg_path): + if not os.path.exists("config.json"): + raise FileNotFoundError( + "config.json not found. Please: copy, config, and rename `config.json.example` to `config.json`" + ) + with open(cfg_path, "r") as f: + cfg = json.load(f) + return cfg + + +def write_wav(path, sr, x): + """numpy array to WAV""" + sf.write(path, x, sr) + + +def write_mp3(path, sr, x): + """numpy array to MP3""" + wav_io = io.BytesIO() + scipy.io.wavfile.write(wav_io, sr, x) + wav_io.seek(0) + sound = AudioSegment.from_wav(wav_io) + with open(path, "wb") as af: + sound.export( + af, + format="mp3", + codec="mp3", + bitrate="160000", + ) + + +# 读取文件夹内所有音频文件 +def get_audio_files(folder_path): + audio_files = [] + for root, _, files in os.walk(folder_path): + if "_processed" in root: + continue + for file in files: + if ".temp" in file: + continue + if file.endswith((".mp3", ".wav", ".flac", ".m4a")): + audio_files.append(os.path.join(root, file)) + return audio_files + + +def get_specific_files(folder_path, ext): + audio_files = [] + for root, _, files in os.walk(folder_path): + if "_processed" in root: + continue + for file in files: + if ".temp" in file: + continue + if file.endswith(ext): + audio_files.append(os.path.join(root, file)) + return audio_files + + +def move_vocals(src_directory): + # 遍历根目录下的所有文件和文件夹 + for root, _, files in os.walk(src_directory): + for file in files: + # 检查文件名是否为'vocals.mp3' + if file == "vocals.mp3": + # 构建源文件的完整路径 + src_path = os.path.join(root, file) + # 获取父级目录的名称 + parent_dir_name = os.path.basename(root) + # 构建目标文件的完整路径 + dest_path = os.path.join(src_directory, parent_dir_name + ".mp3") + # 复制文件 + shutil.copy(src_path, dest_path) + + # 删除源文件夹 + shutil.rmtree(src_directory + "/htdemucs") diff --git a/models/tts/debatts/utils/topk_sampling.py b/models/tts/debatts/utils/topk_sampling.py new file mode 100644 index 00000000..236a0f93 --- /dev/null +++ b/models/tts/debatts/utils/topk_sampling.py @@ -0,0 +1,72 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocabulary size) + top_k >0: keep only top k tokens with highest probability (top-k filtering). + top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + + Basic outline taken from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + assert logits.dim() == 2 # [BATCH_SIZE, VOCAB_SIZE] + top_k = min(top_k, logits.size(-1)) # Safety check + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k, dim=1)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # Replace logits to be removed with -inf in the sorted_logits + sorted_logits[sorted_indices_to_remove] = filter_value + # Then reverse the sorting process by mapping back sorted_logits to their original position + logits = torch.gather(sorted_logits, 1, sorted_indices.argsort(-1)) + + # pred_token = torch.multinomial(F.softmax(logits, -1), 1) # [BATCH_SIZE, 1] + return logits + + +def topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0): + """ + Perform top-k and top-p sampling on logits. + + Args: + logits (torch.Tensor): The logits to sample from. + top_k (int, optional): The number of highest probability tokens to keep for top-k filtering. + Must be a positive integer. Defaults to 50. + top_p (float, optional): The cumulative probability threshold for nucleus sampling. + Must be between 0 and 1. Defaults to 1.0. + temperature (float, optional): The scaling factor to adjust the logits distribution. + Must be strictly positive. Defaults to 1.0. + + Returns: + torch.Tensor: The sampled token. + """ + + # Adjust logits using temperature + if temperature != 1.0: + logits = logits / temperature + + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + + # Sample from the filtered distribution + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return token diff --git a/models/tts/debatts/utils/trainer_utils.py b/models/tts/debatts/utils/trainer_utils.py new file mode 100644 index 00000000..e5d9ad79 --- /dev/null +++ b/models/tts/debatts/utils/trainer_utils.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def check_nan(logger, loss, y_pred, y_gt): + if torch.any(torch.isnan(loss)): + logger.info("out has nan: ", torch.any(torch.isnan(y_pred))) + logger.info("y_gt has nan: ", torch.any(torch.isnan(y_gt))) + logger.info("out: ", y_pred) + logger.info("y_gt: ", y_gt) + logger.info("loss = {:.4f}\n".format(loss.item())) + exit() diff --git a/models/tts/debatts/utils/util.py b/models/tts/debatts/utils/util.py new file mode 100644 index 00000000..b7eaf1aa --- /dev/null +++ b/models/tts/debatts/utils/util.py @@ -0,0 +1,687 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
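+
+# General-purpose helpers shared across the Debatts TTS code:
+# - configuration loading and merging (load_config, override_config, JsonHParams)
+# - checkpoint saving/restoring and exponential-moving-average utilities
+# - batching/padding helpers for F0 and mel features
+# - common mask/segment helpers (make_pad_mask, sequence_mask, slice_segments)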
+ + +import collections +import glob +import os +import random +import time +import argparse +from collections import OrderedDict + +import json5 +import numpy as np +import glob +from torch.nn import functional as F + + +try: + from ruamel.yaml import YAML as yaml +except: + from ruamel_yaml import YAML as yaml + +import torch + +from utils.hparam import HParams +import logging +from logging import handlers + + +def str2bool(v): + """Used in argparse.ArgumentParser.add_argument to indicate + that a type is a bool type and user can enter + + - yes, true, t, y, 1, to represent True + - no, false, f, n, 0, to represent False + + See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def find_checkpoint_of_mapper(mapper_ckpt_dir): + mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt")) + + # Select the max steps + mapper_ckpts.sort() + mapper_weights_file = mapper_ckpts[-1] + return mapper_weights_file + + +def pad_f0_to_tensors(f0s, batched=None): + # Initialize + tensors = [] + + if batched == None: + # Get the max frame for padding + size = -1 + for f0 in f0s: + size = max(size, f0.shape[-1]) + + tensor = torch.zeros(len(f0s), size) + + for i, f0 in enumerate(f0s): + tensor[i, : f0.shape[-1]] = f0[:] + + tensors.append(tensor) + else: + start = 0 + while start + batched - 1 < len(f0s): + end = start + batched - 1 + + # Get the max frame for padding + size = -1 + for i in range(start, end + 1): + size = max(size, f0s[i].shape[-1]) + + tensor = torch.zeros(batched, size) + + for i in range(start, end + 1): + tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:] + + tensors.append(tensor) + + start = start + batched + + if start != len(f0s): + end = len(f0s) + + # Get the max frame for padding + size = -1 + for i in range(start, end): + size = max(size, f0s[i].shape[-1]) + + tensor = torch.zeros(len(f0s) - start, size) + + for i in range(start, end): + tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:] + + tensors.append(tensor) + + return tensors + + +def pad_mels_to_tensors(mels, batched=None): + """ + Args: + mels: A list of mel-specs + Returns: + tensors: A list of tensors containing the batched mel-specs + mel_frames: A list of tensors containing the frames of the original mel-specs + """ + # Initialize + tensors = [] + mel_frames = [] + + # Split mel-specs into batches to avoid cuda memory exceed + if batched == None: + # Get the max frame for padding + size = -1 + for mel in mels: + size = max(size, mel.shape[-1]) + + tensor = torch.zeros(len(mels), mels[0].shape[0], size) + mel_frame = torch.zeros(len(mels), dtype=torch.int32) + + for i, mel in enumerate(mels): + tensor[i, :, : mel.shape[-1]] = mel[:] + mel_frame[i] = mel.shape[-1] + + tensors.append(tensor) + mel_frames.append(mel_frame) + else: + start = 0 + while start + batched - 1 < len(mels): + end = start + batched - 1 + + # Get the max frame for padding + size = -1 + for i in range(start, end + 1): + size = max(size, mels[i].shape[-1]) + + tensor = torch.zeros(batched, mels[0].shape[0], size) + mel_frame = torch.zeros(batched, dtype=torch.int32) + + for i in range(start, end + 1): + tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:] + mel_frame[i - start] = mels[i].shape[-1] + + tensors.append(tensor) + 
mel_frames.append(mel_frame) + + start = start + batched + + if start != len(mels): + end = len(mels) + + # Get the max frame for padding + size = -1 + for i in range(start, end): + size = max(size, mels[i].shape[-1]) + + tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size) + mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32) + + for i in range(start, end): + tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:] + mel_frame[i - start] = mels[i].shape[-1] + + tensors.append(tensor) + mel_frames.append(mel_frame) + + return tensors, mel_frames + + +def load_model_config(args): + """Load model configurations (in args.json under checkpoint directory) + + Args: + args (ArgumentParser): arguments to run bins/preprocess.py + + Returns: + dict: dictionary that stores model configurations + """ + if args.checkpoint_dir is None: + assert args.checkpoint_file is not None + checkpoint_dir = os.path.split(args.checkpoint_file)[0] + else: + checkpoint_dir = args.checkpoint_dir + config_path = os.path.join(checkpoint_dir, "args.json") + print("config_path: ", config_path) + + config = load_config(config_path) + return config + + +def remove_and_create(dir): + if os.path.exists(dir): + os.system("rm -r {}".format(dir)) + os.makedirs(dir, exist_ok=True) + + +def has_existed(path, warning=False): + if not warning: + return os.path.exists(path) + + if os.path.exists(path): + answer = input( + "The path {} has existed. \nInput 'y' (or hit Enter) to skip it, and input 'n' to re-write it [y/n]\n".format( + path + ) + ) + if not answer == "n": + return True + + return False + + +def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5): + if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")): + with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f: + ckpts = [x.strip() for x in f.readlines()] + else: + ckpts = [] + ckpts.append(saved_model_name) + for item in ckpts[:-max_to_keep]: + if os.path.exists(os.path.join(checkpoint_dir, item)): + os.remove(os.path.join(checkpoint_dir, item)) + with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f: + for item in ckpts[-max_to_keep:]: + f.write("{}\n".format(item)) + + +def set_all_random_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def save_checkpoint( + args, + generator, + g_optimizer, + step, + discriminator=None, + d_optimizer=None, + max_to_keep=5, +): + saved_model_name = "model.ckpt-{}.pt".format(step) + checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name) + + if discriminator and d_optimizer: + torch.save( + { + "generator": generator.state_dict(), + "discriminator": discriminator.state_dict(), + "g_optimizer": g_optimizer.state_dict(), + "d_optimizer": d_optimizer.state_dict(), + "global_step": step, + }, + checkpoint_path, + ) + else: + torch.save( + { + "generator": generator.state_dict(), + "g_optimizer": g_optimizer.state_dict(), + "global_step": step, + }, + checkpoint_path, + ) + + print("Saved checkpoint: {}".format(checkpoint_path)) + + if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")): + with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f: + ckpts = [x.strip() for x in f.readlines()] + else: + ckpts = [] + ckpts.append(saved_model_name) + for item in ckpts[:-max_to_keep]: + if os.path.exists(os.path.join(args.checkpoint_dir, item)): + os.remove(os.path.join(args.checkpoint_dir, item)) + with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f: + for item in 
ckpts[-max_to_keep:]: + f.write("{}\n".format(item)) + + +def attempt_to_restore( + generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None +): + checkpoint_list = os.path.join(checkpoint_dir, "checkpoint") + if os.path.exists(checkpoint_list): + checkpoint_filename = open(checkpoint_list).readlines()[-1].strip() + checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename)) + print("Restore from {}".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if generator: + if not list(generator.state_dict().keys())[0].startswith("module."): + raw_dict = checkpoint["generator"] + clean_dict = OrderedDict() + for k, v in raw_dict.items(): + if k.startswith("module."): + clean_dict[k[7:]] = v + else: + clean_dict[k] = v + generator.load_state_dict(clean_dict) + else: + generator.load_state_dict(checkpoint["generator"]) + if g_optimizer: + g_optimizer.load_state_dict(checkpoint["g_optimizer"]) + global_step = 100000 + if discriminator and "discriminator" in checkpoint.keys(): + discriminator.load_state_dict(checkpoint["discriminator"]) + global_step = checkpoint["global_step"] + print("restore discriminator") + if d_optimizer and "d_optimizer" in checkpoint.keys(): + d_optimizer.load_state_dict(checkpoint["d_optimizer"]) + print("restore d_optimizer...") + else: + global_step = 0 + return global_step + + +class ExponentialMovingAverage(object): + def __init__(self, decay): + self.decay = decay + self.shadow = {} + + def register(self, name, val): + self.shadow[name] = val.clone() + + def update(self, name, x): + assert name in self.shadow + update_delta = self.shadow[name] - x + self.shadow[name] -= (1.0 - self.decay) * update_delta + + +def apply_moving_average(model, ema): + for name, param in model.named_parameters(): + if name in ema.shadow: + ema.update(name, param.data) + + +def register_model_to_ema(model, ema): + for name, param in model.named_parameters(): + if param.requires_grad: + ema.register(name, param.data) + + +class YParams(HParams): + def __init__(self, yaml_file): + if not os.path.exists(yaml_file): + raise IOError("yaml file: {} is not existed".format(yaml_file)) + super().__init__() + self.d = collections.OrderedDict() + with open(yaml_file) as fp: + for _, v in yaml().load(fp).items(): + for k1, v1 in v.items(): + try: + if self.get(k1): + self.set_hparam(k1, v1) + else: + self.add_hparam(k1, v1) + self.d[k1] = v1 + except Exception: + import traceback + + print(traceback.format_exc()) + + # @property + def get_elements(self): + return self.d.items() + + +def override_config(base_config, new_config): + """Update new configurations in the original dict with the new dict + + Args: + base_config (dict): original dict to be overridden + new_config (dict): dict with new configurations + + Returns: + dict: updated configuration dict + """ + for k, v in new_config.items(): + if type(v) == dict: + if k not in base_config.keys(): + base_config[k] = {} + base_config[k] = override_config(base_config[k], v) + else: + base_config[k] = v + return base_config + + +def get_lowercase_keys_config(cfg): + """Change all keys in cfg to lower case + + Args: + cfg (dict): dictionary that stores configurations + + Returns: + dict: dictionary that stores configurations + """ + updated_cfg = dict() + for k, v in cfg.items(): + if type(v) == dict: + v = get_lowercase_keys_config(v) + updated_cfg[k.lower()] = v + return updated_cfg + + +def _load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + 
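+    If the configuration contains a "base_config" key, that base file is
+    loaded first (its path is resolved against the WORK_DIR environment
+    variable) and then recursively overridden by the current file.
+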
+ Args: + config_fn (str): path to configuration file + lowercase (bool, optional): whether changing keys to lower case. Defaults to False. + + Returns: + dict: dictionary that stores configurations + """ + with open(config_fn, "r") as f: + data = f.read() + config_ = json5.loads(data) + if "base_config" in config_: + # load configurations from new path + p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"]) + p_config_ = _load_config(p_config_path) + config_ = override_config(p_config_, config_) + if lowercase: + # change keys in config_ to lower case + config_ = get_lowercase_keys_config(config_) + return config_ + + +def load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + + Args: + config_fn (str): path to configuration file + lowercase (bool, optional): _description_. Defaults to False. + + Returns: + JsonHParams: an object that stores configurations + """ + config_ = _load_config(config_fn, lowercase=lowercase) + # create an JsonHParams object with configuration dict + cfg = JsonHParams(**config_) + return cfg + + +def save_config(save_path, cfg): + """Save configurations into a json file + + Args: + save_path (str): path to save configurations + cfg (dict): dictionary that stores configurations + """ + with open(save_path, "w") as f: + json5.dump( + cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True + ) + + +class JsonHParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = JsonHParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +class ValueWindow: + def __init__(self, window_size=100): + self._window_size = window_size + self._values = [] + + def append(self, x): + self._values = self._values[-(self._window_size - 1) :] + [x] + + @property + def sum(self): + return sum(self._values) + + @property + def count(self): + return len(self._values) + + @property + def average(self): + return self.sum / max(1, self.count) + + def reset(self): + self._values = [] + + +class Logger(object): + def __init__( + self, + filename, + level="info", + when="D", + backCount=10, + fmt="%(asctime)s : %(message)s", + ): + self.level_relations = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "crit": logging.CRITICAL, + } + if level == "debug": + fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s" + self.logger = logging.getLogger(filename) + format_str = logging.Formatter(fmt) + self.logger.setLevel(self.level_relations.get(level)) + sh = logging.StreamHandler() + sh.setFormatter(format_str) + th = handlers.TimedRotatingFileHandler( + filename=filename, when=when, backupCount=backCount, encoding="utf-8" + ) + th.setFormatter(format_str) + self.logger.addHandler(sh) + self.logger.addHandler(th) + self.logger.info( + "==========================New Starting Here==============================" + ) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def 
get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + +def get_current_time(): + pass + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. 
+ + >>> lengths = torch.tensor([1, 3, 2, 5]) + >>> make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + + return expaned_lengths >= lengths.unsqueeze(-1) diff --git a/models/tts/debatts/utils/whisper_transcription.py b/models/tts/debatts/utils/whisper_transcription.py new file mode 100644 index 00000000..98126987 --- /dev/null +++ b/models/tts/debatts/utils/whisper_transcription.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import pathlib +import string +import time +from multiprocessing import Pool, Value, Lock +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor +import torch +import whisper + +processed_files_count = Value("i", 0) # count of processed files +lock = Lock() # lock for the count + + +def preprocess_text(text): + """Preprocess text after ASR""" + return text.lower().translate(str.maketrans("", "", string.punctuation)) + + +def transcribe_audio(model, processor, audio_file, device): + """Transcribe audio file""" + audio = whisper.load_audio(audio_file) # load from path + audio = whisper.pad_or_trim(audio) # default 30 seconds + inputs = whisper.log_mel_spectrogram(audio).to( + device=device + ) # convert to spectrogram + inputs = inputs.unsqueeze(0).type(torch.cuda.HalfTensor) # add batch dimension + + outputs = model.generate( + inputs=inputs, max_new_tokens=128 + ) # generate transcription + transcription = processor.batch_decode(outputs, skip_special_tokens=True)[ + 0 + ] # decode + transcription_processed = preprocess_text(transcription) # preprocess + return transcription_processed + + +def write_transcription(audio_file, transcription): + """Write transcription to txt file""" + txt_file = audio_file.with_suffix(".txt") + with open(txt_file, "w") as file: + file.write(transcription) + + +def init_whisper(model_id, device): + """Initialize whisper model and processor""" + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + print(f"Loading model {model_id}") # model_id = "distil-whisper/distil-large-v2" + distil_model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=False + ) + distil_model = distil_model.to(device) + processor = AutoProcessor.from_pretrained(model_id) + return distil_model, processor + + +def asr_wav_files(file_list, gpu_id, total_files, model_id): + """Transcribe wav files in a list""" + device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu" + whisper_model, processor = init_whisper(model_id, device) + print(f"Processing on {device} starts") + start_time = time.time() + for audio_file in file_list: + try: + transcription = transcribe_audio( + whisper_model, processor, audio_file, device + ) + write_transcription(audio_file, transcription) + with lock: + processed_files_count.value += 1 + if processed_files_count.value % 5 == 0: + current_time = time.time() + avg_time_per_file = (current_time - start_time) / ( + processed_files_count.value + ) + remaining_files = total_files - processed_files_count.value + estimated_time_remaining = 
avg_time_per_file * remaining_files + remaining_time_formatted = time.strftime( + "%H:%M:%S", time.gmtime(estimated_time_remaining) + ) + print( + f"Processed {processed_files_count.value}/{total_files} files, time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}, Estimated time remaining: {remaining_time_formatted}" + ) + except Exception as e: + print(f"Error processing file {audio_file}: {e}") + + +def asr_main(input_dir, num_gpus, model_id): + """Transcribe wav files in a directory""" + num_processes = min(num_gpus, os.cpu_count()) + print(f"Using {num_processes} GPUs for transcription") + wav_files = list(pathlib.Path(input_dir).rglob("*.wav")) + total_files = len(wav_files) + print(f"Found {total_files} wav files in {input_dir}") + files_per_process = len(wav_files) // num_processes + print(f"Processing {files_per_process} files per process") + with Pool(num_processes) as p: + p.starmap( + asr_wav_files, + [ + ( + wav_files[i * files_per_process : (i + 1) * files_per_process], + i % num_gpus, + total_files, + model_id, + ) + for i in range(num_processes) + ], + ) + print("Done!") + + +if __name__ == "__main__": + input_dir = "/path/to/output/directory" + num_gpus = 2 + model_id = "distil-whisper/distil-large-v2" + asr_main(input_dir, num_gpus, model_id) diff --git a/models/tts/debatts/utils/world.py b/models/tts/debatts/utils/world.py new file mode 100644 index 00000000..ce5f61bd --- /dev/null +++ b/models/tts/debatts/utils/world.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# 1. Extract WORLD features including F0, AP, SP +# 2. Transform between SP and MCEP +import torchaudio +import pyworld as pw +import numpy as np +import torch +import diffsptk +import os +from tqdm import tqdm +import pickle +import torchaudio + + +def get_mcep_params(fs): + """Hyperparameters of transformation between SP and MCEP + + Reference: + https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh + + """ + if fs in [44100, 48000]: + fft_size = 2048 + alpha = 0.77 + if fs in [16000]: + fft_size = 1024 + alpha = 0.58 + return fft_size, alpha + + +def extract_world_features(waveform, frameshift=10): + # waveform: (1, seq) + # x: (seq,) + x = np.array(waveform, dtype=np.double) + + _f0, t = pw.dio(x, fs, frame_period=frameshift) # raw pitch extractor + f0 = pw.stonemask(x, _f0, t, fs) # pitch refinement + sp = pw.cheaptrick(x, f0, t, fs) # extract smoothed spectrogram + ap = pw.d4c(x, f0, t, fs) # extract aperiodicity + + return f0, sp, ap, fs + + +def sp2mcep(x, mcsize, fs): + fft_size, alpha = get_mcep_params(fs) + x = torch.as_tensor(x, dtype=torch.float) + + tmp = diffsptk.ScalarOperation("SquareRoot")(x) + tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp) + mgc = diffsptk.MelCepstralAnalysis( + cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1 + )(tmp) + return mgc.numpy() + + +def mcep2sp(x, mcsize, fs): + fft_size, alpha = get_mcep_params(fs) + x = torch.as_tensor(x, dtype=torch.float) + + tmp = diffsptk.MelGeneralizedCepstrumToSpectrum( + alpha=alpha, + cep_order=mcsize - 1, + fft_length=fft_size, + )(x) + tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp) + sp = diffsptk.ScalarOperation("Power", 2)(tmp) + return sp.double().numpy() + + +def f0_statistics(f0_features, path): + print("\nF0 statistics...") + + total_f0 = [] + for f0 in tqdm(f0_features): + total_f0 += [f for f in f0 if 
f != 0] + + mean = sum(total_f0) / len(total_f0) + print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean)) + + with open(path, "wb") as f: + pickle.dump([mean, total_f0], f) + + +def world_synthesis(f0, sp, ap, fs, frameshift): + y = pw.synthesize( + f0, sp, ap, fs, frame_period=frameshift + ) # synthesize an utterance using the parameters + return y From 42c47802dfc51b51ddff37fe5c2e62f75c45a0df Mon Sep 17 00:00:00 2001 From: hehaorui Date: Mon, 28 Oct 2024 16:57:32 +0800 Subject: [PATCH 2/8] debatts code modified name --- ..._semantic_repcodec_8192_1q_1layer_24k.json | 112 +++++++++++++++++ ...ama_new_semantic_repcodec_8192_1q_24k.json | 113 ++++++++++++++++++ ...c_repcodec_8192_1q_large_101k_fix_new.json | 0 models/tts/debatts/t2s_sft_dataset_new.py | 6 +- .../debatts/try_inference_small_samples.py | 89 ++++---------- .../utils/{g2p_liwei => g2p_new}/__init__.py | 8 +- .../utils/{g2p_liwei => g2p_new}/cleaners.py | 12 +- .../utils/{g2p_liwei => g2p_new}/english.py | 0 .../utils/{g2p_liwei => g2p_new}/french.py | 0 .../g2p_liwei.py => g2p_new/g2p_new.py} | 4 +- .../utils/{g2p_liwei => g2p_new}/german.py | 0 .../utils/{g2p_liwei => g2p_new}/japanese.py | 0 .../utils/{g2p_liwei => g2p_new}/korean.py | 0 .../utils/{g2p_liwei => g2p_new}/mandarin.py | 0 .../{g2p_liwei => g2p_new}/text_tokenizers.py | 0 .../utils/{g2p_liwei => g2p_new}/vacab.json | 0 16 files changed, 261 insertions(+), 83 deletions(-) create mode 100644 models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json create mode 100644 models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json rename models/tts/debatts/{ => t2s_egs}/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json (100%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/__init__.py (87%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/cleaners.py (68%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/english.py (100%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/french.py (100%) rename models/tts/debatts/utils/{g2p_liwei/g2p_liwei.py => g2p_new/g2p_new.py} (59%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/german.py (100%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/japanese.py (100%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/korean.py (100%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/mandarin.py (100%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/text_tokenizers.py (100%) rename models/tts/debatts/utils/{g2p_liwei => g2p_new}/vacab.json (100%) diff --git a/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json b/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json new file mode 100644 index 00000000..ab4c9083 --- /dev/null +++ b/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json @@ -0,0 +1,112 @@ +{ + "model_type": "SoundStorm", + "dataset": ["emilia"], + "preprocess": { + "hop_size": 480, + "sample_rate": 24000, + "processed_dir": "", + "valid_file": "valid.json", + "train_file": "train.json", + "use_phone_cond": false + }, + "model": { + "soundstorm": { + "num_quantizer": 1, + "hidden_size": 1024, + "num_layers": 16, + "num_heads": 16, + "codebook_size": 1024, + "cfg_scale": 0.15, + "mask_layer_schedule": "linear", + "use_cond_code": true, + "cond_codebook_size": 8192, + "cond_dim":1024, + "use_llama_style": 
true, + "use_phone_cond": false, + "use_pretrained_model": false + }, + "kmeans": { + "type": "repcodec", + "stat_mean_var_path": "./stat_ckpt/emilia_wav2vec2bert_stats_10k.pt In HuggingFace", + "repcodec": { + "codebook_size": 8192, + "hidden_size": 1024, + "codebook_dim": 8, + "vocos_dim": 384, + "vocos_intermediate_dim": 2048, + "vocos_num_layers": 12 + }, + "pretrained_path": "./semantic_codec/emilia_50k_8192_norm_8d_86k_steps_model.safetensors In HuggingFace" + }, + "codec": { + "encoder": { + "d_model": 96, + "up_ratios": [3, 4, 5, 8], + "out_channels": 256, + "use_tanh": false, + "pretrained_path": "./acoustic_codec/emilia_50k_model.safetensors In HuggingFace" + }, + "decoder": { + "in_channel": 256, + "upsample_initial_channel": 1536, + "up_ratios": [8, 5, 4, 3], + "num_quantizers": 12, + "codebook_size": 1024, + "codebook_dim": 8, + "quantizer_type": "fvq", + "quantizer_dropout": 0.5, + "commitment": 0.25, + "codebook_loss_weight": 1.0, + "use_l2_normlize": true, + "codebook_type": "euclidean", + "kmeans_init": false, + "kmeans_iters": 10, + "decay": 0.8, + "eps": 0.5, + "threshold_ema_dead_code": 2, + "weight_init": false, + "use_vocos": true, + "vocos_dim": 512, + "vocos_intermediate_dim": 4096, + "vocos_num_layers": 30, + "n_fft": 1920, + "hop_size": 480, + "padding": "same", + "pretrained_path": "./acoustic_codec/emilia_50k_model_1.safetensors In HuggingFace" + } + } + }, + "log_dir": "", + "train": { + "max_epoch": 0, + "use_dynamic_batchsize": true, + "max_tokens": 2000000, + "max_sentences": 20, + "lr_warmup_steps": 32000, + "lr_scheduler": "inverse_sqrt", + "num_train_steps": 800000, + "adam": { + "lr": 1e-4 + }, + "ddp": false, + "random_seed": 114, + "batch_size": 10, + "epochs": 5000, + "max_steps": 1000000, + "total_training_steps": 800000, + "save_summary_steps": 500, + "save_checkpoints_steps": 1000, + "valid_interval": 2000, + "keep_checkpoint_max": 100, + "gradient_accumulation_step": 1, + "tracker": ["tensorboard"], + "save_checkpoint_stride": [1], + "keep_last": [10], + "run_eval": [true], + "dataloader": { + "num_worker": 16, + "pin_memory": true + }, + "use_emilia_dataset": true + } +} \ No newline at end of file diff --git a/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json b/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json new file mode 100644 index 00000000..7711811b --- /dev/null +++ b/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json @@ -0,0 +1,113 @@ +{ + "model_type": "SoundStorm", + "dataset": ["emilia"], + "preprocess": { + "hop_size": 480, + "sample_rate": 24000, + "processed_dir": "", + "valid_file": "valid.json", + "train_file": "train.json", + "use_phone_cond": false + }, + "model": { + "soundstorm": { + "num_quantizer": 12, + "hidden_size": 1024, + "num_layers": 16, + "num_heads": 16, + "codebook_size": 1024, + "cfg_scale": 0.15, + "mask_layer_schedule": "linear", + "use_cond_code": true, + "cond_codebook_size": 8192, + "cond_dim":1024, + "use_llama_style": true, + "use_phone_cond": false, + "use_pretrained_model": false, + "predict_layer_1": false + }, + "kmeans": { + "type": "repcodec", + "stat_mean_var_path": "./stat_ckpt/emilia_wav2vec2bert_stats_10k.pt In HuggingFace", + "repcodec": { + "codebook_size": 8192, + "hidden_size": 1024, + "codebook_dim": 8, + "vocos_dim": 384, + "vocos_intermediate_dim": 2048, + "vocos_num_layers": 12 + }, + "pretrained_path": 
"./semantic_codec/emilia_50k_8192_norm_8d_86k_steps_model.safetensors In HuggingFace" + }, + "codec": { + "encoder": { + "d_model": 96, + "up_ratios": [3, 4, 5, 8], + "out_channels": 256, + "use_tanh": false, + "pretrained_path": "./acoustic_codec/emilia_50k_model.safetensors In HuggingFace" + }, + "decoder": { + "in_channel": 256, + "upsample_initial_channel": 1536, + "up_ratios": [8, 5, 4, 3], + "num_quantizers": 12, + "codebook_size": 1024, + "codebook_dim": 8, + "quantizer_type": "fvq", + "quantizer_dropout": 0.5, + "commitment": 0.25, + "codebook_loss_weight": 1.0, + "use_l2_normlize": true, + "codebook_type": "euclidean", + "kmeans_init": false, + "kmeans_iters": 10, + "decay": 0.8, + "eps": 0.5, + "threshold_ema_dead_code": 2, + "weight_init": false, + "use_vocos": true, + "vocos_dim": 512, + "vocos_intermediate_dim": 4096, + "vocos_num_layers": 30, + "n_fft": 1920, + "hop_size": 480, + "padding": "same", + "pretrained_path": "./acoustic_codec/emilia_50k_model_1.safetensors In HuggingFace" + } + } + }, + "log_dir": "", + "train": { + "max_epoch": 0, + "use_dynamic_batchsize": true, + "max_tokens": 2000000, + "max_sentences": 20, + "lr_warmup_steps": 32000, + "lr_scheduler": "inverse_sqrt", + "num_train_steps": 800000, + "adam": { + "lr": 1e-4 + }, + "ddp": false, + "random_seed": 114, + "batch_size": 10, + "epochs": 5000, + "max_steps": 1000000, + "total_training_steps": 800000, + "save_summary_steps": 500, + "save_checkpoints_steps": 1000, + "valid_interval": 2000, + "keep_checkpoint_max": 100, + "gradient_accumulation_step": 1, + "tracker": ["tensorboard"], + "save_checkpoint_stride": [1], + "keep_last": [10], + "run_eval": [true], + "dataloader": { + "num_worker": 16, + "pin_memory": true + }, + "use_emilia_dataset": true + } +} \ No newline at end of file diff --git a/models/tts/debatts/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json b/models/tts/debatts/t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json similarity index 100% rename from models/tts/debatts/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json rename to models/tts/debatts/t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json diff --git a/models/tts/debatts/t2s_sft_dataset_new.py b/models/tts/debatts/t2s_sft_dataset_new.py index 0e104491..e705fb18 100644 --- a/models/tts/debatts/t2s_sft_dataset_new.py +++ b/models/tts/debatts/t2s_sft_dataset_new.py @@ -27,7 +27,7 @@ import sys sys.path.append('./models/tts/debatts') from utils.g2p_new.g2p import phonemizer_g2p -from utils.g2p_liwei.g2p_liwei import liwei_g2p +from utils.g2p_new.g2p_new import new_g2p from torch.nn.utils.rnn import pad_sequence device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -98,8 +98,8 @@ def __init__( "./w2v-bert-2" ) - def liwei_g2p(self, text, language): - return liwei_g2p(text, language) + def new_g2p(self, text, language): + return new_g2p(text, language) def __len__(self): return self.wav_paths.__len__() diff --git a/models/tts/debatts/try_inference_small_samples.py b/models/tts/debatts/try_inference_small_samples.py index 0e21965f..23e6ee2b 100644 --- a/models/tts/debatts/try_inference_small_samples.py +++ b/models/tts/debatts/try_inference_small_samples.py @@ -7,7 +7,7 @@ import os os.chdir('./models/tts/debatts') sys.path.append('./models/tts/debatts') -from utils.g2p_liwei.g2p_liwei import liwei_g2p +from utils.g2p_new.g2p_new import new_g2p from transformers import Wav2Vec2Model from cgitb import text 
@@ -42,11 +42,10 @@ from models.tts.text2semantic.t2s_model import T2SLlama from models.tts.text2semantic.t2s_model_new import T2SLlama_new -from utils.g2p_liwei.g2p_liwei import liwei_g2p from models.tts.text2semantic.t2s_sft_dataset_new import DownsampleWithMask -def liwei_g2p_(text, language): - return liwei_g2p(text, language) +def new_g2p_(text, language): + return new_g2p(text, language) def build_t2s_model_new(cfg, device): t2s_model = T2SLlama_new(phone_vocab_size=1024, @@ -89,7 +88,7 @@ def build_kmeans_model(cfg, device): return kmeans_model def build_semantic_model(cfg, device): - semantic_model = Wav2Vec2BertModel.from_pretrained("/mntcephfs/lab_data/lijiaqi/debate/gluster-tts/w2v-bert-2") + semantic_model = Wav2Vec2BertModel.from_pretrained("./w2v-bert-2") semantic_model.eval() semantic_model.to(device) @@ -152,22 +151,22 @@ def extract_features(speech, processor): return input_features, attention_mask @torch.no_grad() -def text2semantic(prompt0_speech, prompt0_text, prompt_speech, prompt_text, prompt_language, target_text, target_language, use_prompt_text=True, temp=1.0, top_k=1000, top_p=0.85, infer_mode = "ori"): +def text2semantic(prompt0_speech, prompt0_text, prompt_speech, prompt_text, prompt_language, target_text, target_language, use_prompt_text=True, temp=1.0, top_k=1000, top_p=0.85, infer_mode = "new"): if use_prompt_text: if infer_mode == "new" and prompt0_speech is not None and prompt0_speech.any(): - prompt0_phone_id = liwei_g2p_(prompt0_text, prompt_language)[1] + prompt0_phone_id = new_g2p_(prompt0_text, prompt_language)[1] prompt0_phone_id = torch.tensor(prompt0_phone_id, dtype=torch.long).to(device) - prompt_phone_id = liwei_g2p_(prompt_text, prompt_language)[1] + prompt_phone_id = new_g2p_(prompt_text, prompt_language)[1] prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device) - target_phone_id = liwei_g2p_(target_text, target_language)[1] + target_phone_id = new_g2p_(target_text, target_language)[1] target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) phone_id = torch.cat([prompt_phone_id, torch.LongTensor([4]).to(device), target_phone_id]) else: - target_phone_id = liwei_g2p_(target_text, target_language)[1] + target_phone_id = new_g2p_(target_text, target_language)[1] target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) phone_id = target_phone_id @@ -185,32 +184,17 @@ def text2semantic(prompt0_speech, prompt0_text, prompt_speech, prompt_text, prom semantic_code_prompt0 = extract_semantic_code(semantic_mean, semantic_std, input_fetures_prompt0, attention_mask_prompt0) if use_prompt_text: - if infer_mode =="ori": - predict_semantic = t2s_model.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :], temperature=temp, top_k=top_k, top_p=top_p) - elif infer_mode == "tune": - predict_semantic = t2s_model_tune.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :], temperature=temp, top_k=top_k, top_p=top_p) - elif infer_mode == "new": + if infer_mode == "new": predict_semantic = t2s_model_new.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :], prompt0_ids=semantic_code_prompt0[:, :], temperature=temp, top_k=top_k, top_p=top_p) else: - if infer_mode == "ori": - predict_semantic = t2s_model.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :1], temperature=temp, top_k=top_k, top_p=top_p) - elif infer_mode == "tune": - predict_semantic = t2s_model_tune.sample_hf(phone_ids=phone_id.unsqueeze(0), 
prompt_ids=semantic_code[:, :1], temperature=temp, top_k=top_k, top_p=top_p) - elif infer_mode == "new": + if infer_mode == "new": predict_semantic = t2s_model_new.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :1], prompt0_ids=semantic_code_prompt0[:, :1], temperature=temp, top_k=top_k, top_p=top_p) combine_semantic_code = torch.cat([semantic_code[:,:], predict_semantic], dim=-1) prompt_semantic_code = semantic_code - # max_com_semantic_value = torch.max(combine_semantic_code).item() - # max_prompt_semantic_value = torch.max(prompt_semantic_code).item() - - # print(f"Max token value in com semantic: {max_com_semantic_value}, shape is {combine_semantic_code.shape}") - # print(f"Max token value in prompt semantic: {max_prompt_semantic_value}, shape is {prompt_semantic_code.shape}") - # print(f"combine semantic_code of t2s new is {combine_semantic_code}, shape is {combine_semantic_code.shape}") - return combine_semantic_code, prompt_semantic_code @torch.no_grad() @@ -250,8 +234,8 @@ def semantic2acoustic(combine_semantic_code, acoustic_code): return combine_audio, recovered_audio device = torch.device("cuda:0") -cfg_soundstorm_1layer = load_config("./egs/tts/SoundStorm/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json") -cfg_soundstorm_full = load_config("./models/tts/debatts/egs/tts/SoundStorm/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json") +cfg_soundstorm_1layer = load_config("./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json") +cfg_soundstorm_full = load_config("./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json") soundstorm_1layer = build_soundstorm(cfg_soundstorm_1layer, device) soundstorm_full = build_soundstorm(cfg_soundstorm_full, device) @@ -270,20 +254,9 @@ def semantic2acoustic(combine_semantic_code, acoustic_code): safetensors.torch.load_model(soundstorm_1layer, soundstorm_1layer_path) safetensors.torch.load_model(soundstorm_full, soundstorm_full_path) -t2s_cfg = load_config("./exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") -t2s_model = build_t2s_model(t2s_cfg, device) -t2s_model_ckpt_path = "/mntcephfs/lab_data/lijiaqi/debate/gluster-tts/ckpt/t2s/t2s_625ksteps_model.safetensors" -safetensors.torch.load_model(t2s_model, t2s_model_ckpt_path) -print(t2s_model.bos_target_id, t2s_model.eos_target_id, t2s_model.bos_phone_id, t2s_model.eos_phone_id, t2s_model.pad_token_id) - -t2s_cfg = load_config("./egs/tts/Text2Semantic/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") -t2s_model_tune = build_t2s_model(t2s_cfg, device) -t2s_model_tune_ckpt_path = "/mntcephfs/data/wuzhizheng/debate_/ckpt_ori_tune/epoch-0021_step-0005000_loss-4.354165/model.safetensors" -safetensors.torch.load_model(t2s_model_tune, t2s_model_tune_ckpt_path) - -t2s_cfg = load_config("./egs/tts/Text2Semantic/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") +t2s_cfg = load_config("./t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") t2s_model_new = build_t2s_model_new(t2s_cfg, device) -t2s_model_new_ckpt_path = "./s2a_model/model.safetensors" # 1900(02), 1906 +t2s_model_new_ckpt_path = "./t2s_model/model.safetensors" safetensors.torch.load_model(t2s_model_new, t2s_model_new_ckpt_path) from funasr import AutoModel @@ -361,7 +334,7 @@ def generate_text_data(wav_file): return wav_file, txt, wav_file -def infer(speech_path, prompt_text, target_wav_path, target_text, 
target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="ori", idx = 0, epoch=0, spk_prompt_type=""): +def infer(speech_path, prompt_text, target_wav_path, target_text, target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="new", idx = 0, epoch=0, spk_prompt_type=""): if idx != 0: save_dir = os.path.join("The Path to Store Generated Speech", f"{infer_mode}/{spk_prompt_type}") if not os.path.exists(save_dir): @@ -376,22 +349,20 @@ def infer(speech_path, prompt_text, target_wav_path, target_text, target_languag if os.path.exists(save_path): return save_path - print(f"HERE COMES INFER!!! {infer_mode}") - print(f"IN INFER PROMPT text is {prompt_text}") - print(f"IN INFER Target text is {target_text}") + # print(f"HERE COMES INFER!!! {infer_mode}") + # print(f"IN INFER PROMPT text is {prompt_text}") + # print(f"IN INFER Target text is {target_text}") speech_16k = librosa.load(speech_path, sr=16000)[0] speech = librosa.load(speech_path, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] if infer_mode == "new": speech_16k_prompt0 = librosa.load(speech_path_prompt0, sr=16000)[0] speech_prompt0 = librosa.load(speech_path_prompt0, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] - # combine_semantic_code, _ = text2semantic_new(speech_16k_prompt0, prompt0_text, speech_16k, prompt_text, target_language, target_text, target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) combine_semantic_code, _ = text2semantic(prompt0_speech=speech_16k_prompt0, prompt0_text=prompt0_text, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language=target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode = infer_mode) else: combine_semantic_code, _ = text2semantic(prompt0_speech=None, prompt0_text=None, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language = target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device)) - # print(acoustic_code.shape) combine_audio, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code) @@ -401,7 +372,7 @@ def infer(speech_path, prompt_text, target_wav_path, target_text, target_languag sf.write(save_path, combine_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) return save_path -def infer_small(speech_path, prompt_text, target_text, target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="ori", save_path=None): +def infer_small(speech_path, prompt_text, target_text, target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="new", save_path=None): if os.path.exists(save_path): return save_path @@ -427,24 +398,6 @@ def infer_small(speech_path, prompt_text, target_text, target_language='zh', spe sf.write(save_path, combine_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) return save_path -def get_prompt0_wav_path(wav_path): - base_path = wav_path.split('chenci')[0].rstrip('_') - json_path = f"{base_path}.json" - - try: - with open(json_path, 'r', encoding='utf-8') as file: - data = json.load(file) - prompt0_wav_path = 
data['prompt0_wav_path'] - basename = os.path.basename(prompt0_wav_path) - prompt0_wav_path = os.path.join("Debatts-Data Test Directory", basename) - return prompt0_wav_path - except FileNotFoundError: - print(f"File not found: {json_path}") - except KeyError: - print(f"Cannot find 'prompt0_wav_path' in json") - except json.JSONDecodeError as e: - print(f"Error in reading json: {e}") - ##################################### EVALUATION ################################################################ from funasr import AutoModel import torch.nn.functional as F @@ -454,7 +407,7 @@ def get_prompt0_wav_path(wav_path): from models.tts.soundstorm.try_inference_new import evaluation_new from models.tts.soundstorm.try_inference_new import extract_emotion_similarity -prompt0_wav_path = "./models/tts/debatts/speech_examples/87_SPEAKER01_2_part03_213.wav" +prompt0_wav_path = "./debatts/speech_examples/87_SPEAKER01_2_part03_213.wav" prompt0_text = generate_text_data(prompt0_wav_path)[1] spk_prompt_wav_path = "The Speaker Identity Path" diff --git a/models/tts/debatts/utils/g2p_liwei/__init__.py b/models/tts/debatts/utils/g2p_new/__init__.py similarity index 87% rename from models/tts/debatts/utils/g2p_liwei/__init__.py rename to models/tts/debatts/utils/g2p_new/__init__.py index fa08d7a8..67e6fded 100644 --- a/models/tts/debatts/utils/g2p_liwei/__init__.py +++ b/models/tts/debatts/utils/g2p_new/__init__.py @@ -1,12 +1,12 @@ -from utils.g2p_liwei import cleaners +from utils.g2p_new import cleaners from tokenizers import Tokenizer -from utils.g2p_liwei.text_tokenizers import TextTokenizer +from utils.g2p_new.text_tokenizers import TextTokenizer import json import re class PhonemeBpeTokenizer: - def __init__(self, vacab_path="/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_liwei/vacab.json"): + def __init__(self, vacab_path="./utils/g2p_new/vacab.json"): self.lang2backend = { 'zh': "cmn", 'ja': "ja", @@ -18,7 +18,7 @@ def __init__(self, vacab_path="/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_liwe self.text_tokenizers = {} self.int_text_tokenizers() # TODO - vacab_path="/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_liwei/vacab.json" + vacab_path="/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_new/vacab.json" with open(vacab_path, 'rb') as f: json_data = f.read() data = json.loads(json_data) diff --git a/models/tts/debatts/utils/g2p_liwei/cleaners.py b/models/tts/debatts/utils/g2p_new/cleaners.py similarity index 68% rename from models/tts/debatts/utils/g2p_liwei/cleaners.py rename to models/tts/debatts/utils/g2p_new/cleaners.py index ee99d3be..58755726 100644 --- a/models/tts/debatts/utils/g2p_liwei/cleaners.py +++ b/models/tts/debatts/utils/g2p_new/cleaners.py @@ -1,10 +1,10 @@ import re -from utils.g2p_liwei.japanese import japanese_to_ipa -from utils.g2p_liwei.mandarin import chinese_to_ipa -from utils.g2p_liwei.english import english_to_ipa -from utils.g2p_liwei.french import french_to_ipa -from utils.g2p_liwei.korean import korean_to_ipa -from utils.g2p_liwei.german import german_to_ipa +from utils.g2p_new.japanese import japanese_to_ipa +from utils.g2p_new.mandarin import chinese_to_ipa +from utils.g2p_new.english import english_to_ipa +from utils.g2p_new.french import french_to_ipa +from utils.g2p_new.korean import korean_to_ipa +from utils.g2p_new.german import german_to_ipa def cjekfd_cleaners(text, language, text_tokenizers): diff --git a/models/tts/debatts/utils/g2p_liwei/english.py b/models/tts/debatts/utils/g2p_new/english.py similarity index 100% rename from 
models/tts/debatts/utils/g2p_liwei/english.py rename to models/tts/debatts/utils/g2p_new/english.py diff --git a/models/tts/debatts/utils/g2p_liwei/french.py b/models/tts/debatts/utils/g2p_new/french.py similarity index 100% rename from models/tts/debatts/utils/g2p_liwei/french.py rename to models/tts/debatts/utils/g2p_new/french.py diff --git a/models/tts/debatts/utils/g2p_liwei/g2p_liwei.py b/models/tts/debatts/utils/g2p_new/g2p_new.py similarity index 59% rename from models/tts/debatts/utils/g2p_liwei/g2p_liwei.py rename to models/tts/debatts/utils/g2p_new/g2p_new.py index 70cfb054..332bec80 100644 --- a/models/tts/debatts/utils/g2p_liwei/g2p_liwei.py +++ b/models/tts/debatts/utils/g2p_new/g2p_new.py @@ -1,7 +1,7 @@ -from utils.g2p_liwei import PhonemeBpeTokenizer +from utils.g2p_new import PhonemeBpeTokenizer import tqdm text_tokenizer = PhonemeBpeTokenizer() -def liwei_g2p(text, language): +def new_g2p(text, language): return text_tokenizer.tokenize(text=text, language=language) \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_liwei/german.py b/models/tts/debatts/utils/g2p_new/german.py similarity index 100% rename from models/tts/debatts/utils/g2p_liwei/german.py rename to models/tts/debatts/utils/g2p_new/german.py diff --git a/models/tts/debatts/utils/g2p_liwei/japanese.py b/models/tts/debatts/utils/g2p_new/japanese.py similarity index 100% rename from models/tts/debatts/utils/g2p_liwei/japanese.py rename to models/tts/debatts/utils/g2p_new/japanese.py diff --git a/models/tts/debatts/utils/g2p_liwei/korean.py b/models/tts/debatts/utils/g2p_new/korean.py similarity index 100% rename from models/tts/debatts/utils/g2p_liwei/korean.py rename to models/tts/debatts/utils/g2p_new/korean.py diff --git a/models/tts/debatts/utils/g2p_liwei/mandarin.py b/models/tts/debatts/utils/g2p_new/mandarin.py similarity index 100% rename from models/tts/debatts/utils/g2p_liwei/mandarin.py rename to models/tts/debatts/utils/g2p_new/mandarin.py diff --git a/models/tts/debatts/utils/g2p_liwei/text_tokenizers.py b/models/tts/debatts/utils/g2p_new/text_tokenizers.py similarity index 100% rename from models/tts/debatts/utils/g2p_liwei/text_tokenizers.py rename to models/tts/debatts/utils/g2p_new/text_tokenizers.py diff --git a/models/tts/debatts/utils/g2p_liwei/vacab.json b/models/tts/debatts/utils/g2p_new/vacab.json similarity index 100% rename from models/tts/debatts/utils/g2p_liwei/vacab.json rename to models/tts/debatts/utils/g2p_new/vacab.json From 1bf752361e795de8b24925c233ef55b175cdbe5f Mon Sep 17 00:00:00 2001 From: hehaorui Date: Mon, 28 Oct 2024 17:16:24 +0800 Subject: [PATCH 3/8] Debatts: load the w2v-bert-2 feature extractor from a relative checkpoint path --- models/tts/debatts/try_inference_small_samples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/tts/debatts/try_inference_small_samples.py b/models/tts/debatts/try_inference_small_samples.py index 23e6ee2b..d4193810 100644 --- a/models/tts/debatts/try_inference_small_samples.py +++ b/models/tts/debatts/try_inference_small_samples.py @@ -36,7 +36,7 @@ from tqdm import tqdm from transformers import SeamlessM4TFeatureExtractor -processor = SeamlessM4TFeatureExtractor.from_pretrained("/mntcephfs/lab_data/lijiaqi/debate/gluster-tts/w2v-bert-2") +processor = SeamlessM4TFeatureExtractor.from_pretrained("./ckpt/w2v-bert-2") from transformers import AutoProcessor, AutoModel From 5a95bc71c7d65f484bc56ace443283a0d8d6230d Mon Sep 17 00:00:00 2001 From: hehaorui Date: Mon, 28 Oct 2024 17:27:26 +0800 Subject: [PATCH 4/8] Debatts: remove unused utils and update copyright headers ---
models/tts/debatts/utils/audio.py | 74 --- models/tts/debatts/utils/audio_slicer.py | 476 ------------------ models/tts/debatts/utils/cut_by_vad.py | 105 ---- models/tts/debatts/utils/data_utils.py | 2 +- models/tts/debatts/utils/distribution.py | 2 +- models/tts/debatts/utils/dsp.py | 2 +- models/tts/debatts/utils/duration.py | 2 +- models/tts/debatts/utils/f0.py | 275 ---------- models/tts/debatts/utils/hparam.py | 2 +- models/tts/debatts/utils/hubert.py | 155 ------ models/tts/debatts/utils/io.py | 182 ------- models/tts/debatts/utils/io_optim.py | 123 ----- models/tts/debatts/utils/logger.py | 5 + models/tts/debatts/utils/mel.py | 280 ----------- models/tts/debatts/utils/mert.py | 139 ----- models/tts/debatts/utils/mfa_prepare.py | 116 ----- models/tts/debatts/utils/model_summary.py | 74 --- models/tts/debatts/utils/prompt_preparer.py | 68 --- models/tts/debatts/utils/ssim.py | 2 +- models/tts/debatts/utils/stft.py | 2 +- models/tts/debatts/utils/symbol_table.py | 2 +- models/tts/debatts/utils/tokenizer.py | 2 +- models/tts/debatts/utils/tool.py | 5 + models/tts/debatts/utils/topk_sampling.py | 2 +- models/tts/debatts/utils/trainer_utils.py | 2 +- models/tts/debatts/utils/util.py | 2 +- .../debatts/utils/whisper_transcription.py | 122 ----- models/tts/debatts/utils/world.py | 92 ---- 28 files changed, 22 insertions(+), 2293 deletions(-) delete mode 100644 models/tts/debatts/utils/audio.py delete mode 100644 models/tts/debatts/utils/audio_slicer.py delete mode 100644 models/tts/debatts/utils/cut_by_vad.py delete mode 100644 models/tts/debatts/utils/f0.py delete mode 100644 models/tts/debatts/utils/hubert.py delete mode 100644 models/tts/debatts/utils/io.py delete mode 100644 models/tts/debatts/utils/io_optim.py delete mode 100644 models/tts/debatts/utils/mel.py delete mode 100644 models/tts/debatts/utils/mert.py delete mode 100644 models/tts/debatts/utils/mfa_prepare.py delete mode 100644 models/tts/debatts/utils/model_summary.py delete mode 100644 models/tts/debatts/utils/prompt_preparer.py delete mode 100644 models/tts/debatts/utils/whisper_transcription.py delete mode 100644 models/tts/debatts/utils/world.py diff --git a/models/tts/debatts/utils/audio.py b/models/tts/debatts/utils/audio.py deleted file mode 100644 index 374d5091..00000000 --- a/models/tts/debatts/utils/audio.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import numpy as np -from numpy import linalg as LA -import librosa -import soundfile as sf -import librosa.filters - - -def load_audio_torch(wave_file, fs): - """Load audio data into torch tensor - - Args: - wave_file (str): path to wave file - fs (int): sample rate - - Returns: - audio (tensor): audio data in tensor - fs (int): sample rate - """ - - audio, sample_rate = librosa.load(wave_file, sr=fs, mono=True) - # audio: (T,) - assert len(audio) > 2 - - # Check the audio type (for soundfile loading backbone) - float, 8bit or 16bit - if np.issubdtype(audio.dtype, np.integer): - max_mag = -np.iinfo(audio.dtype).min - else: - max_mag = max(np.amax(audio), -np.amin(audio)) - max_mag = ( - (2**31) + 1 - if max_mag > (2**15) - else ((2**15) + 1 if max_mag > 1.01 else 1.0) - ) - - # Normalize the audio - audio = torch.FloatTensor(audio.astype(np.float32)) / max_mag - - if (torch.isnan(audio) | torch.isinf(audio)).any(): - return [], sample_rate or fs or 48000 - - # Resample the audio to our target samplerate - if fs is not None and fs != sample_rate: - audio = torch.from_numpy( - librosa.core.resample(audio.numpy(), orig_sr=sample_rate, target_sr=fs) - ) - sample_rate = fs - - return audio, fs - - -def _stft(y, cfg): - return librosa.stft( - y=y, n_fft=cfg.n_fft, hop_length=cfg.hop_size, win_length=cfg.win_size - ) - - -def energy(wav, cfg): - D = _stft(wav, cfg) - magnitudes = np.abs(D).T # [F, T] - return LA.norm(magnitudes, axis=1) - - -def get_energy_from_tacotron(audio, _stft): - audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) - audio = torch.autograd.Variable(audio, requires_grad=False) - mel, energy = _stft.mel_spectrogram(audio) - energy = torch.squeeze(energy, 0).numpy().astype(np.float32) - return mel, energy diff --git a/models/tts/debatts/utils/audio_slicer.py b/models/tts/debatts/utils/audio_slicer.py deleted file mode 100644 index 28474596..00000000 --- a/models/tts/debatts/utils/audio_slicer.py +++ /dev/null @@ -1,476 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import os -import json -import numpy as np -from tqdm import tqdm -import torch -import torchaudio - -from utils.io import save_audio -from utils.audio import load_audio_torch - - -# This function is obtained from librosa. 
-def get_rms( - y, - *, - frame_length=2048, - hop_length=512, - pad_mode="constant", -): - padding = (int(frame_length // 2), int(frame_length // 2)) - y = np.pad(y, padding, mode=pad_mode) - - axis = -1 - # put our new within-frame axis at the end for now - out_strides = y.strides + tuple([y.strides[axis]]) - # Reduce the shape on the framing axis - x_shape_trimmed = list(y.shape) - x_shape_trimmed[axis] -= frame_length - 1 - out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) - xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides) - if axis < 0: - target_axis = axis - 1 - else: - target_axis = axis + 1 - xw = np.moveaxis(xw, -1, target_axis) - # Downsample along the target axis - slices = [slice(None)] * xw.ndim - slices[axis] = slice(0, None, hop_length) - x = xw[tuple(slices)] - - # Calculate power - power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) - - return np.sqrt(power) - - -class Slicer: - """ - Copy from: https://github.com/openvpi/audio-slicer/blob/main/slicer2.py - """ - - def __init__( - self, - sr: int, - threshold: float = -40.0, - min_length: int = 5000, - min_interval: int = 300, - hop_size: int = 10, - max_sil_kept: int = 5000, - ): - if not min_length >= min_interval >= hop_size: - raise ValueError( - "The following condition must be satisfied: min_length >= min_interval >= hop_size" - ) - if not max_sil_kept >= hop_size: - raise ValueError( - "The following condition must be satisfied: max_sil_kept >= hop_size" - ) - min_interval = sr * min_interval / 1000 - self.threshold = 10 ** (threshold / 20.0) - self.hop_size = round(sr * hop_size / 1000) - self.win_size = min(round(min_interval), 4 * self.hop_size) - self.min_length = round(sr * min_length / 1000 / self.hop_size) - self.min_interval = round(min_interval / self.hop_size) - self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) - - def _apply_slice(self, waveform, begin, end): - begin = begin * self.hop_size - if len(waveform.shape) > 1: - end = min(waveform.shape[1], end * self.hop_size) - return waveform[:, begin:end], begin, end - else: - end = min(waveform.shape[0], end * self.hop_size) - return waveform[begin:end], begin, end - - # @timeit - def slice(self, waveform, return_chunks_positions=False): - if len(waveform.shape) > 1: - # (#channle, wave_len) -> (wave_len) - samples = waveform.mean(axis=0) - else: - samples = waveform - if samples.shape[0] <= self.min_length: - return [waveform] - rms_list = get_rms( - y=samples, frame_length=self.win_size, hop_length=self.hop_size - ).squeeze(0) - sil_tags = [] - silence_start = None - clip_start = 0 - for i, rms in enumerate(rms_list): - # Keep looping while frame is silent. - if rms < self.threshold: - # Record start of silent frames. - if silence_start is None: - silence_start = i - continue - # Keep looping while frame is not silent and silence start has not been recorded. - if silence_start is None: - continue - # Clear recorded silence start if interval is not enough or clip is too short - is_leading_silence = silence_start == 0 and i > self.max_sil_kept - need_slice_middle = ( - i - silence_start >= self.min_interval - and i - clip_start >= self.min_length - ) - if not is_leading_silence and not need_slice_middle: - silence_start = None - continue - # Need slicing. Record the range of silent frames to be removed. 
- if i - silence_start <= self.max_sil_kept: - pos = rms_list[silence_start : i + 1].argmin() + silence_start - if silence_start == 0: - sil_tags.append((0, pos)) - else: - sil_tags.append((pos, pos)) - clip_start = pos - elif i - silence_start <= self.max_sil_kept * 2: - pos = rms_list[ - i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 - ].argmin() - pos += i - self.max_sil_kept - pos_l = ( - rms_list[ - silence_start : silence_start + self.max_sil_kept + 1 - ].argmin() - + silence_start - ) - pos_r = ( - rms_list[i - self.max_sil_kept : i + 1].argmin() - + i - - self.max_sil_kept - ) - if silence_start == 0: - sil_tags.append((0, pos_r)) - clip_start = pos_r - else: - sil_tags.append((min(pos_l, pos), max(pos_r, pos))) - clip_start = max(pos_r, pos) - else: - pos_l = ( - rms_list[ - silence_start : silence_start + self.max_sil_kept + 1 - ].argmin() - + silence_start - ) - pos_r = ( - rms_list[i - self.max_sil_kept : i + 1].argmin() - + i - - self.max_sil_kept - ) - if silence_start == 0: - sil_tags.append((0, pos_r)) - else: - sil_tags.append((pos_l, pos_r)) - clip_start = pos_r - silence_start = None - # Deal with trailing silence. - total_frames = rms_list.shape[0] - if ( - silence_start is not None - and total_frames - silence_start >= self.min_interval - ): - silence_end = min(total_frames, silence_start + self.max_sil_kept) - pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start - sil_tags.append((pos, total_frames + 1)) - # Apply and return slices. - if len(sil_tags) == 0: - return [waveform] - else: - chunks = [] - chunks_pos_of_waveform = [] - - if sil_tags[0][0] > 0: - chunk, begin, end = self._apply_slice(waveform, 0, sil_tags[0][0]) - chunks.append(chunk) - chunks_pos_of_waveform.append((begin, end)) - - for i in range(len(sil_tags) - 1): - chunk, begin, end = self._apply_slice( - waveform, sil_tags[i][1], sil_tags[i + 1][0] - ) - chunks.append(chunk) - chunks_pos_of_waveform.append((begin, end)) - - if sil_tags[-1][1] < total_frames: - chunk, begin, end = self._apply_slice( - waveform, sil_tags[-1][1], total_frames - ) - chunks.append(chunk) - chunks_pos_of_waveform.append((begin, end)) - - return ( - chunks - if not return_chunks_positions - else ( - chunks, - chunks_pos_of_waveform, - ) - ) - - -def split_utterances_from_audio( - wav_file, - output_dir, - max_duration_of_utterance=10.0, - min_interval=300, - db_threshold=-40, -): - """ - Split a long audio into utterances accoring to the silence (VAD). - - max_duration_of_utterance (second): - The maximum duration of every utterance (seconds) - min_interval (millisecond): - The smaller min_interval is, the more sliced audio clips this script is likely to generate. 
- """ - print("File:", wav_file.split("/")[-1]) - waveform, fs = torchaudio.load(wav_file) - - slicer = Slicer(sr=fs, min_interval=min_interval, threshold=db_threshold) - chunks, positions = slicer.slice(waveform, return_chunks_positions=True) - - durations = [(end - begin) / fs for begin, end in positions] - print( - "Slicer's min silence part is {}ms, min and max duration of sliced utterances is {}s and {}s".format( - min_interval, min(durations), max(durations) - ) - ) - - res_chunks, res_positions = [], [] - for i, chunk in enumerate(chunks): - if len(chunk.shape) == 1: - chunk = chunk[None, :] - - begin, end = positions[i] - assert end - begin == chunk.shape[-1] - - max_wav_len = max_duration_of_utterance * fs - if chunk.shape[-1] <= max_wav_len: - res_chunks.append(chunk) - res_positions.append(positions[i]) - else: - # TODO: to reserve overlapping and conduct fade-in, fade-out - - # Get segments number - number = 2 - while chunk.shape[-1] // number >= max_wav_len: - number += 1 - seg_len = chunk.shape[-1] // number - - # Split - for num in range(number): - s = seg_len * num - t = min(s + seg_len, chunk.shape[-1]) - - seg_begin = begin + s - seg_end = begin + t - - res_chunks.append(chunk[:, s:t]) - res_positions.append((seg_begin, seg_end)) - - # Save utterances - os.makedirs(output_dir, exist_ok=True) - res = {"fs": int(fs)} - for i, chunk in enumerate(res_chunks): - filename = "{:04d}.wav".format(i) - res[filename] = [int(p) for p in res_positions[i]] - save_audio(os.path.join(output_dir, filename), chunk, fs) - - # Save positions - with open(os.path.join(output_dir, "positions.json"), "w") as f: - json.dump(res, f, indent=4, ensure_ascii=False) - return res - - -def is_silence( - wavform, - fs, - threshold=-40.0, - min_interval=300, - hop_size=10, - min_length=5000, -): - """ - Detect whether the given wavform is a silence - - wavform: (T, ) - """ - threshold = 10 ** (threshold / 20.0) - - hop_size = round(fs * hop_size / 1000) - win_size = min(round(min_interval), 4 * hop_size) - min_length = round(fs * min_length / 1000 / hop_size) - - if wavform.shape[0] <= min_length: - return True - - # (#Frame,) - rms_array = get_rms(y=wavform, frame_length=win_size, hop_length=hop_size).squeeze( - 0 - ) - return (rms_array < threshold).all() - - -def split_audio( - wav_file, target_sr, output_dir, max_duration_of_segment=10.0, overlap_duration=1.0 -): - """ - Split a long audio into segments. - - target_sr: - The target sampling rate to save the segments. 
- max_duration_of_utterance (second): - The maximum duration of every utterance (second) - overlap_duraion: - Each segment has "overlap duration" (second) overlap with its previous and next segment - """ - # (#channel, T) -> (T,) - waveform, fs = torchaudio.load(wav_file) - waveform = torchaudio.functional.resample( - waveform, orig_freq=fs, new_freq=target_sr - ) - waveform = torch.mean(waveform, dim=0) - - # waveform, _ = load_audio_torch(wav_file, target_sr) - assert len(waveform.shape) == 1 - - assert overlap_duration < max_duration_of_segment - length = int(max_duration_of_segment * target_sr) - stride = int((max_duration_of_segment - overlap_duration) * target_sr) - chunks = [] - for i in range(0, len(waveform), stride): - # (length,) - chunks.append(waveform[i : i + length]) - if i + length >= len(waveform): - break - - # Save segments - os.makedirs(output_dir, exist_ok=True) - results = [] - for i, chunk in enumerate(chunks): - uid = "{:04d}".format(i) - filename = os.path.join(output_dir, "{}.wav".format(uid)) - results.append( - {"Uid": uid, "Path": filename, "Duration": len(chunk) / target_sr} - ) - save_audio( - filename, - chunk, - target_sr, - turn_up=not is_silence(chunk, target_sr), - add_silence=False, - ) - - return results - - -def merge_segments_torchaudio(wav_files, fs, output_path, overlap_duration=1.0): - """Merge the given wav_files (may have overlaps) into a long audio - - fs: - The sampling rate of the wav files. - output_path: - The output path to save the merged audio. - overlap_duration (float, optional): - Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0. - """ - - waveforms = [] - for file in wav_files: - # (T,) - waveform, _ = load_audio_torch(file, fs) - waveforms.append(waveform) - - if len(waveforms) == 1: - save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False) - return - - overlap_len = int(overlap_duration * fs) - fade_out = torchaudio.transforms.Fade(fade_out_len=overlap_len) - fade_in = torchaudio.transforms.Fade(fade_in_len=overlap_len) - fade_in_and_out = torchaudio.transforms.Fade(fade_out_len=overlap_len) - - segments_lens = [len(wav) for wav in waveforms] - merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1) - merged_waveform = torch.zeros(merged_waveform_len) - - start = 0 - for index, wav in enumerate( - tqdm(waveforms, desc="Merge for {}".format(output_path)) - ): - wav_len = len(wav) - - if index == 0: - wav = fade_out(wav) - elif index == len(waveforms) - 1: - wav = fade_in(wav) - else: - wav = fade_in_and_out(wav) - - merged_waveform[start : start + wav_len] = wav - start += wav_len - overlap_len - - save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True) - - -def merge_segments_encodec(wav_files, fs, output_path, overlap_duration=1.0): - """Merge the given wav_files (may have overlaps) into a long audio - - fs: - The sampling rate of the wav files. - output_path: - The output path to save the merged audio. - overlap_duration (float, optional): - Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0. 
- """ - - waveforms = [] - for file in wav_files: - # (T,) - waveform, _ = load_audio_torch(file, fs) - waveforms.append(waveform) - - if len(waveforms) == 1: - save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False) - return - - device = waveforms[0].device - dtype = waveforms[0].dtype - shape = waveforms[0].shape[:-1] - - overlap_len = int(overlap_duration * fs) - segments_lens = [len(wav) for wav in waveforms] - merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1) - - sum_weight = torch.zeros(merged_waveform_len, device=device, dtype=dtype) - out = torch.zeros(*shape, merged_waveform_len, device=device, dtype=dtype) - offset = 0 - - for frame in waveforms: - frame_length = frame.size(-1) - t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=torch.float32)[ - 1:-1 - ] - weight = 0.5 - (t - 0.5).abs() - weighted_frame = frame * weight - - cur = out[..., offset : offset + frame_length] - cur += weighted_frame[..., : cur.size(-1)] - out[..., offset : offset + frame_length] = cur - - cur = sum_weight[offset : offset + frame_length] - cur += weight[..., : cur.size(-1)] - sum_weight[offset : offset + frame_length] = cur - - offset += frame_length - overlap_len - - assert sum_weight.min() > 0 - merged_waveform = out / sum_weight - save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True) diff --git a/models/tts/debatts/utils/cut_by_vad.py b/models/tts/debatts/utils/cut_by_vad.py deleted file mode 100644 index 0d41a4a1..00000000 --- a/models/tts/debatts/utils/cut_by_vad.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" This code is modified from https://github.com/facebookresearch/libri-light/blob/main/data_preparation/cut_by_vad.py""" -import pathlib -import soundfile as sf -import numpy as np -import json -import multiprocessing -import tqdm - - -def save(seq, fname, index, extension): - """save audio sequences to file""" - output = np.hstack(seq) - file_name = fname.parent / (fname.stem + f"_{index:04}{extension}") - fname.parent.mkdir(exist_ok=True, parents=True) - sf.write(file_name, output, samplerate=16000) - - -def cut_sequence(path, vad, path_out, target_len_sec, out_extension): - """cut audio sequences based on VAD""" - data, samplerate = sf.read(path) - - assert len(data.shape) == 1 - assert samplerate == 16000 - - to_stitch = [] - length_accumulated = 0.0 - - i = 0 - # Iterate over VAD segments - for start, end in vad: - start_index = int(start * samplerate) - end_index = int(end * samplerate) - slice = data[start_index:end_index] - - # Save slices that exceed the target length or if there's already accumulated audio - if ( - length_accumulated + (end - start) > target_len_sec - and length_accumulated > 0 - ): - save(to_stitch, path_out, i, out_extension) - to_stitch = [] - i += 1 - length_accumulated = 0 - - # Add the current slice to the list to be stitched - to_stitch.append(slice) - length_accumulated += end - start - - # Save any remaining slices - if to_stitch: - save(to_stitch, path_out, i, out_extension) - - -def cut_book(task): - """process each book in the dataset""" - path_book, root_out, target_len_sec, extension = task - - speaker = pathlib.Path(path_book.parent.name) - - for i, meta_file_path in enumerate(path_book.glob("*.json")): - with open(meta_file_path, "r") as f: - meta = json.loads(f.read()) - book_id = meta["book_meta"]["id"] - vad = 
meta["voice_activity"] - - sound_file = meta_file_path.parent / (meta_file_path.stem + ".flac") - - path_out = root_out / speaker / book_id / (meta_file_path.stem) - cut_sequence(sound_file, vad, path_out, target_len_sec, extension) - - -def cut_segments( - input_dir, output_dir, target_len_sec=30, n_process=32, out_extension=".wav" -): - """Main function to cut segments from audio files""" - - pathlib.Path(output_dir).mkdir(exist_ok=True, parents=True) - list_dir = pathlib.Path(input_dir).glob("*/*") - list_dir = [x for x in list_dir if x.is_dir()] - - print(f"{len(list_dir)} directories detected") - print(f"Launching {n_process} processes") - - # Create tasks for multiprocessing - tasks = [ - (path_book, output_dir, target_len_sec, out_extension) for path_book in list_dir - ] - - # Process tasks in parallel using multiprocessing - with multiprocessing.Pool(processes=n_process) as pool: - for _ in tqdm.tqdm(pool.imap_unordered(cut_book, tasks), total=len(tasks)): - pass - - -if __name__ == "__main__": - input_dir = "/path/to/input_dir" - output_dir = "/path/to/output_dir" - target_len_sec = 10 - n_process = 16 - cut_segments(input_dir, output_dir, target_len_sec, n_process) diff --git a/models/tts/debatts/utils/data_utils.py b/models/tts/debatts/utils/data_utils.py index 8c0bc2ff..15f240d9 100644 --- a/models/tts/debatts/utils/data_utils.py +++ b/models/tts/debatts/utils/data_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/distribution.py b/models/tts/debatts/utils/distribution.py index de3000e9..6718693b 100644 --- a/models/tts/debatts/utils/distribution.py +++ b/models/tts/debatts/utils/distribution.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/dsp.py b/models/tts/debatts/utils/dsp.py index 18f9466f..bc68606f 100644 --- a/models/tts/debatts/utils/dsp.py +++ b/models/tts/debatts/utils/dsp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/duration.py b/models/tts/debatts/utils/duration.py index c9544b40..35ad3624 100644 --- a/models/tts/debatts/utils/duration.py +++ b/models/tts/debatts/utils/duration.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/f0.py b/models/tts/debatts/utils/f0.py deleted file mode 100644 index 169b1403..00000000 --- a/models/tts/debatts/utils/f0.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import librosa -import numpy as np -import torch -import parselmouth -import torchcrepe -import pyworld as pw - - -def f0_to_coarse(f0, pitch_bin, f0_min, f0_max): - """ - Convert f0 (Hz) to pitch (mel scale), and then quantize the mel-scale pitch to the - range from [1, 2, 3, ..., pitch_bin-1] - - Reference: https://en.wikipedia.org/wiki/Mel_scale - - Args: - f0 (array or Tensor): Hz - pitch_bin (int): the vocabulary size - f0_min (int): the minimum f0 (Hz) - f0_max (int): the maximum f0 (Hz) - - Returns: - quantized f0 (array or Tensor) - """ - f0_mel_min = 1127 * np.log(1 + f0_min / 700) - f0_mel_max = 1127 * np.log(1 + f0_max / 700) - - is_torch = isinstance(f0, torch.Tensor) - f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) - f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (pitch_bin - 2) / ( - f0_mel_max - f0_mel_min - ) + 1 - - f0_mel[f0_mel <= 1] = 1 - f0_mel[f0_mel > pitch_bin - 1] = pitch_bin - 1 - f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int32) - assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( - f0_coarse.max(), - f0_coarse.min(), - ) - return f0_coarse - - -def interpolate(f0): - """Interpolate the unvoiced part. Thus the f0 can be passed to a subtractive synthesizer. - Args: - f0: A numpy array of shape (seq_len,) - Returns: - f0: Interpolated f0 of shape (seq_len,) - uv: Unvoiced part of shape (seq_len,) - """ - uv = f0 == 0 - if len(f0[~uv]) > 0: - # interpolate the unvoiced f0 - f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) - uv = uv.astype("float") - uv = np.min(np.array([uv[:-2], uv[1:-1], uv[2:]]), axis=0) - uv = np.pad(uv, (1, 1)) - return f0, uv - - -def get_log_f0(f0): - f0[np.where(f0 == 0)] = 1 - log_f0 = np.log(f0) - return log_f0 - - -def get_f0_features_using_pyin(audio, cfg): - """Using pyin to extract the f0 feature. - Args: - audio - fs - win_length - hop_length - f0_min - f0_max - Returns: - f0: numpy array of shape (frame_len,) - """ - f0, voiced_flag, voiced_probs = librosa.pyin( - y=audio, - fmin=cfg.f0_min, - fmax=cfg.f0_max, - sr=cfg.sample_rate, - win_length=cfg.win_size, - hop_length=cfg.hop_size, - ) - # Set nan to 0 - f0[voiced_flag == False] = 0 - return f0 - - -def get_f0_features_using_parselmouth(audio, cfg, speed=1): - """Using parselmouth to extract the f0 feature. - Args: - audio - mel_len - hop_length - fs - f0_min - f0_max - speed(default=1) - Returns: - f0: numpy array of shape (frame_len,) - pitch_coarse: numpy array of shape (frame_len,) - """ - hop_size = int(np.round(cfg.hop_size * speed)) - - # Calculate the time step for pitch extraction - time_step = hop_size / cfg.sample_rate * 1000 - - f0 = ( - parselmouth.Sound(audio, cfg.sample_rate) - .to_pitch_ac( - time_step=time_step / 1000, - voicing_threshold=0.6, - pitch_floor=cfg.f0_min, - pitch_ceiling=cfg.f0_max, - ) - .selected_array["frequency"] - ) - return f0 - - -def get_f0_features_using_dio(audio, cfg): - """Using dio to extract the f0 feature. 
- Args: - audio - mel_len - fs - hop_length - f0_min - f0_max - Returns: - f0: numpy array of shape (frame_len,) - """ - # Get the raw f0 - _f0, t = pw.dio( - audio.astype("double"), - cfg.sample_rate, - f0_floor=cfg.f0_min, - f0_ceil=cfg.f0_max, - channels_in_octave=2, - frame_period=(1000 * cfg.hop_size / cfg.sample_rate), - ) - # Get the f0 - f0 = pw.stonemask(audio.astype("double"), _f0, t, cfg.sample_rate) - return f0 - - -def get_f0_features_using_harvest(audio, mel_len, fs, hop_length, f0_min, f0_max): - """Using harvest to extract the f0 feature. - Args: - audio - mel_len - fs - hop_length - f0_min - f0_max - Returns: - f0: numpy array of shape (frame_len,) - """ - f0, _ = pw.harvest( - audio.astype("double"), - fs, - f0_floor=f0_min, - f0_ceil=f0_max, - frame_period=(1000 * hop_length / fs), - ) - f0 = f0.astype("float")[:mel_len] - return f0 - - -def get_f0_features_using_crepe( - audio, mel_len, fs, hop_length, hop_length_new, f0_min, f0_max, threshold=0.3 -): - """Using torchcrepe to extract the f0 feature. - Args: - audio - mel_len - fs - hop_length - hop_length_new - f0_min - f0_max - threshold(default=0.3) - Returns: - f0: numpy array of shape (frame_len,) - """ - # Currently, crepe only supports 16khz audio - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - audio_16k = librosa.resample(audio, orig_sr=fs, target_sr=16000) - audio_16k_torch = torch.FloatTensor(audio_16k).unsqueeze(0).to(device) - - # Get the raw pitch - f0, pd = torchcrepe.predict( - audio_16k_torch, - 16000, - hop_length_new, - f0_min, - f0_max, - pad=True, - model="full", - batch_size=1024, - device=device, - return_periodicity=True, - ) - - # Filter, de-silence, set up threshold for unvoiced part - pd = torchcrepe.filter.median(pd, 3) - pd = torchcrepe.threshold.Silence(-60.0)(pd, audio_16k_torch, 16000, hop_length_new) - f0 = torchcrepe.threshold.At(threshold)(f0, pd) - f0 = torchcrepe.filter.mean(f0, 3) - - # Convert unvoiced part to 0hz - f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0) - - # Interpolate f0 - nzindex = torch.nonzero(f0[0]).squeeze() - f0 = torch.index_select(f0[0], dim=0, index=nzindex).cpu().numpy() - time_org = 0.005 * nzindex.cpu().numpy() - time_frame = np.arange(mel_len) * hop_length / fs - f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) - return f0 - - -def get_f0(audio, cfg, use_interpolate=False, return_uv=False): - if cfg.pitch_extractor == "dio": - f0 = get_f0_features_using_dio(audio, cfg) - elif cfg.pitch_extractor == "pyin": - f0 = get_f0_features_using_pyin(audio, cfg) - elif cfg.pitch_extractor == "parselmouth": - f0 = get_f0_features_using_parselmouth(audio, cfg) - - if use_interpolate: - f0, uv = interpolate(f0) - else: - uv = f0 == 0 - - if return_uv: - return f0, uv - - return f0 - - -def get_cents(f0_hz): - """ - F_{cent} = 1200 * log2 (F/440) - - Reference: - APSIPA'17, Perceptual Evaluation of Singing Quality - """ - voiced_f0 = f0_hz[f0_hz != 0] - return 1200 * np.log2(voiced_f0 / 440) - - -def get_pitch_derivatives(f0_hz): - """ - f0_hz: (,T) - """ - f0_cent = get_cents(f0_hz) - return f0_cent[1:] - f0_cent[:-1] - - -def get_pitch_sub_median(f0_hz): - """ - f0_hz: (,T) - """ - f0_cent = get_cents(f0_hz) - return f0_cent - np.median(f0_cent) diff --git a/models/tts/debatts/utils/hparam.py b/models/tts/debatts/utils/hparam.py index c5dd35c6..551792a1 100644 --- a/models/tts/debatts/utils/hparam.py +++ b/models/tts/debatts/utils/hparam.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. 
+# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/hubert.py b/models/tts/debatts/utils/hubert.py deleted file mode 100644 index 84b509fb..00000000 --- a/models/tts/debatts/utils/hubert.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.0/preprocess_hubert_f0.py - -import os -import librosa -import torch -import numpy as np -from fairseq import checkpoint_utils -from tqdm import tqdm -import torch - - -def load_hubert_model(hps): - # Load model - ckpt_path = hps.hubert_file - print("Load Hubert Model...") - - models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( - [ckpt_path], - suffix="", - ) - model = models[0] - model.eval() - - if torch.cuda.is_available(): - model = model.cuda() - - return model - - -def get_hubert_content(hmodel, wav_16k_tensor): - feats = wav_16k_tensor - if feats.dim() == 2: # double channels - feats = feats.mean(-1) - assert feats.dim() == 1, feats.dim() - feats = feats.view(1, -1) - padding_mask = torch.BoolTensor(feats.shape).fill_(False) - inputs = { - "source": feats.to(wav_16k_tensor.device), - "padding_mask": padding_mask.to(wav_16k_tensor.device), - "output_layer": 9, # layer 9 - } - with torch.no_grad(): - logits = hmodel.extract_features(**inputs) - feats = hmodel.final_proj(logits[0]).squeeze(0) - - return feats - - -def content_vector_encoder(model, audio_path, default_sampling_rate=16000): - """ - # content vector default sr: 16000 - """ - - wav16k, sr = librosa.load(audio_path, sr=default_sampling_rate) - device = next(model.parameters()).device - wav16k = torch.from_numpy(wav16k).to(device) - - # (1, 256, frame_len) - content_feature = get_hubert_content(model, wav_16k_tensor=wav16k) - - return content_feature.cpu().detach().numpy() - - -def repeat_expand_2d(content, target_len): - """ - content : [hubert_dim(256), src_len] - target: [hubert_dim(256), target_len] - """ - src_len = content.shape[-1] - target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to( - content.device - ) - temp = torch.arange(src_len + 1) * target_len / src_len - current_pos = 0 - for i in range(target_len): - if i < temp[current_pos + 1]: - target[:, i] = content[:, current_pos] - else: - current_pos += 1 - target[:, i] = content[:, current_pos] - - return target - - -def get_mapped_features(raw_content_features, mapping_features): - """ - Content Vector: frameshift = 20ms, hop_size = 480 in 24k - - Now it's only used for mapping to bigvgan's mels (sr = 24k, hop_size = 256, frameshift ~= 10.7 ms) - """ - source_hop = 480 - target_hop = 256 - - factor = np.gcd(source_hop, target_hop) - source_hop //= factor - target_hop //= factor - print( - "Mapping source's {} frames => target's {} frames".format( - target_hop, source_hop - ) - ) - - results = [] - for index, mapping_feat in enumerate(tqdm(mapping_features)): - # mappping_feat: (mels_frame_len, n_mels) - target_len = len(mapping_feat) - - # (source_len, 256) - raw_feats = raw_content_features[index][0].cpu().numpy().T - source_len, width = raw_feats.shape - - # const ~= target_len * target_hop - const = source_len * source_hop // target_hop * target_hop - - # (source_len * source_hop, dim) - up_sampling_feats = 
np.repeat(raw_feats, source_hop, axis=0) - # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) - down_sampling_feats = np.average( - up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 - ) - - err = abs(target_len - len(down_sampling_feats)) - if err > 3: - print("index:", index) - print("mels:", mapping_feat.shape) - print("raw content vector:", raw_feats.shape) - print("up_sampling:", up_sampling_feats.shape) - print("down_sampling_feats:", down_sampling_feats.shape) - exit() - if len(down_sampling_feats) < target_len: - # (1, dim) -> (err, dim) - end = down_sampling_feats[-1][None, :].repeat(err, axis=0) - down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) - - # (target_len, dim) - feats = down_sampling_feats[:target_len] - results.append(feats) - - return results - - -def extract_hubert_features_of_dataset(datasets, model, out_dir): - for utt in tqdm(datasets): - uid = utt["Uid"] - audio_path = utt["Path"] - - content_vector_feature = content_vector_encoder(model, audio_path) # (T, 256) - - save_path = os.path.join(out_dir, uid + ".npy") - np.save(save_path, content_vector_feature) diff --git a/models/tts/debatts/utils/io.py b/models/tts/debatts/utils/io.py deleted file mode 100644 index a93e75c6..00000000 --- a/models/tts/debatts/utils/io.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import os -import numpy as np -import torch -import torchaudio - - -def save_feature(process_dir, feature_dir, item, feature, overrides=True): - """Save features to path - - Args: - process_dir (str): directory to store features - feature_dir (_type_): directory to store one type of features (mel, energy, ...) - item (str): uid - feature (tensor): feature tensor - overrides (bool, optional): whether to override existing files. Defaults to True. - """ - process_dir = os.path.join(process_dir, feature_dir) - os.makedirs(process_dir, exist_ok=True) - out_path = os.path.join(process_dir, item + ".npy") - - if os.path.exists(out_path): - if overrides: - np.save(out_path, feature) - else: - np.save(out_path, feature) - - -def save_txt(process_dir, feature_dir, item, feature, overrides=True): - process_dir = os.path.join(process_dir, feature_dir) - os.makedirs(process_dir, exist_ok=True) - out_path = os.path.join(process_dir, item + ".txt") - - if os.path.exists(out_path): - if overrides: - f = open(out_path, "w") - f.writelines(feature) - f.close() - else: - f = open(out_path, "w") - f.writelines(feature) - f.close() - - -def save_audio(path, waveform, fs, add_silence=False, turn_up=False, volume_peak=0.9): - """Save audio to path with processing (turn up volume, add silence) - Args: - path (str): path to save audio - waveform (numpy array): waveform to save - fs (int): sampling rate - add_silence (bool, optional): whether to add silence to beginning and end. Defaults to False. - turn_up (bool, optional): whether to turn up volume. Defaults to False. - volume_peak (float, optional): volume peak. Defaults to 0.9. 
- """ - if turn_up: - # continue to turn up to volume_peak - ratio = volume_peak / max(waveform.max(), abs(waveform.min())) - waveform = waveform * ratio - - if add_silence: - silence_len = fs // 20 - silence = np.zeros((silence_len,), dtype=waveform.dtype) - result = np.concatenate([silence, waveform, silence]) - waveform = result - - waveform = torch.as_tensor(waveform, dtype=torch.float32, device="cpu") - if len(waveform.size()) == 1: - waveform = waveform[None, :] - elif waveform.size(0) != 1: - # Stereo to mono - waveform = torch.mean(waveform, dim=0, keepdim=True) - torchaudio.save(path, waveform, fs, encoding="PCM_S", bits_per_sample=16) - - -def save_torch_audio(process_dir, feature_dir, item, wav_torch, fs, overrides=True): - """Save torch audio to path without processing - Args: - process_dir (str): directory to store features - feature_dir (_type_): directory to store one type of features (mel, energy, ...) - item (str): uid - wav_torch (tensor): feature tensor - fs (int): sampling rate - overrides (bool, optional): whether to override existing files. Defaults to True. - """ - if wav_torch.shape != 2: - wav_torch = wav_torch.unsqueeze(0) - - process_dir = os.path.join(process_dir, feature_dir) - os.makedirs(process_dir, exist_ok=True) - out_path = os.path.join(process_dir, item + ".wav") - - torchaudio.save(out_path, wav_torch, fs) - - -async def async_load_audio(path, sample_rate: int = 24000): - r""" - Args: - path: The source loading path. - sample_rate: The target sample rate, will automatically resample if necessary. - - Returns: - waveform: The waveform object. Should be [1 x sequence_len]. - """ - - async def use_torchaudio_load(path): - return torchaudio.load(path) - - waveform, sr = await use_torchaudio_load(path) - waveform = torch.mean(waveform, dim=0, keepdim=True) - - if sr != sample_rate: - waveform = torchaudio.functional.resample(waveform, sr, sample_rate) - - if torch.any(torch.isnan(waveform) or torch.isinf(waveform)): - raise ValueError("NaN or Inf found in waveform.") - return waveform - - -async def async_save_audio( - path, - waveform, - sample_rate: int = 24000, - add_silence: bool = False, - volume_peak: float = 0.9, -): - r""" - Args: - path: The target saving path. - waveform: The waveform object. Should be [n_channel x sequence_len]. - sample_rate: Sample rate. - add_silence: If ``true``, concat 0.05s silence to beginning and end. - volume_peak: Turn up volume for larger number, vice versa. 
- """ - - async def use_torchaudio_save(path, waveform, sample_rate): - torchaudio.save( - path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16 - ) - - waveform = torch.as_tensor(waveform, device="cpu", dtype=torch.float32) - shape = waveform.size()[:-1] - - ratio = abs(volume_peak) / max(waveform.max(), abs(waveform.min())) - waveform = waveform * ratio - - if add_silence: - silence_len = sample_rate // 20 - silence = torch.zeros((*shape, silence_len), dtype=waveform.type()) - waveform = torch.concatenate((silence, waveform, silence), dim=-1) - - if waveform.dim() == 1: - waveform = waveform[None] - - await use_torchaudio_save(path, waveform, sample_rate) - - -def load_mel_extrema(cfg, dataset_name, split): - dataset_dir = os.path.join( - cfg.OUTPUT_PATH, - "preprocess/{}_version".format(cfg.data.process_version), - dataset_name, - ) - - min_file = os.path.join( - dataset_dir, - "mel_min_max", - split.split("_")[-1], - "mel_min.npy", - ) - max_file = os.path.join( - dataset_dir, - "mel_min_max", - split.split("_")[-1], - "mel_max.npy", - ) - mel_min = np.load(min_file) - mel_max = np.load(max_file) - return mel_min, mel_max diff --git a/models/tts/debatts/utils/io_optim.py b/models/tts/debatts/utils/io_optim.py deleted file mode 100644 index e9afaa06..00000000 --- a/models/tts/debatts/utils/io_optim.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torchaudio -import json -import os -import numpy as np -import librosa -import whisper -from torch.nn.utils.rnn import pad_sequence - - -class TorchaudioDataset(torch.utils.data.Dataset): - def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): - """ - Args: - cfg: config - dataset: dataset name - - """ - assert isinstance(dataset, str) - - self.sr = sr - self.cfg = cfg - - if metadata is None: - self.train_metadata_path = os.path.join( - cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file - ) - self.valid_metadata_path = os.path.join( - cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file - ) - self.metadata = self.get_metadata() - else: - self.metadata = metadata - - if accelerator is not None: - self.device = accelerator.device - elif torch.cuda.is_available(): - self.device = torch.device("cuda") - else: - self.device = torch.device("cpu") - - def get_metadata(self): - metadata = [] - with open(self.train_metadata_path, "r", encoding="utf-8") as t: - metadata.extend(json.load(t)) - with open(self.valid_metadata_path, "r", encoding="utf-8") as v: - metadata.extend(json.load(v)) - return metadata - - def __len__(self): - return len(self.metadata) - - def __getitem__(self, index): - utt_info = self.metadata[index] - wav_path = utt_info["Path"] - - wav, sr = torchaudio.load(wav_path) - - # resample - if sr != self.sr: - wav = torchaudio.functional.resample(wav, sr, self.sr) - # downmixing - if wav.shape[0] > 1: - wav = torch.mean(wav, dim=0, keepdim=True) - assert wav.shape[0] == 1 - wav = wav.squeeze(0) - # record the length of wav without padding - length = wav.shape[0] - # wav: (T) - return utt_info, wav, length - - -class LibrosaDataset(TorchaudioDataset): - def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): - super().__init__(cfg, dataset, sr, accelerator, metadata) - - def __getitem__(self, index): - utt_info = self.metadata[index] - wav_path = utt_info["Path"] - - wav, _ = librosa.load(wav_path, 
sr=self.sr) - # wav: (T) - wav = torch.from_numpy(wav) - - # record the length of wav without padding - length = wav.shape[0] - return utt_info, wav, length - - -class FFmpegDataset(TorchaudioDataset): - def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): - super().__init__(cfg, dataset, sr, accelerator, metadata) - - def __getitem__(self, index): - utt_info = self.metadata[index] - wav_path = utt_info["Path"] - - # wav: (T,) - wav = whisper.load_audio(wav_path, sr=16000) # sr = 16000 - # convert to torch tensor - wav = torch.from_numpy(wav) - # record the length of wav without padding - length = wav.shape[0] - - return utt_info, wav, length - - -def collate_batch(batch_list): - """ - Args: - batch_list: list of (metadata, wav, length) - """ - metadata = [item[0] for item in batch_list] - # wavs: (B, T) - wavs = pad_sequence([item[1] for item in batch_list], batch_first=True) - lens = [item[2] for item in batch_list] - - return metadata, wavs, lens diff --git a/models/tts/debatts/utils/logger.py b/models/tts/debatts/utils/logger.py index 4d31167d..df811ebc 100644 --- a/models/tts/debatts/utils/logger.py +++ b/models/tts/debatts/utils/logger.py @@ -1,3 +1,8 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import logging import time import os diff --git a/models/tts/debatts/utils/mel.py b/models/tts/debatts/utils/mel.py deleted file mode 100644 index 3894b73c..00000000 --- a/models/tts/debatts/utils/mel.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from librosa.filters import mel as librosa_mel_fn - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - # Min value: ln(1e-5) = -11.5129 - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def extract_linear_features(y, cfg, center=False): - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global hann_window - hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - # complex tensor as default, then use view_as_real for future pytorch compatibility - spec = torch.stft( - y, - cfg.n_fft, - hop_length=cfg.hop_size, - win_length=cfg.win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.view_as_real(spec) - spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - spec = torch.squeeze(spec, 0) - return spec - - -def mel_spectrogram_torch(y, cfg, center=False): - """ - TODO: to merge this funtion with the extract_mel_features below - """ - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window - if cfg.fmax not in mel_basis: - mel = librosa_mel_fn( - sr=cfg.sample_rate, - n_fft=cfg.n_fft, - n_mels=cfg.n_mel, - fmin=cfg.fmin, - fmax=cfg.fmax, - ) - mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( - torch.from_numpy(mel).float().to(y.device) - ) - 
hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.stft( - y, - cfg.n_fft, - hop_length=cfg.hop_size, - win_length=cfg.win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - - spec = torch.view_as_real(spec) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - - spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) - - return spec - - -mel_basis = {} -hann_window = {} - - -def extract_mel_features( - y, - cfg, - center=False, -): - """Extract mel features - - Args: - y (tensor): audio data in tensor - cfg (dict): configuration in cfg.preprocess - center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False. - - Returns: - tensor: a tensor containing the mel feature calculated based on STFT result - """ - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window - if cfg.fmax not in mel_basis: - mel = librosa_mel_fn( - sr=cfg.sample_rate, - n_fft=cfg.n_fft, - n_mels=cfg.n_mel, - fmin=cfg.fmin, - fmax=cfg.fmax, - ) - mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( - torch.from_numpy(mel).float().to(y.device) - ) - hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - # complex tensor as default, then use view_as_real for future pytorch compatibility - spec = torch.stft( - y, - cfg.n_fft, - hop_length=cfg.hop_size, - win_length=cfg.win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.view_as_real(spec) - spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - - spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) - return spec.squeeze(0) - - -def extract_mel_features_tts( - y, - cfg, - center=False, - taco=False, - _stft=None, -): - """Extract mel features - - Args: - y (tensor): audio data in tensor - cfg (dict): configuration in cfg.preprocess - center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False. 
- taco: use tacotron mel - - Returns: - tensor: a tensor containing the mel feature calculated based on STFT result - """ - if not taco: - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window - if cfg.fmax not in mel_basis: - mel = librosa_mel_fn( - sr=cfg.sample_rate, - n_fft=cfg.n_fft, - n_mels=cfg.n_mel, - fmin=cfg.fmin, - fmax=cfg.fmax, - ) - mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( - torch.from_numpy(mel).float().to(y.device) - ) - hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - # complex tensor as default, then use view_as_real for future pytorch compatibility - spec = torch.stft( - y, - cfg.n_fft, - hop_length=cfg.hop_size, - win_length=cfg.win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - spec = torch.view_as_real(spec) - spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - - spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) - else: - audio = torch.clip(y, -1, 1) - audio = torch.autograd.Variable(audio, requires_grad=False) - spec, energy = _stft.mel_spectrogram(audio) - - return spec.squeeze(0) - - -def amplitude_phase_spectrum(y, cfg): - hann_window = torch.hann_window(cfg.win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - stft_spec = torch.stft( - y, - cfg.n_fft, - hop_length=cfg.hop_size, - win_length=cfg.win_size, - window=hann_window, - center=False, - return_complex=True, - ) - - stft_spec = torch.view_as_real(stft_spec) - if stft_spec.size()[0] == 1: - stft_spec = stft_spec.squeeze(0) - - if len(list(stft_spec.size())) == 4: - rea = stft_spec[:, :, :, 0] # [batch_size, n_fft//2+1, frames] - imag = stft_spec[:, :, :, 1] # [batch_size, n_fft//2+1, frames] - else: - rea = stft_spec[:, :, 0] # [n_fft//2+1, frames] - imag = stft_spec[:, :, 1] # [n_fft//2+1, frames] - - log_amplitude = torch.log( - torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5 - ) # [n_fft//2+1, frames] - phase = torch.atan2(imag, rea) # [n_fft//2+1, frames] - - return log_amplitude, phase, rea, imag diff --git a/models/tts/debatts/utils/mert.py b/models/tts/debatts/utils/mert.py deleted file mode 100644 index 4181429f..00000000 --- a/models/tts/debatts/utils/mert.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -# This code is modified from https://huggingface.co/m-a-p/MERT-v1-330M - -import torch -from tqdm import tqdm -import numpy as np - -from transformers import Wav2Vec2FeatureExtractor -from transformers import AutoModel -import torchaudio -import torchaudio.transforms as T -from sklearn.preprocessing import StandardScaler - - -def mert_encoder(model, processor, audio_path, hps): - """ - # mert default sr: 24000 - """ - with torch.no_grad(): - resample_rate = processor.sampling_rate - device = next(model.parameters()).device - - input_audio, sampling_rate = torchaudio.load(audio_path) - input_audio = input_audio.squeeze() - - if sampling_rate != resample_rate: - resampler = T.Resample(sampling_rate, resample_rate) - input_audio = resampler(input_audio) - - inputs = processor( - input_audio, sampling_rate=resample_rate, return_tensors="pt" - ).to( - device - ) # {input_values: tensor, attention_mask: tensor} - - outputs = model(**inputs, output_hidden_states=True) # list: len is 25 - - # [25 layer, Time steps, 1024 feature_dim] - # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() - # mert_features.append(all_layer_hidden_states) - - feature = outputs.hidden_states[ - hps.mert_feature_layer - ].squeeze() # [1, frame len, 1024] -> [frame len, 1024] - - return feature.cpu().detach().numpy() - - -def mert_features_normalization(raw_mert_features): - normalized_mert_features = list() - - mert_features = np.array(raw_mert_features) - scaler = StandardScaler().fit(mert_features) - for raw_mert_feature in raw_mert_feature: - normalized_mert_feature = scaler.transform(raw_mert_feature) - normalized_mert_features.append(normalized_mert_feature) - return normalized_mert_features - - -def get_mapped_mert_features(raw_mert_features, mapping_features, fast_mapping=True): - source_hop = 320 - target_hop = 256 - - factor = np.gcd(source_hop, target_hop) - source_hop //= factor - target_hop //= factor - print( - "Mapping source's {} frames => target's {} frames".format( - target_hop, source_hop - ) - ) - - mert_features = [] - for index, mapping_feat in enumerate(tqdm(mapping_features)): - # mapping_feat: (mels_frame_len, n_mels) - target_len = mapping_feat.shape[0] - - # (frame_len, 1024) - raw_feats = raw_mert_features[index].cpu().numpy() - source_len, width = raw_feats.shape - - # const ~= target_len * target_hop - const = source_len * source_hop // target_hop * target_hop - - # (source_len * source_hop, dim) - up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) - # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) - down_sampling_feats = np.average( - up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 - ) - - err = abs(target_len - len(down_sampling_feats)) - if err > 3: - print("index:", index) - print("mels:", mapping_feat.shape) - print("raw mert vector:", raw_feats.shape) - print("up_sampling:", up_sampling_feats.shape) - print("const:", const) - print("down_sampling_feats:", down_sampling_feats.shape) - exit() - if len(down_sampling_feats) < target_len: - # (1, dim) -> (err, dim) - end = down_sampling_feats[-1][None, :].repeat(err, axis=0) - down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) - - # (target_len, dim) - feats = down_sampling_feats[:target_len] - mert_features.append(feats) - - return mert_features - - -def load_mert_model(hps): - print("Loading MERT Model: ", hps.mert_model) - - # Load model - model_name = hps.mert_model - model = AutoModel.from_pretrained(model_name, 
trust_remote_code=True) - - if torch.cuda.is_available(): - model = model.cuda() - - # model = model.eval() - - preprocessor = Wav2Vec2FeatureExtractor.from_pretrained( - model_name, trust_remote_code=True - ) - return model, preprocessor - - -# loading the corresponding preprocessor config -# def load_preprocessor (model_name="m-a-p/MERT-v1-330M"): -# print('load_preprocessor...') -# preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(model_name,trust_remote_code=True) -# return preprocessor diff --git a/models/tts/debatts/utils/mfa_prepare.py b/models/tts/debatts/utils/mfa_prepare.py deleted file mode 100644 index b79ba862..00000000 --- a/models/tts/debatts/utils/mfa_prepare.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" This code is modified from https://montreal-forced-aligner.readthedocs.io/en/latest/user_guide/performance.html""" - -import os -import subprocess -from multiprocessing import Pool -from tqdm import tqdm -import torchaudio -from pathlib import Path - - -def remove_empty_dirs(path): - """remove empty directories in a given path""" - # Check if the given path is a directory - if not os.path.isdir(path): - print(f"{path} is not a directory") - return - - # Walk through all directories and subdirectories - for root, dirs, _ in os.walk(path, topdown=False): - for dir in dirs: - dir_path = os.path.join(root, dir) - # Check if the directory is empty - if not os.listdir(dir_path): - os.rmdir(dir_path) # "Removed empty directory - - -def process_single_wav_file(task): - """process a single wav file""" - wav_file, output_dir = task - speaker_id, book_name, filename = Path(wav_file).parts[-3:] - - output_book_dir = Path(output_dir, speaker_id) - output_book_dir.mkdir(parents=True, exist_ok=True) - new_filename = f"{speaker_id}_{book_name}_{filename}" - - new_wav_file = Path(output_book_dir, new_filename) - command = [ - "ffmpeg", - "-nostdin", - "-hide_banner", - "-loglevel", - "error", - "-nostats", - "-i", - wav_file, - "-acodec", - "pcm_s16le", - "-ar", - "16000", - new_wav_file, - ] - subprocess.check_call( - command - ) # Run the command to convert the file to 16kHz and 16-bit PCM - os.remove(wav_file) - - -def process_wav_files(wav_files, output_dir, n_process): - """process wav files in parallel""" - tasks = [(wav_file, output_dir) for wav_file in wav_files] - print(f"Processing {len(tasks)} files") - with Pool(processes=n_process) as pool: - for _ in tqdm( - pool.imap_unordered(process_single_wav_file, tasks), total=len(tasks) - ): - pass - print("Removing empty directories...") - remove_empty_dirs(output_dir) - print("Done!") - - -def get_wav_files(dataset_path): - """get all wav files in the dataset""" - wav_files = [] - for speaker_id in os.listdir(dataset_path): - speaker_dir = os.path.join(dataset_path, speaker_id) - if not os.path.isdir(speaker_dir): - continue - for book_name in os.listdir(speaker_dir): - book_dir = os.path.join(speaker_dir, book_name) - if not os.path.isdir(book_dir): - continue - for file in os.listdir(book_dir): - if file.endswith(".wav"): - wav_files.append(os.path.join(book_dir, file)) - print("Found {} wav files".format(len(wav_files))) - return wav_files - - -def filter_wav_files_by_length(wav_files, max_len_sec=15): - """filter wav files by length""" - print("original wav files: {}".format(len(wav_files))) - filtered_wav_files = [] - for audio_file in wav_files: - metadata = 
torchaudio.info(str(audio_file)) - audio_length = metadata.num_frames / metadata.sample_rate - if audio_length <= max_len_sec: - filtered_wav_files.append(audio_file) - else: - os.remove(audio_file) - print("filtered wav files: {}".format(len(filtered_wav_files))) - return filtered_wav_files - - -if __name__ == "__main__": - dataset_path = "/path/to/output/directory" - n_process = 16 - max_len_sec = 15 - wav_files = get_wav_files(dataset_path) - filtered_wav_files = filter_wav_files_by_length(wav_files, max_len_sec) - process_wav_files(filtered_wav_files, dataset_path, n_process) diff --git a/models/tts/debatts/utils/model_summary.py b/models/tts/debatts/utils/model_summary.py deleted file mode 100644 index ec72b0d1..00000000 --- a/models/tts/debatts/utils/model_summary.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import humanfriendly -import numpy as np -import torch - - -def get_human_readable_count(number: int) -> str: - """Return human_readable_count - - Originated from: - https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py - - Abbreviates an integer number with K, M, B, T for thousands, millions, - billions and trillions, respectively. - Examples: - >>> get_human_readable_count(123) - '123 ' - >>> get_human_readable_count(1234) # (one thousand) - '1 K' - >>> get_human_readable_count(2e6) # (two million) - '2 M' - >>> get_human_readable_count(3e9) # (three billion) - '3 B' - >>> get_human_readable_count(4e12) # (four trillion) - '4 T' - >>> get_human_readable_count(5e15) # (more than trillion) - '5,000 T' - Args: - number: a positive integer number - Return: - A string formatted according to the pattern described above. - """ - assert number >= 0 - labels = [" ", "K", "M", "B", "T"] - num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) - num_groups = int(np.ceil(num_digits / 3)) - num_groups = min(num_groups, len(labels)) - shift = -3 * (num_groups - 1) - number = number * (10**shift) - index = num_groups - 1 - return f"{number:.2f} {labels[index]}" - - -def to_bytes(dtype) -> int: - return int(str(dtype)[-2:]) // 8 - - -def model_summary(model: torch.nn.Module) -> str: - message = "Model structure:\n" - message += str(model) - tot_params = sum(p.numel() for p in model.parameters()) - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params) - tot_params = get_human_readable_count(tot_params) - num_params = get_human_readable_count(num_params) - message += "\n\nModel summary:\n" - message += f" Class Name: {model.__class__.__name__}\n" - message += f" Total Number of model parameters: {tot_params}\n" - message += ( - f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n" - ) - num_bytes = humanfriendly.format_size( - sum( - p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad - ) - ) - message += f" Size: {num_bytes}\n" - dtype = next(iter(model.parameters())).dtype - message += f" Type: {dtype}" - return message diff --git a/models/tts/debatts/utils/prompt_preparer.py b/models/tts/debatts/utils/prompt_preparer.py deleted file mode 100644 index 945a5e24..00000000 --- a/models/tts/debatts/utils/prompt_preparer.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2023 Amphion. 
-# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import torch - - -class PromptPreparer: - def prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): - if self.prefix_mode == 0: - y_emb, prefix_len = self._handle_prefix_mode_0(y, codes, nar_stage) - elif self.prefix_mode == 1: - y_emb, prefix_len = self._handle_prefix_mode_1(y, y_lens, codes, nar_stage) - elif self.prefix_mode in [2, 4]: - y_emb, prefix_len = self._handle_prefix_mode_2_4( - y, y_lens, codes, nar_stage, y_prompts_codes - ) - else: - raise ValueError("Invalid prefix mode") - - return y_emb, prefix_len - - def _handle_prefix_mode_0(self, y, codes, nar_stage): - prefix_len = 0 - y_emb = self.nar_audio_embeddings[0](y) - for j in range(1, nar_stage): - y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) - return y_emb, 0 - - def _handle_prefix_mode_1(self, y, y_lens, codes, nar_stage): - int_low = (0.25 * y_lens.min()).type(torch.int64).item() - prefix_len = torch.randint(int_low, int_low * 2, size=()).item() - prefix_len = min(prefix_len, 225) - - y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) - y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) - for j in range(1, self.num_quantizers): - y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j]) - if j < nar_stage: - y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j]) - y_emb = torch.concat([y_prompts, y_emb], axis=1) - return y_emb, prefix_len - - def _handle_prefix_mode_2_4(self, y, y_lens, codes, nar_stage, y_prompts_codes): - if self.prefix_mode == 2: - prefix_len = min(225, int(0.25 * y_lens.min().item())) - - y_prompts_codes = [] - for b in range(codes.shape[0]): - start = self.rng.randint(0, y_lens[b].item() - prefix_len) - y_prompts_codes.append( - torch.clone(codes[b, start : start + prefix_len]) - ) - codes[b, start : start + prefix_len, nar_stage] = self.audio_token_num - y_prompts_codes = torch.stack(y_prompts_codes, dim=0) - else: - prefix_len = y_prompts_codes.shape[1] - - y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) - y_emb = self.nar_audio_embeddings[0](y) - for j in range(1, self.num_quantizers): - y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j]) - if j < nar_stage: - y_emb += self.nar_audio_embeddings[j](codes[..., j]) - y_emb = torch.concat([y_prompts, y_emb], axis=1) - - return y_emb, prefix_len diff --git a/models/tts/debatts/utils/ssim.py b/models/tts/debatts/utils/ssim.py index a0b95007..904aa976 100644 --- a/models/tts/debatts/utils/ssim.py +++ b/models/tts/debatts/utils/ssim.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/stft.py b/models/tts/debatts/utils/stft.py index bcec4c84..ce74eeec 100644 --- a/models/tts/debatts/utils/stft.py +++ b/models/tts/debatts/utils/stft.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/symbol_table.py b/models/tts/debatts/utils/symbol_table.py index a0e736fe..ef68b94c 100644 --- a/models/tts/debatts/utils/symbol_table.py +++ b/models/tts/debatts/utils/symbol_table.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. 
# # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/tokenizer.py b/models/tts/debatts/utils/tokenizer.py index 7eeef586..1414dc2e 100644 --- a/models/tts/debatts/utils/tokenizer.py +++ b/models/tts/debatts/utils/tokenizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/tool.py b/models/tts/debatts/utils/tool.py index 8af5e9ba..ac680ea4 100644 --- a/models/tts/debatts/utils/tool.py +++ b/models/tts/debatts/utils/tool.py @@ -1,3 +1,8 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import json import os import io diff --git a/models/tts/debatts/utils/topk_sampling.py b/models/tts/debatts/utils/topk_sampling.py index 236a0f93..0a15fd33 100644 --- a/models/tts/debatts/utils/topk_sampling.py +++ b/models/tts/debatts/utils/topk_sampling.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/trainer_utils.py b/models/tts/debatts/utils/trainer_utils.py index e5d9ad79..f6269631 100644 --- a/models/tts/debatts/utils/trainer_utils.py +++ b/models/tts/debatts/utils/trainer_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/util.py b/models/tts/debatts/utils/util.py index b7eaf1aa..b497ed2a 100644 --- a/models/tts/debatts/utils/util.py +++ b/models/tts/debatts/utils/util.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Amphion. +# Copyright (c) 2024 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/whisper_transcription.py b/models/tts/debatts/utils/whisper_transcription.py deleted file mode 100644 index 98126987..00000000 --- a/models/tts/debatts/utils/whisper_transcription.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import os -import pathlib -import string -import time -from multiprocessing import Pool, Value, Lock -from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor -import torch -import whisper - -processed_files_count = Value("i", 0) # count of processed files -lock = Lock() # lock for the count - - -def preprocess_text(text): - """Preprocess text after ASR""" - return text.lower().translate(str.maketrans("", "", string.punctuation)) - - -def transcribe_audio(model, processor, audio_file, device): - """Transcribe audio file""" - audio = whisper.load_audio(audio_file) # load from path - audio = whisper.pad_or_trim(audio) # default 30 seconds - inputs = whisper.log_mel_spectrogram(audio).to( - device=device - ) # convert to spectrogram - inputs = inputs.unsqueeze(0).type(torch.cuda.HalfTensor) # add batch dimension - - outputs = model.generate( - inputs=inputs, max_new_tokens=128 - ) # generate transcription - transcription = processor.batch_decode(outputs, skip_special_tokens=True)[ - 0 - ] # decode - transcription_processed = preprocess_text(transcription) # preprocess - return transcription_processed - - -def write_transcription(audio_file, transcription): - """Write transcription to txt file""" - txt_file = audio_file.with_suffix(".txt") - with open(txt_file, "w") as file: - file.write(transcription) - - -def init_whisper(model_id, device): - """Initialize whisper model and processor""" - torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 - print(f"Loading model {model_id}") # model_id = "distil-whisper/distil-large-v2" - distil_model = AutoModelForSpeechSeq2Seq.from_pretrained( - model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=False - ) - distil_model = distil_model.to(device) - processor = AutoProcessor.from_pretrained(model_id) - return distil_model, processor - - -def asr_wav_files(file_list, gpu_id, total_files, model_id): - """Transcribe wav files in a list""" - device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu" - whisper_model, processor = init_whisper(model_id, device) - print(f"Processing on {device} starts") - start_time = time.time() - for audio_file in file_list: - try: - transcription = transcribe_audio( - whisper_model, processor, audio_file, device - ) - write_transcription(audio_file, transcription) - with lock: - processed_files_count.value += 1 - if processed_files_count.value % 5 == 0: - current_time = time.time() - avg_time_per_file = (current_time - start_time) / ( - processed_files_count.value - ) - remaining_files = total_files - processed_files_count.value - estimated_time_remaining = avg_time_per_file * remaining_files - remaining_time_formatted = time.strftime( - "%H:%M:%S", time.gmtime(estimated_time_remaining) - ) - print( - f"Processed {processed_files_count.value}/{total_files} files, time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}, Estimated time remaining: {remaining_time_formatted}" - ) - except Exception as e: - print(f"Error processing file {audio_file}: {e}") - - -def asr_main(input_dir, num_gpus, model_id): - """Transcribe wav files in a directory""" - num_processes = min(num_gpus, os.cpu_count()) - print(f"Using {num_processes} GPUs for transcription") - wav_files = list(pathlib.Path(input_dir).rglob("*.wav")) - total_files = len(wav_files) - print(f"Found {total_files} wav files in {input_dir}") - files_per_process = len(wav_files) // num_processes - print(f"Processing {files_per_process} files per process") - with Pool(num_processes) as p: - p.starmap( - 
asr_wav_files, - [ - ( - wav_files[i * files_per_process : (i + 1) * files_per_process], - i % num_gpus, - total_files, - model_id, - ) - for i in range(num_processes) - ], - ) - print("Done!") - - -if __name__ == "__main__": - input_dir = "/path/to/output/directory" - num_gpus = 2 - model_id = "distil-whisper/distil-large-v2" - asr_main(input_dir, num_gpus, model_id) diff --git a/models/tts/debatts/utils/world.py b/models/tts/debatts/utils/world.py deleted file mode 100644 index ce5f61bd..00000000 --- a/models/tts/debatts/utils/world.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2023 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# 1. Extract WORLD features including F0, AP, SP -# 2. Transform between SP and MCEP -import torchaudio -import pyworld as pw -import numpy as np -import torch -import diffsptk -import os -from tqdm import tqdm -import pickle -import torchaudio - - -def get_mcep_params(fs): - """Hyperparameters of transformation between SP and MCEP - - Reference: - https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh - - """ - if fs in [44100, 48000]: - fft_size = 2048 - alpha = 0.77 - if fs in [16000]: - fft_size = 1024 - alpha = 0.58 - return fft_size, alpha - - -def extract_world_features(waveform, frameshift=10): - # waveform: (1, seq) - # x: (seq,) - x = np.array(waveform, dtype=np.double) - - _f0, t = pw.dio(x, fs, frame_period=frameshift) # raw pitch extractor - f0 = pw.stonemask(x, _f0, t, fs) # pitch refinement - sp = pw.cheaptrick(x, f0, t, fs) # extract smoothed spectrogram - ap = pw.d4c(x, f0, t, fs) # extract aperiodicity - - return f0, sp, ap, fs - - -def sp2mcep(x, mcsize, fs): - fft_size, alpha = get_mcep_params(fs) - x = torch.as_tensor(x, dtype=torch.float) - - tmp = diffsptk.ScalarOperation("SquareRoot")(x) - tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp) - mgc = diffsptk.MelCepstralAnalysis( - cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1 - )(tmp) - return mgc.numpy() - - -def mcep2sp(x, mcsize, fs): - fft_size, alpha = get_mcep_params(fs) - x = torch.as_tensor(x, dtype=torch.float) - - tmp = diffsptk.MelGeneralizedCepstrumToSpectrum( - alpha=alpha, - cep_order=mcsize - 1, - fft_length=fft_size, - )(x) - tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp) - sp = diffsptk.ScalarOperation("Power", 2)(tmp) - return sp.double().numpy() - - -def f0_statistics(f0_features, path): - print("\nF0 statistics...") - - total_f0 = [] - for f0 in tqdm(f0_features): - total_f0 += [f for f in f0 if f != 0] - - mean = sum(total_f0) / len(total_f0) - print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean)) - - with open(path, "wb") as f: - pickle.dump([mean, total_f0], f) - - -def world_synthesis(f0, sp, ap, fs, frameshift): - y = pw.synthesize( - f0, sp, ap, fs, frame_period=frameshift - ) # synthesize an utterance using the parameters - return y From 167321aead7dd4b9d293bbf74c7618fa1cdb9322 Mon Sep 17 00:00:00 2001 From: hehaorui Date: Mon, 28 Oct 2024 17:43:40 +0800 Subject: [PATCH 5/8] debatts code modified by black --- models/tts/debatts/t2s_model_new.py | 192 +++++--- models/tts/debatts/t2s_sft_dataset_new.py | 131 ++++-- .../debatts/try_inference_small_samples.py | 412 +++++++++++++----- models/tts/debatts/utils/g2p_new/__init__.py | 101 ++--- models/tts/debatts/utils/g2p_new/cleaners.py | 27 +- models/tts/debatts/utils/g2p_new/english.py | 163 ++++--- 
models/tts/debatts/utils/g2p_new/french.py | 66 +-- models/tts/debatts/utils/g2p_new/g2p_new.py | 3 +- models/tts/debatts/utils/g2p_new/german.py | 64 +-- models/tts/debatts/utils/g2p_new/japanese.py | 127 +++--- models/tts/debatts/utils/g2p_new/korean.py | 136 +++--- models/tts/debatts/utils/g2p_new/mandarin.py | 243 ++++++----- .../debatts/utils/g2p_new/text_tokenizers.py | 21 +- models/tts/debatts/utils/topk_sampling.py | 26 +- 14 files changed, 1081 insertions(+), 631 deletions(-) diff --git a/models/tts/debatts/t2s_model_new.py b/models/tts/debatts/t2s_model_new.py index ab20e468..f00ecaf7 100644 --- a/models/tts/debatts/t2s_model_new.py +++ b/models/tts/debatts/t2s_model_new.py @@ -14,9 +14,11 @@ import torch.nn as nn import tqdm from einops import rearrange -os.chdir('./models/tts/debatts') + +os.chdir("./models/tts/debatts") import sys -sys.path.append('./models/tts/debatts') + +sys.path.append("./models/tts/debatts") from utils.topk_sampling import top_k_top_p_filtering import pickle @@ -113,7 +115,6 @@ def __init__( else eos_prompt0_id ) - self.config = LlamaConfig( vocab_size=phone_vocab_size + target_vocab_size + 20, hidden_size=hidden_size, @@ -124,7 +125,7 @@ def __init__( bos_token_id=bos_target_id, eos_token_id=eos_target_id, bos_prompt0_id=bos_prompt0_id, - eos_prompt0_id=eos_prompt0_id + eos_prompt0_id=eos_prompt0_id, ) self.phone_vocab_size = phone_vocab_size self.target_vocab_size = target_vocab_size @@ -145,15 +146,24 @@ def __init__( torch.nn.init.normal_(self.lang_emb.weight, mean=0.0, std=0.02) def forward( - self, prompt0_ids, prompt0_mask, phone_ids, phone_mask, target_ids, target_mask, lang_id=None, + self, + prompt0_ids, + prompt0_mask, + phone_ids, + phone_mask, + target_ids, + target_mask, + lang_id=None, ): - prompt0_ids, prompt0_mask, prompt0_label, prompt0_lang_mask = self.add_phone_eos_bos_label( - prompt0_ids, - prompt0_mask, - self.eos_prompt0_id, - self.bos_prompt0_id, - self.pad_token_id, - label="prompt0_id" + prompt0_ids, prompt0_mask, prompt0_label, prompt0_lang_mask = ( + self.add_phone_eos_bos_label( + prompt0_ids, + prompt0_mask, + self.eos_prompt0_id, + self.bos_prompt0_id, + self.pad_token_id, + label="prompt0_id", + ) ) phone_ids, phone_mask, phone_label, lang_mask = self.add_phone_eos_bos_label( phone_ids, @@ -161,7 +171,7 @@ def forward( self.eos_phone_id, self.bos_phone_id, self.pad_token_id, - label="phone_id" + label="phone_id", ) target_ids, target_mask, target_label = self.add_target_eos_bos_label( target_ids, @@ -178,9 +188,15 @@ def forward( # lang_id: (B,); lang_mask: (B, T) if self.use_lang_emb: - lang_embedding = self.lang_emb(lang_id).unsqueeze(1) # (B, d) -> (B, 1, d) - lang_embedding = lang_embedding * torch.cat([prompt0_lang_mask, lang_mask, torch.zeros_like(target_mask)], dim=-1).unsqueeze(-1) # (B, T, d) - input_token_embedding = self.model.model.embed_tokens(input_token_ids) # (B, T, d) + lang_embedding = self.lang_emb(lang_id).unsqueeze(1) # (B, d) -> (B, 1, d) + lang_embedding = lang_embedding * torch.cat( + [prompt0_lang_mask, lang_mask, torch.zeros_like(target_mask)], dim=-1 + ).unsqueeze( + -1 + ) # (B, T, d) + input_token_embedding = self.model.model.embed_tokens( + input_token_ids + ) # (B, T, d) inputs_embeds = input_token_embedding + lang_embedding out = self.model( @@ -233,12 +249,16 @@ def add_phone_eos_bos_label( """ phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad( 1 - phone_mask, (0, 1), value=1 - ) # make pad token eos token, add eos token at the end - phone_mask = F.pad(phone_mask, (1, 0), 
value=1) # add eos mask - phone_ids = phone_ids * phone_mask + pad_token_id * (1 - phone_mask) # restore pad token ids - phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token - phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask - phone_label = -100 * torch.ones_like(phone_ids) # loss for entire phone is not computed (passed to llama) + ) # make pad token eos token, add eos token at the end + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask + phone_ids = phone_ids * phone_mask + pad_token_id * ( + 1 - phone_mask + ) # restore pad token ids + phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token + phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask + phone_label = -100 * torch.ones_like( + phone_ids + ) # loss for entire phone is not computed (passed to llama) return phone_ids, phone_mask, phone_label, lang_mask def add_target_eos_bos_label( @@ -254,9 +274,11 @@ def add_target_eos_bos_label( target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask) target_ids = F.pad(target_ids, (1, 0), value=target_bos_id) target_mask = F.pad(target_mask, (1, 0), value=1) - target_label = target_ids * target_mask + (-100) * (1 - target_mask) # loss for target is computed on unmasked tokens + target_label = target_ids * target_mask + (-100) * ( + 1 - target_mask + ) # loss for target is computed on unmasked tokens return target_ids, target_mask, target_label - + def add_phone_middle_label( self, prompt0_ids, prompt0_mask, eos_prompt0_id, pad_token_id ): @@ -267,19 +289,28 @@ def add_phone_middle_label( prompt0_ids = F.pad(prompt0_ids, (0, 1), value=0) + eos_prompt0_id * F.pad( 1 - prompt0_mask, (0, 1), value=1 ) # Add eos_prompt0_id at the positions transitioning to padding - prompt0_mask = F.pad(prompt0_mask, (1, 0), value=1) # Pad the mask for the new eos_prompt0_id - prompt0_ids = prompt0_ids * prompt0_mask + pad_token_id * (1 - prompt0_mask) # Restore pad tokens - prompt0_ids = F.pad(prompt0_ids, (1, 0), value=eos_prompt0_id) # Add eos_prompt0_id at the beginning - prompt0_mask = F.pad(prompt0_mask, (1, 0), value=1) # Adjust the mask for the added eos_prompt0_id - prompt0_label = prompt0_ids * prompt0_mask + (-100) * (1 - prompt0_mask) # Set up labels for loss computation + prompt0_mask = F.pad( + prompt0_mask, (1, 0), value=1 + ) # Pad the mask for the new eos_prompt0_id + prompt0_ids = prompt0_ids * prompt0_mask + pad_token_id * ( + 1 - prompt0_mask + ) # Restore pad tokens + prompt0_ids = F.pad( + prompt0_ids, (1, 0), value=eos_prompt0_id + ) # Add eos_prompt0_id at the beginning + prompt0_mask = F.pad( + prompt0_mask, (1, 0), value=1 + ) # Adjust the mask for the added eos_prompt0_id + prompt0_label = prompt0_ids * prompt0_mask + (-100) * ( + 1 - prompt0_mask + ) # Set up labels for loss computation return prompt0_ids, prompt0_mask, prompt0_label - @torch.no_grad() def sample_hf( self, - phone_ids, # the phones of prompt and target should be concatenated together。在实际使用中,phone_ids是文本的token输入 + phone_ids, # the phones of prompt and target should be concatenated together。在实际使用中,phone_ids是文本的token输入 prompt_ids, prompt0_ids=None, max_length=100000, @@ -287,12 +318,12 @@ def sample_hf( top_k=30, top_p=0.7, repeat_penalty=3.5, - lang_ids=None + lang_ids=None, ): if prompt0_ids is not None: phone_mask = torch.ones_like(phone_ids) prompt_mask = torch.ones_like(prompt_ids) - + prompt_mask_prompt0 = torch.ones_like(prompt0_ids) # downsample = DownsampleWithMask(downsample_factor=2) @@ -304,7 +335,7 @@ def sample_hf( 
self.eos_phone_id, self.bos_phone_id, self.pad_token_id, - label="phone_id" + label="phone_id", ) prompt_ids, _, _ = self.add_target_eos_bos_label( prompt_ids, @@ -313,8 +344,8 @@ def sample_hf( self.bos_target_id, self.pad_token_id, ) - prompt_ids = prompt_ids[:, :-1] # remove end token. Make it continue mode - + prompt_ids = prompt_ids[:, :-1] # remove end token. Make it continue mode + prompt0_ids, _, _ = self.add_target_eos_bos_label( prompt0_ids, prompt_mask_prompt0, @@ -322,20 +353,37 @@ def sample_hf( self.bos_prompt0_id, self.pad_token_id, ) - + input_token_ids = torch.cat([prompt0_ids, phone_ids, prompt_ids], dim=-1) input_length = input_token_ids.shape[1] if lang_ids != None and self.use_lang_emb: lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0) - input_token_embedding = self.model.model.embed_tokens(input_token_ids) # (B, T, d) + input_token_embedding = self.model.model.embed_tokens( + input_token_ids + ) # (B, T, d) # lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh'] lang_mask = torch.ones_like(phone_ids) - lang_mask[:,0] = 0 - lang_mask[:,-1] = 0 - lang_embedding = torch.cat([self.lang_emb(lang_ids), self.lang_emb(lang_ids), torch.zeros(lang_ids.shape[0], input_token_ids.shape[1] - lang_ids.shape[1], self.hidden_size).to(input_token_ids.device)], dim=1) * torch.cat([lang_mask, torch.zeros_like(prompt_ids)], dim=-1).unsqueeze(-1) - + lang_mask[:, 0] = 0 + lang_mask[:, -1] = 0 + lang_embedding = torch.cat( + [ + self.lang_emb(lang_ids), + self.lang_emb(lang_ids), + torch.zeros( + lang_ids.shape[0], + input_token_ids.shape[1] - lang_ids.shape[1], + self.hidden_size, + ).to(input_token_ids.device), + ], + dim=1, + ) * torch.cat( + [lang_mask, torch.zeros_like(prompt_ids)], dim=-1 + ).unsqueeze( + -1 + ) + inputs_embeds = input_token_embedding + lang_embedding # if prosody_features is not None: @@ -356,11 +404,11 @@ def sample_hf( repetition_penalty=repeat_penalty, min_new_tokens=50, ) - gen_tokens = generated_ids[:,:-1] + gen_tokens = generated_ids[:, :-1] else: input_token_embedding = self.model.model.embed_tokens(input_token_ids) - + generated_ids = self.model.generate( input_token_ids, do_sample=True, @@ -376,7 +424,7 @@ def sample_hf( gen_tokens = generated_ids[:, input_length:-1] return gen_tokens - + else: phone_mask = torch.ones_like(phone_ids) prompt_mask = torch.ones_like(prompt_ids) @@ -386,7 +434,7 @@ def sample_hf( self.eos_phone_id, self.bos_phone_id, self.pad_token_id, - label="phone_ids" + label="phone_ids", ) prompt_ids, _, _ = self.add_target_eos_bos_label( prompt_ids, @@ -395,22 +443,37 @@ def sample_hf( self.bos_target_id, self.pad_token_id, ) - prompt_ids = prompt_ids[:, :-1] # remove end token. Make it continue mode - + prompt_ids = prompt_ids[:, :-1] # remove end token. 
Make it continue mode + input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1) input_length = input_token_ids.shape[1] if lang_ids != None and self.use_lang_emb: lang_ids = F.pad(F.pad(lang_ids, (1, 0), value=0), (0, 1), value=0) # token to vector - input_token_embedding = self.model.model.embed_tokens(input_token_ids) # (B, T, d) + input_token_embedding = self.model.model.embed_tokens( + input_token_ids + ) # (B, T, d) # lang_ids: [1,1,1,1,1,1,2,2,2,2] which means ['en','en','en','en','en','en','zh','zh','zh','zh'] lang_mask = torch.ones_like(phone_ids) - lang_mask[:,0] = 0 - lang_mask[:,-1] = 0 - lang_embedding = torch.cat([self.lang_emb(lang_ids), torch.zeros(lang_ids.shape[0], input_token_ids.shape[1] - lang_ids.shape[1], self.hidden_size).to(input_token_ids.device)], dim=1) * torch.cat([lang_mask, torch.zeros_like(prompt_ids)], dim=-1).unsqueeze(-1) - - + lang_mask[:, 0] = 0 + lang_mask[:, -1] = 0 + lang_embedding = torch.cat( + [ + self.lang_emb(lang_ids), + torch.zeros( + lang_ids.shape[0], + input_token_ids.shape[1] - lang_ids.shape[1], + self.hidden_size, + ).to(input_token_ids.device), + ], + dim=1, + ) * torch.cat( + [lang_mask, torch.zeros_like(prompt_ids)], dim=-1 + ).unsqueeze( + -1 + ) + inputs_embeds = input_token_embedding + lang_embedding generated_ids = self.model.generate( @@ -427,13 +490,13 @@ def sample_hf( min_new_tokens=50, ) # assert generated_ids.size(1) > input_length, f"Generated tokens length {generated_ids.size(1)} is less than input length {input_length}, generated ids is {generated_ids}" - gen_tokens = generated_ids[:,:-1] + gen_tokens = generated_ids[:, :-1] else: input_token_embedding = self.model.model.embed_tokens(input_token_ids) # if prosody_features is not None: - # + # # prosody_features = prosody_features.unsqueeze(1).expand(-1, input_token_embedding.size(1), -1) # inputs_embeds = input_token_embedding + prosody_features # generated_ids = self.model.generate( @@ -448,7 +511,7 @@ def sample_hf( top_k=top_k, top_p=top_p, repetition_penalty=repeat_penalty, - min_new_tokens=50 + min_new_tokens=50, ) # assert generated_ids.size(1) > input_length, f"Generated tokens length {generated_ids.size(1)} is less than input length {input_length}, generated ids is {generated_ids}" @@ -464,15 +527,20 @@ def __init__(self, downsample_factor=2): def forward(self, x, mask): # x shape: (batch_size, seq_len) # mask shape: (batch_size, seq_len) - + x = x.float() x = x.unsqueeze(1) # add channel dimension: (batch_size, 1, seq_len) - x = F.avg_pool1d(x, kernel_size=self.downsample_factor, stride=self.downsample_factor) - x = x.squeeze(1) # remove channel dimension: (batch_size, seq_len // downsample_factor) + x = F.avg_pool1d( + x, kernel_size=self.downsample_factor, stride=self.downsample_factor + ) + x = x.squeeze( + 1 + ) # remove channel dimension: (batch_size, seq_len // downsample_factor) x = x.long() # average pooling mask = mask.float() # convert mask to float for pooling mask = mask.unsqueeze(1) # add channel dimension: (batch_size, 1, seq_len) - mask = F.avg_pool1d(mask, kernel_size=self.downsample_factor, stride=self.downsample_factor) - \ No newline at end of file + mask = F.avg_pool1d( + mask, kernel_size=self.downsample_factor, stride=self.downsample_factor + ) diff --git a/models/tts/debatts/t2s_sft_dataset_new.py b/models/tts/debatts/t2s_sft_dataset_new.py index e705fb18..11328e7c 100644 --- a/models/tts/debatts/t2s_sft_dataset_new.py +++ b/models/tts/debatts/t2s_sft_dataset_new.py @@ -23,9 +23,11 @@ from pathlib import Path from transformers import 
SeamlessM4TFeatureExtractor from transformers import Wav2Vec2BertModel -os.chdir('./models/tts/debatts') + +os.chdir("./models/tts/debatts") import sys -sys.path.append('./models/tts/debatts') + +sys.path.append("./models/tts/debatts") from utils.g2p_new.g2p import phonemizer_g2p from utils.g2p_new.g2p_new import new_g2p from torch.nn.utils.rnn import pad_sequence @@ -44,6 +46,7 @@ def filter(self, record): return False return True + filter = WarningFilter() logging.getLogger("phonemizer").addFilter(filter) logging.getLogger("qcloud_cos.cos_client").addFilter(filter) @@ -51,13 +54,14 @@ def filter(self, record): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class T2SDataset(torch.utils.data.Dataset): def __init__( self, cfg=None, ): self.cfg = cfg - + self.meta_info_path = "Debatts-Data Summary Json" with open(self.meta_info_path, "r") as f: self.meta_info_data = json.load(f) @@ -82,7 +86,7 @@ def __init__( self.wav_path_index2spkid.append(info["speaker_id"]) self.wav_path_index2phoneid.append(info["phone_id"]) self.index2num_frames.append(info["duration"] * 50 + len(info["phone_id"])) - lang_id = self.lang2id[info['language']] + lang_id = self.lang2id[info["language"]] self.index2lang.append(lang_id) # self.index2num_frames.append(info["duration"] * self.cfg.preprocess.sample_rate) @@ -94,9 +98,7 @@ def __init__( ) ) - self.processor = SeamlessM4TFeatureExtractor.from_pretrained( - "./w2v-bert-2" - ) + self.processor = SeamlessM4TFeatureExtractor.from_pretrained("./w2v-bert-2") def new_g2p(self, text, language): return new_g2p(text, language) @@ -105,12 +107,23 @@ def __len__(self): return self.wav_paths.__len__() def get_num_frames(self, index): - return self.wav_path_index2duration[index] * 50 + self.wav_path_index2phonelen[index] + return ( + self.wav_path_index2duration[index] * 50 + + self.wav_path_index2phonelen[index] + ) def __getitem__(self, idx): wav_path = self.wav_paths[idx] speech, sr = librosa.load(wav_path, sr=self.cfg.preprocess.sample_rate) - speech = np.pad(speech, (0, self.cfg.preprocess.hop_size - len(speech) % self.cfg.preprocess.hop_size), mode="constant") + speech = np.pad( + speech, + ( + 0, + self.cfg.preprocess.hop_size + - len(speech) % self.cfg.preprocess.hop_size, + ), + mode="constant", + ) # resample the speech to 16k for feature extraction if self.cfg.preprocess.sample_rate != 16000: speech_16k = librosa.resample( @@ -124,8 +137,18 @@ def __getitem__(self, idx): attention_mask = inputs["attention_mask"][0] prompt0_wav_path = self.prompt0_paths[idx] # Get prompt0 path - speech_prompt0, sr_prompt0 = librosa.load(prompt0_wav_path, sr=self.cfg.preprocess.sample_rate) - speech_prompt0 = np.pad(speech_prompt0, (0, self.cfg.preprocess.hop_size - len(speech_prompt0) % self.cfg.preprocess.hop_size), mode="constant") + speech_prompt0, sr_prompt0 = librosa.load( + prompt0_wav_path, sr=self.cfg.preprocess.sample_rate + ) + speech_prompt0 = np.pad( + speech_prompt0, + ( + 0, + self.cfg.preprocess.hop_size + - len(speech_prompt0) % self.cfg.preprocess.hop_size, + ), + mode="constant", + ) # resample the speech to 16k for feature extraction if self.cfg.preprocess.sample_rate != 16000: speech_16k_prompt0 = librosa.resample( @@ -133,7 +156,7 @@ def __getitem__(self, idx): ) else: speech_16k_prompt0 = speech_prompt0 - + inputs_prompt0 = self.processor(speech_16k_prompt0, sampling_rate=16000) input_features_prompt0 = inputs_prompt0["input_features"][0] @@ -157,8 +180,8 @@ def __getitem__(self, idx): spk_id = self.wav_path_index2spkid[idx] - 
single_feature.update({"spk_id": spk_id}) - single_feature.update({"lang_id": lang_id}) + single_feature.update({"spk_id": spk_id}) + single_feature.update({"lang_id": lang_id}) single_feature.update({"phone_id": phone_id}) single_feature.update({"phone_mask": phone_mask}) @@ -170,12 +193,11 @@ def __getitem__(self, idx): "mask": mask, "input_features_prompt0": input_features_prompt0, "attention_mask_prompt0": attention_mask_prompt0, - "mask_prompt0":mask_prompt0 + "mask_prompt0": mask_prompt0, } ) return single_feature - class T2SCollator(object): @@ -188,36 +210,70 @@ def __call__(self, batch): for key in batch[0].keys(): if "input_features" in key: packed_batch_features[key] = pad_sequence( - [utt[key].float() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).float() for utt in batch], - batch_first=True + [ + ( + utt[key].float() + if isinstance(utt[key], torch.Tensor) + else torch.tensor(utt[key]).float() + ) + for utt in batch + ], + batch_first=True, ) if "attention_mask" in key: packed_batch_features[key] = pad_sequence( - [utt[key].float() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).float() for utt in batch], - batch_first=True + [ + ( + utt[key].float() + if isinstance(utt[key], torch.Tensor) + else torch.tensor(utt[key]).float() + ) + for utt in batch + ], + batch_first=True, ) if "mask" in key: packed_batch_features[key] = pad_sequence( - [utt[key].long() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).long() for utt in batch], - batch_first=True + [ + ( + utt[key].long() + if isinstance(utt[key], torch.Tensor) + else torch.tensor(utt[key]).long() + ) + for utt in batch + ], + batch_first=True, ) if "semantic_code" in key: packed_batch_features[key] = pad_sequence( - [utt[key].float() if isinstance(utt[key], torch.Tensor) else torch.tensor(utt[key]).float() for utt in batch], - batch_first=True + [ + ( + utt[key].float() + if isinstance(utt[key], torch.Tensor) + else torch.tensor(utt[key]).float() + ) + for utt in batch + ], + batch_first=True, ) if key == "phone_id": packed_batch_features[key] = pad_sequence( - [utt[key].long() for utt in batch], batch_first=True, padding_value=1023, # phone vocab size is 1024 + [utt[key].long() for utt in batch], + batch_first=True, + padding_value=1023, # phone vocab size is 1024 ) if key == "phone_mask": packed_batch_features[key] = pad_sequence( [torch.tensor(utt[key]).long() for utt in batch], batch_first=True ) if key == "lang_id": - packed_batch_features[key] = torch.tensor([utt[key] for utt in batch]).long() + packed_batch_features[key] = torch.tensor( + [utt[key] for utt in batch] + ).long() if key == "spk_id": - packed_batch_features[key] = torch.tensor([utt[key] for utt in batch]).long() + packed_batch_features[key] = torch.tensor( + [utt[key] for utt in batch] + ).long() if key == "spk_emb_input_features": packed_batch_features[key] = pad_sequence( [torch.tensor(utt[key]).float() for utt in batch], batch_first=True @@ -232,7 +288,6 @@ def __call__(self, batch): return packed_batch_features - class DownsampleWithMask(nn.Module): def __init__(self, downsample_factor=2): super(DownsampleWithMask, self).__init__() @@ -250,24 +305,32 @@ def forward(self, x, mask): x = x.float() x = x.permute(1, 0) # to (feature_dim, timestep) x = x.unsqueeze(1) # add channel dimension: (timestep, 1, feature_dim) - + if x.size(-1) < self.downsample_factor: raise ValueError("Input size must be larger than downsample factor") # print(f"################## x size before 
{x.shape}################################") x = F.avg_pool1d(x, kernel_size=self.downsample_factor) - x = x.squeeze(1) # remove channel dimension: (timestep, feature_dim // downsample_factor) + x = x.squeeze( + 1 + ) # remove channel dimension: (timestep, feature_dim // downsample_factor) x = x.long() x = x.permute(1, 0) # to (feature_dim, timestep) mask = mask.float() # convert mask to float for pooling - mask = mask.unsqueeze(0).unsqueeze(0) # add channel dimension: (timestep, 1, feature_dim) - + mask = mask.unsqueeze(0).unsqueeze( + 0 + ) # add channel dimension: (timestep, 1, feature_dim) + if mask.size(-1) < self.downsample_factor: raise ValueError("Mask size must be larger than downsample factor") - mask = F.avg_pool1d(mask, kernel_size=self.downsample_factor, stride=self.downsample_factor) - mask = mask.squeeze(0).squeeze(0) # remove channel dimension: (timestep, feature_dim // downsample_factor) + mask = F.avg_pool1d( + mask, kernel_size=self.downsample_factor, stride=self.downsample_factor + ) + mask = mask.squeeze(0).squeeze( + 0 + ) # remove channel dimension: (timestep, feature_dim // downsample_factor) mask = (mask >= 0.5).long() # if average > 0.5 --> 1, else 0 return x, mask diff --git a/models/tts/debatts/try_inference_small_samples.py b/models/tts/debatts/try_inference_small_samples.py index d4193810..6b8a845c 100644 --- a/models/tts/debatts/try_inference_small_samples.py +++ b/models/tts/debatts/try_inference_small_samples.py @@ -5,8 +5,9 @@ import sys import os -os.chdir('./models/tts/debatts') -sys.path.append('./models/tts/debatts') + +os.chdir("./models/tts/debatts") +sys.path.append("./models/tts/debatts") from utils.g2p_new.g2p_new import new_g2p from transformers import Wav2Vec2Model @@ -36,6 +37,7 @@ from tqdm import tqdm from transformers import SeamlessM4TFeatureExtractor + processor = SeamlessM4TFeatureExtractor.from_pretrained("./ckpt/w2v-bert-2") from transformers import AutoProcessor, AutoModel @@ -43,34 +45,40 @@ from models.tts.text2semantic.t2s_model import T2SLlama from models.tts.text2semantic.t2s_model_new import T2SLlama_new from models.tts.text2semantic.t2s_sft_dataset_new import DownsampleWithMask - + + def new_g2p_(text, language): return new_g2p(text, language) + def build_t2s_model_new(cfg, device): - t2s_model = T2SLlama_new(phone_vocab_size=1024, - target_vocab_size=8192, - hidden_size=2048, - intermediate_size=8192, - pad_token_id=9216, - bos_target_id=9217, - eos_target_id=9218, - bos_phone_id=9219, - eos_phone_id=9220, - bos_prompt0_id=9221, - eos_prompt0_id=9222, - use_lang_emb=False) + t2s_model = T2SLlama_new( + phone_vocab_size=1024, + target_vocab_size=8192, + hidden_size=2048, + intermediate_size=8192, + pad_token_id=9216, + bos_target_id=9217, + eos_target_id=9218, + bos_phone_id=9219, + eos_phone_id=9220, + bos_prompt0_id=9221, + eos_prompt0_id=9222, + use_lang_emb=False, + ) t2s_model.eval() t2s_model.to(device) t2s_model.half() return t2s_model + def build_soundstorm(cfg, device): soundstorm_model = SoundStorm(cfg=cfg.model.soundstorm) soundstorm_model.eval() soundstorm_model.to(device) return soundstorm_model + def build_kmeans_model(cfg, device): if cfg.model.kmeans.type == "kmeans": kmeans_model = KMeans(cfg=cfg.model.kmeans.kmeans) @@ -79,7 +87,7 @@ def build_kmeans_model(cfg, device): elif cfg.model.kmeans.type == "repcodec": kmeans_model = RepCodec(cfg=cfg.model.kmeans.repcodec) kmeans_model.eval() - pretrained_path =cfg.model.kmeans.pretrained_path + pretrained_path = cfg.model.kmeans.pretrained_path if ".bin" in 
pretrained_path: kmeans_model.load_state_dict(torch.load(pretrained_path)) elif ".safetensors" in pretrained_path: @@ -87,6 +95,7 @@ def build_kmeans_model(cfg, device): kmeans_model.to(device) return kmeans_model + def build_semantic_model(cfg, device): semantic_model = Wav2Vec2BertModel.from_pretrained("./w2v-bert-2") semantic_model.eval() @@ -102,6 +111,7 @@ def build_semantic_model(cfg, device): return semantic_model, semantic_mean, semantic_std + def build_codec_model(cfg, device): codec_encoder = CodecEncoder(cfg=cfg.model.codec.encoder) codec_decoder = CodecDecoder(cfg=cfg.model.codec.decoder) @@ -113,14 +123,19 @@ def build_codec_model(cfg, device): torch.load(cfg.model.codec.decoder.pretrained_path) ) else: - accelerate.load_checkpoint_and_dispatch(codec_encoder, cfg.model.codec.encoder.pretrained_path) - accelerate.load_checkpoint_and_dispatch(codec_decoder, cfg.model.codec.decoder.pretrained_path) + accelerate.load_checkpoint_and_dispatch( + codec_encoder, cfg.model.codec.encoder.pretrained_path + ) + accelerate.load_checkpoint_and_dispatch( + codec_decoder, cfg.model.codec.decoder.pretrained_path + ) codec_encoder.eval() codec_decoder.eval() codec_encoder.to(device) codec_decoder.to(device) return codec_encoder, codec_decoder + @torch.no_grad() def extract_acoustic_code(speech): vq_emb = codec_encoder(speech.unsqueeze(1)) @@ -130,6 +145,7 @@ def extract_acoustic_code(speech): ) # (num_quantizer, T, C) -> (T, C, num_quantizer) return acoustic_code + @torch.no_grad() def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask): vq_emb = semantic_model( @@ -143,6 +159,7 @@ def extract_semantic_code(semantic_mean, semantic_std, input_features, attention semantic_code, _ = kmeans_model.quantize(feat) # (B, T) return semantic_code + @torch.no_grad() def extract_features(speech, processor): inputs = processor(speech, sampling_rate=16000, return_tensors="pt") @@ -150,53 +167,90 @@ def extract_features(speech, processor): attention_mask = inputs["attention_mask"][0] return input_features, attention_mask + @torch.no_grad() -def text2semantic(prompt0_speech, prompt0_text, prompt_speech, prompt_text, prompt_language, target_text, target_language, use_prompt_text=True, temp=1.0, top_k=1000, top_p=0.85, infer_mode = "new"): +def text2semantic( + prompt0_speech, + prompt0_text, + prompt_speech, + prompt_text, + prompt_language, + target_text, + target_language, + use_prompt_text=True, + temp=1.0, + top_k=1000, + top_p=0.85, + infer_mode="new", +): if use_prompt_text: if infer_mode == "new" and prompt0_speech is not None and prompt0_speech.any(): prompt0_phone_id = new_g2p_(prompt0_text, prompt_language)[1] - prompt0_phone_id = torch.tensor(prompt0_phone_id, dtype=torch.long).to(device) + prompt0_phone_id = torch.tensor(prompt0_phone_id, dtype=torch.long).to( + device + ) prompt_phone_id = new_g2p_(prompt_text, prompt_language)[1] prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device) target_phone_id = new_g2p_(target_text, target_language)[1] - target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) + target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) - phone_id = torch.cat([prompt_phone_id, torch.LongTensor([4]).to(device), target_phone_id]) + phone_id = torch.cat( + [prompt_phone_id, torch.LongTensor([4]).to(device), target_phone_id] + ) else: target_phone_id = new_g2p_(target_text, target_language)[1] target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device) phone_id = 
target_phone_id - + input_fetures, attention_mask = extract_features(prompt_speech, processor) input_fetures = input_fetures.unsqueeze(0).to(device) attention_mask = attention_mask.unsqueeze(0).to(device) - semantic_code = extract_semantic_code(semantic_mean, semantic_std, input_fetures, attention_mask) - + semantic_code = extract_semantic_code( + semantic_mean, semantic_std, input_fetures, attention_mask + ) if infer_mode == "new": - input_fetures_prompt0, attention_mask_prompt0 = extract_features(prompt0_speech, processor) + input_fetures_prompt0, attention_mask_prompt0 = extract_features( + prompt0_speech, processor + ) input_fetures_prompt0 = input_fetures_prompt0.unsqueeze(0).to(device) attention_mask_prompt0 = attention_mask_prompt0.unsqueeze(0).to(device) attention_mask_prompt0 = attention_mask_prompt0.float() - semantic_code_prompt0 = extract_semantic_code(semantic_mean, semantic_std, input_fetures_prompt0, attention_mask_prompt0) - + semantic_code_prompt0 = extract_semantic_code( + semantic_mean, semantic_std, input_fetures_prompt0, attention_mask_prompt0 + ) + if use_prompt_text: if infer_mode == "new": - predict_semantic = t2s_model_new.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :], prompt0_ids=semantic_code_prompt0[:, :], temperature=temp, top_k=top_k, top_p=top_p) + predict_semantic = t2s_model_new.sample_hf( + phone_ids=phone_id.unsqueeze(0), + prompt_ids=semantic_code[:, :], + prompt0_ids=semantic_code_prompt0[:, :], + temperature=temp, + top_k=top_k, + top_p=top_p, + ) else: if infer_mode == "new": - predict_semantic = t2s_model_new.sample_hf(phone_ids=phone_id.unsqueeze(0), prompt_ids=semantic_code[:, :1], prompt0_ids=semantic_code_prompt0[:, :1], temperature=temp, top_k=top_k, top_p=top_p) - - - combine_semantic_code = torch.cat([semantic_code[:,:], predict_semantic], dim=-1) + predict_semantic = t2s_model_new.sample_hf( + phone_ids=phone_id.unsqueeze(0), + prompt_ids=semantic_code[:, :1], + prompt0_ids=semantic_code_prompt0[:, :1], + temperature=temp, + top_k=top_k, + top_p=top_p, + ) + + combine_semantic_code = torch.cat([semantic_code[:, :], predict_semantic], dim=-1) prompt_semantic_code = semantic_code - + return combine_semantic_code, prompt_semantic_code + @torch.no_grad() def semantic2acoustic(combine_semantic_code, acoustic_code): @@ -205,27 +259,44 @@ def semantic2acoustic(combine_semantic_code, acoustic_code): if soundstorm_1layer.cond_code_layers == 1: cond = soundstorm_1layer.cond_emb(semantic_code) else: - cond = soundstorm_1layer.cond_emb[0](semantic_code[0,:,:]) + cond = soundstorm_1layer.cond_emb[0](semantic_code[0, :, :]) for i in range(1, soundstorm_1layer.cond_code_layers): - cond += soundstorm_1layer.cond_emb[i](semantic_code[i,:,:]) - cond = cond / math.sqrt(soundstorm_1layer.cond_code_layers) - - prompt = acoustic_code[:,:,:] - predict_1layer = soundstorm_1layer.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=[40], cfg=1.0, rescale_cfg=1.0) + cond += soundstorm_1layer.cond_emb[i](semantic_code[i, :, :]) + cond = cond / math.sqrt(soundstorm_1layer.cond_code_layers) + + prompt = acoustic_code[:, :, :] + predict_1layer = soundstorm_1layer.reverse_diffusion( + cond=cond, + prompt=prompt, + temp=1.5, + filter_thres=0.98, + n_timesteps=[40], + cfg=1.0, + rescale_cfg=1.0, + ) if soundstorm_full.cond_code_layers == 1: cond = soundstorm_full.cond_emb(semantic_code) else: - cond = soundstorm_full.cond_emb[0](semantic_code[0,:,:]) + cond = soundstorm_full.cond_emb[0](semantic_code[0, :, :]) for i 
in range(1, soundstorm_full.cond_code_layers): - cond += soundstorm_full.cond_emb[i](semantic_code[i,:,:]) - cond = cond / math.sqrt(soundstorm_full.cond_code_layers) - - prompt = acoustic_code[:,:,:] - predict_full = soundstorm_full.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=[40,16,10,10,10,10,10,10,10,10,10,10], cfg=1.0, rescale_cfg=1.0, gt_code=predict_1layer) - vq_emb = codec_decoder.vq2emb(predict_full.permute(2,0,1), n_quantizers=12) + cond += soundstorm_full.cond_emb[i](semantic_code[i, :, :]) + cond = cond / math.sqrt(soundstorm_full.cond_code_layers) + + prompt = acoustic_code[:, :, :] + predict_full = soundstorm_full.reverse_diffusion( + cond=cond, + prompt=prompt, + temp=1.5, + filter_thres=0.98, + n_timesteps=[40, 16, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], + cfg=1.0, + rescale_cfg=1.0, + gt_code=predict_1layer, + ) + vq_emb = codec_decoder.vq2emb(predict_full.permute(2, 0, 1), n_quantizers=12) recovered_audio = codec_decoder(vq_emb) - prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1), n_quantizers=12) + prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2, 0, 1), n_quantizers=12) recovered_prompt_audio = codec_decoder(prompt_vq_emb) recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy() recovered_audio = recovered_audio[0][0].cpu().numpy() @@ -233,19 +304,28 @@ def semantic2acoustic(combine_semantic_code, acoustic_code): return combine_audio, recovered_audio + device = torch.device("cuda:0") -cfg_soundstorm_1layer = load_config("./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json") -cfg_soundstorm_full = load_config("./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json") +cfg_soundstorm_1layer = load_config( + "./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json" +) +cfg_soundstorm_full = load_config( + "./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json" +) soundstorm_1layer = build_soundstorm(cfg_soundstorm_1layer, device) soundstorm_full = build_soundstorm(cfg_soundstorm_full, device) -semantic_model, semantic_mean, semantic_std = build_semantic_model(cfg_soundstorm_full, device) +semantic_model, semantic_mean, semantic_std = build_semantic_model( + cfg_soundstorm_full, device +) kmeans_model = build_kmeans_model(cfg_soundstorm_full, device) codec_encoder, codec_decoder = build_codec_model(cfg_soundstorm_full, device) -semantic_model, semantic_mean, semantic_std = build_semantic_model(cfg_soundstorm_full, device) +semantic_model, semantic_mean, semantic_std = build_semantic_model( + cfg_soundstorm_full, device +) kmeans_model = build_kmeans_model(cfg_soundstorm_full, device) @@ -254,40 +334,53 @@ def semantic2acoustic(combine_semantic_code, acoustic_code): safetensors.torch.load_model(soundstorm_1layer, soundstorm_1layer_path) safetensors.torch.load_model(soundstorm_full, soundstorm_full_path) -t2s_cfg = load_config("./t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json") +t2s_cfg = load_config( + "./t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json" +) t2s_model_new = build_t2s_model_new(t2s_cfg, device) t2s_model_new_ckpt_path = "./t2s_model/model.safetensors" safetensors.torch.load_model(t2s_model_new, t2s_model_new_ckpt_path) from funasr import AutoModel + print("Loading ASR model...") -asr_model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", vad_kwargs={"max_single_segment_time": 60000}, punc_model="ct-punc", 
device="cuda:0") +asr_model = AutoModel( + model="paraformer-zh", + vad_model="fsmn-vad", + vad_kwargs={"max_single_segment_time": 60000}, + punc_model="ct-punc", + device="cuda:0", +) + def adjust_punctuation(text): """ - Adjust the punctuation so that the comma is followed - by a space and the rest of the punctuation uses the + Adjust the punctuation so that the comma is followed + by a space and the rest of the punctuation uses the full Angle symbol. """ - text = text.replace(',', ', ') + text = text.replace(",", ", ") punct_mapping = { - '。': '。', - '?': '?', - '!': '!', - ':': ':', - ';': ';', - '“': '“', - '”': '”', - '‘': '‘', - '’': '’' + "。": "。", + "?": "?", + "!": "!", + ":": ":", + ";": ";", + "“": "“", + "”": "”", + "‘": "‘", + "’": "’", } for punct, full_punct in punct_mapping.items(): text = text.replace(punct, full_punct) return text + import random import zhconv + + def generate_text_data(wav_file): idx = random.randint(0, 7000) speech = librosa.load(wav_file, sr=16000)[0] @@ -295,15 +388,15 @@ def generate_text_data(wav_file): txt_json_param_path = wav_file.replace(".wav", "_asr_param.json") if os.path.exists(txt_json_path): - with open(txt_json_path, 'r', encoding='utf-8') as file: + with open(txt_json_path, "r", encoding="utf-8") as file: json_data = json.load(file) - + if "text" in json_data: txt = json_data["text"] txt = adjust_punctuation(txt) elif os.path.exists(txt_json_param_path): - with open(txt_json_param_path, 'r', encoding='utf-8') as file: + with open(txt_json_param_path, "r", encoding="utf-8") as file: json_data = json.load(file) if "text" in json_data: txt = json_data["text"] @@ -312,92 +405,199 @@ def generate_text_data(wav_file): else: res = asr_model.generate(input=wav_file, batch_size_s=300) txt = res[0]["text"] - txt = zhconv.convert(txt, 'zh-cn') + txt = zhconv.convert(txt, "zh-cn") txt = adjust_punctuation(txt) json_data["text"] = txt - with open(txt_json_path, 'w', encoding='utf-8') as file: + with open(txt_json_path, "w", encoding="utf-8") as file: json.dump(json_data, file, ensure_ascii=False, indent=4) # If no JSON file is found, generate new text and save it to a new JSON file else: res = asr_model.generate(input=wav_file, batch_size_s=300) txt = res[0]["text"] - txt = zhconv.convert(txt, 'zh-cn') + txt = zhconv.convert(txt, "zh-cn") txt = adjust_punctuation(txt) # txt = re.sub(" ", "", txt) - + json_data = {"text": txt} - with open(txt_json_path, 'w', encoding='utf-8') as file: + with open(txt_json_path, "w", encoding="utf-8") as file: json.dump(json_data, file, ensure_ascii=False, indent=4) return wav_file, txt, wav_file -def infer(speech_path, prompt_text, target_wav_path, target_text, target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="new", idx = 0, epoch=0, spk_prompt_type=""): +def infer( + speech_path, + prompt_text, + target_wav_path, + target_text, + target_language="zh", + speech_path_prompt0=None, + prompt0_text=None, + temperature=0.2, + top_k=20, + top_p=0.9, + concat_prompt=False, + infer_mode="new", + idx=0, + epoch=0, + spk_prompt_type="", +): if idx != 0: - save_dir = os.path.join("The Path to Store Generated Speech", f"{infer_mode}/{spk_prompt_type}") + save_dir = os.path.join( + "The Path to Store Generated Speech", f"{infer_mode}/{spk_prompt_type}" + ) if not os.path.exists(save_dir): os.mkdir(save_dir) - save_path = os.path.join(save_dir, 
f"{os.path.splitext(os.path.basename(target_wav_path))[0]}_infer_{infer_mode}_{idx}_epoch_{epoch}_{spk_prompt_type}.wav") + save_path = os.path.join( + save_dir, + f"{os.path.splitext(os.path.basename(target_wav_path))[0]}_infer_{infer_mode}_{idx}_epoch_{epoch}_{spk_prompt_type}.wav", + ) else: - save_dir = os.path.join("The Path to Store Generated Speech", f"{infer_mode}/{spk_prompt_type}") + save_dir = os.path.join( + "The Path to Store Generated Speech", f"{infer_mode}/{spk_prompt_type}" + ) if not os.path.exists(save_dir): os.mkdir(save_dir) - save_path = os.path.join(save_dir, f"{os.path.splitext(os.path.basename(target_wav_path))[0]}_infer_{infer_mode}_epoch_{epoch}_{spk_prompt_type}.wav") - + save_path = os.path.join( + save_dir, + f"{os.path.splitext(os.path.basename(target_wav_path))[0]}_infer_{infer_mode}_epoch_{epoch}_{spk_prompt_type}.wav", + ) + if os.path.exists(save_path): return save_path - + # print(f"HERE COMES INFER!!! {infer_mode}") # print(f"IN INFER PROMPT text is {prompt_text}") # print(f"IN INFER Target text is {target_text}") speech_16k = librosa.load(speech_path, sr=16000)[0] - speech = librosa.load(speech_path, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] + speech = librosa.load(speech_path, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[ + 0 + ] if infer_mode == "new": speech_16k_prompt0 = librosa.load(speech_path_prompt0, sr=16000)[0] - speech_prompt0 = librosa.load(speech_path_prompt0, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] - combine_semantic_code, _ = text2semantic(prompt0_speech=speech_16k_prompt0, prompt0_text=prompt0_text, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language=target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode = infer_mode) + speech_prompt0 = librosa.load( + speech_path_prompt0, sr=cfg_soundstorm_1layer.preprocess.sample_rate + )[0] + combine_semantic_code, _ = text2semantic( + prompt0_speech=speech_16k_prompt0, + prompt0_text=prompt0_text, + prompt_speech=speech_16k, + prompt_text=prompt_text, + prompt_language=target_language, + target_text=target_text, + target_language=target_language, + temp=temperature, + top_k=top_k, + top_p=top_p, + infer_mode=infer_mode, + ) else: - combine_semantic_code, _ = text2semantic(prompt0_speech=None, prompt0_text=None, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language = target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) + combine_semantic_code, _ = text2semantic( + prompt0_speech=None, + prompt0_text=None, + prompt_speech=speech_16k, + prompt_text=prompt_text, + prompt_language=target_language, + target_text=target_text, + target_language=target_language, + temp=temperature, + top_k=top_k, + top_p=top_p, + infer_mode=infer_mode, + ) acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device)) - combine_audio, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code) - - + combine_audio, recovered_audio = semantic2acoustic( + combine_semantic_code, acoustic_code + ) + if not concat_prompt: - combine_audio = combine_audio[speech.shape[-1]:] + combine_audio = combine_audio[speech.shape[-1] :] # sf.write(os.path.join(save_path, "{}.wav".format(uid)), recovered_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) - sf.write(save_path, combine_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) + sf.write( + save_path, + 
combine_audio, + samplerate=cfg_soundstorm_1layer.preprocess.sample_rate, + ) return save_path -def infer_small(speech_path, prompt_text, target_text, target_language='zh', speech_path_prompt0=None, prompt0_text=None, temperature=0.2, top_k=20, top_p=0.9, concat_prompt=False, infer_mode="new", save_path=None): + +def infer_small( + speech_path, + prompt_text, + target_text, + target_language="zh", + speech_path_prompt0=None, + prompt0_text=None, + temperature=0.2, + top_k=20, + top_p=0.9, + concat_prompt=False, + infer_mode="new", + save_path=None, +): if os.path.exists(save_path): return save_path - + speech_16k = librosa.load(speech_path, sr=16000)[0] - speech = librosa.load(speech_path, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] + speech = librosa.load(speech_path, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[ + 0 + ] if infer_mode == "new": speech_16k_prompt0 = librosa.load(speech_path_prompt0, sr=16000)[0] - speech_prompt0 = librosa.load(speech_path_prompt0, sr=cfg_soundstorm_1layer.preprocess.sample_rate)[0] + speech_prompt0 = librosa.load( + speech_path_prompt0, sr=cfg_soundstorm_1layer.preprocess.sample_rate + )[0] # combine_semantic_code, _ = text2semantic_new(speech_16k_prompt0, prompt0_text, speech_16k, prompt_text, target_language, target_text, target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) - combine_semantic_code, _ = text2semantic(prompt0_speech=speech_16k_prompt0, prompt0_text=prompt0_text, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language=target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode = infer_mode) + combine_semantic_code, _ = text2semantic( + prompt0_speech=speech_16k_prompt0, + prompt0_text=prompt0_text, + prompt_speech=speech_16k, + prompt_text=prompt_text, + prompt_language=target_language, + target_text=target_text, + target_language=target_language, + temp=temperature, + top_k=top_k, + top_p=top_p, + infer_mode=infer_mode, + ) else: - combine_semantic_code, _ = text2semantic(prompt0_speech=None, prompt0_text=None, prompt_speech=speech_16k, prompt_text=prompt_text, prompt_language = target_language, target_text=target_text, target_language=target_language, temp=temperature, top_k=top_k, top_p=top_p, infer_mode=infer_mode) + combine_semantic_code, _ = text2semantic( + prompt0_speech=None, + prompt0_text=None, + prompt_speech=speech_16k, + prompt_text=prompt_text, + prompt_language=target_language, + target_text=target_text, + target_language=target_language, + temp=temperature, + top_k=top_k, + top_p=top_p, + infer_mode=infer_mode, + ) acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device)) - combine_audio, recovered_audio = semantic2acoustic(combine_semantic_code, acoustic_code) - - + combine_audio, recovered_audio = semantic2acoustic( + combine_semantic_code, acoustic_code + ) + if not concat_prompt: - combine_audio = combine_audio[speech.shape[-1]:] + combine_audio = combine_audio[speech.shape[-1] :] # sf.write(os.path.join(save_path, "{}.wav".format(uid)), recovered_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) - sf.write(save_path, combine_audio, samplerate=cfg_soundstorm_1layer.preprocess.sample_rate) + sf.write( + save_path, + combine_audio, + samplerate=cfg_soundstorm_1layer.preprocess.sample_rate, + ) return save_path + ##################################### EVALUATION ################################################################ from funasr import AutoModel import 
torch.nn.functional as F @@ -416,4 +616,12 @@ def infer_small(speech_path, prompt_text, target_text, target_language='zh', spe save_path_dir = "The Path to Save Generated Speech" wav_filename = "The Filename of Generated Speech" save_path = os.path.join(save_path_infer_dir, wav_filename) -save_path = infer_small(speech_path=spk_prompt_wav_path, prompt_text = spk_prompt_text, target_text=target_text, speech_path_prompt0 = prompt0_wav_path, prompt0_text = prompt0_text, infer_mode = "new", save_path = save_path) +save_path = infer_small( + speech_path=spk_prompt_wav_path, + prompt_text=spk_prompt_text, + target_text=target_text, + speech_path_prompt0=prompt0_wav_path, + prompt0_text=prompt0_text, + infer_mode="new", + save_path=save_path, +) diff --git a/models/tts/debatts/utils/g2p_new/__init__.py b/models/tts/debatts/utils/g2p_new/__init__.py index 67e6fded..80b61797 100644 --- a/models/tts/debatts/utils/g2p_new/__init__.py +++ b/models/tts/debatts/utils/g2p_new/__init__.py @@ -4,63 +4,66 @@ import json import re + class PhonemeBpeTokenizer: - def __init__(self, vacab_path="./utils/g2p_new/vacab.json"): - self.lang2backend = { - 'zh': "cmn", - 'ja': "ja", - "en": "en-us", - "fr": "fr-fr", - "ko": "ko", - "de": "de", - } - self.text_tokenizers = {} - self.int_text_tokenizers() - # TODO - vacab_path="/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_new/vacab.json" - with open(vacab_path, 'rb') as f: - json_data = f.read() - data = json.loads(json_data) - self.vocab = data['vocab'] + def __init__(self, vacab_path="./utils/g2p_new/vacab.json"): + self.lang2backend = { + "zh": "cmn", + "ja": "ja", + "en": "en-us", + "fr": "fr-fr", + "ko": "ko", + "de": "de", + } + self.text_tokenizers = {} + self.int_text_tokenizers() + # TODO + vacab_path = "/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_new/vacab.json" + with open(vacab_path, "rb") as f: + json_data = f.read() + data = json.loads(json_data) + self.vocab = data["vocab"] - def int_text_tokenizers(self): - for key, value in self.lang2backend.items(): - self.text_tokenizers[key] = TextTokenizer(language=value) + def int_text_tokenizers(self): + for key, value in self.lang2backend.items(): + self.text_tokenizers[key] = TextTokenizer(language=value) - def tokenize(self, text, language): + def tokenize(self, text, language): - # 1. convert text to phoneme - phonemes = self._clean_text(text, language, ['cjekfd_cleaners']) - # print('clean text: ', phonemes) + # 1. convert text to phoneme + phonemes = self._clean_text(text, language, ["cjekfd_cleaners"]) + # print('clean text: ', phonemes) - # 2. tokenize phonemes - phoneme_tokens = self.phoneme2token(phonemes) - # print('encode: ', phoneme_tokens) + # 2. tokenize phonemes + phoneme_tokens = self.phoneme2token(phonemes) + # print('encode: ', phoneme_tokens) - # # 3. decode tokens [optional] - # decoded_text = self.tokenizer.decode(phoneme_tokens) - # print('decoded: ', decoded_text) + # # 3. 
decode tokens [optional] + # decoded_text = self.tokenizer.decode(phoneme_tokens) + # print('decoded: ', decoded_text) - return phonemes, phoneme_tokens + return phonemes, phoneme_tokens - def _clean_text(self, text, language, cleaner_names): + def _clean_text(self, text, language, cleaner_names): - for name in cleaner_names: - cleaner = getattr(cleaners, name) - if not cleaner: - raise Exception('Unknown cleaner: %s' % name) - text = cleaner(text, language, self.text_tokenizers) - return text + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception("Unknown cleaner: %s" % name) + text = cleaner(text, language, self.text_tokenizers) + return text - def phoneme2token(self, phonemes): - # 使用的是国际音标,可以将音素转化成token。实际上输入的phone id也是将音频先asr成文本再转成token的,使用的是同一套vocab体系 - tokens = [] - if isinstance(phonemes, list): - for phone in phonemes: - phonemes_split = phone.split("|") - tokens.append([self.vocab[p] for p in phonemes_split if p in self.vocab]) - else: - phonemes_split = phonemes.split("|") - tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab] - return tokens \ No newline at end of file + def phoneme2token(self, phonemes): + # 使用的是国际音标,可以将音素转化成token。实际上输入的phone id也是将音频先asr成文本再转成token的,使用的是同一套vocab体系 + tokens = [] + if isinstance(phonemes, list): + for phone in phonemes: + phonemes_split = phone.split("|") + tokens.append( + [self.vocab[p] for p in phonemes_split if p in self.vocab] + ) + else: + phonemes_split = phonemes.split("|") + tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab] + return tokens diff --git a/models/tts/debatts/utils/g2p_new/cleaners.py b/models/tts/debatts/utils/g2p_new/cleaners.py index 58755726..4a9509a4 100644 --- a/models/tts/debatts/utils/g2p_new/cleaners.py +++ b/models/tts/debatts/utils/g2p_new/cleaners.py @@ -6,20 +6,21 @@ from utils.g2p_new.korean import korean_to_ipa from utils.g2p_new.german import german_to_ipa + def cjekfd_cleaners(text, language, text_tokenizers): - if language == 'zh': - return chinese_to_ipa(text, text_tokenizers['zh']) - elif language == 'ja': - return japanese_to_ipa(text, text_tokenizers['ja']) - elif language == 'en': - return english_to_ipa(text, text_tokenizers['en']) - elif language == 'fr': - return french_to_ipa(text, text_tokenizers['fr']) - elif language == 'ko': - return korean_to_ipa(text, text_tokenizers['ko']) - elif language == 'de': - return german_to_ipa(text, text_tokenizers['de']) + if language == "zh": + return chinese_to_ipa(text, text_tokenizers["zh"]) + elif language == "ja": + return japanese_to_ipa(text, text_tokenizers["ja"]) + elif language == "en": + return english_to_ipa(text, text_tokenizers["en"]) + elif language == "fr": + return french_to_ipa(text, text_tokenizers["fr"]) + elif language == "ko": + return korean_to_ipa(text, text_tokenizers["ko"]) + elif language == "de": + return german_to_ipa(text, text_tokenizers["de"]) else: - raise Exception('Unknown language: %s' % language) + raise Exception("Unknown language: %s" % language) return None diff --git a/models/tts/debatts/utils/g2p_new/english.py b/models/tts/debatts/utils/g2p_new/english.py index 5951814e..64442c4c 100644 --- a/models/tts/debatts/utils/g2p_new/english.py +++ b/models/tts/debatts/utils/g2p_new/english.py @@ -1,94 +1,102 @@ import re from unidecode import unidecode import inflect -''' + +""" Text clean time -''' +""" _inflect = inflect.engine() -_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') -_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 
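A minimal usage sketch for the PhonemeBpeTokenizer defined above, assuming the script is run from models/tts/debatts so that utils is importable; note that the constructor currently overrides the vacab_path argument with a hard-coded path (see the TODO above), so the argument shown here is illustrative only:

from utils.g2p_new import PhonemeBpeTokenizer

# Builds one espeak-backed TextTokenizer per supported language and loads the
# phoneme vocabulary from vacab.json.
tokenizer = PhonemeBpeTokenizer(vacab_path="./utils/g2p_new/vacab.json")

# tokenize() cleans the text into "|"-separated IPA phonemes with the
# language-specific cleaner, then maps each phoneme to its integer id.
phonemes, phoneme_tokens = tokenizer.tokenize("今天天气不错", language="zh")
print(phonemes)        # IPA string, phonemes separated by "|"
print(phoneme_tokens)  # list of vocabulary ids, one per phoneme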
-_percent_number_re = re.compile(r'([0-9\.\,]*[0-9]+%)') -_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') -_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') -_fraction_re = re.compile(r'([0-9]+)/([0-9]+)') -_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') -_number_re = re.compile(r'[0-9]+') +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_fraction_re = re.compile(r"([0-9]+)/([0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") # List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [(re.compile('\\b%s\\b' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), - ('etc', 'et cetera'), - ('btw', 'by the way'), -]] +_abbreviations = [ + (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ("etc", "et cetera"), + ("btw", "by the way"), + ] +] _special_map = [ - ('t|ɹ', 'tɹ'), - ('d|ɹ', 'dɹ'), - ('t|s', 'ts'), - ('d|z', 'dz'), - ('ɪ|ɹ', 'ɪɹ'), - ('ɐ', 'ɚ'), - ('ᵻ', 'ɪ'), - ('əl', 'l'), - ('x', 'k'), - ('ɬ', 'l'), - ('ʔ', 't'), - ('n̩', 'n'), - ('oː|ɹ', 'oːɹ') + ("t|ɹ", "tɹ"), + ("d|ɹ", "dɹ"), + ("t|s", "ts"), + ("d|z", "dz"), + ("ɪ|ɹ", "ɪɹ"), + ("ɐ", "ɚ"), + ("ᵻ", "ɪ"), + ("əl", "l"), + ("x", "k"), + ("ɬ", "l"), + ("ʔ", "t"), + ("n̩", "n"), + ("oː|ɹ", "oːɹ"), ] + def expand_abbreviations(text): for regex, replacement in _abbreviations: text = re.sub(regex, replacement, text) return text + def _remove_commas(m): - return m.group(1).replace(',', '') + return m.group(1).replace(",", "") def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') + return m.group(1).replace(".", " point ") + def _expand_percent(m): - return m.group(1).replace('%', ' percent ') + return m.group(1).replace("%", " percent ") def _expand_dollars(m): match = m.group(1) - parts = match.split('.') + parts = match.split(".") if len(parts) > 2: - return match + ' dollars' # Unexpected format + return match + " dollars" # Unexpected format dollars = int(parts[0]) if parts[0] else 0 cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 if dollars and cents: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) elif dollars: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - return '%s %s' % (dollars, dollar_unit) + dollar_unit = "dollar" if dollars == 1 else "dollars" + 
return "%s %s" % (dollars, dollar_unit) elif cents: - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s' % (cents, cent_unit) + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) else: - return 'zero dollars' + return "zero dollars" + def fraction_to_words(numerator, denominator): if numerator == 1 and denominator == 2: @@ -99,34 +107,48 @@ def fraction_to_words(numerator, denominator): return _inflect.number_to_words(numerator) + " halves" if denominator == 4: return _inflect.number_to_words(numerator) + " quarters" - return _inflect.number_to_words(numerator) + " " + _inflect.ordinal(_inflect.number_to_words(denominator)) + return ( + _inflect.number_to_words(numerator) + + " " + + _inflect.ordinal(_inflect.number_to_words(denominator)) + ) + def _expand_fraction(m): numerator = int(m.group(1)) denominator = int(m.group(2)) return fraction_to_words(numerator, denominator) + def _expand_ordinal(m): return _inflect.number_to_words(m.group(0)) + def _expand_number(m): num = int(m.group(0)) if num > 1000 and num < 3000: if num == 2000: - return ' two thousand ' + return " two thousand " elif num > 2000 and num < 2010: - return ' two thousand ' + _inflect.number_to_words(num % 100) + ' ' + return " two thousand " + _inflect.number_to_words(num % 100) + " " elif num % 100 == 0: - return ' ' + _inflect.number_to_words(num // 100) + ' hundred ' + return " " + _inflect.number_to_words(num // 100) + " hundred " else: - return ' ' + _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + ' ' + return ( + " " + + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace( + ", ", " " + ) + + " " + ) else: - return ' ' + _inflect.number_to_words(num, andword='') + ' ' + return " " + _inflect.number_to_words(num, andword="") + " " + # Normalize numbers pronunciation def normalize_numbers(text): text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_pounds_re, r"\1 pounds", text) text = re.sub(_dollars_re, _expand_dollars, text) text = re.sub(_fraction_re, _expand_fraction, text) text = re.sub(_decimal_number_re, _expand_decimal_point, text) @@ -135,21 +157,26 @@ def normalize_numbers(text): text = re.sub(_number_re, _expand_number, text) return text + def _english_to_ipa(text): # text = unidecode(text).lower() text = expand_abbreviations(text) text = normalize_numbers(text) return text + # special map def special_map(text): for regex, replacement in _special_map: regex = regex.replace("|", "\|") - while re.search(r'(^|[_|]){}([_|]|$)'.format(regex), text): - text = re.sub(r'(^|[_|]){}([_|]|$)'.format(regex), r'\1{}\2'.format(replacement), text) + while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text): + text = re.sub( + r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text + ) # text = re.sub(r'([,.!?])', r'|\1', text) return text + # Add some special operation def english_to_ipa(text, text_tokenizer): if type(text) == str: @@ -163,4 +190,4 @@ def english_to_ipa(text, text_tokenizer): result_ph = [] for phone in phonemes: result_ph.append(special_map(phone)) - return result_ph \ No newline at end of file + return result_ph diff --git a/models/tts/debatts/utils/g2p_new/french.py b/models/tts/debatts/utils/g2p_new/french.py index b1337860..9c059e08 100644 --- a/models/tts/debatts/utils/g2p_new/french.py +++ b/models/tts/debatts/utils/g2p_new/french.py @@ -1,4 +1,5 @@ -'''https://github.com/bootphon/phonemizer''' 
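To illustrate what the English number/abbreviation normalization above produces, a small sketch (expected outputs in the comments are approximate; a later commit in this series deletes g2p_new/english.py and keeps only the Mandarin cleaner):

from utils.g2p_new.english import expand_abbreviations, normalize_numbers

raw = "Dr. Smith paid $2.50 for 3/4 lb, btw."
text = expand_abbreviations(raw.lower())
# -> "doctor. smith paid $2.50 for 3/4 lb, by the way."
text = normalize_numbers(text)
# -> roughly "doctor. smith paid two dollars, fifty cents for three quarters lb, by the way."
print(text)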
+"""https://github.com/bootphon/phonemizer""" + import re from phonemizer import phonemize from phonemizer.separator import Separator @@ -82,35 +83,42 @@ "~": "-", "「": "", "」": "", - "¿" : "", - "¡" : "" + "¿": "", + "¡": "", } -_special_map = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('ø', 'ɸ'), - ('ː', ':'), - ('j', 'jˈ'), # To avoid incorrect connect - ('n', 'ˈn'), # To avoid incorrect connect - ('w', 'wˈ'), # To avoid incorrect connect - ('ã', 'a~'), - ('ɑ̃', 'ɑ~'), - ('ɔ̃', 'ɔ~'), - ('ɛ̃', 'ɛ~'), - ('œ̃', 'œ~'), -]] +_special_map = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ø", "ɸ"), + ("ː", ":"), + ("j", "jˈ"), # To avoid incorrect connect + ("n", "ˈn"), # To avoid incorrect connect + ("w", "wˈ"), # To avoid incorrect connect + ("ã", "a~"), + ("ɑ̃", "ɑ~"), + ("ɔ̃", "ɔ~"), + ("ɛ̃", "ɛ~"), + ("œ̃", "œ~"), + ] +] + def collapse_whitespace(text): # Regular expression matching whitespace: _whitespace_re = re.compile(r"\s+") return re.sub(_whitespace_re, " ", text).strip() + def remove_punctuation_at_begin(text): - return re.sub(r'^[,.!?]+', '', text) + return re.sub(r"^[,.!?]+", "", text) + def remove_aux_symbols(text): text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) return text + def replace_symbols(text): text = text.replace(";", ",") text = text.replace("-", " ") @@ -118,16 +126,19 @@ def replace_symbols(text): text = text.replace("&", " et ") return text + def expand_abbreviations(text): for regex, replacement in _abbreviations: text = re.sub(regex, replacement, text) return text + def replace_punctuation(text): pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) return replaced_text + def text_normalize(text): text = expand_abbreviations(text) text = replace_punctuation(text) @@ -135,30 +146,33 @@ def text_normalize(text): text = remove_aux_symbols(text) text = remove_punctuation_at_begin(text) text = collapse_whitespace(text) - text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text) + text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text) return text + # special map def special_map(text): for regex, replacement in _special_map: text = re.sub(regex, replacement, text) return text + def french_to_ipa(text): text = text_normalize(text) - ipa = phonemize(text.strip(), - language="fr-fr", - backend="espeak", - separator=Separator(phone=None, word=' ', syllable='|'), - strip=True, - preserve_punctuation=True, - njobs=4) - + ipa = phonemize( + text.strip(), + language="fr-fr", + backend="espeak", + separator=Separator(phone=None, word=" ", syllable="|"), + strip=True, + preserve_punctuation=True, + njobs=4, + ) + # remove "(en)" and "(fr)" tag ipa = ipa.replace("(en)", "").replace("(fr)", "") ipa = special_map(ipa) return ipa - \ No newline at end of file diff --git a/models/tts/debatts/utils/g2p_new/g2p_new.py b/models/tts/debatts/utils/g2p_new/g2p_new.py index 332bec80..eaa6ebbd 100644 --- a/models/tts/debatts/utils/g2p_new/g2p_new.py +++ b/models/tts/debatts/utils/g2p_new/g2p_new.py @@ -3,5 +3,6 @@ text_tokenizer = PhonemeBpeTokenizer() + def new_g2p(text, language): - return text_tokenizer.tokenize(text=text, language=language) \ No newline at end of file + return text_tokenizer.tokenize(text=text, language=language) diff --git a/models/tts/debatts/utils/g2p_new/german.py b/models/tts/debatts/utils/g2p_new/german.py index ef54e4ef..3f9259a3 100644 --- a/models/tts/debatts/utils/g2p_new/german.py +++ b/models/tts/debatts/utils/g2p_new/german.py @@ -1,4 +1,5 @@ -'''https://github.com/bootphon/phonemizer''' 
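The French, German and Korean converters in this patch all call espeak-ng through phonemizer with the same separator convention; a standalone sketch of that call, assuming an espeak-ng backend is installed on the machine (the patch itself does not install it):

from phonemizer import phonemize
from phonemizer.separator import Separator

# No separator inside a phone, spaces between words, "|" between syllables,
# punctuation preserved -- matching french_to_ipa/german_to_ipa above.
ipa = phonemize(
    "Bonjour, comment allez-vous ?",
    language="fr-fr",
    backend="espeak",
    separator=Separator(phone=None, word=" ", syllable="|"),
    strip=True,
    preserve_punctuation=True,
    njobs=1,
)
print(ipa)  # IPA string; french_to_ipa additionally strips "(en)"/"(fr)" tags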
+"""https://github.com/bootphon/phonemizer""" + import re from phonemizer import phonemize from phonemizer.separator import Separator @@ -35,74 +36,87 @@ "~": "-", "「": "", "」": "", - "¿" : "", - "¡" : "" + "¿": "", + "¡": "", } -_special_map = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('ø', 'ɸ'), - ('ː', ':'), - ('ɜ', 'ʒ'), - ('ɑ̃', 'ɑ~'), - ('j', 'jˈ'), # To avoid incorrect connect - ('n', 'ˈn'), # To avoid incorrect connect - ('t', 'tˈ'), # To avoid incorrect connect - ('ŋ', 'ˈŋ'), # To avoid incorrect connect - ('ɪ', 'ˈɪ'), # To avoid incorrect connect -]] +_special_map = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ø", "ɸ"), + ("ː", ":"), + ("ɜ", "ʒ"), + ("ɑ̃", "ɑ~"), + ("j", "jˈ"), # To avoid incorrect connect + ("n", "ˈn"), # To avoid incorrect connect + ("t", "tˈ"), # To avoid incorrect connect + ("ŋ", "ˈŋ"), # To avoid incorrect connect + ("ɪ", "ˈɪ"), # To avoid incorrect connect + ] +] + def collapse_whitespace(text): # Regular expression matching whitespace: _whitespace_re = re.compile(r"\s+") return re.sub(_whitespace_re, " ", text).strip() + def remove_punctuation_at_begin(text): - return re.sub(r'^[,.!?]+', '', text) + return re.sub(r"^[,.!?]+", "", text) + def remove_aux_symbols(text): text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) return text + def replace_symbols(text): text = text.replace(";", ",") text = text.replace("-", " ") text = text.replace(":", ",") return text + def replace_punctuation(text): pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) return replaced_text + def text_normalize(text): text = replace_punctuation(text) text = replace_symbols(text) text = remove_aux_symbols(text) text = remove_punctuation_at_begin(text) text = collapse_whitespace(text) - text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text) + text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text) return text + # special map def special_map(text): for regex, replacement in _special_map: text = re.sub(regex, replacement, text) return text + def german_to_ipa(text): text = text_normalize(text) - ipa = phonemize(text.strip(), - language="de", - backend="espeak", - separator=Separator(phone=None, word=' ', syllable='|'), - strip=True, - preserve_punctuation=True, - njobs=4) - + ipa = phonemize( + text.strip(), + language="de", + backend="espeak", + separator=Separator(phone=None, word=" ", syllable="|"), + strip=True, + preserve_punctuation=True, + njobs=4, + ) + # remove "(en)" and "(fr)" tag ipa = ipa.replace("(en)", "").replace("(de)", "") ipa = special_map(ipa) - return ipa \ No newline at end of file + return ipa diff --git a/models/tts/debatts/utils/g2p_new/japanese.py b/models/tts/debatts/utils/g2p_new/japanese.py index b33dd502..ed1dc2e3 100644 --- a/models/tts/debatts/utils/g2p_new/japanese.py +++ b/models/tts/debatts/utils/g2p_new/japanese.py @@ -1,55 +1,65 @@ """from https://github.com/Plachtaa/VALL-E-X/g2p""" + import re from unidecode import unidecode -''' +""" Text clean time -''' +""" # Regular expression matching Japanese without punctuation marks: _japanese_characters = re.compile( - r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) # Regular expression matching non-Japanese characters or punctuation marks: _japanese_marks = re.compile( - r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + 
r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" +) # List of (symbol, Japanese) pairs for marks: -_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('%', 'パーセント') -]] +_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]] # List of (romaji, ipa2) pairs for marks: -_romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('u', 'ɯ'), - ('ʧ', 'tʃ'), - ('j', 'dʑ'), - ('y', 'j'), - ('ni', 'n^i'), - ('nj', 'n^'), - ('hi', 'çi'), - ('hj', 'ç'), - ('f', 'ɸ'), - ('I', 'i*'), - ('U', 'ɯ*'), - ('r', 'ɾ') -]] +_romaji_to_ipa2 = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("u", "ɯ"), + ("ʧ", "tʃ"), + ("j", "dʑ"), + ("y", "j"), + ("ni", "n^i"), + ("nj", "n^"), + ("hi", "çi"), + ("hj", "ç"), + ("f", "ɸ"), + ("I", "i*"), + ("U", "ɯ*"), + ("r", "ɾ"), + ] +] # List of (consonant, sokuon) pairs: -_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ - (r'Q([↑↓]*[kg])', r'k#\1'), - (r'Q([↑↓]*[tdjʧ])', r't#\1'), - (r'Q([↑↓]*[sʃ])', r's\1'), - (r'Q([↑↓]*[pb])', r'p#\1') -]] +_real_sokuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"Q([↑↓]*[kg])", r"k#\1"), + (r"Q([↑↓]*[tdjʧ])", r"t#\1"), + (r"Q([↑↓]*[sʃ])", r"s\1"), + (r"Q([↑↓]*[pb])", r"p#\1"), + ] +] # List of (consonant, hatsuon) pairs: -_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ - (r'N([↑↓]*[pbm])', r'm\1'), - (r'N([↑↓]*[ʧʥj])', r'n^\1'), - (r'N([↑↓]*[tdn])', r'n\1'), - (r'N([↑↓]*[kg])', r'ŋ\1') -]] +_real_hatsuon = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + (r"N([↑↓]*[pbm])", r"m\1"), + (r"N([↑↓]*[ʧʥj])", r"n^\1"), + (r"N([↑↓]*[tdn])", r"n\1"), + (r"N([↑↓]*[kg])", r"ŋ\1"), + ] +] def symbols_to_japanese(text): @@ -59,44 +69,45 @@ def symbols_to_japanese(text): def japanese_to_romaji_with_accent(text): - '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' + """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html""" import pyopenjtalk + text = symbols_to_japanese(text) sentences = re.split(_japanese_marks, text) marks = re.findall(_japanese_marks, text) - text = '' + text = "" for i, sentence in enumerate(sentences): if re.match(_japanese_characters, sentence): - if text != '': - text += ' ' + if text != "": + text += " " labels = pyopenjtalk.extract_fullcontext(sentence) for n, label in enumerate(labels): - phoneme = re.search(r'\-([^\+]*)\+', label).group(1) - if phoneme not in ['sil', 'pau']: - text += phoneme.replace('ch', 'ʧ').replace('sh', - 'ʃ').replace('cl', 'Q') + phoneme = re.search(r"\-([^\+]*)\+", label).group(1) + if phoneme not in ["sil", "pau"]: + text += ( + phoneme.replace("ch", "ʧ").replace("sh", "ʃ").replace("cl", "Q") + ) else: continue # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) a2 = int(re.search(r"\+(\d+)\+", label).group(1)) a3 = int(re.search(r"\+(\d+)/", label).group(1)) - if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']: + if re.search(r"\-([^\+]*)\+", labels[n + 1]).group(1) in ["sil", "pau"]: a2_next = -1 else: - a2_next = int( - re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) + a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) # Accent phrase boundary if a3 == 1 and a2_next == 1: - text += ' ' + text += " " # Falling elif a1 == 0 and a2_next == a2 + 1: - text += '↓' + text += "↓" # Rising elif a2 == 1 and a2_next == 2: - text += '↑' + text += "↑" if i < len(marks): - text += 
unidecode(marks[i]).replace(' ', '') + text += unidecode(marks[i]).replace(" ", "") return text @@ -111,17 +122,21 @@ def get_real_hatsuon(text): text = re.sub(regex, replacement, text) return text + def japanese_to_ipa(text): - text = japanese_to_romaji_with_accent(text).replace('...', '…') + text = japanese_to_romaji_with_accent(text).replace("...", "…") text = get_real_sokuon(text) text = get_real_hatsuon(text) for regex, replacement in _romaji_to_ipa2: text = re.sub(regex, replacement, text) return text -''' + +""" Phoneme merge time -''' +""" + + def _connect_tone(phoneme_tokens, vocab): tone_list = ["→", "↑", "↓↑", "↓"] @@ -142,13 +157,19 @@ def _connect_tone(phoneme_tokens, vocab): for t in phoneme_tokens: cur_token = t if t in tone_token: - cur_token = last_single_token + (pre_token - base) * len(tone_list) + tone_token.index(t) + 1 + cur_token = ( + last_single_token + + (pre_token - base) * len(tone_list) + + tone_token.index(t) + + 1 + ) res_token.pop() res_token.append(cur_token) pre_token = t return res_token + def japanese_merge_phoneme(phoneme_tokens, vocab): phoneme_tokens = _connect_tone(phoneme_tokens, vocab) - return phoneme_tokens \ No newline at end of file + return phoneme_tokens diff --git a/models/tts/debatts/utils/g2p_new/korean.py b/models/tts/debatts/utils/g2p_new/korean.py index ebf9ecf5..60ba2e13 100644 --- a/models/tts/debatts/utils/g2p_new/korean.py +++ b/models/tts/debatts/utils/g2p_new/korean.py @@ -1,5 +1,7 @@ -'''https://github.com/bootphon/phonemizer''' +"""https://github.com/bootphon/phonemizer""" + import re + # from g2pkk import G2p # from jamo import hangul_to_jamo @@ -44,64 +46,74 @@ } # List of (jamo, ipa) pairs: (need to update) -_jamo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('ㅏ', 'ɐ'), - ('ㅑ', 'jɐ'), - ('ㅓ', 'ʌ'), - ('ㅕ', 'jʌ'), - ('ㅗ', 'o'), - ('ㅛ', 'jo'), - ('ᅮ', 'u'), - ('ㅠ', 'ju'), - ('ᅳ', 'ɯ'), - ('ㅣ', 'i'), - ('ㅔ', 'e'), - ('ㅐ', 'ɛ'), - ('ㅖ', 'je'), - ('ㅒ', 'jɛ'), # lost - ('ㅚ', 'we'), - ('ㅟ', 'wi'), - ('ㅢ', 'ɯj'), - ('ㅘ', 'wɐ'), - ('ㅙ', 'wɛ'), # lost - ('ㅝ', 'wʌ'), - ('ㅞ', 'wɛ'), # lost - ('ㄱ', 'q'), # 'ɡ' or 'k' - ('ㄴ', 'n'), - ('ㄷ', 't'), # d - ('ㄹ', 'ɫ'), # 'ᄅ' is 'r', 'ᆯ' is 'ɫ' - ('ㅁ', 'm'), - ('ㅂ', 'p'), - ('ㅅ', 's'), # 'ᄉ'is 't', 'ᆺ'is 's' - ('ㅇ', 'ŋ'), # 'ᄋ' is None, 'ᆼ' is 'ŋ' - ('ㅈ', 'tɕ'), - ('ㅊ', 'tɕʰ'), # tʃh - ('ㅋ', 'kʰ'), # kh - ('ㅌ', 'tʰ'), # th - ('ㅍ', 'pʰ'), # ph - ('ㅎ', 'h'), - ('ㄲ', 'k*'), # q - ('ㄸ', 't*'), # t - ('ㅃ', 'p*'), # p - ('ㅆ', 's*'), # 'ᄊ' is 's', 'ᆻ' is 't' - ('ㅉ', 'tɕ*'), # tɕ ? -]] - -_special_map = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('ʃ', 'ɕ'), - ('tɕh', 'tɕʰ'), - ('kh', 'kʰ'), - ('th', 'tʰ'), - ('ph', 'pʰ'), -]] +_jamo_to_ipa = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ㅏ", "ɐ"), + ("ㅑ", "jɐ"), + ("ㅓ", "ʌ"), + ("ㅕ", "jʌ"), + ("ㅗ", "o"), + ("ㅛ", "jo"), + ("ᅮ", "u"), + ("ㅠ", "ju"), + ("ᅳ", "ɯ"), + ("ㅣ", "i"), + ("ㅔ", "e"), + ("ㅐ", "ɛ"), + ("ㅖ", "je"), + ("ㅒ", "jɛ"), # lost + ("ㅚ", "we"), + ("ㅟ", "wi"), + ("ㅢ", "ɯj"), + ("ㅘ", "wɐ"), + ("ㅙ", "wɛ"), # lost + ("ㅝ", "wʌ"), + ("ㅞ", "wɛ"), # lost + ("ㄱ", "q"), # 'ɡ' or 'k' + ("ㄴ", "n"), + ("ㄷ", "t"), # d + ("ㄹ", "ɫ"), # 'ᄅ' is 'r', 'ᆯ' is 'ɫ' + ("ㅁ", "m"), + ("ㅂ", "p"), + ("ㅅ", "s"), # 'ᄉ'is 't', 'ᆺ'is 's' + ("ㅇ", "ŋ"), # 'ᄋ' is None, 'ᆼ' is 'ŋ' + ("ㅈ", "tɕ"), + ("ㅊ", "tɕʰ"), # tʃh + ("ㅋ", "kʰ"), # kh + ("ㅌ", "tʰ"), # th + ("ㅍ", "pʰ"), # ph + ("ㅎ", "h"), + ("ㄲ", "k*"), # q + ("ㄸ", "t*"), # t + ("ㅃ", "p*"), # p + ("ㅆ", "s*"), # 'ᄊ' is 's', 'ᆻ' is 't' + ("ㅉ", "tɕ*"), # tɕ ? 
+ ] +] + +_special_map = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ʃ", "ɕ"), + ("tɕh", "tɕʰ"), + ("kh", "kʰ"), + ("th", "tʰ"), + ("ph", "pʰ"), + ] +] + def normalize(text): text = text.strip() - text = re.sub("[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text) + text = re.sub( + "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text + ) text = normalize_english(text) text = text.lower() return text + def normalize_english(text): def fn(m): word = m.group() @@ -112,6 +124,7 @@ def fn(m): text = re.sub("([A-Za-z]+)", fn, text) return text + # Convert jamo to IPA def jamo_to_ipa(text): res = "" @@ -121,25 +134,30 @@ def jamo_to_ipa(text): res += t return res + # special map def special_map(text): for regex, replacement in _special_map: text = re.sub(regex, replacement, text) return text + def korean_to_ipa(text): text = normalize(text) # espeak-ng from phonemizer import phonemize from phonemizer.separator import Separator - ipa = phonemize(text, - language="ko", - backend="espeak", - separator=Separator(phone=None, word=' ', syllable='|'), - strip=True, - preserve_punctuation=True, - njobs=4) + + ipa = phonemize( + text, + language="ko", + backend="espeak", + separator=Separator(phone=None, word=" ", syllable="|"), + strip=True, + preserve_punctuation=True, + njobs=4, + ) ipa = special_map(ipa) # # hangul charactier # g2p = G2p() diff --git a/models/tts/debatts/utils/g2p_new/mandarin.py b/models/tts/debatts/utils/g2p_new/mandarin.py index 075cddf7..9c6a229e 100644 --- a/models/tts/debatts/utils/g2p_new/mandarin.py +++ b/models/tts/debatts/utils/g2p_new/mandarin.py @@ -2,110 +2,117 @@ import jieba import cn2an -''' +""" Text clean time -''' +""" # List of (Latin alphabet, bopomofo) pairs: -_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ - ('a', 'ㄟˉ'), - ('b', 'ㄅㄧˋ'), - ('c', 'ㄙㄧˉ'), - ('d', 'ㄉㄧˋ'), - ('e', 'ㄧˋ'), - ('f', 'ㄝˊㄈㄨˋ'), - ('g', 'ㄐㄧˋ'), - ('h', 'ㄝˇㄑㄩˋ'), - ('i', 'ㄞˋ'), - ('j', 'ㄐㄟˋ'), - ('k', 'ㄎㄟˋ'), - ('l', 'ㄝˊㄛˋ'), - ('m', 'ㄝˊㄇㄨˋ'), - ('n', 'ㄣˉ'), - ('o', 'ㄡˉ'), - ('p', 'ㄆㄧˉ'), - ('q', 'ㄎㄧㄡˉ'), - ('r', 'ㄚˋ'), - ('s', 'ㄝˊㄙˋ'), - ('t', 'ㄊㄧˋ'), - ('u', 'ㄧㄡˉ'), - ('v', 'ㄨㄧˉ'), - ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), - ('x', 'ㄝˉㄎㄨˋㄙˋ'), - ('y', 'ㄨㄞˋ'), - ('z', 'ㄗㄟˋ') -]] +_latin_to_bopomofo = [ + (re.compile("%s" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("a", "ㄟˉ"), + ("b", "ㄅㄧˋ"), + ("c", "ㄙㄧˉ"), + ("d", "ㄉㄧˋ"), + ("e", "ㄧˋ"), + ("f", "ㄝˊㄈㄨˋ"), + ("g", "ㄐㄧˋ"), + ("h", "ㄝˇㄑㄩˋ"), + ("i", "ㄞˋ"), + ("j", "ㄐㄟˋ"), + ("k", "ㄎㄟˋ"), + ("l", "ㄝˊㄛˋ"), + ("m", "ㄝˊㄇㄨˋ"), + ("n", "ㄣˉ"), + ("o", "ㄡˉ"), + ("p", "ㄆㄧˉ"), + ("q", "ㄎㄧㄡˉ"), + ("r", "ㄚˋ"), + ("s", "ㄝˊㄙˋ"), + ("t", "ㄊㄧˋ"), + ("u", "ㄧㄡˉ"), + ("v", "ㄨㄧˉ"), + ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"), + ("x", "ㄝˉㄎㄨˋㄙˋ"), + ("y", "ㄨㄞˋ"), + ("z", "ㄗㄟˋ"), + ] +] # List of (bopomofo, ipa) pairs: -_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ - ('ㄅㄛ', 'p⁼wo'), - ('ㄆㄛ', 'pʰwo'), - ('ㄇㄛ', 'mwo'), - ('ㄈㄛ', 'fwo'), - ('ㄧㄢ', '|jɛn'), - ('ㄩㄢ', '|ɥæn'), - ('ㄧㄣ', '|in'), - ('ㄩㄣ', '|ɥn'), - ('ㄧㄥ', '|iŋ'), - ('ㄨㄥ', '|ʊŋ'), - ('ㄩㄥ', '|jʊŋ'), - # Add - ('ㄧㄚ', '|ia'), - ('ㄧㄝ', '|iɛ'), - ('ㄧㄠ', '|iɑʊ'), - ('ㄧㄡ', '|ioʊ'), - ('ㄧㄤ', '|iɑŋ'), - ('ㄨㄚ', '|ua'), - ('ㄨㄛ', '|uo'), - ('ㄨㄞ', '|uaɪ'), - ('ㄨㄟ', '|ueɪ'), - ('ㄨㄢ', '|uan'), - ('ㄨㄣ', '|uən'), - ('ㄨㄤ', '|uɑŋ'), - ('ㄩㄝ', '|ɥɛ'), - # End - ('ㄅ', 'p⁼'), - ('ㄆ', 'pʰ'), - ('ㄇ', 'm'), - ('ㄈ', 'f'), - ('ㄉ', 't⁼'), - ('ㄊ', 'tʰ'), - ('ㄋ', 'n'), - ('ㄌ', 'l'), - ('ㄍ', 'k⁼'), - ('ㄎ', 'kʰ'), - ('ㄏ', 'x'), - ('ㄐ', 'tʃ⁼'), - ('ㄑ', 'tʃʰ'), - ('ㄒ', 'ʃ'), - ('ㄓ', 'ts`⁼'), - ('ㄔ', 'ts`ʰ'), - ('ㄕ', 's`'), - ('ㄖ', 
'ɹ`'), - ('ㄗ', 'ts⁼'), - ('ㄘ', 'tsʰ'), - ('ㄙ', '|s'), - ('ㄚ', '|a'), - ('ㄛ', '|o'), - ('ㄜ', '|ə'), - ('ㄝ', '|ɛ'), - ('ㄞ', '|aɪ'), - ('ㄟ', '|eɪ'), - ('ㄠ', '|ɑʊ'), - ('ㄡ', '|oʊ'), - ('ㄢ', '|an'), - ('ㄣ', '|ən'), - ('ㄤ', '|ɑŋ'), - ('ㄥ', '|əŋ'), - ('ㄦ', 'əɹ'), - ('ㄧ', '|i'), - ('ㄨ', '|u'), - ('ㄩ', '|ɥ'), - ('ˉ', '→|'), - ('ˊ', '↑|'), - ('ˇ', '↓↑|'), - ('ˋ', '↓|'), - ('˙', '|'), -]] +_bopomofo_to_ipa = [ + (re.compile("%s" % x[0]), x[1]) + for x in [ + ("ㄅㄛ", "p⁼wo"), + ("ㄆㄛ", "pʰwo"), + ("ㄇㄛ", "mwo"), + ("ㄈㄛ", "fwo"), + ("ㄧㄢ", "|jɛn"), + ("ㄩㄢ", "|ɥæn"), + ("ㄧㄣ", "|in"), + ("ㄩㄣ", "|ɥn"), + ("ㄧㄥ", "|iŋ"), + ("ㄨㄥ", "|ʊŋ"), + ("ㄩㄥ", "|jʊŋ"), + # Add + ("ㄧㄚ", "|ia"), + ("ㄧㄝ", "|iɛ"), + ("ㄧㄠ", "|iɑʊ"), + ("ㄧㄡ", "|ioʊ"), + ("ㄧㄤ", "|iɑŋ"), + ("ㄨㄚ", "|ua"), + ("ㄨㄛ", "|uo"), + ("ㄨㄞ", "|uaɪ"), + ("ㄨㄟ", "|ueɪ"), + ("ㄨㄢ", "|uan"), + ("ㄨㄣ", "|uən"), + ("ㄨㄤ", "|uɑŋ"), + ("ㄩㄝ", "|ɥɛ"), + # End + ("ㄅ", "p⁼"), + ("ㄆ", "pʰ"), + ("ㄇ", "m"), + ("ㄈ", "f"), + ("ㄉ", "t⁼"), + ("ㄊ", "tʰ"), + ("ㄋ", "n"), + ("ㄌ", "l"), + ("ㄍ", "k⁼"), + ("ㄎ", "kʰ"), + ("ㄏ", "x"), + ("ㄐ", "tʃ⁼"), + ("ㄑ", "tʃʰ"), + ("ㄒ", "ʃ"), + ("ㄓ", "ts`⁼"), + ("ㄔ", "ts`ʰ"), + ("ㄕ", "s`"), + ("ㄖ", "ɹ`"), + ("ㄗ", "ts⁼"), + ("ㄘ", "tsʰ"), + ("ㄙ", "|s"), + ("ㄚ", "|a"), + ("ㄛ", "|o"), + ("ㄜ", "|ə"), + ("ㄝ", "|ɛ"), + ("ㄞ", "|aɪ"), + ("ㄟ", "|eɪ"), + ("ㄠ", "|ɑʊ"), + ("ㄡ", "|oʊ"), + ("ㄢ", "|an"), + ("ㄣ", "|ən"), + ("ㄤ", "|ɑŋ"), + ("ㄥ", "|əŋ"), + ("ㄦ", "əɹ"), + ("ㄧ", "|i"), + ("ㄨ", "|u"), + ("ㄩ", "|ɥ"), + ("ˉ", "→|"), + ("ˊ", "↑|"), + ("ˇ", "↓↑|"), + ("ˋ", "↓|"), + ("˙", "|"), + ] +] + # Convert numbers to Chinese pronunciation def number_to_chinese(text): @@ -115,6 +122,7 @@ def number_to_chinese(text): text = cn2an.transform(text, "an2cn") return text + def normalization(text): text = text.replace(",", ",") text = text.replace("。", ".") @@ -130,39 +138,44 @@ def normalization(text): text = text.replace("・・・", "…") text = text.replace("...", "…") text = re.sub(r"\s+", "", text) - text = re.sub(r'[^\u4e00-\u9fff\s_,\.\?!;:\'…]', '', text) - text = re.sub(r'\s*([,\.\?!;:\'…])\s*', r'\1', text) + text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'…]", "", text) + text = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", text) return text + # Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo) def chinese_to_bopomofo(text): from pypinyin import lazy_pinyin, BOPOMOFO + words = jieba.lcut(text, cut_all=False) - text = '' + text = "" for word in words: bopomofos = lazy_pinyin(word, BOPOMOFO) - if not re.search('[\u4e00-\u9fff]', word): + if not re.search("[\u4e00-\u9fff]", word): text += word continue for i in range(len(bopomofos)): - bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) - if text != '': - text += '|' - text += '|'.join(bopomofos) + bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i]) + if text != "": + text += "|" + text += "|".join(bopomofos) return text + # Convert latin pronunciation to pinyin (bopomofo) def latin_to_bopomofo(text): for regex, replacement in _latin_to_bopomofo: text = re.sub(regex, replacement, text) return text + # Convert pinyin (bopomofo) to IPA def bopomofo_to_ipa(text): for regex, replacement in _bopomofo_to_ipa: text = re.sub(regex, replacement, text) return text + def _chinese_to_ipa(text): text = number_to_chinese(text.strip()) text = normalization(text) @@ -170,15 +183,15 @@ def _chinese_to_ipa(text): text = chinese_to_bopomofo(text) text = latin_to_bopomofo(text) text = bopomofo_to_ipa(text) - text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', - r'\1ɹ\2', text) - text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', 
r'\1ɹ\2', text) - text = re.sub(r'^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]', '', text) - text = re.sub(r'([,\.\?!;:\'…])', r'|\1|', text) - text = re.sub(r'\|+', '|', text) - text = text.rstrip('|') + text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text) + text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text) + text = re.sub(r"^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]", "", text) + text = re.sub(r"([,\.\?!;:\'…])", r"|\1|", text) + text = re.sub(r"\|+", "|", text) + text = text.rstrip("|") return text + # Convert Chinese to IPA def chinese_to_ipa(text, text_tokenizer): # phonemes = text_tokenizer(text.strip()) diff --git a/models/tts/debatts/utils/g2p_new/text_tokenizers.py b/models/tts/debatts/utils/g2p_new/text_tokenizers.py index 943632df..4bdd5103 100644 --- a/models/tts/debatts/utils/g2p_new/text_tokenizers.py +++ b/models/tts/debatts/utils/g2p_new/text_tokenizers.py @@ -9,7 +9,6 @@ from phonemizer.separator import Separator - class TextTokenizer: """Phonemize Text.""" @@ -34,7 +33,7 @@ def __init__( language_switch=language_switch, words_mismatch=words_mismatch, ) - + self.separator = separator # convert chinese punctuation to english punctuation @@ -60,8 +59,8 @@ def __call__(self, text, strip=True) -> List[str]: normalized_text = [] for line in str2list(text): line = self.convert_chinese_punctuation(line.strip()) - line = re.sub(r'[^\w\s_,\.\?!;:\'…]', '', line) - line = re.sub(r'\s*([,\.\?!;:\'…])\s*', r'\1', line) + line = re.sub(r"[^\w\s_,\.\?!;:\'…]", "", line) + line = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", line) line = re.sub(r"\s+", " ", line) normalized_text.append(line) # print("Normalized test: ", normalized_text[0]) @@ -69,12 +68,12 @@ def __call__(self, text, strip=True) -> List[str]: normalized_text, separator=self.separator, strip=strip, njobs=1 ) if text_type == str: - phonemized = re.sub(r'([,\.\?!;:\'…])', r'|\1|', list2str(phonemized)) - phonemized = re.sub(r'\|+', '|', phonemized) - phonemized = phonemized.rstrip('|') + phonemized = re.sub(r"([,\.\?!;:\'…])", r"|\1|", list2str(phonemized)) + phonemized = re.sub(r"\|+", "|", phonemized) + phonemized = phonemized.rstrip("|") else: for i in range(len(phonemized)): - phonemized[i] = re.sub(r'([,\.\?!;:\'…])', r'|\1|', phonemized[i]) - phonemized[i] = re.sub(r'\|+', '|', phonemized[i]) - phonemized[i] = phonemized[i].rstrip('|') - return phonemized \ No newline at end of file + phonemized[i] = re.sub(r"([,\.\?!;:\'…])", r"|\1|", phonemized[i]) + phonemized[i] = re.sub(r"\|+", "|", phonemized[i]) + phonemized[i] = phonemized[i].rstrip("|") + return phonemized diff --git a/models/tts/debatts/utils/topk_sampling.py b/models/tts/debatts/utils/topk_sampling.py index 0a15fd33..58c36036 100644 --- a/models/tts/debatts/utils/topk_sampling.py +++ b/models/tts/debatts/utils/topk_sampling.py @@ -8,15 +8,15 @@ import torch.nn.functional as F -def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): - """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution shape (vocabulary size) - top_k >0: keep only top k tokens with highest probability (top-k filtering). - top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) - - Basic outline taken from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocabulary size) + top_k >0: keep only top k tokens with highest probability (top-k filtering). + top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + + Basic outline taken from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ assert logits.dim() == 2 # [BATCH_SIZE, VOCAB_SIZE] top_k = min(top_k, logits.size(-1)) # Safety check @@ -24,21 +24,21 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf') # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < torch.topk(logits, top_k, dim=1)[0][..., -1, None] logits[indices_to_remove] = filter_value - + sorted_logits, sorted_indices = torch.sort(logits, descending=True) - + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold sorted_indices_to_remove = cumulative_probs > top_p # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 - + # Replace logits to be removed with -inf in the sorted_logits sorted_logits[sorted_indices_to_remove] = filter_value # Then reverse the sorting process by mapping back sorted_logits to their original position logits = torch.gather(sorted_logits, 1, sorted_indices.argsort(-1)) - + # pred_token = torch.multinomial(F.softmax(logits, -1), 1) # [BATCH_SIZE, 1] return logits From ee1cdc197f2353bf5e73a14807673629d23db847 Mon Sep 17 00:00:00 2001 From: hehaorui Date: Mon, 28 Oct 2024 21:06:04 +0800 Subject: [PATCH 6/8] adding README and environment.yml, also modified the comments --- models/tts/debatts/README.md | 13 + models/tts/debatts/environment.yml | 337 ++++++++++++++++++ models/tts/debatts/t2s_model_new.py | 2 +- models/tts/debatts/t2s_sft_dataset_new.py | 1 - models/tts/debatts/utils/__init__.py | 4 + models/tts/debatts/utils/g2p/__init__.py | 7 +- models/tts/debatts/utils/g2p/cleaners.py | 5 + models/tts/debatts/utils/g2p_new/__init__.py | 10 +- models/tts/debatts/utils/g2p_new/cleaners.py | 27 +- models/tts/debatts/utils/g2p_new/english.py | 193 ---------- models/tts/debatts/utils/g2p_new/french.py | 178 --------- models/tts/debatts/utils/g2p_new/g2p_new.py | 5 + models/tts/debatts/utils/g2p_new/german.py | 122 ------- models/tts/debatts/utils/g2p_new/japanese.py | 175 --------- models/tts/debatts/utils/g2p_new/korean.py | 167 --------- models/tts/debatts/utils/g2p_new/mandarin.py | 5 + .../debatts/utils/g2p_new/text_tokenizers.py | 5 + models/tts/debatts/utils/logger.py | 1 - models/tts/debatts/utils/tool.py | 89 ----- 19 files changed, 396 insertions(+), 950 deletions(-) create mode 100644 models/tts/debatts/README.md create mode 100644 models/tts/debatts/environment.yml delete mode 100644 models/tts/debatts/utils/g2p_new/english.py delete mode 100644 models/tts/debatts/utils/g2p_new/french.py delete mode 100644 models/tts/debatts/utils/g2p_new/german.py delete mode 100644 models/tts/debatts/utils/g2p_new/japanese.py delete mode 
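top_k_top_p_filtering above only masks logits; the caller still draws the sample. A minimal sketch of the intended use, with a made-up batch/vocab size and the temperature, top_k and top_p defaults from the inference script earlier in this patch:

import torch
import torch.nn.functional as F

from utils.topk_sampling import top_k_top_p_filtering

logits = torch.randn(2, 8192)                  # [BATCH_SIZE, VOCAB_SIZE]
logits = logits / 0.2                          # temperature before filtering
filtered = top_k_top_p_filtering(logits, top_k=20, top_p=0.9)
probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)   # [BATCH_SIZE, 1]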
100644 models/tts/debatts/utils/g2p_new/korean.py delete mode 100644 models/tts/debatts/utils/tool.py diff --git a/models/tts/debatts/README.md b/models/tts/debatts/README.md new file mode 100644 index 00000000..9f3fcb55 --- /dev/null +++ b/models/tts/debatts/README.md @@ -0,0 +1,13 @@ +# Debatts - Mandarin Debate TTS Model + +## Introduction +Debatts is an advanced text-to-speech (TTS) model specifically designed for Mandarin debate contexts. This innovative model leverages short audio prompts to learn and replicate speaker characteristics while dynamically adjusting speaking style by analyzing the audio of debate opponents. This capability allows Debatts to integrate seamlessly into debate scenarios, offering not just speech synthesis but a responsive adaptation to the changing dynamics of debate interactions. + +## Environment Setup +To set up the necessary environment to run Debatts, please use the provided `environment.yml` file. This file contains all the required dependencies and can be easily set up with the following Conda command: + +```bash +conda env create -f environment.yml + +## Continuous Updates +The Debatts project is actively being developed, with continuous updates aimed at enhancing model performance and expanding features. We encourage users to regularly check our repository for the latest updates and improvements to ensure optimal functionality and to take advantage of new capabilities as they become available. \ No newline at end of file diff --git a/models/tts/debatts/environment.yml b/models/tts/debatts/environment.yml new file mode 100644 index 00000000..4d358104 --- /dev/null +++ b/models/tts/debatts/environment.yml @@ -0,0 +1,337 @@ +name: debatts +channels: + - pytorch + - nvidia + - https://repo.anaconda.com/pkgs/main + - conda-forge + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2 + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - asttokens=2.4.1=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - blas=1.0=mkl + - brotli-python=1.0.9=py39h6a678d5_8 + - bzip2=1.0.8=h7f98852_4 + - ca-certificates=2024.7.4=hbcca054_0 + - certifi=2024.7.4=pyhd8ed1ab_0 + - charset-normalizer=3.3.2=pyhd8ed1ab_0 + - comm=0.2.2=pyhd8ed1ab_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.5.39=0 + - cuda-runtime=12.1.0=0 + - cuda-version=12.5=3 + - debugpy=1.6.7=py39h6a678d5_0 + - decorator=5.1.1=pyhd8ed1ab_0 + - entrypoints=0.4=pyhd8ed1ab_0 + - executing=2.0.1=pyhd8ed1ab_0 + - ffmpeg=4.2.2=h20bf706_0 + - filelock=3.15.4=pyhd8ed1ab_0 + - freetype=2.10.4=h0708190_1 + - gmp=6.1.2=hf484d3e_1000 + - gmpy2=2.1.2=py39heeb90bb_0 + - gnutls=3.6.15=he1e5248_0 + - idna=3.7=py39h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - ipykernel=6.29.4=pyh3099207_0 + - ipython=8.12.0=pyh41d4057_0 + - jedi=0.19.1=pyhd8ed1ab_0 + - jinja2=3.1.4=py39h06a4308_0 + - jpeg=9e=h5eee18b_1 + - jupyter_client=7.3.4=pyhd8ed1ab_0 + - jupyter_core=5.7.2=py39hf3d152e_0 + - lame=3.100=h7f98852_1001 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.10.0.4=0 + - libcurand=10.3.6.39=0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdeflate=1.17=h5eee18b_1 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=11.2.0=h1234567_1 + - 
libgomp=11.2.0=h1234567_1 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libnpp=12.0.2.50=0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libopus=1.3.1=h7f98852_1 + - libpng=1.6.39=h5eee18b_0 + - libsodium=1.0.18=h36c2ea0_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h7f98852_0 + - libvpx=1.7.0=h439df22_0 + - libwebp-base=1.3.2=h5eee18b_0 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_1 + - markupsafe=2.1.1=py39hb9d737c_1 + - matplotlib-inline=0.1.7=pyhd8ed1ab_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py39h5eee18b_1 + - mkl_fft=1.3.8=py39h5eee18b_0 + - mkl_random=1.2.4=py39hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py39h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nest-asyncio=1.6.0=pyhd8ed1ab_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=3.2.1=py39h06a4308_0 + - numpy-base=1.26.4=py39hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=1.1.1w=h7f8727e_0 + - packaging=24.1=pyhd8ed1ab_0 + - parso=0.8.4=pyhd8ed1ab_0 + - pexpect=4.9.0=pyhd8ed1ab_0 + - pickleshare=0.7.5=py_1003 + - pillow=10.3.0=py39h5eee18b_0 + - platformdirs=4.2.2=pyhd8ed1ab_0 + - prompt-toolkit=3.0.47=pyha770c72_0 + - prompt_toolkit=3.0.47=hd8ed1ab_0 + - psutil=5.9.1=py39hb9d737c_0 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - pygments=2.18.0=pyhd8ed1ab_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.15=h7a1cb2a_2 + - python-dateutil=2.9.0=pyhd8ed1ab_0 + - python_abi=3.9=2_cp39 + - pytorch=2.3.1=py3.9_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.1=py39h5eee18b_0 + - pyzmq=25.1.2=py39h6a678d5_0 + - readline=8.2=h5eee18b_0 + - requests=2.32.3=pyhd8ed1ab_0 + - six=1.16.0=pyh6c4a22f_0 + - sqlite=3.45.3=h5eee18b_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - sympy=1.12.1=pyh04b8f61_3 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.14=h39e8969_0 + - torchtriton=2.3.1=py39 + - torchvision=0.18.1=py39_cu121 + - tornado=6.1=py39hb9d737c_3 + - traitlets=5.14.3=pyhd8ed1ab_0 + - typing_extensions=4.11.0=py39h06a4308_0 + - urllib3=2.2.2=py39h06a4308_0 + - wcwidth=0.2.13=pyhd8ed1ab_0 + - wheel=0.43.0=py39h06a4308_0 + - x264=1!157.20191217=h7b6447c_0 + - xz=5.4.6=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zeromq=4.3.5=h6a678d5_0 + - zlib=1.2.13=h5eee18b_1 + - zstd=1.5.5=hc292b87_2 + - pip: + - absl-py==2.1.0 + - accelerate==0.24.1 + - aiohttp==3.9.5 + - aiosignal==1.3.1 + - aliyun-python-sdk-core==2.15.1 + - aliyun-python-sdk-kms==2.16.3 + - antlr4-python3-runtime==4.9.3 + - argparse==1.4.0 + - asteroid==0.7.0 + - asteroid-filterbanks==0.4.0 + - async-timeout==4.0.3 + - attrs==23.2.0 + - audiomentations==0.36.0 + - babel==2.15.0 + - bitarray==2.9.2 + - black==24.1.1 + - braceexpand==0.1.7 + - cached-property==1.5.2 + - cffi==1.16.0 + - click==8.1.7 + - cn2an==0.5.22 + - colorama==0.4.6 + - coloredlogs==15.0.1 + - contourpy==1.2.1 + - crcmod==1.7 + - cryptography==43.0.0 + - cycler==0.12.1 + - cython==3.0.10 + - cytoolz==0.12.3 + - datasets==2.20.0 + - diffsptk==2.1.0 + - diffusers==0.29.2 + - dill==0.3.8 + - distance==0.1.3 + - docker-pycreds==0.4.0 + - easydict==1.13 + - editdistance==0.8.1 + - einops==0.8.0 + - encodec==0.1.1 + - evaluate==0.4.2 + - fairseq==0.12.2 + - fastdtw==0.3.4 + - ffmpeg-python==0.2.0 + - flatbuffers==24.3.25 + - fonttools==4.53.1 + - frechet-audio-distance==0.3.1 + - frozenlist==1.4.1 + - fsspec==2024.5.0 + - ftfy==6.2.0 + - funasr==1.1.4 + - future==1.0.0 + - 
g2p-en==2.1.0 + - gitdb==4.0.11 + - gitpython==3.1.43 + - grpcio==1.64.1 + - h5py==3.11.0 + - huggingface-hub==0.23.4 + - humanfriendly==10.0 + - hydra-core==1.3.2 + - importlib-metadata==8.0.0 + - importlib-resources==6.4.0 + - inflect==7.3.1 + - intervaltree==3.1.0 + - jaconv==0.4.0 + - jamo==0.4.1 + - jieba==0.42.1 + - jiwer==3.0.4 + - jmespath==0.10.0 + - joblib==1.4.2 + - json5==0.9.25 + - jsonschema==4.22.0 + - jsonschema-specifications==2023.12.1 + - julius==0.2.7 + - kaldiio==2.18.0 + - kiwisolver==1.4.5 + - laion-clap==1.1.2 + - lazy-loader==0.4 + - lhotse==1.25.0.dev0+git.da4d70d.clean + - librosa==0.10.2.post1 + - lightning-utilities==0.11.3.post0 + - lilcom==1.8.0 + - llvmlite==0.43.0 + - loguru==0.7.2 + - lxml==5.2.2 + - markdown==3.6 + - matplotlib==3.9.1 + - mir-eval==0.7 + - modelscope==1.17.1 + - modules==1.0.0 + - more-itertools==10.3.0 + - msgpack==1.0.8 + - multidict==6.0.5 + - multiprocess==0.70.16 + - mypy-extensions==1.0.0 + - nltk==3.8.1 + - nnaudio==0.3.3 + - noisereduce==3.0.2 + - npy-append-array==0.9.16 + - numba==0.60.0 + - numpy==1.23.4 + - omegaconf==2.3.0 + - onnxruntime==1.19.0 + - openai-whisper==20231117 + - oss2==2.18.6 + - pandas==2.2.2 + - pathspec==0.12.1 + - pb-bss-eval==0.0.2 + - pedalboard==0.9.9 + - pesq==0.0.4 + - pip==24.2 + - pooch==1.8.2 + - portalocker==2.10.0 + - praat-parselmouth==0.4.3 + - proces==0.1.7 + - progressbar==2.5 + - protobuf==4.25.3 + - ptwt==0.1.9 + - pyarrow==16.1.0 + - pyarrow-hotfix==0.6 + - pycparser==2.22 + - pycryptodome==3.20.0 + - pydub==0.25.1 + - pymcd==0.2.1 + - pynndescent==0.5.13 + - pyparsing==3.1.2 + - pypesq==1.2.4 + - pypinyin==0.48.0 + - pysptk==1.0.1 + - pystoi==0.4.1 + - pytorch-lightning==2.3.2 + - pytorch-ranger==0.1.1 + - pytorch-wpe==0.0.1 + - pytz==2024.1 + - pywavelets==1.6.0 + - pyworld==0.3.4 + - rapidfuzz==3.9.6 + - referencing==0.35.1 + - regex==2024.5.15 + - resampy==0.4.3 + - resemblyzer==0.1.4 + - rir-generator==0.2.0 + - rpds-py==0.18.1 + - ruamel-yaml==0.18.6 + - ruamel-yaml-clib==0.2.8 + - sacrebleu==2.4.2 + - safetensors==0.4.3 + - scikit-learn==1.5.1 + - scipy==1.10.1 + - semantic-version==2.10.0 + - sentencepiece==0.2.0 + - sentry-sdk==2.8.0 + - setproctitle==1.3.3 + - setuptools==70.3.0 + - setuptools-rust==1.9.0 + - smmap==5.0.1 + - sortedcontainers==2.4.0 + - soundfile==0.12.1 + - soxr==0.3.7 + - tabulate==0.9.0 + - tensorboard==2.17.0 + - tensorboard-data-server==0.7.2 + - tensorboardx==2.6.2.2 + - tgt==1.5 + - threadpoolctl==3.5.0 + - tiktoken==0.7.0 + - timm==1.0.8 + - tokenizers==0.19.1 + - tomli==2.0.1 + - toolz==0.12.1 + - torch-complex==0.4.4 + - torch-optimizer==0.1.0 + - torch-stoi==0.2.1 + - torchaudio==2.3.1 + - torchcomp==0.1.1 + - torchcrepe==0.0.23 + - torchlibrosa==0.1.0 + - torchlpc==0.4 + - torchmetrics==0.11.4 + - tqdm==4.66.4 + - transformers==4.44.0 + - trash-cli==0.24.5.26 + - typeguard==4.3.0 + - typing==3.7.4.3 + - tzdata==2024.1 + - umap-learn==0.5.6 + - unidecode==1.3.8 + - vector-quantize-pytorch==1.12.5 + - wandb==0.17.4 + - webdataset==0.2.86 + - webrtcvad==2.0.10 + - werkzeug==3.0.3 + - wget==3.2 + - xxhash==3.4.1 + - yarl==1.9.4 + - zhconv==1.4.3 + - zhon==2.0.2 + - zipp==3.19.2 diff --git a/models/tts/debatts/t2s_model_new.py b/models/tts/debatts/t2s_model_new.py index f00ecaf7..279d0a5a 100644 --- a/models/tts/debatts/t2s_model_new.py +++ b/models/tts/debatts/t2s_model_new.py @@ -310,7 +310,7 @@ def add_phone_middle_label( @torch.no_grad() def sample_hf( self, - phone_ids, # the phones of prompt and target should be concatenated 
together。在实际使用中,phone_ids是文本的token输入 + phone_ids, # the phones of prompt and target should be concatenated together prompt_ids, prompt0_ids=None, max_length=100000, diff --git a/models/tts/debatts/t2s_sft_dataset_new.py b/models/tts/debatts/t2s_sft_dataset_new.py index 11328e7c..7c5abcba 100644 --- a/models/tts/debatts/t2s_sft_dataset_new.py +++ b/models/tts/debatts/t2s_sft_dataset_new.py @@ -28,7 +28,6 @@ import sys sys.path.append("./models/tts/debatts") -from utils.g2p_new.g2p import phonemizer_g2p from utils.g2p_new.g2p_new import new_g2p from torch.nn.utils.rnn import pad_sequence diff --git a/models/tts/debatts/utils/__init__.py b/models/tts/debatts/utils/__init__.py index e69de29b..5cc79194 100644 --- a/models/tts/debatts/utils/__init__.py +++ b/models/tts/debatts/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/tts/debatts/utils/g2p/__init__.py b/models/tts/debatts/utils/g2p/__init__.py index 3a2b3988..915c81e5 100644 --- a/models/tts/debatts/utils/g2p/__init__.py +++ b/models/tts/debatts/utils/g2p/__init__.py @@ -1,4 +1,9 @@ -""" from https://github.com/keithito/tacotron """ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# from https://github.com/keithito/tacotron import sys import utils.g2p.cleaners diff --git a/models/tts/debatts/utils/g2p/cleaners.py b/models/tts/debatts/utils/g2p/cleaners.py index 7d96e84a..8edc06ac 100644 --- a/models/tts/debatts/utils/g2p/cleaners.py +++ b/models/tts/debatts/utils/g2p/cleaners.py @@ -1,3 +1,8 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import re from utils.g2p.japanese import japanese_to_ipa from utils.g2p.mandarin import chinese_to_ipa diff --git a/models/tts/debatts/utils/g2p_new/__init__.py b/models/tts/debatts/utils/g2p_new/__init__.py index 80b61797..fd6f360e 100644 --- a/models/tts/debatts/utils/g2p_new/__init__.py +++ b/models/tts/debatts/utils/g2p_new/__init__.py @@ -1,3 +1,8 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from utils.g2p_new import cleaners from tokenizers import Tokenizer from utils.g2p_new.text_tokenizers import TextTokenizer @@ -18,8 +23,7 @@ def __init__(self, vacab_path="./utils/g2p_new/vacab.json"): } self.text_tokenizers = {} self.int_text_tokenizers() - # TODO - vacab_path = "/mntcephfs/lab_data/lijiaqi/Speech/utils/g2p_new/vacab.json" + vacab_path = "./g2p_new/vacab.json" with open(vacab_path, "rb") as f: json_data = f.read() data = json.loads(json_data) @@ -55,7 +59,7 @@ def _clean_text(self, text, language, cleaner_names): return text def phoneme2token(self, phonemes): - # 使用的是国际音标,可以将音素转化成token。实际上输入的phone id也是将音频先asr成文本再转成token的,使用的是同一套vocab体系 + # converts phonemes into tokens. 
In fact, the input phone id is also the first asr audio into text and then converted into token, using the same set of vocab system tokens = [] if isinstance(phonemes, list): for phone in phonemes: diff --git a/models/tts/debatts/utils/g2p_new/cleaners.py b/models/tts/debatts/utils/g2p_new/cleaners.py index 4a9509a4..110851c4 100644 --- a/models/tts/debatts/utils/g2p_new/cleaners.py +++ b/models/tts/debatts/utils/g2p_new/cleaners.py @@ -1,26 +1,15 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import re -from utils.g2p_new.japanese import japanese_to_ipa from utils.g2p_new.mandarin import chinese_to_ipa -from utils.g2p_new.english import english_to_ipa -from utils.g2p_new.french import french_to_ipa -from utils.g2p_new.korean import korean_to_ipa -from utils.g2p_new.german import german_to_ipa - def cjekfd_cleaners(text, language, text_tokenizers): - if language == "zh": - return chinese_to_ipa(text, text_tokenizers["zh"]) - elif language == "ja": - return japanese_to_ipa(text, text_tokenizers["ja"]) - elif language == "en": - return english_to_ipa(text, text_tokenizers["en"]) - elif language == "fr": - return french_to_ipa(text, text_tokenizers["fr"]) - elif language == "ko": - return korean_to_ipa(text, text_tokenizers["ko"]) - elif language == "de": - return german_to_ipa(text, text_tokenizers["de"]) + if language == 'zh': + return chinese_to_ipa(text, text_tokenizers['zh']) else: - raise Exception("Unknown language: %s" % language) + raise Exception('Unknown or Not supported yet language: %s' % language) return None diff --git a/models/tts/debatts/utils/g2p_new/english.py b/models/tts/debatts/utils/g2p_new/english.py deleted file mode 100644 index 64442c4c..00000000 --- a/models/tts/debatts/utils/g2p_new/english.py +++ /dev/null @@ -1,193 +0,0 @@ -import re -from unidecode import unidecode -import inflect - -""" - Text clean time -""" -_inflect = inflect.engine() -_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") -_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") -_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)") -_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") -_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") -_fraction_re = re.compile(r"([0-9]+)/([0-9]+)") -_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") -_number_re = re.compile(r"[0-9]+") - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [ - (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) - for x in [ - ("mrs", "misess"), - ("mr", "mister"), - ("dr", "doctor"), - ("st", "saint"), - ("co", "company"), - ("jr", "junior"), - ("maj", "major"), - ("gen", "general"), - ("drs", "doctors"), - ("rev", "reverend"), - ("lt", "lieutenant"), - ("hon", "honorable"), - ("sgt", "sergeant"), - ("capt", "captain"), - ("esq", "esquire"), - ("ltd", "limited"), - ("col", "colonel"), - ("ft", "fort"), - ("etc", "et cetera"), - ("btw", "by the way"), - ] -] - -_special_map = [ - ("t|ɹ", "tɹ"), - ("d|ɹ", "dɹ"), - ("t|s", "ts"), - ("d|z", "dz"), - ("ɪ|ɹ", "ɪɹ"), - ("ɐ", "ɚ"), - ("ᵻ", "ɪ"), - ("əl", "l"), - ("x", "k"), - ("ɬ", "l"), - ("ʔ", "t"), - ("n̩", "n"), - ("oː|ɹ", "oːɹ"), -] - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def _remove_commas(m): - return m.group(1).replace(",", "") - - -def _expand_decimal_point(m): - return m.group(1).replace(".", " point ") - - -def 
_expand_percent(m): - return m.group(1).replace("%", " percent ") - - -def _expand_dollars(m): - match = m.group(1) - parts = match.split(".") - if len(parts) > 2: - return match + " dollars" # Unexpected format - dollars = int(parts[0]) if parts[0] else 0 - cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: - dollar_unit = "dollar" if dollars == 1 else "dollars" - cent_unit = "cent" if cents == 1 else "cents" - return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) - elif dollars: - dollar_unit = "dollar" if dollars == 1 else "dollars" - return "%s %s" % (dollars, dollar_unit) - elif cents: - cent_unit = "cent" if cents == 1 else "cents" - return "%s %s" % (cents, cent_unit) - else: - return "zero dollars" - - -def fraction_to_words(numerator, denominator): - if numerator == 1 and denominator == 2: - return "one half" - if numerator == 1 and denominator == 4: - return "one quarter" - if denominator == 2: - return _inflect.number_to_words(numerator) + " halves" - if denominator == 4: - return _inflect.number_to_words(numerator) + " quarters" - return ( - _inflect.number_to_words(numerator) - + " " - + _inflect.ordinal(_inflect.number_to_words(denominator)) - ) - - -def _expand_fraction(m): - numerator = int(m.group(1)) - denominator = int(m.group(2)) - return fraction_to_words(numerator, denominator) - - -def _expand_ordinal(m): - return _inflect.number_to_words(m.group(0)) - - -def _expand_number(m): - num = int(m.group(0)) - if num > 1000 and num < 3000: - if num == 2000: - return " two thousand " - elif num > 2000 and num < 2010: - return " two thousand " + _inflect.number_to_words(num % 100) + " " - elif num % 100 == 0: - return " " + _inflect.number_to_words(num // 100) + " hundred " - else: - return ( - " " - + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace( - ", ", " " - ) - + " " - ) - else: - return " " + _inflect.number_to_words(num, andword="") + " " - - -# Normalize numbers pronunciation -def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r"\1 pounds", text) - text = re.sub(_dollars_re, _expand_dollars, text) - text = re.sub(_fraction_re, _expand_fraction, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_percent_number_re, _expand_percent, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text - - -def _english_to_ipa(text): - # text = unidecode(text).lower() - text = expand_abbreviations(text) - text = normalize_numbers(text) - return text - - -# special map -def special_map(text): - for regex, replacement in _special_map: - regex = regex.replace("|", "\|") - while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text): - text = re.sub( - r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text - ) - # text = re.sub(r'([,.!?])', r'|\1', text) - return text - - -# Add some special operation -def english_to_ipa(text, text_tokenizer): - if type(text) == str: - text = _english_to_ipa(text) - else: - text = [_english_to_ipa(t) for t in text] - phonemes = text_tokenizer(text) - if type(text) == str: - return special_map(phonemes) - else: - result_ph = [] - for phone in phonemes: - result_ph.append(special_map(phone)) - return result_ph diff --git a/models/tts/debatts/utils/g2p_new/french.py b/models/tts/debatts/utils/g2p_new/french.py deleted file mode 100644 index 9c059e08..00000000 --- a/models/tts/debatts/utils/g2p_new/french.py +++ 
/dev/null @@ -1,178 +0,0 @@ -"""https://github.com/bootphon/phonemizer""" - -import re -from phonemizer import phonemize -from phonemizer.separator import Separator - -# List of (regular expression, replacement) pairs for abbreviations in french: -_abbreviations = [ - (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) - for x in [ - ("M", "monsieur"), - ("Mlle", "mademoiselle"), - ("Mlles", "mesdemoiselles"), - ("Mme", "Madame"), - ("Mmes", "Mesdames"), - ("N.B", "nota bene"), - ("M", "monsieur"), - ("p.c.q", "parce que"), - ("Pr", "professeur"), - ("qqch", "quelque chose"), - ("rdv", "rendez-vous"), - ("max", "maximum"), - ("min", "minimum"), - ("no", "numéro"), - ("adr", "adresse"), - ("dr", "docteur"), - ("st", "saint"), - ("co", "companie"), - ("jr", "junior"), - ("sgt", "sergent"), - ("capt", "capitain"), - ("col", "colonel"), - ("av", "avenue"), - ("av. J.-C", "avant Jésus-Christ"), - ("apr. J.-C", "après Jésus-Christ"), - ("art", "article"), - ("boul", "boulevard"), - ("c.-à-d", "c’est-à-dire"), - ("etc", "et cetera"), - ("ex", "exemple"), - ("excl", "exclusivement"), - ("boul", "boulevard"), - ] -] + [ - (re.compile("\\b%s" % x[0]), x[1]) - for x in [ - ("Mlle", "mademoiselle"), - ("Mlles", "mesdemoiselles"), - ("Mme", "Madame"), - ("Mmes", "Mesdames"), - ] -] - -rep_map = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - "·": ",", - "、": ",", - "...": ".", - "…": ".", - "$": ".", - "“": "", - "”": "", - "‘": "", - "’": "", - "(": "", - ")": "", - "(": "", - ")": "", - "《": "", - "》": "", - "【": "", - "】": "", - "[": "", - "]": "", - "—": "", - "~": "-", - "~": "-", - "「": "", - "」": "", - "¿": "", - "¡": "", -} - -_special_map = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - ("ø", "ɸ"), - ("ː", ":"), - ("j", "jˈ"), # To avoid incorrect connect - ("n", "ˈn"), # To avoid incorrect connect - ("w", "wˈ"), # To avoid incorrect connect - ("ã", "a~"), - ("ɑ̃", "ɑ~"), - ("ɔ̃", "ɔ~"), - ("ɛ̃", "ɛ~"), - ("œ̃", "œ~"), - ] -] - - -def collapse_whitespace(text): - # Regular expression matching whitespace: - _whitespace_re = re.compile(r"\s+") - return re.sub(_whitespace_re, " ", text).strip() - - -def remove_punctuation_at_begin(text): - return re.sub(r"^[,.!?]+", "", text) - - -def remove_aux_symbols(text): - text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) - return text - - -def replace_symbols(text): - text = text.replace(";", ",") - text = text.replace("-", " ") - text = text.replace(":", ",") - text = text.replace("&", " et ") - return text - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def replace_punctuation(text): - pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) - replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) - return replaced_text - - -def text_normalize(text): - text = expand_abbreviations(text) - text = replace_punctuation(text) - text = replace_symbols(text) - text = remove_aux_symbols(text) - text = remove_punctuation_at_begin(text) - text = collapse_whitespace(text) - text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text) - return text - - -# special map -def special_map(text): - for regex, replacement in _special_map: - text = re.sub(regex, replacement, text) - return text - - -def french_to_ipa(text): - text = text_normalize(text) - - ipa = phonemize( - text.strip(), - language="fr-fr", - backend="espeak", - separator=Separator(phone=None, word=" ", syllable="|"), - strip=True, - preserve_punctuation=True, 
- njobs=4, - ) - - # remove "(en)" and "(fr)" tag - ipa = ipa.replace("(en)", "").replace("(fr)", "") - - ipa = special_map(ipa) - - return ipa diff --git a/models/tts/debatts/utils/g2p_new/g2p_new.py b/models/tts/debatts/utils/g2p_new/g2p_new.py index eaa6ebbd..ae5b72fc 100644 --- a/models/tts/debatts/utils/g2p_new/g2p_new.py +++ b/models/tts/debatts/utils/g2p_new/g2p_new.py @@ -1,3 +1,8 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + from utils.g2p_new import PhonemeBpeTokenizer import tqdm diff --git a/models/tts/debatts/utils/g2p_new/german.py b/models/tts/debatts/utils/g2p_new/german.py deleted file mode 100644 index 3f9259a3..00000000 --- a/models/tts/debatts/utils/g2p_new/german.py +++ /dev/null @@ -1,122 +0,0 @@ -"""https://github.com/bootphon/phonemizer""" - -import re -from phonemizer import phonemize -from phonemizer.separator import Separator - -rep_map = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - "·": ",", - "、": ",", - "...": ".", - "…": ".", - "$": ".", - "“": "", - "”": "", - "‘": "", - "’": "", - "(": "", - ")": "", - "(": "", - ")": "", - "《": "", - "》": "", - "【": "", - "】": "", - "[": "", - "]": "", - "—": "", - "~": "-", - "~": "-", - "「": "", - "」": "", - "¿": "", - "¡": "", -} - -_special_map = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - ("ø", "ɸ"), - ("ː", ":"), - ("ɜ", "ʒ"), - ("ɑ̃", "ɑ~"), - ("j", "jˈ"), # To avoid incorrect connect - ("n", "ˈn"), # To avoid incorrect connect - ("t", "tˈ"), # To avoid incorrect connect - ("ŋ", "ˈŋ"), # To avoid incorrect connect - ("ɪ", "ˈɪ"), # To avoid incorrect connect - ] -] - - -def collapse_whitespace(text): - # Regular expression matching whitespace: - _whitespace_re = re.compile(r"\s+") - return re.sub(_whitespace_re, " ", text).strip() - - -def remove_punctuation_at_begin(text): - return re.sub(r"^[,.!?]+", "", text) - - -def remove_aux_symbols(text): - text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) - return text - - -def replace_symbols(text): - text = text.replace(";", ",") - text = text.replace("-", " ") - text = text.replace(":", ",") - return text - - -def replace_punctuation(text): - pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) - replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) - return replaced_text - - -def text_normalize(text): - text = replace_punctuation(text) - text = replace_symbols(text) - text = remove_aux_symbols(text) - text = remove_punctuation_at_begin(text) - text = collapse_whitespace(text) - text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text) - return text - - -# special map -def special_map(text): - for regex, replacement in _special_map: - text = re.sub(regex, replacement, text) - return text - - -def german_to_ipa(text): - text = text_normalize(text) - - ipa = phonemize( - text.strip(), - language="de", - backend="espeak", - separator=Separator(phone=None, word=" ", syllable="|"), - strip=True, - preserve_punctuation=True, - njobs=4, - ) - - # remove "(en)" and "(fr)" tag - ipa = ipa.replace("(en)", "").replace("(de)", "") - - ipa = special_map(ipa) - - return ipa diff --git a/models/tts/debatts/utils/g2p_new/japanese.py b/models/tts/debatts/utils/g2p_new/japanese.py deleted file mode 100644 index ed1dc2e3..00000000 --- a/models/tts/debatts/utils/g2p_new/japanese.py +++ /dev/null @@ -1,175 +0,0 @@ -"""from https://github.com/Plachtaa/VALL-E-X/g2p""" - -import re -from unidecode import 
unidecode - -""" - Text clean time -""" - -# Regular expression matching Japanese without punctuation marks: -_japanese_characters = re.compile( - r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" -) - -# Regular expression matching non-Japanese characters or punctuation marks: -_japanese_marks = re.compile( - r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" -) - -# List of (symbol, Japanese) pairs for marks: -_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]] - -# List of (romaji, ipa2) pairs for marks: -_romaji_to_ipa2 = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - ("u", "ɯ"), - ("ʧ", "tʃ"), - ("j", "dʑ"), - ("y", "j"), - ("ni", "n^i"), - ("nj", "n^"), - ("hi", "çi"), - ("hj", "ç"), - ("f", "ɸ"), - ("I", "i*"), - ("U", "ɯ*"), - ("r", "ɾ"), - ] -] - -# List of (consonant, sokuon) pairs: -_real_sokuon = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - (r"Q([↑↓]*[kg])", r"k#\1"), - (r"Q([↑↓]*[tdjʧ])", r"t#\1"), - (r"Q([↑↓]*[sʃ])", r"s\1"), - (r"Q([↑↓]*[pb])", r"p#\1"), - ] -] - -# List of (consonant, hatsuon) pairs: -_real_hatsuon = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - (r"N([↑↓]*[pbm])", r"m\1"), - (r"N([↑↓]*[ʧʥj])", r"n^\1"), - (r"N([↑↓]*[tdn])", r"n\1"), - (r"N([↑↓]*[kg])", r"ŋ\1"), - ] -] - - -def symbols_to_japanese(text): - for regex, replacement in _symbols_to_japanese: - text = re.sub(regex, replacement, text) - return text - - -def japanese_to_romaji_with_accent(text): - """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html""" - import pyopenjtalk - - text = symbols_to_japanese(text) - sentences = re.split(_japanese_marks, text) - marks = re.findall(_japanese_marks, text) - text = "" - for i, sentence in enumerate(sentences): - if re.match(_japanese_characters, sentence): - if text != "": - text += " " - labels = pyopenjtalk.extract_fullcontext(sentence) - for n, label in enumerate(labels): - phoneme = re.search(r"\-([^\+]*)\+", label).group(1) - if phoneme not in ["sil", "pau"]: - text += ( - phoneme.replace("ch", "ʧ").replace("sh", "ʃ").replace("cl", "Q") - ) - else: - continue - # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) - a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) - a2 = int(re.search(r"\+(\d+)\+", label).group(1)) - a3 = int(re.search(r"\+(\d+)/", label).group(1)) - if re.search(r"\-([^\+]*)\+", labels[n + 1]).group(1) in ["sil", "pau"]: - a2_next = -1 - else: - a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) - # Accent phrase boundary - if a3 == 1 and a2_next == 1: - text += " " - # Falling - elif a1 == 0 and a2_next == a2 + 1: - text += "↓" - # Rising - elif a2 == 1 and a2_next == 2: - text += "↑" - if i < len(marks): - text += unidecode(marks[i]).replace(" ", "") - return text - - -def get_real_sokuon(text): - for regex, replacement in _real_sokuon: - text = re.sub(regex, replacement, text) - return text - - -def get_real_hatsuon(text): - for regex, replacement in _real_hatsuon: - text = re.sub(regex, replacement, text) - return text - - -def japanese_to_ipa(text): - text = japanese_to_romaji_with_accent(text).replace("...", "…") - text = get_real_sokuon(text) - text = get_real_hatsuon(text) - for regex, replacement in _romaji_to_ipa2: - text = re.sub(regex, replacement, text) - return text - - -""" - Phoneme merge time -""" - - -def _connect_tone(phoneme_tokens, vocab): - - tone_list = ["→", "↑", "↓↑", "↓"] - tone_token = [] - last_single_token = 0 
- base = 0 - pattern = r"\[[^\[\]]*\]" # Exclude "[" and "]" - for tone, idx in vocab.items(): - if re.match(pattern, tone): - base = idx + 1 - if tone in tone_list: - tone_token.append(idx) - last_single_token = idx - - pre_token = None - cur_token = None - res_token = [] - for t in phoneme_tokens: - cur_token = t - if t in tone_token: - cur_token = ( - last_single_token - + (pre_token - base) * len(tone_list) - + tone_token.index(t) - + 1 - ) - res_token.pop() - res_token.append(cur_token) - pre_token = t - - return res_token - - -def japanese_merge_phoneme(phoneme_tokens, vocab): - phoneme_tokens = _connect_tone(phoneme_tokens, vocab) - return phoneme_tokens diff --git a/models/tts/debatts/utils/g2p_new/korean.py b/models/tts/debatts/utils/g2p_new/korean.py deleted file mode 100644 index 60ba2e13..00000000 --- a/models/tts/debatts/utils/g2p_new/korean.py +++ /dev/null @@ -1,167 +0,0 @@ -"""https://github.com/bootphon/phonemizer""" - -import re - -# from g2pkk import G2p -# from jamo import hangul_to_jamo - -english_dictionary = { - "KOREA": "코리아", - "IDOL": "아이돌", - "IT": "아이티", - "IQ": "아이큐", - "UP": "업", - "DOWN": "다운", - "PC": "피씨", - "CCTV": "씨씨티비", - "SNS": "에스엔에스", - "AI": "에이아이", - "CEO": "씨이오", - "A": "에이", - "B": "비", - "C": "씨", - "D": "디", - "E": "이", - "F": "에프", - "G": "지", - "H": "에이치", - "I": "아이", - "J": "제이", - "K": "케이", - "L": "엘", - "M": "엠", - "N": "엔", - "O": "오", - "P": "피", - "Q": "큐", - "R": "알", - "S": "에스", - "T": "티", - "U": "유", - "V": "브이", - "W": "더블유", - "X": "엑스", - "Y": "와이", - "Z": "제트", -} - -# List of (jamo, ipa) pairs: (need to update) -_jamo_to_ipa = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - ("ㅏ", "ɐ"), - ("ㅑ", "jɐ"), - ("ㅓ", "ʌ"), - ("ㅕ", "jʌ"), - ("ㅗ", "o"), - ("ㅛ", "jo"), - ("ᅮ", "u"), - ("ㅠ", "ju"), - ("ᅳ", "ɯ"), - ("ㅣ", "i"), - ("ㅔ", "e"), - ("ㅐ", "ɛ"), - ("ㅖ", "je"), - ("ㅒ", "jɛ"), # lost - ("ㅚ", "we"), - ("ㅟ", "wi"), - ("ㅢ", "ɯj"), - ("ㅘ", "wɐ"), - ("ㅙ", "wɛ"), # lost - ("ㅝ", "wʌ"), - ("ㅞ", "wɛ"), # lost - ("ㄱ", "q"), # 'ɡ' or 'k' - ("ㄴ", "n"), - ("ㄷ", "t"), # d - ("ㄹ", "ɫ"), # 'ᄅ' is 'r', 'ᆯ' is 'ɫ' - ("ㅁ", "m"), - ("ㅂ", "p"), - ("ㅅ", "s"), # 'ᄉ'is 't', 'ᆺ'is 's' - ("ㅇ", "ŋ"), # 'ᄋ' is None, 'ᆼ' is 'ŋ' - ("ㅈ", "tɕ"), - ("ㅊ", "tɕʰ"), # tʃh - ("ㅋ", "kʰ"), # kh - ("ㅌ", "tʰ"), # th - ("ㅍ", "pʰ"), # ph - ("ㅎ", "h"), - ("ㄲ", "k*"), # q - ("ㄸ", "t*"), # t - ("ㅃ", "p*"), # p - ("ㅆ", "s*"), # 'ᄊ' is 's', 'ᆻ' is 't' - ("ㅉ", "tɕ*"), # tɕ ? 
- ] -] - -_special_map = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - ("ʃ", "ɕ"), - ("tɕh", "tɕʰ"), - ("kh", "kʰ"), - ("th", "tʰ"), - ("ph", "pʰ"), - ] -] - - -def normalize(text): - text = text.strip() - text = re.sub( - "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text - ) - text = normalize_english(text) - text = text.lower() - return text - - -def normalize_english(text): - def fn(m): - word = m.group() - if word in english_dictionary: - return english_dictionary.get(word) - return word - - text = re.sub("([A-Za-z]+)", fn, text) - return text - - -# Convert jamo to IPA -def jamo_to_ipa(text): - res = "" - for t in text: - for regex, replacement in _jamo_to_ipa: - t = re.sub(regex, replacement, t) - res += t - return res - - -# special map -def special_map(text): - for regex, replacement in _special_map: - text = re.sub(regex, replacement, text) - return text - - -def korean_to_ipa(text): - text = normalize(text) - - # espeak-ng - from phonemizer import phonemize - from phonemizer.separator import Separator - - ipa = phonemize( - text, - language="ko", - backend="espeak", - separator=Separator(phone=None, word=" ", syllable="|"), - strip=True, - preserve_punctuation=True, - njobs=4, - ) - ipa = special_map(ipa) - # # hangul charactier - # g2p = G2p() - # text = g2p(text) - # text = list(hangul_to_jamo(text)) # '하늘' --> ['ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆯ'] - # ipa = jamo_to_ipa(text) - return ipa diff --git a/models/tts/debatts/utils/g2p_new/mandarin.py b/models/tts/debatts/utils/g2p_new/mandarin.py index 9c6a229e..13ac24c7 100644 --- a/models/tts/debatts/utils/g2p_new/mandarin.py +++ b/models/tts/debatts/utils/g2p_new/mandarin.py @@ -1,3 +1,8 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import re import jieba import cn2an diff --git a/models/tts/debatts/utils/g2p_new/text_tokenizers.py b/models/tts/debatts/utils/g2p_new/text_tokenizers.py index 4bdd5103..45cb3481 100644 --- a/models/tts/debatts/utils/g2p_new/text_tokenizers.py +++ b/models/tts/debatts/utils/g2p_new/text_tokenizers.py @@ -1,3 +1,8 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import re import os from typing import List, Pattern, Union diff --git a/models/tts/debatts/utils/logger.py b/models/tts/debatts/utils/logger.py index df811ebc..2dd8e34f 100644 --- a/models/tts/debatts/utils/logger.py +++ b/models/tts/debatts/utils/logger.py @@ -25,7 +25,6 @@ def init_logger(name): fh.setFormatter(formatter) logger.addHandler(fh) - # 创建一个自定义的日志格式器,将特定级别的日志设置为红色 class ColorFormatter(logging.Formatter): def format(self, record): if record.levelno >= logging.ERROR: diff --git a/models/tts/debatts/utils/tool.py b/models/tts/debatts/utils/tool.py deleted file mode 100644 index ac680ea4..00000000 --- a/models/tts/debatts/utils/tool.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2024 Amphion. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os -import io -import scipy -import os -import shutil -from pydub import AudioSegment -import soundfile as sf - - -def load_cfg(cfg_path): - if not os.path.exists("config.json"): - raise FileNotFoundError( - "config.json not found. 
Please: copy, config, and rename `config.json.example` to `config.json`" - ) - with open(cfg_path, "r") as f: - cfg = json.load(f) - return cfg - - -def write_wav(path, sr, x): - """numpy array to WAV""" - sf.write(path, x, sr) - - -def write_mp3(path, sr, x): - """numpy array to MP3""" - wav_io = io.BytesIO() - scipy.io.wavfile.write(wav_io, sr, x) - wav_io.seek(0) - sound = AudioSegment.from_wav(wav_io) - with open(path, "wb") as af: - sound.export( - af, - format="mp3", - codec="mp3", - bitrate="160000", - ) - - -# 读取文件夹内所有音频文件 -def get_audio_files(folder_path): - audio_files = [] - for root, _, files in os.walk(folder_path): - if "_processed" in root: - continue - for file in files: - if ".temp" in file: - continue - if file.endswith((".mp3", ".wav", ".flac", ".m4a")): - audio_files.append(os.path.join(root, file)) - return audio_files - - -def get_specific_files(folder_path, ext): - audio_files = [] - for root, _, files in os.walk(folder_path): - if "_processed" in root: - continue - for file in files: - if ".temp" in file: - continue - if file.endswith(ext): - audio_files.append(os.path.join(root, file)) - return audio_files - - -def move_vocals(src_directory): - # 遍历根目录下的所有文件和文件夹 - for root, _, files in os.walk(src_directory): - for file in files: - # 检查文件名是否为'vocals.mp3' - if file == "vocals.mp3": - # 构建源文件的完整路径 - src_path = os.path.join(root, file) - # 获取父级目录的名称 - parent_dir_name = os.path.basename(root) - # 构建目标文件的完整路径 - dest_path = os.path.join(src_directory, parent_dir_name + ".mp3") - # 复制文件 - shutil.copy(src_path, dest_path) - - # 删除源文件夹 - shutil.rmtree(src_directory + "/htdemucs") From 2f3caed691548aaff0e7220cf6a651f9a63031f4 Mon Sep 17 00:00:00 2001 From: hehaorui Date: Mon, 28 Oct 2024 21:41:15 +0800 Subject: [PATCH 7/8] add env.sh and remove new in filename --- models/tts/debatts/env.sh | 45 +++ models/tts/debatts/environment.yml | 337 ------------------ models/tts/debatts/requirements.txt | 286 +++++++++++++++ .../{t2s_model_new.py => t2s_model.py} | 0 ..._sft_dataset_new.py => t2s_sft_dataset.py} | 0 5 files changed, 331 insertions(+), 337 deletions(-) create mode 100644 models/tts/debatts/env.sh delete mode 100644 models/tts/debatts/environment.yml create mode 100644 models/tts/debatts/requirements.txt rename models/tts/debatts/{t2s_model_new.py => t2s_model.py} (100%) rename models/tts/debatts/{t2s_sft_dataset_new.py => t2s_sft_dataset.py} (100%) diff --git a/models/tts/debatts/env.sh b/models/tts/debatts/env.sh new file mode 100644 index 00000000..bf3d8f55 --- /dev/null +++ b/models/tts/debatts/env.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +sudo apt-get update +sudo apt-get install -y espeak-ng + +pip install accelerate==0.24.1 +pip install cn2an +pip install -U cos-python-sdk-v5 +pip install datasets +pip install ffmpeg-python +pip install setuptools ruamel.yaml tqdm +pip install tensorboard tensorboardX torch==2.3.1 +pip install transformers===4.41.1 +pip install -U encodec +pip install black==24.1.1 +pip install -U funasr +pip install g2p-en +pip install jieba +pip install json5 +pip install librosa +pip install matplotlib +pip install modelscope +pip install numba==0.60.0 +pip install numpy +pip install omegaconf +pip install onnxruntime +pip install -U openai-whisper +pip install openpyxl +pip install pandas +pip install phonemizer +pip install protobuf +pip install pydub +pip install pypinyin +pip install pyworld +pip install ruamel.yaml +pip install scikit-learn scipy +pip install soundfile +pip install timm tokenizers +pip install torchaudio==2.3.1 +pip install 
torchvision==0.18.1 +pip install tqdm==4.66.4 +pip install transformers==4.44.0 +pip install unidecode +pip install zhconv zhon wandb + diff --git a/models/tts/debatts/environment.yml b/models/tts/debatts/environment.yml deleted file mode 100644 index 4d358104..00000000 --- a/models/tts/debatts/environment.yml +++ /dev/null @@ -1,337 +0,0 @@ -name: debatts -channels: - - pytorch - - nvidia - - https://repo.anaconda.com/pkgs/main - - conda-forge - - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2 - - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r - - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - asttokens=2.4.1=pyhd8ed1ab_0 - - backcall=0.2.0=pyh9f0ad1d_0 - - blas=1.0=mkl - - brotli-python=1.0.9=py39h6a678d5_8 - - bzip2=1.0.8=h7f98852_4 - - ca-certificates=2024.7.4=hbcca054_0 - - certifi=2024.7.4=pyhd8ed1ab_0 - - charset-normalizer=3.3.2=pyhd8ed1ab_0 - - comm=0.2.2=pyhd8ed1ab_0 - - cuda-cudart=12.1.105=0 - - cuda-cupti=12.1.105=0 - - cuda-libraries=12.1.0=0 - - cuda-nvrtc=12.1.105=0 - - cuda-nvtx=12.1.105=0 - - cuda-opencl=12.5.39=0 - - cuda-runtime=12.1.0=0 - - cuda-version=12.5=3 - - debugpy=1.6.7=py39h6a678d5_0 - - decorator=5.1.1=pyhd8ed1ab_0 - - entrypoints=0.4=pyhd8ed1ab_0 - - executing=2.0.1=pyhd8ed1ab_0 - - ffmpeg=4.2.2=h20bf706_0 - - filelock=3.15.4=pyhd8ed1ab_0 - - freetype=2.10.4=h0708190_1 - - gmp=6.1.2=hf484d3e_1000 - - gmpy2=2.1.2=py39heeb90bb_0 - - gnutls=3.6.15=he1e5248_0 - - idna=3.7=py39h06a4308_0 - - intel-openmp=2023.1.0=hdb19cb5_46306 - - ipykernel=6.29.4=pyh3099207_0 - - ipython=8.12.0=pyh41d4057_0 - - jedi=0.19.1=pyhd8ed1ab_0 - - jinja2=3.1.4=py39h06a4308_0 - - jpeg=9e=h5eee18b_1 - - jupyter_client=7.3.4=pyhd8ed1ab_0 - - jupyter_core=5.7.2=py39hf3d152e_0 - - lame=3.100=h7f98852_1001 - - lcms2=2.12=h3be6417_0 - - ld_impl_linux-64=2.38=h1181459_1 - - lerc=3.0=h295c915_0 - - libcublas=12.1.0.26=0 - - libcufft=11.0.2.4=0 - - libcufile=1.10.0.4=0 - - libcurand=10.3.6.39=0 - - libcusolver=11.4.4.55=0 - - libcusparse=12.0.2.55=0 - - libdeflate=1.17=h5eee18b_1 - - libffi=3.4.4=h6a678d5_1 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libidn2=2.3.4=h5eee18b_0 - - libjpeg-turbo=2.0.0=h9bf148f_0 - - libnpp=12.0.2.50=0 - - libnvjitlink=12.1.105=0 - - libnvjpeg=12.1.1.14=0 - - libopus=1.3.1=h7f98852_1 - - libpng=1.6.39=h5eee18b_0 - - libsodium=1.0.18=h36c2ea0_1 - - libstdcxx-ng=11.2.0=h1234567_1 - - libtasn1=4.19.0=h5eee18b_0 - - libtiff=4.5.1=h6a678d5_0 - - libunistring=0.9.10=h7f98852_0 - - libvpx=1.7.0=h439df22_0 - - libwebp-base=1.3.2=h5eee18b_0 - - llvm-openmp=14.0.6=h9e868ea_0 - - lz4-c=1.9.4=h6a678d5_1 - - markupsafe=2.1.1=py39hb9d737c_1 - - matplotlib-inline=0.1.7=pyhd8ed1ab_0 - - mkl=2023.1.0=h213fc3f_46344 - - mkl-service=2.4.0=py39h5eee18b_1 - - mkl_fft=1.3.8=py39h5eee18b_0 - - mkl_random=1.2.4=py39hdb19cb5_0 - - mpc=1.1.0=h10f8cd9_1 - - mpfr=4.0.2=hb69a4c5_1 - - mpmath=1.3.0=py39h06a4308_0 - - ncurses=6.4=h6a678d5_0 - - nest-asyncio=1.6.0=pyhd8ed1ab_0 - - nettle=3.7.3=hbbd107a_1 - - networkx=3.2.1=py39h06a4308_0 - - numpy-base=1.26.4=py39hb5e798b_0 - - openh264=2.1.1=h4ff587b_0 - - openjpeg=2.4.0=h3ad879b_0 - - openssl=1.1.1w=h7f8727e_0 - - packaging=24.1=pyhd8ed1ab_0 - - parso=0.8.4=pyhd8ed1ab_0 - - pexpect=4.9.0=pyhd8ed1ab_0 - - pickleshare=0.7.5=py_1003 - - pillow=10.3.0=py39h5eee18b_0 - - platformdirs=4.2.2=pyhd8ed1ab_0 - - prompt-toolkit=3.0.47=pyha770c72_0 - - prompt_toolkit=3.0.47=hd8ed1ab_0 - - psutil=5.9.1=py39hb9d737c_0 - - 
ptyprocess=0.7.0=pyhd3deb0d_0 - - pure_eval=0.2.2=pyhd8ed1ab_0 - - pygments=2.18.0=pyhd8ed1ab_0 - - pysocks=1.7.1=py39h06a4308_0 - - python=3.9.15=h7a1cb2a_2 - - python-dateutil=2.9.0=pyhd8ed1ab_0 - - python_abi=3.9=2_cp39 - - pytorch=2.3.1=py3.9_cuda12.1_cudnn8.9.2_0 - - pytorch-cuda=12.1=ha16c6d3_5 - - pytorch-mutex=1.0=cuda - - pyyaml=6.0.1=py39h5eee18b_0 - - pyzmq=25.1.2=py39h6a678d5_0 - - readline=8.2=h5eee18b_0 - - requests=2.32.3=pyhd8ed1ab_0 - - six=1.16.0=pyh6c4a22f_0 - - sqlite=3.45.3=h5eee18b_0 - - stack_data=0.6.2=pyhd8ed1ab_0 - - sympy=1.12.1=pyh04b8f61_3 - - tbb=2021.8.0=hdb19cb5_0 - - tk=8.6.14=h39e8969_0 - - torchtriton=2.3.1=py39 - - torchvision=0.18.1=py39_cu121 - - tornado=6.1=py39hb9d737c_3 - - traitlets=5.14.3=pyhd8ed1ab_0 - - typing_extensions=4.11.0=py39h06a4308_0 - - urllib3=2.2.2=py39h06a4308_0 - - wcwidth=0.2.13=pyhd8ed1ab_0 - - wheel=0.43.0=py39h06a4308_0 - - x264=1!157.20191217=h7b6447c_0 - - xz=5.4.6=h5eee18b_1 - - yaml=0.2.5=h7b6447c_0 - - zeromq=4.3.5=h6a678d5_0 - - zlib=1.2.13=h5eee18b_1 - - zstd=1.5.5=hc292b87_2 - - pip: - - absl-py==2.1.0 - - accelerate==0.24.1 - - aiohttp==3.9.5 - - aiosignal==1.3.1 - - aliyun-python-sdk-core==2.15.1 - - aliyun-python-sdk-kms==2.16.3 - - antlr4-python3-runtime==4.9.3 - - argparse==1.4.0 - - asteroid==0.7.0 - - asteroid-filterbanks==0.4.0 - - async-timeout==4.0.3 - - attrs==23.2.0 - - audiomentations==0.36.0 - - babel==2.15.0 - - bitarray==2.9.2 - - black==24.1.1 - - braceexpand==0.1.7 - - cached-property==1.5.2 - - cffi==1.16.0 - - click==8.1.7 - - cn2an==0.5.22 - - colorama==0.4.6 - - coloredlogs==15.0.1 - - contourpy==1.2.1 - - crcmod==1.7 - - cryptography==43.0.0 - - cycler==0.12.1 - - cython==3.0.10 - - cytoolz==0.12.3 - - datasets==2.20.0 - - diffsptk==2.1.0 - - diffusers==0.29.2 - - dill==0.3.8 - - distance==0.1.3 - - docker-pycreds==0.4.0 - - easydict==1.13 - - editdistance==0.8.1 - - einops==0.8.0 - - encodec==0.1.1 - - evaluate==0.4.2 - - fairseq==0.12.2 - - fastdtw==0.3.4 - - ffmpeg-python==0.2.0 - - flatbuffers==24.3.25 - - fonttools==4.53.1 - - frechet-audio-distance==0.3.1 - - frozenlist==1.4.1 - - fsspec==2024.5.0 - - ftfy==6.2.0 - - funasr==1.1.4 - - future==1.0.0 - - g2p-en==2.1.0 - - gitdb==4.0.11 - - gitpython==3.1.43 - - grpcio==1.64.1 - - h5py==3.11.0 - - huggingface-hub==0.23.4 - - humanfriendly==10.0 - - hydra-core==1.3.2 - - importlib-metadata==8.0.0 - - importlib-resources==6.4.0 - - inflect==7.3.1 - - intervaltree==3.1.0 - - jaconv==0.4.0 - - jamo==0.4.1 - - jieba==0.42.1 - - jiwer==3.0.4 - - jmespath==0.10.0 - - joblib==1.4.2 - - json5==0.9.25 - - jsonschema==4.22.0 - - jsonschema-specifications==2023.12.1 - - julius==0.2.7 - - kaldiio==2.18.0 - - kiwisolver==1.4.5 - - laion-clap==1.1.2 - - lazy-loader==0.4 - - lhotse==1.25.0.dev0+git.da4d70d.clean - - librosa==0.10.2.post1 - - lightning-utilities==0.11.3.post0 - - lilcom==1.8.0 - - llvmlite==0.43.0 - - loguru==0.7.2 - - lxml==5.2.2 - - markdown==3.6 - - matplotlib==3.9.1 - - mir-eval==0.7 - - modelscope==1.17.1 - - modules==1.0.0 - - more-itertools==10.3.0 - - msgpack==1.0.8 - - multidict==6.0.5 - - multiprocess==0.70.16 - - mypy-extensions==1.0.0 - - nltk==3.8.1 - - nnaudio==0.3.3 - - noisereduce==3.0.2 - - npy-append-array==0.9.16 - - numba==0.60.0 - - numpy==1.23.4 - - omegaconf==2.3.0 - - onnxruntime==1.19.0 - - openai-whisper==20231117 - - oss2==2.18.6 - - pandas==2.2.2 - - pathspec==0.12.1 - - pb-bss-eval==0.0.2 - - pedalboard==0.9.9 - - pesq==0.0.4 - - pip==24.2 - - pooch==1.8.2 - - portalocker==2.10.0 - - praat-parselmouth==0.4.3 - - 
proces==0.1.7 - - progressbar==2.5 - - protobuf==4.25.3 - - ptwt==0.1.9 - - pyarrow==16.1.0 - - pyarrow-hotfix==0.6 - - pycparser==2.22 - - pycryptodome==3.20.0 - - pydub==0.25.1 - - pymcd==0.2.1 - - pynndescent==0.5.13 - - pyparsing==3.1.2 - - pypesq==1.2.4 - - pypinyin==0.48.0 - - pysptk==1.0.1 - - pystoi==0.4.1 - - pytorch-lightning==2.3.2 - - pytorch-ranger==0.1.1 - - pytorch-wpe==0.0.1 - - pytz==2024.1 - - pywavelets==1.6.0 - - pyworld==0.3.4 - - rapidfuzz==3.9.6 - - referencing==0.35.1 - - regex==2024.5.15 - - resampy==0.4.3 - - resemblyzer==0.1.4 - - rir-generator==0.2.0 - - rpds-py==0.18.1 - - ruamel-yaml==0.18.6 - - ruamel-yaml-clib==0.2.8 - - sacrebleu==2.4.2 - - safetensors==0.4.3 - - scikit-learn==1.5.1 - - scipy==1.10.1 - - semantic-version==2.10.0 - - sentencepiece==0.2.0 - - sentry-sdk==2.8.0 - - setproctitle==1.3.3 - - setuptools==70.3.0 - - setuptools-rust==1.9.0 - - smmap==5.0.1 - - sortedcontainers==2.4.0 - - soundfile==0.12.1 - - soxr==0.3.7 - - tabulate==0.9.0 - - tensorboard==2.17.0 - - tensorboard-data-server==0.7.2 - - tensorboardx==2.6.2.2 - - tgt==1.5 - - threadpoolctl==3.5.0 - - tiktoken==0.7.0 - - timm==1.0.8 - - tokenizers==0.19.1 - - tomli==2.0.1 - - toolz==0.12.1 - - torch-complex==0.4.4 - - torch-optimizer==0.1.0 - - torch-stoi==0.2.1 - - torchaudio==2.3.1 - - torchcomp==0.1.1 - - torchcrepe==0.0.23 - - torchlibrosa==0.1.0 - - torchlpc==0.4 - - torchmetrics==0.11.4 - - tqdm==4.66.4 - - transformers==4.44.0 - - trash-cli==0.24.5.26 - - typeguard==4.3.0 - - typing==3.7.4.3 - - tzdata==2024.1 - - umap-learn==0.5.6 - - unidecode==1.3.8 - - vector-quantize-pytorch==1.12.5 - - wandb==0.17.4 - - webdataset==0.2.86 - - webrtcvad==2.0.10 - - werkzeug==3.0.3 - - wget==3.2 - - xxhash==3.4.1 - - yarl==1.9.4 - - zhconv==1.4.3 - - zhon==2.0.2 - - zipp==3.19.2 diff --git a/models/tts/debatts/requirements.txt b/models/tts/debatts/requirements.txt new file mode 100644 index 00000000..07e07419 --- /dev/null +++ b/models/tts/debatts/requirements.txt @@ -0,0 +1,286 @@ +absl-py==2.1.0 +accelerate==0.24.1 +addict==2.4.0 +aiofiles==23.2.1 +aiohttp==3.9.5 +aiosignal==1.3.1 +aliyun-python-sdk-core==2.15.1 +aliyun-python-sdk-kms==2.16.3 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +asteroid==0.7.0 +asteroid-filterbanks==0.4.0 +asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work +async-timeout==4.0.3 +attrs==23.2.0 +audiomentations==0.36.0 +audioread==3.0.1 +Babel==2.15.0 +backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work +bitarray==2.9.2 +black==24.1.1 +braceexpand==0.1.7 +Brotli @ file:///croot/brotli-split_1714483155106/work +bypy==1.8.5 +cached-property==1.5.2 +certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1720457958366/work/certifi +cffi==1.16.0 +charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work +click==8.1.7 +cn2an==0.5.22 +colorama==0.4.6 +coloredlogs==15.0.1 +comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work +contourpy==1.3.0 +crcmod==1.7 +cryptography==43.0.0 +cycler==0.12.1 +Cython==3.0.10 +cytoolz==0.12.3 +datasets==2.20.0 +debugpy @ file:///croot/debugpy_1690905042057/work +decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work +decord==0.6.0 +diffsptk==2.1.0 +diffusers==0.29.2 +dill==0.3.8 +Distance==0.1.3 +docker-pycreds==0.4.0 +easydict==1.13 +editdistance==0.6.2 +einops==0.8.0 +encodec==0.1.1 +entrypoints @ 
file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work +evaluate==0.4.2 +executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work +fairscale==0.4.0 +# Editable Git install with no remote (fairseq==0.12.2) +-e /mntnfs/lee_data1/qjw/fairseq +fastapi==0.115.2 +fastdtw==0.3.4 +ffmpeg-python==0.2.0 +ffmpy==0.4.0 +filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1719088281970/work +flatbuffers==24.3.25 +fonttools==4.53.1 +frechet_audio_distance==0.3.1 +frozenlist==1.4.1 +fsspec==2024.5.0 +ftfy==6.2.0 +funasr==1.1.4 +future==1.0.0 +g2p-en==2.1.0 +gitdb==4.0.11 +GitPython==3.1.43 +gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645438755360/work +gradio==4.41.0 +gradio_client==1.3.0 +grpcio==1.64.1 +h11==0.14.0 +h5py==3.11.0 +httpcore==1.0.6 +httpx==0.27.2 +huggingface-hub==0.26.1 +humanfriendly==10.0 +hydra-core==1.3.2 +idna @ file:///croot/idna_1714398848350/work +importlib_metadata==8.0.0 +importlib_resources==6.4.5 +inflect==7.3.1 +intervaltree==3.1.0 +ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1717717528849/work +ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work +jaconv==0.4.0 +jamo==0.4.1 +jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work +jieba==0.42.1 +Jinja2 @ file:///croot/jinja2_1716993405101/work +jiwer==3.0.4 +jmespath==0.10.0 +joblib==1.4.2 +json5==0.9.25 +jsonlines==4.0.0 +jsonschema==4.22.0 +jsonschema-specifications==2023.12.1 +julius==0.2.7 +jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work +jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257447442/work +kaldiio==2.18.0 +kiwisolver==1.4.5 +laion-clap==1.1.2 +lazy_loader==0.4 +lhotse @ git+https://github.com/lhotse-speech/lhotse@da4d70d7affc477eb8dc3a51f9b13d387817059a +librosa==0.10.2.post1 +lightning-utilities==0.11.3.post0 +lilcom==1.8.0 +llvmlite==0.43.0 +loguru==0.7.2 +lxml==5.2.2 +Markdown==3.6 +markdown-it-py==3.0.0 +markdown2==2.4.10 +MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1648737556467/work +matplotlib==3.7.4 +matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work +mdurl==0.1.2 +mir_eval==0.7 +mkl-fft @ file:///croot/mkl_fft_1695058164594/work +mkl-random @ file:///croot/mkl_random_1695059800811/work +mkl-service==2.4.0 +modelscope==1.17.1 +modelscope_studio @ http://thunlp.oss-cn-qingdao.aliyuncs.com/multi_modal/never_delete/modelscope_studio-0.4.0.9-py3-none-any.whl +modules==1.0.0 +more-itertools==10.1.0 +mpmath @ file:///croot/mpmath_1690848262763/work +msgpack==1.0.8 +multidict==6.0.5 +multiprocess==0.70.16 +mypy-extensions==1.0.0 +nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work +networkx @ file:///croot/networkx_1717597493534/work +nltk==3.8.1 +nnAudio==0.3.3 +noisereduce==3.0.2 +npy-append-array==0.9.16 +numba==0.60.0 +numpy==1.23.4 +omegaconf==2.3.0 +onnxruntime==1.19.0 +openai-whisper==20231117 +opencv-python-headless==4.5.5.64 +openpyxl==3.1.2 +orjson==3.10.9 +oss2==2.18.6 +packaging==23.2 +pandas==2.2.2 +parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work +pathspec==0.12.1 +pb-bss-eval==0.0.2 +pedalboard==0.9.9 +pesq==0.0.4 +pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work +pickleshare @ 
file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work +Pillow==10.1.0 +platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work +pooch==1.8.2 +portalocker==2.10.1 +praat-parselmouth==0.4.3 +proces==0.1.7 +progressbar==2.5 +prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work +protobuf==4.25.3 +psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1653089170447/work +ptwt==0.1.9 +ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl +pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work +pyarrow==16.1.0 +pyarrow-hotfix==0.6 +pycparser==2.22 +pycryptodome==3.20.0 +pydantic==2.9.2 +pydantic_core==2.23.4 +pydub==0.25.1 +Pygments==2.18.0 +pymcd==0.2.1 +pynndescent==0.5.13 +pyparsing==3.1.2 +pypesq @ https://github.com/vBaiCai/python-pesq/archive/master.zip#sha256=fba27c3d95e8f72fed7c55f675ce6057a64b26a1a67a2e469df2804cca69b8cc +pypinyin==0.48.0 +PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work +pysptk==1.0.1 +pystoi==0.4.1 +python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work +python-multipart==0.0.12 +pytorch-lightning==2.3.2 +pytorch-ranger==0.1.1 +pytorch-wpe==0.0.1 +pytz==2024.1 +PyWavelets==1.6.0 +pyworld==0.3.4 +PyYAML @ file:///croot/pyyaml_1698096049011/work +pyzmq @ file:///croot/pyzmq_1705605076900/work +rapidfuzz==3.9.6 +referencing==0.35.1 +regex==2024.5.15 +requests==2.32.3 +requests-toolbelt==1.0.0 +resampy==0.4.3 +Resemblyzer==0.1.4 +rich==13.9.2 +rir-generator==0.2.0 +rpds-py==0.18.1 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.8 +ruff==0.7.0 +sacrebleu==2.3.2 +safetensors==0.4.5 +scikit-learn==1.5.1 +scipy==1.10.1 +seaborn==0.13.0 +semantic-version==2.10.0 +sentencepiece==0.2.0 +sentry-sdk==2.8.0 +setproctitle==1.3.3 +setuptools-rust==1.9.0 +shellingham==1.5.4 +shortuuid==1.0.11 +six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work +smmap==5.0.1 +socksio==1.0.0 +sortedcontainers==2.4.0 +soundfile==0.12.1 +soxr==0.3.7 +stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work +starlette==0.40.0 +sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1718625546171/work +tabulate==0.9.0 +tensorboard==2.17.0 +tensorboard-data-server==0.7.2 +tensorboardX==2.6.2.2 +tgt==1.5 +threadpoolctl==3.5.0 +tiktoken==0.7.0 +timm==0.9.10 +tokenizers==0.19.1 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.1 +torch==2.3.1 +torch-complex==0.4.4 +torch-optimizer==0.1.0 +torch-stoi==0.2.1 +torchaudio==2.3.1 +torchcomp==0.1.1 +torchcrepe==0.0.23 +torchlibrosa==0.1.0 +torchlpc==0.4 +torchmetrics==0.11.4 +torchvision==0.18.1 +tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827245914/work +tqdm==4.66.4 +traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work +transformers==4.44.0 +trash-cli==0.24.5.26 +triton==2.3.1 +typeguard==4.3.0 +typer==0.12.5 +typing==3.7.4.3 +typing_extensions @ file:///croot/typing_extensions_1715268824938/work +tzdata==2024.1 +umap-learn==0.5.6 +Unidecode==1.3.8 +urllib3==2.2.3 +uvicorn==0.24.0.post1 +vector-quantize-pytorch==1.12.5 +wandb==0.17.4 +wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work +webdataset==0.2.86 +webrtcvad==2.0.10 +websockets==12.0 +Werkzeug==3.0.3 +wget==3.2 +xxhash==3.4.1 +yarl==1.9.4 
+zhconv==1.4.3 +zhon==2.0.2 +zipp==3.19.2 diff --git a/models/tts/debatts/t2s_model_new.py b/models/tts/debatts/t2s_model.py similarity index 100% rename from models/tts/debatts/t2s_model_new.py rename to models/tts/debatts/t2s_model.py diff --git a/models/tts/debatts/t2s_sft_dataset_new.py b/models/tts/debatts/t2s_sft_dataset.py similarity index 100% rename from models/tts/debatts/t2s_sft_dataset_new.py rename to models/tts/debatts/t2s_sft_dataset.py From 102d1159ba51fbd09b2136ce8eb29e6508420f5a Mon Sep 17 00:00:00 2001 From: hehaorui Date: Mon, 18 Nov 2024 23:04:17 +0800 Subject: [PATCH 8/8] commit changes of filenames and README and arxiv --- models/tts/debatts/README.md | 29 +++++++++++++++++-- ...layer_24k.json => s2a_debatts_1layer.json} | 0 ...8192_1q_24k.json => s2a_debatts_full.json} | 0 .../87_SPEAKER00_7_part11_212_prompt.json | 3 ++ ...rge_101k_fix_new.json => t2s_debatts.json} | 0 models/tts/debatts/t2s_model.py | 4 --- .../debatts/try_inference_small_samples.py | 16 +++++----- 7 files changed, 37 insertions(+), 15 deletions(-) rename models/tts/debatts/s2a_egs/{exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json => s2a_debatts_1layer.json} (100%) rename models/tts/debatts/s2a_egs/{exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json => s2a_debatts_full.json} (100%) create mode 100644 models/tts/debatts/speech_examples/87_SPEAKER00_7_part11_212_prompt.json rename models/tts/debatts/t2s_egs/{exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json => t2s_debatts.json} (100%) diff --git a/models/tts/debatts/README.md b/models/tts/debatts/README.md index 9f3fcb55..ed2e6596 100644 --- a/models/tts/debatts/README.md +++ b/models/tts/debatts/README.md @@ -4,10 +4,33 @@ Debatts is an advanced text-to-speech (TTS) model specifically designed for Mandarin debate contexts. This innovative model leverages short audio prompts to learn and replicate speaker characteristics while dynamically adjusting speaking style by analyzing the audio of debate opponents. This capability allows Debatts to integrate seamlessly into debate scenarios, offering not just speech synthesis but a responsive adaptation to the changing dynamics of debate interactions. ## Environment Setup -To set up the necessary environment to run Debatts, please use the provided `environment.yml` file. This file contains all the required dependencies and can be easily set up with the following Conda command: +To set up the necessary environment to run Debatts, please use the provided `env.sh` script. It installs all the required dependencies; set up the environment with the following commands: + +**Clone and install** ```bash -conda env create -f environment.yml +git clone https://github.com/open-mmlab/Amphion.git +# create env +bash ./models/tts/debatts/env.sh +``` + +**Application** +We provide an example of how to apply the model in the try_inference Python code, together with the supported example speeches. For more debating speech samples, users can refer to the Hugging Face dataset [Debatts-Data](https://huggingface.co/datasets/amphion/Debatts-Data). Modify the corresponding speech paths in the inference code. ## Continuous Updates -The Debatts project is actively being developed, with continuous updates aimed at enhancing model performance and expanding features. We encourage users to regularly check our repository for the latest updates and improvements to ensure optimal functionality and to take advantage of new capabilities as they become available.
\ No newline at end of file +The Debatts project is actively being developed, with continuous updates aimed at enhancing model performance and expanding features. We encourage users to regularly check our repository for the latest updates and improvements to ensure optimal functionality and to take advantage of new capabilities as they become available. + +## Citations +If you use Debatts in your research, please cite the following paper: + +```bibtex +@misc{huang2024debattszeroshotdebatingtexttospeech, + title={Debatts: Zero-Shot Debating Text-to-Speech Synthesis}, + author={Yiqiao Huang and Yuancheng Wang and Jiaqi Li and Haotian Guo and Haorui He and Shunsi Zhang and Zhizheng Wu}, + year={2024}, + eprint={2411.06540}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + url={https://arxiv.org/abs/2411.06540}, +} +``` diff --git a/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json b/models/tts/debatts/s2a_egs/s2a_debatts_1layer.json similarity index 100% rename from models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json rename to models/tts/debatts/s2a_egs/s2a_debatts_1layer.json diff --git a/models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json b/models/tts/debatts/s2a_egs/s2a_debatts_full.json similarity index 100% rename from models/tts/debatts/s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json rename to models/tts/debatts/s2a_egs/s2a_debatts_full.json diff --git a/models/tts/debatts/speech_examples/87_SPEAKER00_7_part11_212_prompt.json b/models/tts/debatts/speech_examples/87_SPEAKER00_7_part11_212_prompt.json new file mode 100644 index 00000000..d7d7a2ce --- /dev/null +++ b/models/tts/debatts/speech_examples/87_SPEAKER00_7_part11_212_prompt.json @@ -0,0 +1,3 @@ +{ + "text": "对着我就是莽夫, 我不看这个公司是不是套牌, 是不是垃圾, 是不是龙头, 我就随便投, 代表是莽夫, 肯定也不是这样呀。" +} \ No newline at end of file diff --git a/models/tts/debatts/t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json b/models/tts/debatts/t2s_egs/t2s_debatts.json similarity index 100% rename from models/tts/debatts/t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json rename to models/tts/debatts/t2s_egs/t2s_debatts.json diff --git a/models/tts/debatts/t2s_model.py b/models/tts/debatts/t2s_model.py index 279d0a5a..60c4d20f 100644 --- a/models/tts/debatts/t2s_model.py +++ b/models/tts/debatts/t2s_model.py @@ -513,9 +513,7 @@ def sample_hf( repetition_penalty=repeat_penalty, min_new_tokens=50, ) - # assert generated_ids.size(1) > input_length, f"Generated tokens length {generated_ids.size(1)} is less than input length {input_length}, generated ids is {generated_ids}" - gen_tokens = generated_ids[:, input_length:-1] return gen_tokens @@ -526,8 +524,6 @@ def __init__(self, downsample_factor=2): def forward(self, x, mask): # x shape: (batch_size, seq_len) - # mask shape: (batch_size, seq_len) - x = x.float() x = x.unsqueeze(1) # add channel dimension: (batch_size, 1, seq_len) x = F.avg_pool1d( diff --git a/models/tts/debatts/try_inference_small_samples.py b/models/tts/debatts/try_inference_small_samples.py index 6b8a845c..d234d609 100644 --- a/models/tts/debatts/try_inference_small_samples.py +++ b/models/tts/debatts/try_inference_small_samples.py @@ -307,10 +307,10 @@ def semantic2acoustic(combine_semantic_code, acoustic_code): device = torch.device("cuda:0") cfg_soundstorm_1layer = load_config( -
"./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_1layer_24k.json" + "./s2a_egs/s2a_debatts_1layer.json" ) cfg_soundstorm_full = load_config( - "./s2a_egs/exp_config_16k_emilia_llama_new_semantic_repcodec_8192_1q_24k.json" + "./s2a_egs/s2a_debatts_full.json" ) soundstorm_1layer = build_soundstorm(cfg_soundstorm_1layer, device) @@ -328,14 +328,13 @@ def semantic2acoustic(combine_semantic_code, acoustic_code): ) kmeans_model = build_kmeans_model(cfg_soundstorm_full, device) - -soundstorm_1layer_path = "./s2a_model/emilia_50k_8192_331k_model.safetensors" -soundstorm_full_path = "./s2a_model/emilia_50k_8192_519k_model.safetensors" +soundstorm_1layer_path = "./s2a_model/s2a_model_1layer/onelayer_model.safetensors" +soundstorm_full_path = "./s2a_model/s2a_model_full/full_model.safetensors" safetensors.torch.load_model(soundstorm_1layer, soundstorm_1layer_path) safetensors.torch.load_model(soundstorm_full, soundstorm_full_path) t2s_cfg = load_config( - "./t2s_egs/exp_config_16k_emilia_new_semantic_repcodec_8192_1q_large_101k_fix_new.json" + "./t2s_egs/t2s_debatts.json" ) t2s_model_new = build_t2s_model_new(t2s_cfg, device) t2s_model_new_ckpt_path = "./t2s_model/model.safetensors" @@ -607,12 +606,13 @@ def infer_small( from models.tts.soundstorm.try_inference_new import evaluation_new from models.tts.soundstorm.try_inference_new import extract_emotion_similarity -prompt0_wav_path = "./debatts/speech_examples/87_SPEAKER01_2_part03_213.wav" +prompt0_wav_path = "./speech_examples/87_SPEAKER01_2_part03_213.wav" prompt0_text = generate_text_data(prompt0_wav_path)[1] -spk_prompt_wav_path = "The Speaker Identity Path" +spk_prompt_wav_path = "./speech_examples/87_SPEAKER00_7_part11_212_prompt.wav" spk_prompt_text = generate_text_data(spk_prompt_wav_path)[1] +# TODO save_path_dir = "The Path to Save Generated Speech" wav_filename = "The Filename of Generated Speech" save_path = os.path.join(save_path_infer_dir, wav_filename)