Skip to content

Commit 3de3c5d

Browse files
authored
Supports for local functions in translator (#96)
* fix suffix * one fix * fix * fix ut * fix ir_version * doc
1 parent 664e084 commit 3de3c5d

File tree

6 files changed

+257
-32
lines changed

6 files changed

+257
-32
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.3.1
55
+++++
66

7+
* :pr:`96`: supports local functions in translator
78
* :pr:`95`: improves translation to GraphBuilder
89

910
0.3.0

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 125 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
from textwrap import dedent
33
import numpy as np
4+
import onnx.helper as oh
45
from onnx import ModelProto, TensorProto
56
from onnx.checker import check_model
67
from onnx.defs import onnx_opset_version
@@ -29,37 +30,43 @@ def test_exp(self):
2930
self.assertEqualArray(np.exp(a), got)
3031

3132
code = translate(onx, api="builder")
32-
expected = dedent(
33-
"""
33+
expected = (
34+
dedent(
35+
"""
3436
def light_api(
3537
op: "GraphBuilder",
3638
X: "FLOAT[]",
3739
):
38-
Y = op.Exp(X)
40+
Y = op.Exp(X, outputs=['Y'])
3941
op.Identity(Y, outputs=["Y"])
4042
return Y
4143
4244
g = GraphBuilder({'': 19}, ir_version=10)
4345
g.make_tensor_input("X", TensorProto.FLOAT, ())
4446
light_api(g.op, "X")
45-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
47+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
4648
model = g.to_onnx()
4749
"""
48-
).strip("\n")
50+
)
51+
.strip("\n")
52+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
53+
)
4954
self.assertEqual(expected, code.strip("\n"))
5055

5156
def light_api(
5257
op: "GraphBuilder",
5358
X: "FLOAT[]", # noqa: F722
5459
):
55-
Y = op.Exp(X)
60+
Y = op.Exp(X, outputs=["Y"])
5661
op.Identity(Y, outputs=["Y"])
5762
return Y
5863

5964
g2 = GraphBuilder({"": 19})
6065
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
6166
light_api(g2.op, "X")
62-
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
67+
g2.make_tensor_output(
68+
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
69+
)
6370
onx2 = g2.to_onnx()
6471

6572
ref = ReferenceEvaluator(onx2)
@@ -78,25 +85,29 @@ def test_zdoc(self):
7885
.to_onnx()
7986
)
8087
code = translate(onx, api="builder")
81-
expected = dedent(
82-
"""
88+
expected = (
89+
dedent(
90+
"""
8391
def light_api(
8492
op: "GraphBuilder",
8593
X: "FLOAT[]",
8694
):
8795
r = np.array([-1, 1], dtype=np.int64)
88-
r0_0 = op.Reshape(X, r)
89-
Y = op.Transpose(r0_0, perm=[1, 0])
96+
r0_0 = op.Reshape(X, r, outputs=['r0_0'])
97+
Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'])
9098
op.Identity(Y, outputs=["Y"])
9199
return Y
92100
93101
g = GraphBuilder({'': 19}, ir_version=10)
94102
g.make_tensor_input("X", TensorProto.FLOAT, ())
95103
light_api(g.op, "X")
96-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
104+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
97105
model = g.to_onnx()
98106
"""
99-
).strip("\n")
107+
)
108+
.strip("\n")
109+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
110+
)
100111
self.maxDiff = None
101112
self.assertEqual(expected, code.strip("\n"))
102113

@@ -130,13 +141,14 @@ def test_exp_f(self):
130141
tr = Translater(onx, emitter=BuilderEmitter("mm"))
131142
code = tr.export(as_str=True)
132143

133-
expected = dedent(
134-
"""
144+
expected = (
145+
dedent(
146+
"""
135147
def light_api(
136148
op: "GraphBuilder",
137149
X: "FLOAT[]",
138150
):
139-
Y = op.Exp(X)
151+
Y = op.Exp(X, outputs=['Y'])
140152
op.Identity(Y, outputs=["Y"])
141153
return Y
142154
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
145157
g = GraphBuilder({'': 19}, ir_version=10)
146158
g.make_tensor_input("X", TensorProto.FLOAT, ())
147159
light_api(g.op, "X")
148-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
160+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
149161
model = g.to_onnx()
150162
return model
151163
152164
153165
model = mm()
154166
"""
155-
).strip("\n")
167+
)
168+
.strip("\n")
169+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
170+
)
156171
self.assertEqual(expected, code.strip("\n"))
157172

158173
def light_api(
@@ -166,14 +181,105 @@ def light_api(
166181
g2 = GraphBuilder({"": 19})
167182
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
168183
light_api(g2.op, "X")
169-
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
184+
g2.make_tensor_output(
185+
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
186+
)
170187
onx2 = g2.to_onnx()
171188

172189
ref = ReferenceEvaluator(onx2)
173190
a = np.arange(10).astype(np.float32)
174191
got = ref.run(None, {"X": a})[0]
175192
self.assertEqualArray(np.exp(a), got)
176193

194+
def test_local_function(self):
195+
new_domain = "custom"
196+
197+
linear_regression = oh.make_function(
198+
new_domain,
199+
"LinearRegression",
200+
["x", "a", "b"],
201+
["y"],
202+
[
203+
oh.make_node("MatMul", ["x", "a"], ["xa"]),
204+
oh.make_node("Add", ["xa", "b"], ["y"]),
205+
],
206+
[oh.make_opsetid("", 14)],
207+
[],
208+
)
209+
210+
graph = oh.make_graph(
211+
[
212+
oh.make_node(
213+
"LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
214+
),
215+
oh.make_node("Abs", ["Y1"], ["Y"]),
216+
],
217+
"example",
218+
[
219+
oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]),
220+
oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
221+
oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
222+
],
223+
[oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
224+
)
225+
226+
onnx_model = oh.make_model(
227+
graph,
228+
opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)],
229+
functions=[linear_regression],
230+
ir_version=10,
231+
)
232+
tr = Translater(onnx_model, emitter=BuilderEmitter("mm"))
233+
code = tr.export(as_str=True)
234+
235+
expected = (
236+
dedent(
237+
"""
238+
def example(
239+
op: "GraphBuilder",
240+
X: "FLOAT[, ]",
241+
A: "FLOAT[, ]",
242+
B: "FLOAT[, ]",
243+
):
244+
Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
245+
Y = op.Abs(Y1, outputs=['Y'])
246+
op.Identity(Y, outputs=["Y"])
247+
return Y
248+
249+
250+
def make_custom_LinearRegression(g: "GraphBuilder"):
251+
gr = GraphBuilder({'': 14}, as_function=True)
252+
x = gr.make_tensor_input('x')
253+
a = gr.make_tensor_input('a')
254+
b = gr.make_tensor_input('b')
255+
op = gr.op
256+
xa = op.MatMul(x, a, outputs=['xa'])
257+
y = op.Add(xa, b, outputs=['y'])
258+
gr.make_tensor_output(y)
259+
g.add_function(builder=gr)
260+
return gr
261+
262+
263+
def mm() -> "ModelProto":
264+
g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
265+
g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
266+
g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
267+
g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
268+
example(g.op, "X", "A", "B")
269+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
270+
make_custom_LinearRegression(g)
271+
model = g.to_onnx()
272+
return model
273+
274+
275+
model = mm()
276+
"""
277+
)
278+
.strip("\n")
279+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
280+
)
281+
self.assertEqual(expected, code.strip("\n"))
282+
177283

178284
if __name__ == "__main__":
179285
unittest.main(verbosity=2)

onnx_array_api/graph_api/graph_builder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
self._known_shapes = {}
195195
self._known_types = {}
196196
self.constants_ = {}
197+
self.functions_ = {}
197198
elif isinstance(target_opset_or_existing_proto, ModelProto):
198199
assert (
199200
not input_names
@@ -223,6 +224,8 @@ def __init__(
223224
self.constants_[node.output[0]] = node
224225
self.set_shape(node.output[0], self._get_tensor_shape(node))
225226
self.set_type(node.output[0], self._get_tensor_type(node))
227+
for f in proto.functions:
228+
self.add_function(f)
226229
else:
227230
raise NotImplementedError(
228231
f"{type(target_opset_or_existing_proto)} is not supported."
@@ -231,6 +234,14 @@ def __init__(
231234
self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None
232235
self._cache_array = []
233236

237+
def add_local_function(self, domain: str, name: str, gr: "GraphBuilder"):
238+
"Adds a local function."
239+
assert (
240+
domain,
241+
name,
242+
) not in self.functions_, f"Function {(domain, name)} was already added."
243+
self.functions_[domain, name] = gr
244+
234245
def _get_tensor_shape(
235246
self, proto: Union[NodeProto, TensorProto]
236247
) -> Tuple[int, ...]:
@@ -417,6 +428,8 @@ def make_tensor_output(
417428
name: Union[str, List[str]],
418429
elem_type: Optional[int] = None,
419430
shape: Optional[Tuple[int, ...]] = None,
431+
is_dimension: bool = False,
432+
indexed: bool = False,
420433
) -> Union[str, List[str]]:
421434
if isinstance(name, list):
422435
res = []

onnx_array_api/translate_api/base_emitter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class EventType(IntEnum):
2525
END_SIGNATURE = 16
2626
BEGIN_RETURN = 17
2727
END_RETURN = 18
28+
BEGIN_FUNCTION_SIGNATURE = 19
29+
END_FUNCTION_SIGNATURE = 20
30+
BEGIN_FUNCTION_RETURN = 21
31+
END_FUNCTION_RETURN = 22
2832

2933
@classmethod
3034
def to_str(cls, self) -> str:
@@ -76,6 +80,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
7680
if event == EventType.BEGIN_FUNCTION:
7781
return self._emit_begin_function(**kwargs)
7882

83+
if event == EventType.BEGIN_FUNCTION_SIGNATURE:
84+
return self._emit_begin_function_signature(**kwargs)
85+
86+
if event == EventType.END_FUNCTION_SIGNATURE:
87+
return self._emit_end_function_signature(**kwargs)
88+
7989
if event == EventType.END_FUNCTION:
8090
return self._emit_end_function(**kwargs)
8191

@@ -100,6 +110,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
100110
if event == EventType.END_RETURN:
101111
return self._emit_end_return(**kwargs)
102112

113+
if event == EventType.BEGIN_FUNCTION_RETURN:
114+
return self._emit_begin_function_return(**kwargs)
115+
116+
if event == EventType.END_FUNCTION_RETURN:
117+
return self._emit_end_function_return(**kwargs)
118+
103119
raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
104120

105121
def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
@@ -224,6 +240,12 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
224240
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
225241
)
226242

243+
def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
244+
return []
245+
246+
def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
247+
return []
248+
227249
def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
228250
raise NotImplementedError(
229251
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
@@ -250,3 +272,9 @@ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
250272

251273
def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
252274
return []
275+
276+
def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
277+
return []
278+
279+
def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
280+
return []

0 commit comments

Comments
 (0)