Skip to content

Commit 407c1ab

Browse files
authored
Fix conversion of TensorData, TensorsData to json (microsoft#22166)
### Description

Fix `write_calibration_table` to support `TensorData` and `TensorsData`.
1 parent 280c013 commit 407c1ab

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

onnxruntime/python/tools/quantization/calibrate.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@ class TensorData:
6969
_floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])
7070

7171
def __init__(self, **kwargs):
72+
self._attrs = list(kwargs.keys())
7273
for k, v in kwargs.items():
7374
if k not in TensorData._allowed:
7475
raise ValueError(f"Unexpected value {k!r} not in {TensorData._allowed}.")
@@ -91,6 +92,12 @@ def avg_std(self):
9192
raise AttributeError(f"Attributes 'avg' and/or 'std' missing in {dir(self)}.")
9293
return (self.avg, self.std)
9394

95+
def to_dict(self):
96+
# This is needed to serialize the data into JSON.
97+
data = {k: getattr(self, k) for k in self._attrs}
98+
data["CLS"] = self.__class__.__name__
99+
return data
100+
94101

95102
class TensorsData:
96103
def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]]):
@@ -125,12 +132,24 @@ def __setitem__(self, key, value):
125132
raise RuntimeError(f"Only an existing tensor can be modified, {key!r} is not.")
126133
self.data[key] = value
127134

135+
def keys(self):
136+
return self.data.keys()
137+
128138
def values(self):
129139
return self.data.values()
130140

131141
def items(self):
132142
return self.data.items()
133143

144+
def to_dict(self):
145+
# This is needed to serialize the data into JSON.
146+
data = {
147+
"CLS": self.__class__.__name__,
148+
"data": self.data,
149+
"calibration_method": self.calibration_method,
150+
}
151+
return data
152+
134153

135154
class CalibrationMethod(Enum):
136155
MinMax = 0

onnxruntime/python/tools/quantization/quant_utils.py

Lines changed: 30 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -671,21 +671,41 @@ def write_calibration_table(calibration_cache, dir="."):
671671
import json
672672

673673
import flatbuffers
674+
import numpy as np
674675

675676
import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
676677
import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
678+
from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData
677679

678680
logging.info(f"calibration cache: {calibration_cache}")
679681

682+
class MyEncoder(json.JSONEncoder):
683+
def default(self, obj):
684+
if isinstance(obj, (TensorData, TensorsData)):
685+
return obj.to_dict()
686+
if isinstance(obj, np.ndarray):
687+
return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
688+
if isinstance(obj, CalibrationMethod):
689+
return {"CLS": obj.__class__.__name__, "value": str(obj)}
690+
return json.JSONEncoder.default(self, obj)
691+
692+
json_data = json.dumps(calibration_cache, cls=MyEncoder)
693+
680694
with open(os.path.join(dir, "calibration.json"), "w") as file:
681-
file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse
695+
file.write(json_data) # use `json.loads` to do the reverse
682696

683697
# Serialize data using FlatBuffers
698+
zero = np.array(0)
684699
builder = flatbuffers.Builder(1024)
685700
key_value_list = []
686701
for key in sorted(calibration_cache.keys()):
687702
values = calibration_cache[key]
688-
value = str(max(abs(values[0]), abs(values[1])))
703+
d_values = values.to_dict()
704+
floats = [
705+
float(d_values.get("highest", zero).item()),
706+
float(d_values.get("lowest", zero).item()),
707+
]
708+
value = str(max(floats))
689709

690710
flat_key = builder.CreateString(key)
691711
flat_value = builder.CreateString(value)
@@ -724,9 +744,14 @@ def write_calibration_table(calibration_cache, dir="."):
724744
# write plain text
725745
with open(os.path.join(dir, "calibration.cache"), "w") as file:
726746
for key in sorted(calibration_cache.keys()):
727-
value = calibration_cache[key]
728-
s = key + " " + str(max(abs(value[0]), abs(value[1])))
729-
file.write(s)
747+
values = calibration_cache[key]
748+
d_values = values.to_dict()
749+
floats = [
750+
float(d_values.get("highest", zero).item()),
751+
float(d_values.get("lowest", zero).item()),
752+
]
753+
value = key + " " + str(max(floats))
754+
file.write(value)
730755
file.write("\n")
731756

732757

onnxruntime/test/python/quantization/test_qdq.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@
2222
create_clip_node,
2323
)
2424

25-
from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static
26-
from onnxruntime.quantization.calibrate import TensorData
25+
from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static, write_calibration_table
26+
from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData
2727

2828

2929
class TestQDQFormat(unittest.TestCase):
@@ -1720,6 +1720,11 @@ def test_int4_qdq_per_channel_conv(self):
17201720
size_ratio = weight_quant_init.ByteSize() / unpacked_size
17211721
self.assertLess(size_ratio, 0.55)
17221722

1723+
def test_json_serialization(self):
1724+
td = TensorData(lowest=np.array([0.1], dtype=np.float32), highest=np.array([1.1], dtype=np.float32))
1725+
new_calibrate_tensors_range = TensorsData(CalibrationMethod.MinMax, {"td": td})
1726+
write_calibration_table(new_calibrate_tensors_range)
1727+
17231728

17241729
if __name__ == "__main__":
17251730
unittest.main()

0 commit comments

Comments
 (0)