infer.py (forked from WindQAQ/listen-attend-and-spell)
import argparse
import numpy as np
import tensorflow as tf
import utils
from model_helper import las_model_fn


def parse_args():
    parser = argparse.ArgumentParser(
        description='Listen, Attend and Spell (LAS) implementation based on TensorFlow. '
                    'The model utilizes the input pipeline and Estimator API of TensorFlow, '
                    'which makes the training procedure truly end-to-end.')

    parser.add_argument('--data', type=str, required=True,
                        help='inference data in TFRecord format')
    parser.add_argument('--vocab', type=str, required=True,
                        help='vocabulary table, listing the vocabulary line by line')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='path of the trained model to import')
    parser.add_argument('--save', type=str, required=True,
                        help='path for saving inference results')
    parser.add_argument('--beam_width', type=int, default=0,
                        help='number of beams (default 0: use greedy decoding)')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='batch size')
    parser.add_argument('--num_channels', type=int, default=39,
                        help='number of input feature channels')

    return parser.parse_args()


def input_fn(dataset_filename, vocab_filename, num_channels=39, batch_size=8, num_epochs=1):
    # Build the tf.data pipeline consumed by the Estimator during inference.
    dataset = utils.read_dataset(dataset_filename, num_channels)
    vocab_table = utils.create_vocab_table(vocab_filename)
    dataset = utils.process_dataset(
        dataset, vocab_table, utils.SOS, utils.EOS, batch_size, num_epochs, is_infer=True)

    return dataset
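
For debugging the input pipeline, the dataset returned by input_fn can be materialized outside the Estimator. The helper below is a sketch that is not part of the original script; it assumes TF 1.x graph mode (consistent with the tf.estimator and tf.logging calls in this file), and the paths in its docstring are placeholders.

def inspect_one_batch(dataset_filename, vocab_filename):
    """Sketch: e.g. inspect_one_batch('data/test.tfrecord', 'misc/vocab.table')."""
    dataset = input_fn(dataset_filename, vocab_filename)
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    with tf.Session() as sess:
        # Both the vocabulary lookup table and the iterator need explicit
        # initialization in TF 1.x graph mode.
        sess.run([tf.tables_initializer(), iterator.initializer])
        return sess.run(next_batch)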


def main(args):
    vocab_list = np.array(utils.load_vocab(args.vocab))
    vocab_size = len(vocab_list)

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(
        args, vocab_size, utils.SOS_ID, utils.EOS_ID)
    hparams.decoder.set_hparam('beam_width', args.beam_width)

    model = tf.estimator.Estimator(
        model_fn=las_model_fn,
        config=config,
        params=hparams)

    predictions = model.predict(
        input_fn=lambda: input_fn(
            args.data, args.vocab, num_channels=args.num_channels,
            batch_size=args.batch_size, num_epochs=1),
        predict_keys='sample_ids')

    if args.beam_width > 0:
        # Beam search: sample_ids carries a beam axis; keep only the top beam.
        predictions = [vocab_list[y['sample_ids'][:, 0]].tolist() + [utils.EOS]
                       for y in predictions]
    else:
        # Greedy decoding: sample_ids is a flat sequence of token ids.
        predictions = [vocab_list[y['sample_ids']].tolist() + [utils.EOS]
                       for y in predictions]

    # Truncate each hypothesis at the first end-of-sentence symbol.
    predictions = [' '.join(y[:y.index(utils.EOS)]) for y in predictions]

    with open(args.save, 'w') as f:
        f.write('\n'.join(predictions))
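
When beam search is enabled, the branch above keeps only the top beam (sample_ids[:, 0]). If an n-best list is wanted instead, each beam can be decoded separately. The helper below is a sketch, assuming sample_ids is a NumPy array of shape [max_time, beam_width], as the indexing above suggests.

def nbest_from_beam(sample_ids, vocab_list, eos=utils.EOS):
    """Sketch: decode one transcript per beam from a [max_time, beam_width] array."""
    hypotheses = []
    for beam in range(sample_ids.shape[1]):
        tokens = vocab_list[sample_ids[:, beam]].tolist() + [eos]
        # Truncate at the first end-of-sentence symbol, mirroring main() above.
        hypotheses.append(' '.join(tokens[:tokens.index(eos)]))
    return hypotheses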


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    args = parse_args()
    main(args)
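
A minimal usage sketch, assuming the script is importable as infer and run from a separate driver; the file paths below are placeholders rather than files shipped with this listing. It is equivalent to invoking infer.py from the command line with the same flags.

# Equivalent to:
#   python infer.py --data data/test.tfrecord --vocab misc/vocab.table \
#       --model_dir model/las --save predictions.txt --beam_width 4
import sys

import infer

sys.argv = [
    'infer.py',
    '--data', 'data/test.tfrecord',    # TFRecord features for inference (placeholder)
    '--vocab', 'misc/vocab.table',     # one vocabulary token per line (placeholder)
    '--model_dir', 'model/las',        # directory with trained checkpoints (placeholder)
    '--save', 'predictions.txt',       # decoded transcripts are written here
    '--beam_width', '4',               # 0 falls back to greedy decoding
]
infer.main(infer.parse_args())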