3
3
from enum import IntEnum
4
4
import numpy as np
5
5
from onnx import ModelProto , TensorProto , ValueInfoProto , load
6
+ from onnx .reference import ReferenceEvaluator
6
7
from onnx .helper import tensor_dtype_to_np_dtype
7
8
from onnx .shape_inference import infer_shapes
8
9
from . import to_array_extended
@@ -138,17 +139,23 @@ class YieldEvaluator:
138
139
139
140
:param onnx_model: model to run
140
141
:param recursive: dig into subgraph and functions as well
142
+ :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator
143
+ <onnx_array_api.reference.ExtendedReferenceEvaluator>`
141
144
"""
142
145
143
146
def __init__ (
144
147
self ,
145
148
onnx_model : ModelProto ,
146
149
recursive : bool = False ,
147
- cls = ExtendedReferenceEvaluator ,
150
+ cls : Optional [ type [ ExtendedReferenceEvaluator ]] = None ,
148
151
):
149
152
assert not recursive , "recursive=True is not yet implemented"
150
153
self .onnx_model = onnx_model
151
- self .evaluator = cls (onnx_model ) if cls is not None else None
154
+ self .evaluator = (
155
+ cls (onnx_model )
156
+ if cls is not None
157
+ else ExtendedReferenceEvaluator (onnx_model )
158
+ )
152
159
153
160
def enumerate_results (
154
161
self ,
@@ -166,9 +173,9 @@ def enumerate_results(
166
173
Returns:
167
174
iterator on tuple(result kind, name, value, node.op_type or None)
168
175
"""
169
- assert isinstance (self .evaluator , ExtendedReferenceEvaluator ), (
176
+ assert isinstance (self .evaluator , ReferenceEvaluator ), (
170
177
f"This implementation only works with "
171
- f"ExtendedReferenceEvaluator not { type (self .evaluator )} "
178
+ f"ReferenceEvaluator not { type (self .evaluator )} "
172
179
)
173
180
attributes = {}
174
181
if output_names is None :
@@ -595,6 +602,7 @@ def compare_onnx_execution(
595
602
raise_exc : bool = True ,
596
603
mode : str = "execute" ,
597
604
keep_tensor : bool = False ,
605
+ cls : Optional [type [ReferenceEvaluator ]] = None ,
598
606
) -> Tuple [List [ResultExecution ], List [ResultExecution ], List [Tuple [int , int ]]]:
599
607
"""
600
608
Compares the execution of two onnx models.
@@ -611,6 +619,7 @@ def compare_onnx_execution(
611
619
:param mode: the model should be executed but the function can be executed
612
620
but the comparison may append on nodes only
613
621
:param keep_tensor: keeps the tensor in order to compute a precise distance
622
+ :param cls: evaluator class to use
614
623
:return: four results, a sequence of results
615
624
for the first model and the second model,
616
625
the alignment between the two, DistanceExecution
@@ -634,15 +643,15 @@ def compare_onnx_execution(
634
643
print (f"[compare_onnx_execution] execute with { len (inputs )} inputs" )
635
644
print ("[compare_onnx_execution] execute first model" )
636
645
res1 = list (
637
- YieldEvaluator (model1 ).enumerate_summarized (
646
+ YieldEvaluator (model1 , cls = cls ).enumerate_summarized (
638
647
None , feeds1 , raise_exc = raise_exc , keep_tensor = keep_tensor
639
648
)
640
649
)
641
650
if verbose :
642
651
print (f"[compare_onnx_execution] got { len (res1 )} results" )
643
652
print ("[compare_onnx_execution] execute second model" )
644
653
res2 = list (
645
- YieldEvaluator (model2 ).enumerate_summarized (
654
+ YieldEvaluator (model2 , cls = cls ).enumerate_summarized (
646
655
None , feeds2 , raise_exc = raise_exc , keep_tensor = keep_tensor
647
656
)
648
657
)
0 commit comments