
Commit

format g2p
HeCheng0625 committed Oct 18, 2024
1 parent 0eb0465 commit 40190c5
Showing 6 changed files with 43 additions and 131 deletions.
5 changes: 0 additions & 5 deletions models/tts/maskgct/g2p/g2p/__init__.py
@@ -40,7 +40,6 @@ def tokenize(self, text, sentence, language):
# 1. convert text to phoneme
phonemes = []
if language == "auto":
# Automatically split the text into language segments
seglist = LangSegment.getTexts(text)
tmp_ph = []
for seg in seglist:
@@ -65,27 +64,23 @@ def tokenize(self, text, sentence, language):
return phonemes, phoneme_tokens

def _clean_text(self, text, sentence, language, cleaner_names):
# Dispatch to the cleaner function for the corresponding language
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception("Unknown cleaner: %s" % name)
# Get the corresponding language's function to process the text
text = cleaner(text, sentence, language, self.text_tokenizers)
return text

def phoneme2token(self, phonemes):
tokens = []
if isinstance(phonemes, list):
for phone in phonemes:
# A corresponding common transcription may have been appended after the IPA phonemes; strip it here
phone = phone.split("\t")[0]
phonemes_split = phone.split("|")
tokens.append(
[self.vocab[p] for p in phonemes_split if p in self.vocab]
)
else:
# A corresponding common transcription may have been appended after the IPA phonemes; strip it here
phonemes = phonemes.split("\t")[0]
phonemes_split = phonemes.split("|")
tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
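As an aside, the `\t`/`|` handling in `phoneme2token` can be illustrated with a small self-contained sketch; the vocab below is hypothetical, not the repository's actual phoneme vocabulary:

# Minimal sketch of the phoneme2token logic above (hypothetical vocab).
# A phoneme string may carry a trailing "\t<common transcription>" that is
# stripped; individual phonemes are separated by "|".
vocab = {"h": 1, "ə": 2, "l": 3, "oʊ": 4}

def phoneme2token(phonemes: str) -> list[int]:
    phonemes = phonemes.split("\t")[0]  # drop the appended common transcription
    return [vocab[p] for p in phonemes.split("|") if p in vocab]

print(phoneme2token("h|ə|l|oʊ\thello"))  # [1, 2, 3, 4]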
41 changes: 8 additions & 33 deletions models/tts/maskgct/g2p/g2p/chinese_model_g2p.py
@@ -42,7 +42,6 @@ def preprocess(self, origin_sentences, origin_labels):
for token in line:
words.append(token)
word_lens.append(1)
# Turn the sentence into a list of single characters, with [CLS] prepended
token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
sentences.append(((words, token_start_idxs), 0))
###
@@ -64,12 +63,7 @@ def __len__(self):
return len(self.dataset)

def collate_fn(self, batch):
"""
process batch data, including:
1. padding: 将每个batch的data padding到同一长度(batch中最长的data长度)
2. aligning: 找到每个sentence sequence里面有label项,文本与label对齐
3. tensor:转化为tensor
"""

sentences = [x[0][0] for x in batch]
ori_sents = [x[0][1] for x in batch]
labels = [x[1] for x in batch]
@@ -85,7 +79,6 @@ def collate_fn(self, batch):
for j in range(batch_len):
cur_len = len(sentences[j][0])
batch_data[j][:cur_len] = sentences[j][0]
# Find the indices of labeled data ([CLS] not counted)
label_start_idx = sentences[j][-1]
label_starts = np.zeros(max_len)
label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
@@ -118,14 +111,11 @@ def __init__(self, bert_model, jsonr_file, json_file):
with open(json_file, "r", encoding="utf8") as fp:
self.pron_dict_id_2_pinyin = json.load(fp)
self.num_polyphone = len(self.pron_dict)
# Load the trained model
self.device = "cpu"
self.polydataset = PolyDataset
options = SessionOptions() # initialize session options
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
print(os.path.join(bert_model, "poly_bert_model.onnx"))
# Pass the path of the ONNX model saved in the previous section
# Prefer CUDA; fall back to CPU if CUDA is unavailable
self.session = InferenceSession(
os.path.join(bert_model, "poly_bert_model.onnx"),
sess_options=options,
@@ -134,30 +124,23 @@ def __init__(self, bert_model, jsonr_file, json_file):
"CPUExecutionProvider",
], # CPUExecutionProvider #CUDAExecutionProvider
)
# Set the GPU id
# self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [ {'device_id': 0}])

# Disable the session.run() fallback mechanism; it prevents a reset of the execution provider
self.session.disable_fallback()
# print("BERT POLY 初始化结束...")

def predict_process(self, txt_list):
# Preprocess the data
word_test, label_test, texts_test = self.get_examples_po(txt_list)
data = self.polydataset(word_test, label_test)
# Set the batch_size here
predict_loader = DataLoader(
data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
)
# Output the prediction results
pred_tags = self.predict_onnx(predict_loader)
# print("BERT预测拼音:{}".format(pred_tags))
return pred_tags

def predict_onnx(self, dev_loader):
pred_tags = []
with torch.no_grad():
# Wrapping the loader with tqdm would display progress
for idx, batch_samples in enumerate(dev_loader):
# [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
@@ -176,20 +159,14 @@ def predict_onnx(self, dev_loader):
)[0]
label_masks = batch_pmasks == 1
batch_labels = batch_labels.to("cpu").numpy()
# In practice, only the candidate pinyins need to be considered here
for i, indices in enumerate(np.argmax(batch_output, axis=2)):
for j, idx in enumerate(indices):
if label_masks[i][j]:
# pred_tag.append(idx)
pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
return pred_tags

# Data processing
def get_examples_po(self, text_list):
"""
将txt文件每一行中的文本分离出来,存储为words列表
BMES标注法标记文本对应的标签,存储为labels
"""

word_list = []
label_list = []
@@ -198,25 +175,19 @@ def get_examples_po(self, text_list):
for line in [text_list]:
sentence = line[0]
words = []
# tokens is used instead of the raw line above to avoid interference from special characters
tokens = line[0]
index = line[-1]
front = index
back = len(tokens) - index - 1
labels = [0] * front + [1] + [0] * back
# Then convert the input to ids
words = ["[CLS]"] + [item for item in sentence]
# Store the token ids
words = self.tokenizer.convert_tokens_to_ids(words)
# Done
word_list.append(words)
label_list.append(labels)
# Store the original text here instead
sentence_list.append(sentence)

id += 1
# mask_list.append(masks)
# Validate
assert len(labels) + 1 == len(words), print(
(
poly,
@@ -229,9 +200,13 @@ def get_examples_po(self, text_list):
len(labels),
)
)
assert len(labels) + 1 == len(words), "Number of labels does not match number of words"
assert len(labels) == len(sentence), "Number of labels does not match number of sentences"
assert len(labels) + 1 == len(
words
), "Number of labels does not match number of words"
assert len(labels) == len(
sentence
), "Number of labels does not match number of sentences"
assert len(word_list) == len(
label_list
), "label 句子数量与 word 句子数量不匹配"
), "Number of label sentences does not match number of word sentences"
return word_list, label_list, text_list
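
Before the next file, a hypothetical end-to-end call of the predictor defined above may help orient the reader. The class name, the checkpoint paths, and the [sentence, index] input format are assumptions inferred from this excerpt (see get_examples_po), not confirmed interfaces:

# Hypothetical usage sketch of the ONNX polyphone predictor defined above.
# The class name BertPolyPredict, all paths, and the [sentence, index] input
# format are assumptions inferred from the excerpt.
from models.tts.maskgct.g2p.g2p.chinese_model_g2p import BertPolyPredict

predictor = BertPolyPredict(
    bert_model="./ckpts/poly_bert",                       # dir containing poly_bert_model.onnx
    jsonr_file="./ckpts/poly_bert/polychar.json",         # hypothetical pron-dict path
    json_file="./ckpts/poly_bert/poly_id_2_pinyin.json",  # hypothetical id-to-pinyin map
)

# A sentence plus the index of the polyphonic character to disambiguate.
pred_pinyin = predictor.predict_process(["这种方便面很好吃", 3])
print(pred_pinyin)  # e.g. ["bian4"]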
1 change: 0 additions & 1 deletion models/tts/maskgct/g2p/g2p/english.py
@@ -191,7 +191,6 @@ def english_to_ipa(text, text_tokenizer):
else:
text = [_english_to_ipa(t) for t in text]
phonemes = text_tokenizer(text)
# Put all IPA phonemes into one string; if it ends with an IPA phoneme, append a blank; if it ends with a non-IPA symbol (possibly punctuation), do not append a blank
if phonemes[-1] in "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː":
phonemes += "|_"
if type(text) == str:
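The trailing-blank rule this comment describes can also be sketched in isolation; the helper name is hypothetical and the IPA character set is copied from the condition above:

# Standalone sketch of the trailing-blank rule above (hypothetical helper name).
IPA_CHARS = "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː"

def append_blank(phonemes: str) -> str:
    # Append "|_" only when the string ends with an IPA phoneme character,
    # not with punctuation or another non-IPA symbol.
    return phonemes + "|_" if phonemes and phonemes[-1] in IPA_CHARS else phonemes

print(append_blank("h|ə|l|oʊ"))   # h|ə|l|oʊ|_   (ends with an IPA phoneme)
print(append_blank("h|ə|l|oʊ."))  # h|ə|l|oʊ.    (ends with punctuation)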
