forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: create-lmdb.py
executable file
·136 lines (113 loc) · 4.52 KB
/
create-lmdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: create-lmdb.py
# Author: Yuxin Wu
import os
import scipy.io.wavfile as wavfile
import string
import numpy as np
import argparse
import bob.ap
from tensorpack.dataflow import DataFlow, LMDBSerializer
from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments
from tensorpack.utils import serialize, fs, logger
from tensorpack.utils.utils import get_tqdm
# Characters kept in letter-level transcripts: lowercase a-z plus space.
_ALPHABET = string.ascii_lowercase + ' '
CHARSET = set(_ALPHABET)

# The 61 TIMIT phoneme symbols; a symbol's position in this list is its label id.
PHONEME_LIST = [
    'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay',
    'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx',
    'eh', 'el', 'em', 'en', 'eng', 'epi', 'er', 'ey',
    'f', 'g', 'gcl', 'h#', 'hh', 'hv',
    'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm', 'n', 'ng', 'nx',
    'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r',
    's', 'sh', 't', 'tcl', 'th', 'uh', 'uw', 'ux',
    'v', 'w', 'y', 'z', 'zh',
]

# Symbol -> integer id lookup tables (phonemes: 0..60; letters: a..z -> 0..25, ' ' -> 26).
PHONEME_DIC = dict(zip(PHONEME_LIST, range(len(PHONEME_LIST))))
WORD_DIC = dict(zip(_ALPHABET, range(len(_ALPHABET))))
def read_timit_txt(f):
    """Read a TIMIT ``.TXT`` transcript and encode it as letter indices.

    Only the first line of the file is used. It is split on single spaces and
    its first two tokens are dropped (presumably the utterance's begin/end
    sample numbers -- confirm against the TIMIT docs). The remaining text is
    lowercased and every character outside ``CHARSET`` (a-z and space) is
    discarded.

    Args:
        f: path to the transcript file.

    Returns:
        1-D integer ``np.ndarray`` of ``WORD_DIC`` indices, one per kept
        character.

    Raises:
        IndexError: if the file is empty (no first line).
    """
    # Context manager guarantees the handle is closed even if parsing fails.
    with open(f) as fin:
        first_line = fin.readlines()[0]
    text = ' '.join(first_line.strip().split(' ')[2:]).lower()
    # dtype=int keeps the result integer-typed even when no character survives
    # the CHARSET filter (np.asarray([]) would otherwise be float64).
    return np.asarray([WORD_DIC[c] for c in text if c in CHARSET], dtype=int)
def read_timit_phoneme(f):
    """Read a TIMIT ``.PHN`` file and encode its phoneme sequence as indices.

    For every non-blank line, the last whitespace-separated token is taken as
    the phoneme symbol and mapped through ``PHONEME_DIC``.

    Args:
        f: path to the phoneme-alignment file.

    Returns:
        1-D integer ``np.ndarray`` of ``PHONEME_DIC`` indices, in file order.

    Raises:
        KeyError: if a symbol is not in ``PHONEME_LIST``.
    """
    phonemes = []
    with open(f) as fin:
        for line in fin:
            tokens = line.split()
            if not tokens:
                # Blank lines (e.g. a trailing empty line) carry no phoneme;
                # looking up '' would raise KeyError.
                continue
            phonemes.append(PHONEME_DIC[tokens[-1]])
    # dtype=int keeps the dtype stable for an empty file as well.
    return np.asarray(phonemes, dtype=int)
@memoized
def get_bob_extractor(fs, win_length_ms=10, win_shift_ms=5,
                      n_filters=55, n_ceps=15, f_min=0., f_max=6000,
                      delta_win=2, pre_emphasis_coef=0.95, dct_norm=True,
                      mel_scale=True):
    """Build a ``bob.ap.Ceps`` cepstral feature extractor.

    Wrapped in tensorpack's ``memoized``, so repeated calls with the same
    arguments presumably reuse one extractor instance.

    Args:
        fs: sampling rate of the signal the extractor will be applied to.
        win_length_ms, win_shift_ms: analysis window length and hop, in ms.
        n_filters, n_ceps: filter-bank size and number of cepstral coeffs.
        f_min, f_max: frequency range of the filter bank.
        delta_win, pre_emphasis_coef, dct_norm, mel_scale: forwarded to
            ``bob.ap.Ceps`` unchanged.

    Returns:
        A ``bob.ap.Ceps`` object (callable on a float64 signal).
    """
    # Arguments are passed positionally; note ``mel_scale`` goes before
    # ``dct_norm`` here, the reverse of this function's parameter order.
    return bob.ap.Ceps(
        fs,
        win_length_ms, win_shift_ms,
        n_filters, n_ceps,
        f_min, f_max,
        delta_win, pre_emphasis_coef,
        mel_scale, dct_norm,
    )
def diff_feature(feat, nd=1):
    """Append temporal differences (deltas) to a feature matrix.

    Differences are taken along axis 0 and the results are concatenated along
    axis 1, so ``feat`` is expected to be 2-D and indexed ``feat[time, dim]``.

    Args:
        feat: 2-D array of shape ``(T, D)``.
        nd: number of difference orders to append.
            - ``nd=1`` gives ``[feat, delta]``.
            - ``nd=2`` gives ``[feat, delta, delta-delta]``.
            - Higher orders follow the same pattern.
            - ``nd=0`` returns (a copy of) ``feat``.

    Returns:
        Array of shape ``(max(T - nd, 0), D * (nd + 1))``. Every order is
        aligned on the last ``T - nd`` frames, because the k-th difference is
        only defined from frame k on.

    Raises:
        ValueError: if ``nd`` is negative.
    """
    if nd < 0:
        raise ValueError("nd must be non-negative, got {}".format(nd))
    orders = [feat]
    for _ in range(nd):
        prev = orders[-1]
        orders.append(prev[1:] - prev[:-1])
    # The k-th order difference has T - k rows; drop its first (nd - k) rows so
    # all orders line up on the same trailing T - nd frames.
    return np.concatenate([d[nd - k:] for k, d in enumerate(orders)], axis=1)
