forked from onnx/onnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
139 lines (112 loc) · 4.58 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections import namedtuple
from typing import Any, NewType, Sequence
import numpy
import onnx.checker
import onnx.onnx_cpp2py_export.checker as c_checker
from onnx import IR_VERSION, ModelProto, NodeProto
class DeviceType:
    """Enumerates the kinds of device a backend can be asked to target.

    Every member is an ``int`` wrapped in the private ``_Type`` NewType, so
    static type checkers can distinguish device kinds from ordinary integers
    while at runtime the values are plain ints.
    """

    # Static-only marker type; calling it returns its argument unchanged.
    _Type = NewType("_Type", int)

    # Host processor.
    CPU: _Type = _Type(0)
    # CUDA device.
    CUDA: _Type = _Type(1)
class Device:
    """Describes device type and device id
    syntax: device_type:device_id(optional)
    example: 'CPU', 'CUDA', 'CUDA:1'

    ``device_type`` must name a public member of ``DeviceType``; the optional
    ``device_id`` is parsed with ``int()`` and defaults to 0.
    """

    def __init__(self, device: str) -> None:
        """Parse ``device`` into ``self.type`` and ``self.device_id``.

        Args:
            device: A string of the form ``device_type[:device_id]``.

        Raises:
            AttributeError: If the device type is not a public integer member
                of ``DeviceType``.
            ValueError: If the device id is not an integer, or the string
                contains more than one ``:`` separator.
        """
        options = device.split(":")
        type_name = options[0]
        # Only public integer members of DeviceType are device types. A bare
        # getattr would also accept names such as "_Type", "__doc__" or "mro".
        device_type = (
            None if type_name.startswith("_") else getattr(DeviceType, type_name, None)
        )
        if not isinstance(device_type, int):
            raise AttributeError(
                f"Unknown device type {type_name!r} in device string {device!r}"
            )
        # "CPU:1:2" used to silently drop the trailing component.
        if len(options) > 2:
            raise ValueError(
                f"Invalid device string {device!r}; expected 'device_type[:device_id]'"
            )
        self.type = device_type
        self.device_id = int(options[1]) if len(options) > 1 else 0
def namedtupledict(
    typename: str, field_names: Sequence[str], *args: Any, **kwargs: Any
) -> type[tuple[Any, ...]]:
    """Create a namedtuple class whose instances can also be indexed by field name.

    Items can be read positionally (``t[0]``, ``t[1:]``) or by the ORIGINAL
    field name (``t["0"]``), even when ``namedtuple`` had to rename an invalid
    identifier.

    Args:
        typename: Name of the generated class.
        field_names: Field names, given either as an iterable of strings or as
            a single space/comma separated string (as accepted by
            ``collections.namedtuple``).
        *args: Forwarded to ``collections.namedtuple``.
        **kwargs: Forwarded to ``collections.namedtuple``; ``rename`` defaults
            to True.

    Returns:
        The generated namedtuple class.
    """
    # Normalize to a list: a plain string would otherwise be enumerated
    # character by character, and a one-shot iterator would be exhausted by
    # the index map before namedtuple sees it.
    if isinstance(field_names, str):
        field_names = field_names.replace(",", " ").split()
    else:
        field_names = list(field_names)
    field_names_map = {n: i for i, n in enumerate(field_names)}
    # Some output names are invalid python identifier, e.g. "0"
    kwargs.setdefault("rename", True)
    data = namedtuple(typename, field_names, *args, **kwargs)  # type: ignore # noqa: PYI024

    def getitem(self: Any, key: Any) -> Any:
        if isinstance(key, str):
            key = field_names_map[key]
        # Call tuple.__getitem__ directly: super(type(self), self) recurses
        # forever when the generated class is subclassed.
        return tuple.__getitem__(self, key)

    data.__getitem__ = getitem  # type: ignore[assignment]
    return data
