@@ -671,21 +671,41 @@ def write_calibration_table(calibration_cache, dir="."):
671
671
import json
672
672
673
673
import flatbuffers
674
+ import numpy as np
674
675
675
676
import onnxruntime .quantization .CalTableFlatBuffers .KeyValue as KeyValue
676
677
import onnxruntime .quantization .CalTableFlatBuffers .TrtTable as TrtTable
678
+ from onnxruntime .quantization .calibrate import CalibrationMethod , TensorData , TensorsData
677
679
678
680
logging .info (f"calibration cache: { calibration_cache } " )
679
681
682
+ class MyEncoder (json .JSONEncoder ):
683
+ def default (self , obj ):
684
+ if isinstance (obj , (TensorData , TensorsData )):
685
+ return obj .to_dict ()
686
+ if isinstance (obj , np .ndarray ):
687
+ return {"data" : obj .tolist (), "dtype" : str (obj .dtype ), "CLS" : "numpy.array" }
688
+ if isinstance (obj , CalibrationMethod ):
689
+ return {"CLS" : obj .__class__ .__name__ , "value" : str (obj )}
690
+ return json .JSONEncoder .default (self , obj )
691
+
692
+ json_data = json .dumps (calibration_cache , cls = MyEncoder )
693
+
680
694
with open (os .path .join (dir , "calibration.json" ), "w" ) as file :
681
- file .write (json . dumps ( calibration_cache ) ) # use `json.loads` to do the reverse
695
+ file .write (json_data ) # use `json.loads` to do the reverse
682
696
683
697
# Serialize data using FlatBuffers
698
+ zero = np .array (0 )
684
699
builder = flatbuffers .Builder (1024 )
685
700
key_value_list = []
686
701
for key in sorted (calibration_cache .keys ()):
687
702
values = calibration_cache [key ]
688
- value = str (max (abs (values [0 ]), abs (values [1 ])))
703
+ d_values = values .to_dict ()
704
+ floats = [
705
+ float (d_values .get ("highest" , zero ).item ()),
706
+ float (d_values .get ("lowest" , zero ).item ()),
707
+ ]
708
+ value = str (max (floats ))
689
709
690
710
flat_key = builder .CreateString (key )
691
711
flat_value = builder .CreateString (value )
@@ -724,9 +744,14 @@ def write_calibration_table(calibration_cache, dir="."):
724
744
# write plain text
725
745
with open (os .path .join (dir , "calibration.cache" ), "w" ) as file :
726
746
for key in sorted (calibration_cache .keys ()):
727
- value = calibration_cache [key ]
728
- s = key + " " + str (max (abs (value [0 ]), abs (value [1 ])))
729
- file .write (s )
747
+ values = calibration_cache [key ]
748
+ d_values = values .to_dict ()
749
+ floats = [
750
+ float (d_values .get ("highest" , zero ).item ()),
751
+ float (d_values .get ("lowest" , zero ).item ()),
752
+ ]
753
+ value = key + " " + str (max (floats ))
754
+ file .write (value )
730
755
file .write ("\n " )
731
756
732
757
0 commit comments