Skip to content

Commit f4ac0c7

Browse files
authored
Merge pull request #2143 from lym0302/mix_front
[tts] add mix frontend
2 parents ae7a73b + 207bb5d commit f4ac0c7

File tree

3 files changed

+201
-6
lines changed

3 files changed

+201
-6
lines changed

paddlespeech/t2s/exps/syn_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from paddlespeech.t2s.datasets.data_table import DataTable
3131
from paddlespeech.t2s.frontend import English
32+
from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
3233
from paddlespeech.t2s.frontend.zh_frontend import Frontend
3334
from paddlespeech.t2s.modules.normalizer import ZScore
3435
from paddlespeech.utils.dynamic_import import dynamic_import
@@ -98,6 +99,8 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
9899
sentence = "".join(items[1:])
99100
elif lang == 'en':
100101
sentence = " ".join(items[1:])
102+
elif lang == 'mix':
103+
sentence = " ".join(items[1:])
101104
sentences.append((utt_id, sentence))
102105
return sentences
103106

@@ -111,7 +114,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
111114
am_dataset = am[am.rindex('_') + 1:]
112115
if am_name == 'fastspeech2':
113116
fields = ["utt_id", "text"]
114-
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
117+
if am_dataset in {"aishell3", "vctk",
118+
"mix"} and speaker_dict is not None:
115119
print("multiple speaker fastspeech2!")
116120
fields += ["spk_id"]
117121
elif voice_cloning:
@@ -140,6 +144,10 @@ def get_frontend(lang: str='zh',
140144
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
141145
elif lang == 'en':
142146
frontend = English(phone_vocab_path=phones_dict)
147+
elif lang == 'mix':
148+
frontend = MixFrontend(
149+
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
150+
143151
else:
144152
print("wrong lang!")
145153
print("frontend done!")
@@ -341,8 +349,12 @@ def get_am_output(
341349
input_ids = frontend.get_input_ids(
342350
input, merge_sentences=merge_sentences)
343351
phone_ids = input_ids["phone_ids"]
352+
elif lang == 'mix':
353+
input_ids = frontend.get_input_ids(
354+
input, merge_sentences=merge_sentences)
355+
phone_ids = input_ids["phone_ids"]
344356
else:
345-
print("lang should in {'zh', 'en'}!")
357+
print("lang should in {'zh', 'en', 'mix'}!")
346358

347359
if get_tone_ids:
348360
tone_ids = input_ids["tone_ids"]

paddlespeech/t2s/exps/synthesize_e2e.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,20 @@ def evaluate(args):
113113
input_ids = frontend.get_input_ids(
114114
sentence, merge_sentences=merge_sentences)
115115
phone_ids = input_ids["phone_ids"]
116+
elif args.lang == 'mix':
117+
input_ids = frontend.get_input_ids(
118+
sentence, merge_sentences=merge_sentences)
119+
phone_ids = input_ids["phone_ids"]
116120
else:
117-
print("lang should in {'zh', 'en'}!")
121+
print("lang should in {'zh', 'en', 'mix'}!")
118122
with paddle.no_grad():
119123
flags = 0
120124
for i in range(len(phone_ids)):
121125
part_phone_ids = phone_ids[i]
122126
# acoustic model
123127
if am_name == 'fastspeech2':
124128
# multi speaker
125-
if am_dataset in {"aishell3", "vctk"}:
129+
if am_dataset in {"aishell3", "vctk", "mix"}:
126130
spk_id = paddle.to_tensor(args.spk_id)
127131
mel = am_inference(part_phone_ids, spk_id)
128132
else:
@@ -170,7 +174,7 @@ def parse_args():
170174
choices=[
171175
'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
172176
'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
173-
'tacotron2_csmsc', 'tacotron2_ljspeech'
177+
'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix'
174178
],
175179
help='Choose acoustic model type of tts task.')
176180
parser.add_argument(
@@ -231,7 +235,7 @@ def parse_args():
231235
'--lang',
232236
type=str,
233237
default='zh',
234-
help='Choose model language. zh or en')
238+
help='Choose model language. zh or en or mix')
235239

236240
parser.add_argument(
237241
"--inference_dir",
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import re
15+
from typing import Dict
16+
from typing import List
17+
18+
import paddle
19+
20+
from paddlespeech.t2s.frontend import English
21+
from paddlespeech.t2s.frontend.zh_frontend import Frontend
22+
23+
24+
class MixFrontend():
    """Frontend for mixed Chinese/English text.

    Splits input text into sentences, segments each sentence into
    single-language runs, delegates each run to the matching language
    frontend (zh or en), and concatenates the resulting phone-id
    tensors, optionally inserting an 'sp' (pause) token between runs.
    """

    def __init__(self,
                 g2p_model="pypinyin",
                 phone_vocab_path=None,
                 tone_vocab_path=None):
        # NOTE(review): `g2p_model` is currently unused here; kept for
        # interface compatibility with the other frontends.
        self.zh_frontend = Frontend(
            phone_vocab_path=phone_vocab_path, tone_vocab_path=tone_vocab_path)
        self.en_frontend = English(phone_vocab_path=phone_vocab_path)
        self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)')
        # id of the short-pause phone in the (shared) zh vocabulary
        self.sp_id = self.zh_frontend.vocab_phones["sp"]
        self.sp_id_tensor = paddle.to_tensor([self.sp_id])

    def is_chinese(self, char: str) -> bool:
        """Return True if `char` is a CJK ideograph in U+4E00..U+9FA5."""
        return '\u4e00' <= char <= '\u9fa5'

    def is_alphabet(self, char: str) -> bool:
        """Return True if `char` is an ASCII letter (A-Z or a-z)."""
        return ('\u0041' <= char <= '\u005a') or ('\u0061' <= char <= '\u007a')

    def is_number(self, char: str) -> bool:
        """Return True if `char` is an ASCII digit (0-9)."""
        return '\u0030' <= char <= '\u0039'

    def is_other(self, char: str) -> bool:
        """Return True if `char` is none of: Chinese, ASCII letter, digit."""
        return not (self.is_chinese(char) or self.is_number(char) or
                    self.is_alphabet(char))

    def _split(self, text: str) -> List[str]:
        """Split `text` into sub-sentences at zh/ASCII punctuation marks."""
        # Strip characters that carry no pronunciation.
        text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text)
        # Insert a newline after each sentence-final punctuation mark,
        # then split on runs of newlines.
        text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
        text = text.strip()
        return [sentence.strip() for sentence in re.split(r'\n+', text)]

    def _distinguish(self, text: str) -> List:
        """Segment `text` into runs of a single language.

        Returns a list of (segment, lang) pairs with lang in
        {"zh", "en", "num"}. Digits attach to the surrounding segment
        ("num" only survives for pure-digit runs); spaces are kept inside
        English segments as word breaks; unknown characters are dropped.
        """
        # Classify every character: zh / en / num / blank / unk.
        types = []
        for ch in text:
            if self.is_chinese(ch):
                types.append("zh")
            elif self.is_alphabet(ch):
                types.append("en")
            elif ch == " ":
                types.append("blank")
            elif self.is_number(ch):
                types.append("num")
            else:
                types.append("unk")
        assert len(types) == len(text)

        segments = []
        flag = 0  # 1 while we are inside an open segment
        temp_seg = ""
        temp_lang = ""
        for i in range(len(types)):
            if flag == 0:
                # looking for the first char of a new segment
                if types[i] != "unk" and types[i] != "blank":
                    temp_seg += text[i]
                    temp_lang = types[i]
                    flag = 1
            else:
                if types[i] == temp_lang or types[i] == "num":
                    # same language (or a digit) extends the segment
                    temp_seg += text[i]
                elif temp_lang == "num" and types[i] != "unk":
                    # a digit-initial segment adopts the first real
                    # language it meets
                    temp_seg += text[i]
                    if types[i] == "zh" or types[i] == "en":
                        temp_lang = types[i]
                elif temp_lang == "en" and types[i] == "blank":
                    # keep spaces inside English segments (word breaks)
                    temp_seg += text[i]
                elif types[i] == "unk":
                    # skip unknown chars without closing the segment
                    pass
                else:
                    # language switch: close the current segment
                    segments.append((temp_seg, temp_lang))
                    if types[i] != "unk" and types[i] != "blank":
                        temp_seg = text[i]
                        temp_lang = types[i]
                        flag = 1
                    else:
                        flag = 0
                        temp_seg = ""
                        temp_lang = ""
        # Fix: only flush a non-empty trailing segment. The original
        # always appended, yielding a bogus ("", "") pair for text with
        # no pronounceable characters.
        if temp_seg:
            segments.append((temp_seg, temp_lang))
        return segments

    def get_input_ids(self,
                      sentence: str,
                      merge_sentences: bool=True,
                      get_tone_ids: bool=False,
                      add_sp: bool=True) -> Dict[str, List["paddle.Tensor"]]:
        """Convert a mixed zh/en sentence into phone-id tensors.

        Args:
            sentence: input text; may mix Chinese and English.
            merge_sentences: if True, concatenate all sub-sentences into
                a single tensor.
            get_tone_ids: forwarded to the Chinese frontend.
            add_sp: append an 'sp' (pause) token after every segment.

        Returns:
            {"phone_ids": [Tensor, ...]}
        """
        phones_list = []
        for text in self._split(sentence):
            phones_seg = []
            for content, lang in self._distinguish(text):
                if lang == "en":
                    input_ids = self.en_frontend.get_input_ids(
                        content, merge_sentences=True)
                elif lang in ("zh", "num"):
                    # Pure-digit segments are read by the Chinese
                    # frontend, which text-normalizes them. The original
                    # code handled only "zh"/"en" and reused a stale (or
                    # undefined) `input_ids` for any other tag.
                    input_ids = self.zh_frontend.get_input_ids(
                        content,
                        merge_sentences=True,
                        get_tone_ids=get_tone_ids)
                else:
                    # unknown tag: skip instead of reusing previous ids
                    continue
                phones_seg.append(input_ids["phone_ids"][0])
                if add_sp:
                    phones_seg.append(self.sp_id_tensor)
            if not phones_seg:
                # nothing pronounceable in this sub-sentence; skipping
                # avoids paddle.concat on an empty list
                continue
            phones_list.append(paddle.concat(phones_seg))

        result = {}
        if merge_sentences and phones_list:
            merge_list = paddle.concat(phones_list)
            # Drop the trailing 'sp': the training data has no final
            # pause, and keeping it adds noise at the end of synthesis.
            if merge_list[-1] == self.sp_id_tensor:
                merge_list = merge_list[:-1]
            phones_list = [merge_list]
        result["phone_ids"] = phones_list
        return result

0 commit comments

Comments
 (0)