class BackendRep:
    """Handle a Backend hands back after preparing a model for repeated runs.

    Callers keep this object around and call :meth:`run` with inputs each
    time they want the model's corresponding outputs.
    """

    def run(self, inputs: Any, **kwargs: Any) -> tuple[Any, ...]:  # noqa: ARG002
        """Execute the prepared model on ``inputs``.

        Abstract placeholder: it ignores its arguments and returns a
        one-element tuple containing ``None``.
        """
        placeholder_result: tuple[Any, ...] = (None,)
        return placeholder_result
class Backend:
    """Backend is the entity that will take an ONNX model with inputs,
    perform a computation, and then return the output.

    For one-off execution, users can use run_node and run_model to obtain results quickly.

    For repeated execution, users should use prepare, in which the Backend
    does all of the preparation work for executing the model repeatedly
    (e.g., loading initializers), and returns a BackendRep handle.
    """

    @classmethod
    def is_compatible(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any  # noqa: ARG003
    ) -> bool:
        """Return whether ``model`` is compatible with this backend.

        The base implementation accepts every model.
        """
        return True

    @classmethod
    def prepare(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any  # noqa: ARG003
    ) -> BackendRep | None:
        """Validate ``model`` and prepare it for repeated execution.

        The base implementation only runs ``onnx.checker.check_model`` and
        returns ``None``; a backend that supports execution returns a
        ``BackendRep`` handle instead.
        """
        # TODO Remove Optional from return type
        onnx.checker.check_model(model)
        return None

    @classmethod
    def run_model(
        cls, model: ModelProto, inputs: Any, device: str = "CPU", **kwargs: Any
    ) -> tuple[Any, ...]:
        """Prepare ``model`` and run it once on ``inputs``.

        Args:
            model: The model proto.
            inputs: Inputs passed to the prepared handle's ``run``.
            device: The device to run on.
            kwargs: Forwarded to ``prepare`` (not to ``run``).

        Returns:
            The result of the prepared handle's ``run(inputs)``.

        Raises:
            RuntimeError: If ``prepare`` returned ``None``, as the base
                implementation does.
        """
        backend = cls.prepare(model, device, **kwargs)
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        if backend is None:
            raise RuntimeError(
                f"{cls.__name__}.prepare() returned None; "
                "it must return a BackendRep to run a model"
            )
        return backend.run(inputs)

    @classmethod
    def run_node(
        cls,
        node: NodeProto,
        inputs: Any,  # noqa: ARG003
        device: str = "CPU",  # noqa: ARG003
        outputs_info: (  # noqa: ARG003
            Sequence[tuple[numpy.dtype, tuple[int, ...]]] | None
        ) = None,
        **kwargs: Any,
    ) -> tuple[Any, ...] | None:
        """Simple run one operator and return the results.

        Args:
            node: The node proto.
            inputs: Inputs to the node.
            device: The device to run on.
            outputs_info: a list of tuples, which contains the element type and
                shape of each output. First element of the tuple is the dtype, and
                the second element is the shape. More use case can be found in
                https://github.com/onnx/onnx/blob/main/onnx/backend/test/runner/__init__.py
            kwargs: Other keyword arguments. If ``opset_version`` is given, the
                node is checked against that version of the default domain.

        Returns:
            ``None`` in the base implementation, which only validates ``node``.
        """
        # TODO Remove Optional from return type
        if "opset_version" in kwargs:
            # Build a checker context pinned to the requested default-domain
            # opset instead of using check_node's default context.
            special_context = c_checker.CheckerContext()
            special_context.ir_version = IR_VERSION
            special_context.opset_imports = {"": kwargs["opset_version"]}
            onnx.checker.check_node(node, special_context)
        else:
            onnx.checker.check_node(node)
        return None

    @classmethod
    def supports_device(cls, device: str) -> bool:  # noqa: ARG003
        """Checks whether the backend is compiled with particular device support.
        In particular it's used in the testing suite.

        The base implementation reports every device as supported.
        """
        return True