aishell.py
import argparse
import functools
import json
import os

import soundfile
from tqdm import tqdm

from utils.utils import add_arguments, print_arguments
from utils.utils import download, unpack

DATA_URL = 'https://openslr.elda.org/resources/33/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
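# AISHELL-1 (OpenSLR resource 33) is an open-source Mandarin speech corpus of
# roughly 178 hours; the archive above is on the order of 15 GB, so make sure
# target_dir has enough free disk space before letting the script download it.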
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("filepath",        default=None,             type=str,  help="Path to the data_aishell.tgz archive; downloaded automatically if not specified")
add_arg("target_dir",      default="dataset/audio/", type=str,  help="Directory in which to store the audio files")
add_arg("annotation_text", default="dataset/",       type=str,  help="Directory in which to store the audio annotation files")
add_arg('add_pun',         default=False,            type=bool, help="Whether to add punctuation to the transcripts")
args = parser.parse_args()
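# Each manifest line written by create_annotation_text() below is one JSON
# object. An illustrative example (path and text are placeholders, but the
# schema matches what this script actually emits):
# {"audio": {"path": "dataset/audio/data_aishell/wav/train/S0002/BAC009S0002W0122.wav"},
#  "sentence": "...", "duration": 3.21,
#  "sentences": [{"start": 0, "end": 3.21, "text": "..."}]}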
def create_annotation_text(data_dir, annotation_path):
    print('Create Aishell annotation text ...')
    if args.add_pun:
        import logging
        from modelscope.pipelines import pipeline
        from modelscope.utils.constant import Tasks
        from modelscope.utils.logger import get_logger
        logger = get_logger(log_level=logging.CRITICAL)
        logger.setLevel(logging.CRITICAL)
        inference_pipeline = pipeline(task=Tasks.punctuation,
                                      model='damo/punc_ct-transformer_cn-en-common-vocab471067-large',
                                      model_revision="v1.0.0")
    if not os.path.exists(annotation_path):
        os.makedirs(annotation_path)
    f_train = open(os.path.join(annotation_path, 'train.json'), 'w', encoding='utf-8')
    f_test = open(os.path.join(annotation_path, 'test.json'), 'w', encoding='utf-8')
    # Each transcript line is "<audio_id> <text>"; build an id -> text lookup
    transcript_path = os.path.join(data_dir, 'transcript', 'aishell_transcript_v0.8.txt')
    transcript_dict = {}
    with open(transcript_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    for line in tqdm(lines):
        line = line.strip()
        if line == '': continue
        audio_id, text = line.split(' ', 1)
        # Remove the spaces inside the transcript
        text = ''.join(text.split())
        if args.add_pun:
            text = inference_pipeline(text_in=text)['text']
        transcript_dict[audio_id] = text
    # Training set (the train and dev splits are merged into train.json)
    data_types = ['train', 'dev']
    lines = []
    for data_type in data_types:
        audio_dir = os.path.join(data_dir, 'wav', data_type)
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for fname in filelist:
                audio_path = os.path.join(subfolder, fname)
                audio_id = fname[:-4]
                # Skip audio files that have no transcription
                if audio_id not in transcript_dict:
                    continue
                text = transcript_dict[audio_id]
                line = {"audio": {"path": audio_path}, "sentence": text}
                lines.append(line)
    # Add the audio durations
    for i in tqdm(range(len(lines))):
        audio_path = lines[i]['audio']['path']
        sample, sr = soundfile.read(audio_path)
        duration = round(sample.shape[-1] / float(sr), 2)
        lines[i]["duration"] = duration
        lines[i]["sentences"] = [{"start": 0, "end": duration, "text": lines[i]["sentence"]}]
    for line in lines:
        f_train.write(json.dumps(line, ensure_ascii=False) + "\n")
    # Test set
    audio_dir = os.path.join(data_dir, 'wav', 'test')
    lines = []
    for subfolder, _, filelist in sorted(os.walk(audio_dir)):
        for fname in filelist:
            audio_path = os.path.join(subfolder, fname)
            audio_id = fname[:-4]
            # Skip audio files that have no transcription
            if audio_id not in transcript_dict:
                continue
            text = transcript_dict[audio_id]
            line = {"audio": {"path": audio_path}, "sentence": text}
            lines.append(line)
    # Add the audio durations
    for i in tqdm(range(len(lines))):
        audio_path = lines[i]['audio']['path']
        sample, sr = soundfile.read(audio_path)
        duration = round(sample.shape[-1] / float(sr), 2)
        lines[i]["duration"] = duration
        lines[i]["sentences"] = [{"start": 0, "end": duration, "text": lines[i]["sentence"]}]
    for line in lines:
        f_test.write(json.dumps(line, ensure_ascii=False) + "\n")
    f_test.close()
    f_train.close()
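# Note on the archive layout: data_aishell.tgz unpacks to data_aishell/, whose
# wav/ directory holds one tar archive per speaker; prepare_dataset() therefore
# unpacks the outer archive first and then every inner archive it finds.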
def prepare_dataset(url, md5sum, target_dir, annotation_path, filepath=None):
    """Download, unpack and create manifest file."""
    data_dir = os.path.join(target_dir, 'data_aishell')
    if not os.path.exists(data_dir):
        if filepath is None:
            filepath = download(url, md5sum, target_dir)
        unpack(filepath, target_dir)
        # Unpack all the per-speaker audio tar files
        audio_dir = os.path.join(data_dir, 'wav')
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for ftar in filelist:
                unpack(os.path.join(subfolder, ftar), subfolder, True)
        # The archive is removed after unpacking, even if it was user-supplied
        os.remove(filepath)
    else:
        print("Skip downloading and unpacking. Aishell data already exists in %s." % target_dir)
    create_annotation_text(data_dir, annotation_path)
def main():
    print_arguments(args)
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)
    prepare_dataset(url=DATA_URL,
                    md5sum=MD5_DATA,
                    target_dir=args.target_dir,
                    annotation_path=args.annotation_text,
                    filepath=args.filepath)


if __name__ == '__main__':
    main()
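# Example invocations (assuming utils.add_arguments registers each option as a
# --<name> flag, as is usual in this repo; adjust if your utils differ):
#   python aishell.py --target_dir=dataset/audio/ --annotation_text=dataset/
#   python aishell.py --filepath=/path/to/data_aishell.tgz --add_pun=True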