Skip to content

Commit eb106e2

Browse files
authored
Export evaluator type in compare_onnx_execution (#93)
* Export evaluator type in compare_onnx_execution * doc * doc
1 parent 07c3683 commit eb106e2

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

LICENSE.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2023-2024, Xavier Dupré
1+
Copyright (c) 2023-2025, Xavier Dupré
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

onnx_array_api/reference/evaluator_yield.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from enum import IntEnum
44
import numpy as np
55
from onnx import ModelProto, TensorProto, ValueInfoProto, load
6+
from onnx.reference import ReferenceEvaluator
67
from onnx.helper import tensor_dtype_to_np_dtype
78
from onnx.shape_inference import infer_shapes
89
from . import to_array_extended
@@ -138,17 +139,23 @@ class YieldEvaluator:
138139
139140
:param onnx_model: model to run
140141
:param recursive: dig into subgraph and functions as well
142+
:param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator
143+
<onnx_array_api.reference.ExtendedReferenceEvaluator>`
141144
"""
142145

143146
def __init__(
144147
self,
145148
onnx_model: ModelProto,
146149
recursive: bool = False,
147-
cls=ExtendedReferenceEvaluator,
150+
cls: Optional[type[ExtendedReferenceEvaluator]] = None,
148151
):
149152
assert not recursive, "recursive=True is not yet implemented"
150153
self.onnx_model = onnx_model
151-
self.evaluator = cls(onnx_model) if cls is not None else None
154+
self.evaluator = (
155+
cls(onnx_model)
156+
if cls is not None
157+
else ExtendedReferenceEvaluator(onnx_model)
158+
)
152159

153160
def enumerate_results(
154161
self,
@@ -166,9 +173,9 @@ def enumerate_results(
166173
Returns:
167174
iterator on tuple(result kind, name, value, node.op_type or None)
168175
"""
169-
assert isinstance(self.evaluator, ExtendedReferenceEvaluator), (
176+
assert isinstance(self.evaluator, ReferenceEvaluator), (
170177
f"This implementation only works with "
171-
f"ExtendedReferenceEvaluator not {type(self.evaluator)}"
178+
f"ReferenceEvaluator not {type(self.evaluator)}"
172179
)
173180
attributes = {}
174181
if output_names is None:
@@ -595,6 +602,7 @@ def compare_onnx_execution(
595602
raise_exc: bool = True,
596603
mode: str = "execute",
597604
keep_tensor: bool = False,
605+
cls: Optional[type[ReferenceEvaluator]] = None,
598606
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
599607
"""
600608
Compares the execution of two onnx models.
@@ -611,6 +619,7 @@ def compare_onnx_execution(
611619
:param mode: the model should be executed but the function can be executed
612620
but the comparison may append on nodes only
613621
:param keep_tensor: keeps the tensor in order to compute a precise distance
622+
:param cls: evaluator class to use
614623
:return: four results, a sequence of results
615624
for the first model and the second model,
616625
the alignment between the two, DistanceExecution
@@ -634,15 +643,15 @@ def compare_onnx_execution(
634643
print(f"[compare_onnx_execution] execute with {len(inputs)} inputs")
635644
print("[compare_onnx_execution] execute first model")
636645
res1 = list(
637-
YieldEvaluator(model1).enumerate_summarized(
646+
YieldEvaluator(model1, cls=cls).enumerate_summarized(
638647
None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
639648
)
640649
)
641650
if verbose:
642651
print(f"[compare_onnx_execution] got {len(res1)} results")
643652
print("[compare_onnx_execution] execute second model")
644653
res2 = list(
645-
YieldEvaluator(model2).enumerate_summarized(
654+
YieldEvaluator(model2, cls=cls).enumerate_summarized(
646655
None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
647656
)
648657
)

0 commit comments

Comments
 (0)