-
Notifications
You must be signed in to change notification settings - Fork 1k
/
calibrator.py
100 lines (84 loc) · 3.51 KB
/
calibrator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import cv2
import glob
from Processor import letterbox
import ctypes
import logging
logger = logging.getLogger(__name__)

# Declare the C signature of PyCapsule_GetPointer for ctypes: it is called with
# a Python object and a char* name, and because restype is c_char_p its result
# is handed back to Python as ``bytes``.
# NOTE(review): this reconfigures the process-wide ``ctypes.pythonapi`` function
# object, and nothing in the visible code calls it; presumably it is kept for
# reading a capsule-wrapped calibration cache. Confirm before removing.
_pycapsule_get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
_pycapsule_get_pointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
_pycapsule_get_pointer.restype = ctypes.c_char_p

# TensorRT provides four INT8 calibrator base classes:
#   trt.IInt8LegacyCalibrator
#   trt.IInt8EntropyCalibrator
#   trt.IInt8EntropyCalibrator2
#   trt.IInt8MinMaxCalibrator
class Calibrator(trt.IInt8MinMaxCalibrator):
    """INT8 calibrator built on TensorRT's min-max calibrator base class.

    Batches are pulled from ``stream`` and copied into a single device buffer
    whose address is handed to TensorRT. ``stream`` must expose
    ``calibration_data`` (a host array sized for one batch), ``batch_size``,
    ``reset()`` and ``next_batch()``; ``DataLoader`` in this module fits.
    """

    def __init__(self, stream, cache_file=""):
        """Allocate the device input buffer and rewind ``stream``.

        Args:
            stream: Batch source, as described on the class.
            cache_file: Path of the calibration cache file. An empty string
                means no cache: nothing is read and nothing is written.
        """
        trt.IInt8MinMaxCalibrator.__init__(self)
        self.stream = stream
        # One device buffer sized for a full host batch, reused for every batch.
        self.d_input = cuda.mem_alloc(self.stream.calibration_data.nbytes)
        self.cache_file = cache_file
        stream.reset()

    def get_batch_size(self):
        """Return the number of images in each calibration batch."""
        return self.stream.batch_size

    def get_batch(self, names):
        """Upload the next batch to the device and return its device pointer.

        Args:
            names: Input tensor names requested by TensorRT. Only one device
                pointer is returned, so this assumes a single-input network.

        Returns:
            ``[device_pointer]`` for the uploaded batch, or ``None`` once the
            stream yields an empty array (no more calibration data).
        """
        # Previously printed to stdout; kept as debug logging instead.
        logger.debug("get_batch requested inputs: %s", names)
        batch = self.stream.next_batch()
        if not batch.size:
            return None
        cuda.memcpy_htod(self.d_input, batch)
        return [int(self.d_input)]

    def read_calibration_cache(self):
        """Return the cached calibration table bytes, or ``None`` if there is none.

        Returning ``None`` makes TensorRT run calibration instead of reusing a cache.
        """
        if self.cache_file and os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                logger.info("Using calibration cache to save time: %s", self.cache_file)
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Persist the calibration table produced by TensorRT to ``cache_file``.

        When no ``cache_file`` was configured (the default empty string), the
        write is skipped instead of failing on ``open("")``.
        """
        if not self.cache_file:
            logger.info("No cache_file configured; calibration cache not written.")
            return
        with open(self.cache_file, "wb") as f:
            logger.info("Caching calibration data for future use: %s", self.cache_file)
            f.write(cache)
def precess_image(img_src, img_size, stride):
    """Letterbox an image and turn it into a normalized CHW float32 array.

    Args:
        img_src: Source image, e.g. as loaded by ``cv2.imread`` (HWC, BGR).
        img_size: Target size, forwarded unchanged to ``letterbox``.
        stride: Not used by this function; kept for call-site compatibility.

    Returns:
        A C-contiguous ``float32`` array in CHW layout with the channel order
        reversed (BGR -> RGB), divided by 255 (0-255 -> 0.0-1.0).
    """
    resized = letterbox(img_src, img_size, auto=False, return_int=True)[0]
    # Move channels first, then flip the channel axis to go from BGR to RGB.
    chw_rgb = resized.transpose((2, 0, 1))[::-1]
    tensor = np.ascontiguousarray(chw_rgb).astype(np.float32)
    tensor /= 255.
    return tensor
class DataLoader:
    """Serve fixed-size batches of preprocessed calibration images.

    Images are the ``*.jpg`` files found directly in ``calib_img_dir``, sorted
    by path. Each :meth:`next_batch` call fills ``calibration_data`` (shape
    ``(batch_size, 3, input_h, input_w)``, ``float32``) with the next
    ``batch_size`` images, until ``batch_num`` batches have been served.
    """

    def __init__(self, batch_size, batch_num, calib_img_dir, input_w, input_h):
        """Collect the calibration image paths and allocate the host batch buffer.

        Args:
            batch_size: Number of images per batch.
            batch_num: Number of batches to serve before ``next_batch`` returns
                an empty array.
            calib_img_dir: Directory scanned (non-recursively) for ``*.jpg`` files.
            input_w: Network input width in pixels.
            input_h: Network input height in pixels.

        Raises:
            ValueError: If fewer than ``batch_size * batch_num`` images are found.
        """
        self.index = 0
        self.length = batch_num
        self.batch_size = batch_size
        self.input_h = input_h
        self.input_w = input_w
        # Sorted so every run calibrates on the same images; raw glob order is arbitrary.
        self.img_list = sorted(glob.glob(os.path.join(calib_img_dir, "*.jpg")))
        required = self.batch_size * self.length
        # Exactly `required` images is enough; the old strict `>` check was off by one.
        if len(self.img_list) < required:
            raise ValueError(
                '{} must contain at least {} images to calib, found {}'.format(
                    calib_img_dir, required, len(self.img_list)))
        print('found all {} images to calib.'.format(len(self.img_list)))
        self.calibration_data = np.zeros((self.batch_size, 3, input_h, input_w), dtype=np.float32)

    def reset(self):
        """Rewind to the first batch."""
        self.index = 0

    def next_batch(self):
        """Load and return the next batch, or an empty array when exhausted.

        The returned array shares memory with ``self.calibration_data`` and is
        overwritten by the next call.

        Raises:
            FileNotFoundError: If an image cannot be read by ``cv2.imread``.
        """
        if self.index >= self.length:
            return np.array([])
        # Square targets keep the original int size. Otherwise pass (h, w) so the
        # letterboxed image matches the (input_h, input_w) buffer.
        # NOTE(review): assumes letterbox reads a tuple as (height, width), as
        # YOLO-style implementations do. Confirm against Processor.letterbox.
        if self.input_h == self.input_w:
            target_size = self.input_h
        else:
            target_size = (self.input_h, self.input_w)
        start = self.index * self.batch_size
        for slot, path in enumerate(self.img_list[start:start + self.batch_size]):
            img = cv2.imread(path)
            if img is None:  # missing or undecodable file
                raise FileNotFoundError('could not read calibration image: {}'.format(path))
            self.calibration_data[slot] = precess_image(img, target_size, 32)
        self.index += 1
        return np.ascontiguousarray(self.calibration_data, dtype=np.float32)

    def __len__(self):
        """Return the number of batches this loader serves."""
        return self.length