# Copyright (c) 2020, Cerebras Systems, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
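
# pytorch_SUT.py: LoadGen System Under Test (SUT) wrapper for the reference
# RNN-T speech-recognition model. It scripts and freezes the PyTorch model
# with TorchScript and answers LoadGen queries with greedy-decoded
# transcripts.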
import sys
import os

# The reference model code lives in the sibling "pytorch" directory; put it
# on the import path before importing the modules below that depend on it.
sys.path.insert(0, os.path.join(os.getcwd(), "pytorch"))

import array

import torch
import numpy as np
import toml
import mlperf_loadgen as lg
from tqdm import tqdm

from QSL import AudioQSL, AudioQSLInMemory
from decoders import ScriptGreedyDecoder
from helpers import add_blank_label
from preprocessing import AudioPreprocessing
from model_separable_rnnt import RNNT


def load_and_migrate_checkpoint(ckpt_path):
    """Load a training checkpoint and rename its keys to match this model.

    The checkpoint names the joint network "joint_net", while the model in
    model_separable_rnnt exposes it as "joint.net". The featurizer buffers
    are stripped because they do not belong to the RNNT model's state dict.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    migrated_state_dict = {}
    for key, value in checkpoint['state_dict'].items():
        key = key.replace("joint_net", "joint.net")
        migrated_state_dict[key] = value
    del migrated_state_dict["audio_preprocessor.featurizer.fb"]
    del migrated_state_dict["audio_preprocessor.featurizer.window"]
    return migrated_state_dict


class PytorchSUT:
    def __init__(self, config_toml, checkpoint_path, dataset_dir,
                 manifest_filepath, perf_count):
        config = toml.load(config_toml)

        dataset_vocab = config['labels']['labels']
        rnnt_vocab = add_blank_label(dataset_vocab)
        featurizer_config = config['input_eval']

        # Register this SUT's callbacks with LoadGen and build an in-memory
        # QSL over the evaluation dataset.
        self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries,
                                   self.process_latencies)
        self.qsl = AudioQSLInMemory(dataset_dir,
                                    manifest_filepath,
                                    dataset_vocab,
                                    featurizer_config["sample_rate"],
                                    perf_count)

        # Script and freeze the audio preprocessor for inference.
        self.audio_preprocessor = AudioPreprocessing(**featurizer_config)
        self.audio_preprocessor.eval()
        self.audio_preprocessor = torch.jit.script(self.audio_preprocessor)
        self.audio_preprocessor = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(self.audio_preprocessor._c))

        model = RNNT(
            feature_config=featurizer_config,
            rnnt=config['rnnt'],
            num_classes=len(rnnt_vocab)
        )
        model.load_state_dict(load_and_migrate_checkpoint(checkpoint_path),
                              strict=True)
        model.eval()

        # Script and freeze each submodule (encoder, prediction network,
        # joint network) separately, then script the assembled model.
        model.encoder = torch.jit.script(model.encoder)
        model.encoder = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.encoder._c))
        model.prediction = torch.jit.script(model.prediction)
        model.prediction = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.prediction._c))
        model.joint = torch.jit.script(model.joint)
        model.joint = torch.jit._recursive.wrap_cpp_module(
            torch._C._freeze_module(model.joint._c))
        model = torch.jit.script(model)

        # The blank label added by add_blank_label is the last vocabulary id.
        self.greedy_decoder = ScriptGreedyDecoder(len(rnnt_vocab) - 1, model)

    def issue_queries(self, query_samples):
        # LoadGen hands over a batch of query samples; each one is decoded
        # independently and completed immediately.
        for query_sample in query_samples:
            waveform = self.qsl[query_sample.index]
            assert waveform.ndim == 1
            waveform_length = np.array(waveform.shape[0], dtype=np.int64)
            # Add a leading batch dimension of 1.
            waveform = np.expand_dims(waveform, 0)
            waveform_length = np.expand_dims(waveform_length, 0)

            with torch.no_grad():
                waveform = torch.from_numpy(waveform)
                waveform_length = torch.from_numpy(waveform_length)
                feature, feature_length = self.audio_preprocessor.forward(
                    (waveform, waveform_length))
                assert feature.ndim == 3
                assert feature_length.ndim == 1
                # Permute the features to the time-major layout expected by
                # the encoder.
                feature = feature.permute(2, 0, 1)

                _, _, transcript = self.greedy_decoder.forward(
                    feature, feature_length)

            assert len(transcript) == 1

            # Hand the decoded token ids back to LoadGen as a raw buffer.
            response_array = array.array('q', transcript[0])
            bi = response_array.buffer_info()
            response = lg.QuerySampleResponse(query_sample.id, bi[0],
                                              bi[1] * response_array.itemsize)
            lg.QuerySamplesComplete([response])

    def flush_queries(self):
        pass

    def process_latencies(self, latencies_ns):
        print("Average latency (ms) per query:")
        print(np.mean(latencies_ns) / 1000000.0)
        print("Median latency (ms):")
        print(np.percentile(latencies_ns, 50) / 1000000.0)
        print("90 percentile latency (ms):")
        print(np.percentile(latencies_ns, 90) / 1000000.0)

    def __del__(self):
        lg.DestroySUT(self.sut)
        print("Finished destroying SUT.")