Skip to content

Commit d4cf92d

Browse files
authored
RSDK-4895: Make numpy an extra dependency (#429)
Co-authored-by: hexbabe <[email protected]>
1 parent dc32c5d commit d4cf92d

File tree

12 files changed

+181
-148
lines changed

12 files changed

+181
-148
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,5 @@ package: clean buf better_imports format test
7070
@echo "TODO: Create pip-installable package"
7171

7272
install:
73-
poetry install
73+
poetry install --all-extras
7474
sh etc/postinstall.sh

poetry.lock

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ include = ["LICENSE", "src/viam/rpc/libviam_rust_utils.*"]
1717
typing-extensions = "^4.7.1"
1818
Pillow = "^10.0.0"
1919
protobuf = "^4.23.4"
20-
numpy = ">=1.21"
20+
numpy = { version = ">=1.21", optional = true }
2121

2222
[tool.poetry.group.dev.dependencies]
2323
pytest = "^7.4.0"
@@ -60,3 +60,6 @@ line_length = 140
6060
[build-system]
6161
requires = [ "poetry-core>=1.0.0" ]
6262
build-backend = "poetry.core.masonry.api"
63+
64+
[tool.poetry.extras]
65+
mlmodel = ["numpy"]

src/viam/services/mlmodel/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
try:
2+
import numpy
3+
except ImportError:
4+
import warnings
5+
6+
warnings.warn(
7+
(
8+
"""MLModel support in the Viam Python SDK requires the installation of an
9+
additional dependency: numpy. Update your package using the extra [mlmodel]
10+
e.g. `pip install viam-sdk[mlmodel]` or the equivalent update in your dependency manager."""
11+
),
12+
)
13+
raise
14+
115
from viam.proto.service.mlmodel import File, LabelType, Metadata, TensorInfo
216
from viam.resource.registry import Registry, ResourceRegistration
317

src/viam/services/mlmodel/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Dict, Mapping, Optional
2-
31
from grpclib.client import Channel
42
from numpy.typing import NDArray
3+
from typing import Dict, Mapping, Optional
54

65
from viam.proto.common import DoCommandRequest, DoCommandResponse
76
from viam.proto.service.mlmodel import InferRequest, InferResponse, MetadataRequest, MetadataResponse, MLModelServiceStub
87
from viam.resource.rpc_client_base import ReconfigurableResourceRPCClientBase
9-
from viam.utils import ValueTypes, dict_to_struct, flat_tensors_to_ndarrays, ndarrays_to_flat_tensors, struct_to_dict
8+
from viam.services.mlmodel.utils import flat_tensors_to_ndarrays, ndarrays_to_flat_tensors
9+
from viam.utils import ValueTypes, dict_to_struct, struct_to_dict
1010

1111
from .mlmodel import Metadata, MLModel
1212

src/viam/services/mlmodel/service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from viam.proto.service.mlmodel import InferRequest, InferResponse, MetadataRequest, MetadataResponse, MLModelServiceBase
44
from viam.resource.rpc_service_base import ResourceRPCServiceBase
5-
from viam.utils import flat_tensors_to_ndarrays, ndarrays_to_flat_tensors
5+
from viam.services.mlmodel.utils import flat_tensors_to_ndarrays, ndarrays_to_flat_tensors
66

77
from .mlmodel import MLModel
88

src/viam/services/mlmodel/utils.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
from numpy.typing import NDArray
3+
from typing import Dict
4+
5+
from viam.proto.service.mlmodel import (
6+
FlatTensors,
7+
FlatTensor,
8+
FlatTensorDataDouble,
9+
FlatTensorDataFloat,
10+
FlatTensorDataInt16,
11+
FlatTensorDataInt32,
12+
FlatTensorDataInt64,
13+
FlatTensorDataInt8,
14+
FlatTensorDataUInt16,
15+
FlatTensorDataUInt32,
16+
FlatTensorDataUInt64,
17+
FlatTensorDataUInt8,
18+
)
19+
20+
21+
def flat_tensors_to_ndarrays(flat_tensors: FlatTensors) -> Dict[str, NDArray]:
22+
property_name_to_dtype = {
23+
"float_tensor": np.float32,
24+
"double_tensor": np.float64,
25+
"int8_tensor": np.int8,
26+
"int16_tensor": np.int16,
27+
"int32_tensor": np.int32,
28+
"int64_tensor": np.int64,
29+
"uint8_tensor": np.uint8,
30+
"uint16_tensor": np.uint16,
31+
"uint32_tensor": np.uint32,
32+
"uint64_tensor": np.uint64,
33+
}
34+
35+
def make_ndarray(flat_data, dtype, shape):
36+
"""Takes flat data (protobuf RepeatedScalarFieldContainer | bytes) to output an ndarray
37+
of appropriate dtype and shape"""
38+
make_array = np.frombuffer if dtype == np.int8 or dtype == np.uint8 else np.array
39+
return make_array(flat_data, dtype).reshape(shape)
40+
41+
ndarrays: Dict[str, NDArray] = dict()
42+
for name, flat_tensor in flat_tensors.tensors.items():
43+
property_name = flat_tensor.WhichOneof("tensor") or flat_tensor.WhichOneof(b"tensor")
44+
if property_name:
45+
tensor_data = getattr(flat_tensor, property_name)
46+
flat_data, dtype, shape = tensor_data.data, property_name_to_dtype[property_name], flat_tensor.shape
47+
ndarrays[name] = make_ndarray(flat_data, dtype, shape)
48+
return ndarrays
49+
50+
51+
def ndarrays_to_flat_tensors(ndarrays: Dict[str, NDArray]) -> FlatTensors:
52+
dtype_name_to_tensor_data_class = {
53+
"float32": FlatTensorDataFloat,
54+
"float64": FlatTensorDataDouble,
55+
"int8": FlatTensorDataInt8,
56+
"int16": FlatTensorDataInt16,
57+
"int32": FlatTensorDataInt32,
58+
"int64": FlatTensorDataInt64,
59+
"uint8": FlatTensorDataUInt8,
60+
"uint16": FlatTensorDataUInt16,
61+
"uint32": FlatTensorDataUInt32,
62+
"uint64": FlatTensorDataUInt64,
63+
}
64+
65+
def get_tensor_data(ndarray: NDArray):
66+
"""Takes an ndarray and returns the corresponding tensor data class instance
67+
e.g. FlatTensorDataInt8, FlatTensorDataUInt8 etc."""
68+
tensor_data_class = dtype_name_to_tensor_data_class[ndarray.dtype.name]
69+
data = ndarray.flatten()
70+
if tensor_data_class == FlatTensorDataInt8 or tensor_data_class == FlatTensorDataUInt8:
71+
data = data.tobytes() # as per the proto, int8 and uint8 are stored as bytes
72+
elif tensor_data_class == FlatTensorDataInt16 or tensor_data_class == FlatTensorDataUInt16:
73+
data = data.astype(np.uint32) # as per the proto, int16 and uint16 are stored as uint32
74+
tensor_data = tensor_data_class(data=data)
75+
return tensor_data
76+
77+
def get_tensor_data_type(ndarray: NDArray):
78+
"""Takes ndarray and returns a FlatTensor datatype property to be set
79+
e.g. "float_tensor", "uint32_tensor" etc."""
80+
if ndarray.dtype == np.float32:
81+
return "float_tensor"
82+
elif ndarray.dtype == np.float64:
83+
return "double_tensor"
84+
return f"{ndarray.dtype.name}_tensor"
85+
86+
tensors_mapping: Dict[str, FlatTensor] = dict()
87+
for name, ndarray in ndarrays.items():
88+
prop_name, prop_value = get_tensor_data_type(ndarray), get_tensor_data(ndarray)
89+
tensors_mapping[name] = FlatTensor(shape=ndarray.shape, **{prop_name: prop_value})
90+
return FlatTensors(tensors=tensors_mapping)

src/viam/utils.py

Lines changed: 2 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,23 @@
11
import asyncio
22
import contextvars
33
import functools
4+
45
import sys
56
import threading
67
from datetime import datetime
78
from typing import Any, Dict, List, Mapping, Optional, SupportsBytes, SupportsFloat, Type, TypeVar, Union
89

9-
import numpy as np
1010
from google.protobuf.json_format import MessageToDict, ParseDict
1111
from google.protobuf.message import Message
1212
from google.protobuf.struct_pb2 import ListValue, Struct, Value
1313
from google.protobuf.timestamp_pb2 import Timestamp
14-
from numpy.typing import NDArray
1514

1615
from viam.proto.common import Geometry, GeoPoint, GetGeometriesRequest, GetGeometriesResponse, Orientation, ResourceName, Vector3
17-
from viam.proto.service.mlmodel import (
18-
FlatTensor,
19-
FlatTensorDataDouble,
20-
FlatTensorDataFloat,
21-
FlatTensorDataInt8,
22-
FlatTensorDataInt16,
23-
FlatTensorDataInt32,
24-
FlatTensorDataInt64,
25-
FlatTensorDataUInt8,
26-
FlatTensorDataUInt16,
27-
FlatTensorDataUInt32,
28-
FlatTensorDataUInt64,
29-
FlatTensors,
30-
)
3116
from viam.resource.base import ResourceBase
3217
from viam.resource.registry import Registry
3318
from viam.resource.types import Subtype, SupportsGetGeometries
3419

20+
3521
if sys.version_info >= (3, 9):
3622
from collections.abc import Callable
3723
else:
@@ -165,78 +151,6 @@ def _convert(v: ValueTypes) -> Any:
165151
return struct
166152

167153

168-
def flat_tensors_to_ndarrays(flat_tensors: FlatTensors) -> Dict[str, NDArray]:
169-
property_name_to_dtype = {
170-
"float_tensor": np.float32,
171-
"double_tensor": np.float64,
172-
"int8_tensor": np.int8,
173-
"int16_tensor": np.int16,
174-
"int32_tensor": np.int32,
175-
"int64_tensor": np.int64,
176-
"uint8_tensor": np.uint8,
177-
"uint16_tensor": np.uint16,
178-
"uint32_tensor": np.uint32,
179-
"uint64_tensor": np.uint64,
180-
}
181-
182-
def make_ndarray(flat_data, dtype, shape):
183-
"""Takes flat data (protobuf RepeatedScalarFieldContainer | bytes) to output an ndarray
184-
of appropriate dtype and shape"""
185-
make_array = np.frombuffer if dtype == np.int8 or dtype == np.uint8 else np.array
186-
return make_array(flat_data, dtype).reshape(shape)
187-
188-
ndarrays: Dict[str, NDArray] = dict()
189-
for name, flat_tensor in flat_tensors.tensors.items():
190-
property_name = flat_tensor.WhichOneof("tensor") or flat_tensor.WhichOneof(b"tensor")
191-
if property_name:
192-
tensor_data = getattr(flat_tensor, property_name)
193-
flat_data, dtype, shape = tensor_data.data, property_name_to_dtype[property_name], flat_tensor.shape
194-
ndarrays[name] = make_ndarray(flat_data, dtype, shape)
195-
return ndarrays
196-
197-
198-
def ndarrays_to_flat_tensors(ndarrays: Dict[str, NDArray]) -> FlatTensors:
199-
dtype_name_to_tensor_data_class = {
200-
"float32": FlatTensorDataFloat,
201-
"float64": FlatTensorDataDouble,
202-
"int8": FlatTensorDataInt8,
203-
"int16": FlatTensorDataInt16,
204-
"int32": FlatTensorDataInt32,
205-
"int64": FlatTensorDataInt64,
206-
"uint8": FlatTensorDataUInt8,
207-
"uint16": FlatTensorDataUInt16,
208-
"uint32": FlatTensorDataUInt32,
209-
"uint64": FlatTensorDataUInt64,
210-
}
211-
212-
def get_tensor_data(ndarray: NDArray):
213-
"""Takes an ndarray and returns the corresponding tensor data class instance
214-
e.g. FlatTensorDataInt8, FlatTensorDataUInt8 etc."""
215-
tensor_data_class = dtype_name_to_tensor_data_class[ndarray.dtype.name]
216-
data = ndarray.flatten()
217-
if tensor_data_class == FlatTensorDataInt8 or tensor_data_class == FlatTensorDataUInt8:
218-
data = data.tobytes() # as per the proto, int8 and uint8 are stored as bytes
219-
elif tensor_data_class == FlatTensorDataInt16 or tensor_data_class == FlatTensorDataUInt16:
220-
data = data.astype(np.uint32) # as per the proto, int16 and uint16 are stored as uint32
221-
tensor_data = tensor_data_class(data=data)
222-
return tensor_data
223-
224-
def get_tensor_data_type(ndarray: NDArray):
225-
"""Takes ndarray and returns a FlatTensor datatype property to be set
226-
e.g. "float_tensor", "uint32_tensor" etc."""
227-
if ndarray.dtype == np.float32:
228-
return "float_tensor"
229-
elif ndarray.dtype == np.float64:
230-
return "double_tensor"
231-
return f"{ndarray.dtype.name}_tensor"
232-
233-
tensors_mapping: Dict[str, FlatTensor] = dict()
234-
for name, ndarray in ndarrays.items():
235-
prop_name, prop_value = get_tensor_data_type(ndarray), get_tensor_data(ndarray)
236-
tensors_mapping[name] = FlatTensor(shape=ndarray.shape, **{prop_name: prop_value})
237-
return FlatTensors(tensors=tensors_mapping)
238-
239-
240154
def struct_to_dict(struct: Struct) -> Dict[str, ValueTypes]:
241155
return {key: value_to_primitive(value) for (key, value) in struct.fields.items()}
242156

tests/mocks/services.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,11 @@
227227
from viam.proto.service.vision import Classification, Detection
228228
from viam.app.data_client import DataClient
229229
from viam.services.mlmodel import File, LabelType, Metadata, MLModel, TensorInfo
230+
from viam.services.mlmodel.utils import ndarrays_to_flat_tensors, flat_tensors_to_ndarrays
230231
from viam.services.navigation import Navigation
231232
from viam.services.slam import SLAM
232233
from viam.services.vision import Vision
233-
from viam.utils import ValueTypes, datetime_to_timestamp, dict_to_struct, struct_to_dict, ndarrays_to_flat_tensors, flat_tensors_to_ndarrays
234+
from viam.utils import ValueTypes, datetime_to_timestamp, dict_to_struct, struct_to_dict
234235

235236

236237
class MockVision(Vision):

tests/test_mlmodel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def setup_class(cls):
6464
cls.manager = ResourceManager([cls.mlmodel])
6565
cls.service = MLModelRPCService(cls.manager)
6666

67+
# ignore warning about our out-of-bound int casting (i.e. uint32 -> int16) because we don't store uint32s for int16 & uint16 tensor
68+
# data > 2^16-1 in the first place (inherently they are int16, we just cast them to uint32 for the grpc message)
69+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
6770
@pytest.mark.asyncio
6871
async def test_infer(self):
6972
async with ChannelFor([self.service]) as channel:

0 commit comments

Comments
 (0)