def get_feature(f):
    """Load a wav file and return cepstra with first and second differences.

    The samples are cast to float64 before being fed to a cached
    ``get_bob_extractor`` (26 filters, 13 cepstral coefficients). The
    extractor's output is then passed through ``diff_feature`` with ``nd=2``.

    Args:
        f: path to a wav file readable by ``scipy.io.wavfile``.

    Returns:
        The ``diff_feature(..., nd=2)`` matrix for this utterance.
    """
    # ``rate`` (not ``fs``) avoids shadowing the module-level ``fs`` import.
    rate, samples = wavfile.read(f)
    samples = samples.astype('float64')
    extractor = get_bob_extractor(rate, n_filters=26, n_ceps=13)
    cepstra = extractor(samples)
    return diff_feature(cepstra, nd=2)
class RawTIMIT(DataFlow):
    """DataFlow yielding ``[features, label]`` for every ``.wav`` file in a tree.

    Files are found by walking ``dirname`` recursively and keeping paths that
    end in lowercase ``.wav``. These are presumably converted copies of the
    original NIST ``.WAV`` files -- confirm your dataset layout.

    For each wav file:
        - Features come from ``get_feature``.
        - The label is read from the sibling file with the same base name:
          ``.PHN`` for ``label='phoneme'`` or ``.TXT`` for ``label='letter'``.
    """

    def __init__(self, dirname, label='phoneme'):
        """
        Args:
            dirname: root directory to search (e.g. a TIMIT TRAIN/TEST dir).
            label: ``'phoneme'`` or ``'letter'``.

        Raises:
            ValueError: on an unknown ``label``, a non-directory ``dirname``,
                or when no ``.wav`` file is found.
        """
        # Validate the cheap argument before the (possibly slow) directory walk.
        if label not in ('phoneme', 'letter'):
            raise ValueError("label must be 'phoneme' or 'letter', got {!r}".format(label))
        if not os.path.isdir(dirname):
            raise ValueError("Not a directory: {}".format(dirname))
        self.dirname = dirname
        self.label = label
        # Sorted so the iteration order (and hence the LMDB built from it) does
        # not depend on the traversal order of fs.recursive_walk.
        self.filelists = sorted(k for k in fs.recursive_walk(self.dirname)
                                if k.endswith('.wav'))
        logger.info("Found {} wav files ...".format(len(self.filelists)))
        if not self.filelists:
            raise ValueError("Found no '.wav' files!")

    def __len__(self):
        return len(self.filelists)

    def __iter__(self):
        for path in self.filelists:
            feat = get_feature(path)
            base = os.path.splitext(path)[0]
            if self.label == 'phoneme':
                target = read_timit_phoneme(base + '.PHN')
            else:  # 'letter' -- the only other value accepted by __init__
                target = read_timit_txt(base + '.TXT')
            yield [feat, target]
def compute_mean_std(db, fname):
    """Compute feature statistics over an LMDB dataset and write them to disk.

    Every row of each datapoint's first component (``dp[0]``, laid out as
    ``len x dim``) is fed into one ``OnlineMoments`` accumulator. The
    resulting ``[mean, std]`` pair is serialized with
    ``tensorpack.utils.serialize.dumps`` and written in binary mode to
    ``fname``.

    Args:
        db: path of the LMDB file to read (loaded without shuffling).
        fname: output path for the serialized statistics.
    """
    dataflow = LMDBSerializer.load(db, shuffle=False)
    dataflow.reset_state()
    moments = OnlineMoments()
    for datapoint in get_tqdm(dataflow):
        frames = datapoint[0]  # len x dim
        for frame in frames:
            moments.feed(frame)
    logger.info("Writing to {} ...".format(fname))
    with open(fname, 'wb') as out:
        out.write(serialize.dumps([moments.mean, moments.std]))
if __name__ == '__main__':
    # Command-line entry point.
    #   build: extract features + labels from a TIMIT directory into an LMDB.
    #   stat:  compute per-dimension mean/std of an existing LMDB.
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(title='command', dest='command')
    # Sub-commands are optional by default on Python 3. Without this, running
    # the script with no command would silently do nothing.
    subparsers.required = True

    parser_db = subparsers.add_parser('build', help='build a LMDB database')
    parser_db.add_argument('--dataset',
                           help='path to TIMIT TRAIN or TEST directory', required=True)
    parser_db.add_argument('--db', help='output lmdb file', required=True)
    parser_db.add_argument('--label', choices=['phoneme', 'letter'], default='phoneme',
                           help='label type stored with each utterance (default: phoneme)')

    parser_stat = subparsers.add_parser('stat', help='compute statistics (mean/std) of dataset')
    parser_stat.add_argument('--db', help='input lmdb file', required=True)
    parser_stat.add_argument('-o', '--output',
                             help='output statistics file', default='stats.data')
    args = parser.parse_args()

    if args.command == 'build':
        ds = RawTIMIT(args.dataset, label=args.label)
        LMDBSerializer.save(ds, args.db)
    elif args.command == 'stat':
        compute_mean_std(args.db, args.output)