diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index f47f8f63..a8dd1043 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -41,6 +41,7 @@ ) from ._settings import settings from .axis import Axis, AxisId +from .backends import create_model_adapter from .block_meta import BlockMeta from .common import MemberId from .prediction import predict, predict_many @@ -73,6 +74,7 @@ "commands", "common", "compute_dataset_measures", + "create_model_adapter", "create_prediction_pipeline", "digest_spec", "dump_description", @@ -104,7 +106,6 @@ "Stat", "tensor", "Tensor", - "test_description_in_conda_env", "test_description", "test_model", "test_resource", diff --git a/bioimageio/core/_create_model_adapter.py b/bioimageio/core/_create_model_adapter.py deleted file mode 100644 index ee79f260..00000000 --- a/bioimageio/core/_create_model_adapter.py +++ /dev/null @@ -1,127 +0,0 @@ -import warnings -from abc import abstractmethod -from typing import List, Optional, Sequence, Tuple, Union, final - -from bioimageio.spec.model import v0_4, v0_5 - -from ._model_adapter import ( - DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER, - ModelAdapter, - WeightsFormat, -) -from .tensor import Tensor - - -def create_model_adapter( - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - *, - devices: Optional[Sequence[str]] = None, - weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None, -): - """ - Creates model adapter based on the passed spec - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)): - raise TypeError( - f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}" - ) - - weights = model_description.weights - errors: List[Tuple[WeightsFormat, Exception]] = [] - weight_format_priority_order = ( - DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER - if weight_format_priority_order is None - else weight_format_priority_order - ) - # limit weight formats to the ones present - weight_format_priority_order = [ - w for w in weight_format_priority_order if getattr(weights, w) is not None - ] - - for wf in weight_format_priority_order: - if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None: - try: - from .model_adapters_old._pytorch_model_adapter import ( - PytorchModelAdapter, - ) - - return PytorchModelAdapter( - outputs=model_description.outputs, - weights=weights.pytorch_state_dict, - devices=devices, - ) - except Exception as e: - errors.append((wf, e)) - elif ( - wf == "tensorflow_saved_model_bundle" - and weights.tensorflow_saved_model_bundle is not None - ): - try: - from .model_adapters_old._tensorflow_model_adapter import ( - TensorflowModelAdapter, - ) - - return TensorflowModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - elif wf == "onnx" and weights.onnx is not None: - try: - from .model_adapters_old._onnx_model_adapter import ONNXModelAdapter - - return ONNXModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - elif wf == "torchscript" and weights.torchscript is not None: - try: - from .model_adapters_old._torchscript_model_adapter import ( - TorchscriptModelAdapter, - ) - - return TorchscriptModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - elif wf == "keras_hdf5" and weights.keras_hdf5 is not None: - # keras can either be installed as a separate package or used as part of tensorflow - # we try to first import the keras model adapter using the separate package and, - # if it is not available, try to load the one using tf - try: - from .backend.keras import ( - KerasModelAdapter, - keras, # type: ignore - ) - - if keras is None: - from .model_adapters_old._tensorflow_model_adapter import ( - KerasModelAdapter, - ) - - return KerasModelAdapter( - model_description=model_description, devices=devices - ) - except Exception as e: - errors.append((wf, e)) - - assert errors - if len(weight_format_priority_order) == 1: - assert len(errors) == 1 - raise ValueError( - f"The '{weight_format_priority_order[0]}' model adapter could not be created" - + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n" - ) from errors[0][1] - - else: - error_list = "\n - ".join( - f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors - ) - raise ValueError( - "None of the weight format specific model adapters could be created" - + f" in this environment. Errors are:\n\n{error_list}.\n\n" - ) diff --git a/bioimageio/core/_model_adapter.py b/bioimageio/core/_model_adapter.py deleted file mode 100644 index 0438d35e..00000000 --- a/bioimageio/core/_model_adapter.py +++ /dev/null @@ -1,93 +0,0 @@ -import warnings -from abc import ABC, abstractmethod -from typing import List, Optional, Sequence, Tuple, Union, final - -from bioimageio.spec.model import v0_4, v0_5 - -from .tensor import Tensor - -WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] - -__all__ = [ - "ModelAdapter", - "create_model_adapter", - "get_weight_formats", -] - -# Known weight formats in order of priority -# First match wins -DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = ( - "pytorch_state_dict", - "tensorflow_saved_model_bundle", - "torchscript", - "onnx", - "keras_hdf5", -) - - -class ModelAdapter(ABC): - """ - Represents model *without* any preprocessing or postprocessing. - - ``` - from bioimageio.core import load_description - - model = load_description(...) - - # option 1: - adapter = ModelAdapter.create(model) - adapter.forward(...) - adapter.unload() - - # option 2: - with ModelAdapter.create(model) as adapter: - adapter.forward(...) - ``` - """ - - @final - @classmethod - def create( - cls, - model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], - *, - devices: Optional[Sequence[str]] = None, - weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None, - ): - """ - Creates model adapter based on the passed spec - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - from ._create_model_adapter import create_model_adapter - - return create_model_adapter( - model_description, - devices=devices, - weight_format_priority_order=weight_format_priority_order, - ) - - @final - def load(self, *, devices: Optional[Sequence[str]] = None) -> None: - warnings.warn("Deprecated. ModelAdapter is loaded on initialization") - - @abstractmethod - def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: - """ - Run forward pass of model to get model predictions - """ - # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl - - @abstractmethod - def unload(self): - """ - Unload model from any devices, freeing their memory. - The moder adapter should be considered unusable afterwards. - """ - - -def get_weight_formats() -> List[str]: - """ - Return list of supported weight types - """ - return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) diff --git a/bioimageio/core/_resource_tests.py b/bioimageio/core/_resource_tests.py index 8f24d363..e6675b73 100644 --- a/bioimageio/core/_resource_tests.py +++ b/bioimageio/core/_resource_tests.py @@ -37,6 +37,7 @@ from bioimageio.spec._internal.common_nodes import ResourceDescrBase from bioimageio.spec._internal.io import is_yaml_value from bioimageio.spec._internal.io_utils import read_yaml, write_yaml +from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256 from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.model.v0_5 import WeightsFormat from bioimageio.spec.summary import ( @@ -192,7 +193,7 @@ def test_description( decimal=decimal, determinism=determinism, expected_type=expected_type, - sha256=sha256, + sha256=sha256, ) return rd.validation_summary diff --git a/bioimageio/core/backend/__init__.py b/bioimageio/core/backend/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bioimageio/core/backends/__init__.py b/bioimageio/core/backends/__init__.py new file mode 100644 index 00000000..c39b58b5 --- /dev/null +++ b/bioimageio/core/backends/__init__.py @@ -0,0 +1,3 @@ +from ._model_adapter import create_model_adapter + +__all__ = ["create_model_adapter"] diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/backends/_model_adapter.py similarity index 84% rename from bioimageio/core/model_adapters/_model_adapter.py rename to bioimageio/core/backends/_model_adapter.py index 3921f81b..66153f09 100644 --- a/bioimageio/core/model_adapters/_model_adapter.py +++ b/bioimageio/core/backends/_model_adapter.py @@ -73,7 +73,7 @@ def create( for wf in weight_format_priority_order: if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None: try: - from ._pytorch_model_adapter import PytorchModelAdapter + from .pytorch_backend import PytorchModelAdapter return PytorchModelAdapter( outputs=model_description.outputs, @@ -87,7 +87,7 @@ def create( and weights.tensorflow_saved_model_bundle is not None ): try: - from ._tensorflow_model_adapter import TensorflowModelAdapter + from .tensorflow_backend import TensorflowModelAdapter return TensorflowModelAdapter( model_description=model_description, devices=devices @@ -96,7 +96,7 @@ def create( errors.append((wf, e)) elif wf == "onnx" and weights.onnx is not None: try: - from ._onnx_model_adapter import ONNXModelAdapter + from .onnx_backend import ONNXModelAdapter return ONNXModelAdapter( model_description=model_description, devices=devices @@ -105,7 +105,7 @@ def create( errors.append((wf, e)) elif wf == "torchscript" and weights.torchscript is not None: try: - from ._torchscript_model_adapter import TorchscriptModelAdapter + from .torchscript_backend import TorchscriptModelAdapter return TorchscriptModelAdapter( model_description=model_description, devices=devices @@ -117,13 +117,10 @@ def create( # we try to first import the keras model adapter using the separate package and, # if it is not available, try to load the one using tf try: - from ._keras import ( - KerasModelAdapter, - keras, # type: ignore - ) - - if keras is None: - from ._tensorflow_model_adapter import KerasModelAdapter + try: + from .keras_backend import KerasModelAdapter + except Exception: + from .tensorflow_backend import KerasModelAdapter return KerasModelAdapter( model_description=model_description, devices=devices @@ -134,10 +131,11 @@ def create( assert errors if len(weight_format_priority_order) == 1: assert len(errors) == 1 + wf, e = errors[0] raise ValueError( - f"The '{weight_format_priority_order[0]}' model adapter could not be created" - + f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n" - ) from errors[0][1] + f"The '{wf}' model adapter could not be created" + + f" in this environment:\n{e.__class__.__name__}({e}).\n\n" + ) from e else: error_list = "\n - ".join( @@ -165,13 +163,3 @@ def unload(self): Unload model from any devices, freeing their memory. The moder adapter should be considered unusable afterwards. """ - - -def get_weight_formats() -> List[str]: - """ - Return list of supported weight types - """ - return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) - - -create_model_adapter = ModelAdapter.create diff --git a/bioimageio/core/backend/keras.py b/bioimageio/core/backends/keras_backend.py similarity index 74% rename from bioimageio/core/backend/keras.py rename to bioimageio/core/backends/keras_backend.py index 1d273cfc..35ee79fe 100644 --- a/bioimageio/core/backend/keras.py +++ b/bioimageio/core/backends/keras_backend.py @@ -10,30 +10,22 @@ from .._settings import settings from ..digest_spec import get_axes_infos -from ..model_adapters import ModelAdapter from ..tensor import Tensor +from ._model_adapter import ModelAdapter os.environ["KERAS_BACKEND"] = settings.keras_backend # by default, we use the keras integrated with tensorflow +# TODO: check if we should prefer keras try: - import tensorflow as tf # pyright: ignore[reportMissingImports] - from tensorflow import ( # pyright: ignore[reportMissingImports] - keras, # pyright: ignore[reportUnknownVariableType] + import tensorflow as tf + from tensorflow import ( + keras, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] ) - tf_version = Version(tf.__version__) # pyright: ignore[reportUnknownArgumentType] + tf_version = Version(tf.__version__) except Exception: - try: - import keras # pyright: ignore[reportMissingImports] - except Exception as e: - keras = None - keras_error = str(e) - else: - keras_error = None - tf_version = None -else: - keras_error = None + import keras class KerasModelAdapter(ModelAdapter): @@ -43,9 +35,6 @@ def __init__( model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], devices: Optional[Sequence[str]] = None, ) -> None: - if keras is None: - raise ImportError(f"failed to import keras: {keras_error}") - super().__init__() if model_description.weights.keras_hdf5 is None: raise ValueError("model has not keras_hdf5 weights specified") @@ -86,18 +75,26 @@ def __init__( def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: _result: Union[Sequence[NDArray[Any]], NDArray[Any]] - _result = self._network.predict( # pyright: ignore[reportUnknownVariableType] + _result = self._network.predict( # type: ignore *[None if t is None else t.data.data for t in input_tensors] ) if isinstance(_result, (tuple, list)): - result: Sequence[NDArray[Any]] = _result + result = _result # pyright: ignore[reportUnknownVariableType] else: result = [_result] # type: ignore - assert len(result) == len(self._output_axes) + assert len(result) == len( # pyright: ignore[reportUnknownArgumentType] + self._output_axes + ) ret: List[Optional[Tensor]] = [] ret.extend( - [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)] + [ + Tensor(r, dims=axes) # pyright: ignore[reportArgumentType] + for r, axes, in zip( # pyright: ignore[reportUnknownVariableType] + result, # pyright: ignore[reportUnknownArgumentType] + self._output_axes, + ) + ] ) return ret diff --git a/bioimageio/core/backends/onnx_backend.py b/bioimageio/core/backends/onnx_backend.py new file mode 100644 index 00000000..21bbcc09 --- /dev/null +++ b/bioimageio/core/backends/onnx_backend.py @@ -0,0 +1,60 @@ +import warnings +from typing import Any, List, Optional, Sequence, Union + +import onnxruntime as rt + +from bioimageio.spec._internal.type_guards import is_list, is_tuple +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..digest_spec import get_axes_infos +from ..model_adapters import ModelAdapter +from ..tensor import Tensor + + +class ONNXModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + super().__init__() + self._internal_output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in model_description.outputs + ] + if model_description.weights.onnx is None: + raise ValueError("No ONNX weights specified for {model_description.name}") + + self._session = rt.InferenceSession( + str(download(model_description.weights.onnx.source).path) + ) + onnx_inputs = self._session.get_inputs() # type: ignore + self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] # type: ignore + + if devices is not None: + warnings.warn( + f"Device management is not implemented for onnx yet, ignoring the devices {devices}" + ) + + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + assert len(input_tensors) == len(self._input_names) + input_arrays = [None if ipt is None else ipt.data.data for ipt in input_tensors] + result: Any = self._session.run( + None, dict(zip(self._input_names, input_arrays)) + ) + if is_list(result) or is_tuple(result): + result_seq = result + else: + result_seq = [result] + + return [ + None if r is None else Tensor(r, dims=axes) + for r, axes in zip(result_seq, self._internal_output_axes) + ] + + def unload(self) -> None: + warnings.warn( + "Device management is not implemented for onnx yet, cannot unload model" + ) diff --git a/bioimageio/core/backend/pytorch.py b/bioimageio/core/backends/pytorch_backend.py similarity index 100% rename from bioimageio/core/backend/pytorch.py rename to bioimageio/core/backends/pytorch_backend.py diff --git a/bioimageio/core/backends/tensorflow_backend.py b/bioimageio/core/backends/tensorflow_backend.py new file mode 100644 index 00000000..3f9cee9d --- /dev/null +++ b/bioimageio/core/backends/tensorflow_backend.py @@ -0,0 +1,289 @@ +import zipfile +from io import TextIOWrapper +from pathlib import Path +from shutil import copyfileobj +from typing import List, Literal, Optional, Sequence, Union + +import numpy as np +import tensorflow as tf +from loguru import logger + +from bioimageio.spec.common import FileSource, ZipPath +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..digest_spec import get_axes_infos +from ..tensor import Tensor +from ._model_adapter import ModelAdapter + + +class TensorflowModelAdapterBase(ModelAdapter): + weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"] + + def __init__( + self, + *, + devices: Optional[Sequence[str]] = None, + weights: Union[ + v0_4.KerasHdf5WeightsDescr, + v0_4.TensorflowSavedModelBundleWeightsDescr, + v0_5.KerasHdf5WeightsDescr, + v0_5.TensorflowSavedModelBundleWeightsDescr, + ], + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + ): + super().__init__() + self.model_description = model_description + tf_version = v0_5.Version(tf.__version__) + model_tf_version = weights.tensorflow_version + if model_tf_version is None: + logger.warning( + "The model does not specify the tensorflow version." + + f"Cannot check if it is compatible with intalled tensorflow {tf_version}." + ) + elif model_tf_version > tf_version: + logger.warning( + f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}." + ) + elif (model_tf_version.major, model_tf_version.minor) != ( + tf_version.major, + tf_version.minor, + ): + logger.warning( + "The tensorflow version specified by the model does not match the installed: " + + f"{model_tf_version} != {tf_version}." + ) + + self.use_keras_api = ( + tf_version.major > 1 + or self.weight_format == KerasModelAdapter.weight_format + ) + + # TODO tf device management + if devices is not None: + logger.warning( + f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" + ) + + weight_file = self.require_unzipped(weights.source) + self._network = self._get_network(weight_file) + self._internal_output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in model_description.outputs + ] + + # TODO: check how to load tf weights without unzipping + def require_unzipped(self, weight_file: FileSource): + local_weights_file = download(weight_file).path + if isinstance(local_weights_file, ZipPath): + # weights file is in a bioimageio zip package + out_path = ( + Path("bioimageio_unzipped_tf_weights") / local_weights_file.filename + ) + with local_weights_file.open("rb") as src, out_path.open("wb") as dst: + assert not isinstance(src, TextIOWrapper) + copyfileobj(src, dst) + + local_weights_file = out_path + + if zipfile.is_zipfile(local_weights_file): + # weights file itself is a zipfile + out_path = local_weights_file.with_suffix(".unzipped") + with zipfile.ZipFile(local_weights_file, "r") as f: + f.extractall(out_path) + + return out_path + else: + return local_weights_file + + def _get_network( # pyright: ignore[reportUnknownParameterType] + self, weight_file: FileSource + ): + weight_file = self.require_unzipped(weight_file) + assert tf is not None + if self.use_keras_api: + try: + return tf.keras.layers.TFSMLayer( # pyright: ignore[reportAttributeAccessIssue,reportUnknownVariableType] + weight_file, + call_endpoint="serve", + ) + except Exception as e: + try: + return tf.keras.layers.TFSMLayer( # pyright: ignore[reportAttributeAccessIssue,reportUnknownVariableType] + weight_file, call_endpoint="serving_default" + ) + except Exception as ee: + logger.opt(exception=ee).info( + "keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'" + ) + raise e + else: + # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model + return str(weight_file) + + # TODO currently we relaod the model every time. it would be better to keep the graph and session + # alive in between of forward passes (but then the sessions need to be properly opened / closed) + def _forward_tf( # pyright: ignore[reportUnknownParameterType] + self, *input_tensors: Optional[Tensor] + ): + assert tf is not None + input_keys = [ + ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id + for ipt in self.model_description.inputs + ] + output_keys = [ + out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id + for out in self.model_description.outputs + ] + # TODO read from spec + tag = ( # pyright: ignore[reportUnknownVariableType] + tf.saved_model.tag_constants.SERVING # pyright: ignore[reportAttributeAccessIssue] + ) + signature_key = ( # pyright: ignore[reportUnknownVariableType] + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pyright: ignore[reportAttributeAccessIssue] + ) + + graph = tf.Graph() + with graph.as_default(): + with tf.Session( # pyright: ignore[reportAttributeAccessIssue] + graph=graph + ) as sess: # pyright: ignore[reportUnknownVariableType] + # load the model and the signature + graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] + sess, [tag], self._network + ) + signature = ( # pyright: ignore[reportUnknownVariableType] + graph_def.signature_def + ) + + # get the tensors into the graph + in_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].inputs[key].name for key in input_keys + ] + out_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].outputs[key].name for key in output_keys + ] + in_tensors = [ + graph.get_tensor_by_name( + name # pyright: ignore[reportUnknownArgumentType] + ) + for name in in_names # pyright: ignore[reportUnknownVariableType] + ] + out_tensors = [ + graph.get_tensor_by_name( + name # pyright: ignore[reportUnknownArgumentType] + ) + for name in out_names # pyright: ignore[reportUnknownVariableType] + ] + + # run prediction + res = sess.run( # pyright: ignore[reportUnknownVariableType] + dict( + zip( + out_names, # pyright: ignore[reportUnknownArgumentType] + out_tensors, + ) + ), + dict( + zip( + in_tensors, + [None if t is None else t.data for t in input_tensors], + ) + ), + ) + # from dict to list of tensors + res = [ # pyright: ignore[reportUnknownVariableType] + res[out] + for out in out_names # pyright: ignore[reportUnknownVariableType] + ] + + return res # pyright: ignore[reportUnknownVariableType] + + def _forward_keras( # pyright: ignore[reportUnknownParameterType] + self, *input_tensors: Optional[Tensor] + ): + assert self.use_keras_api + assert not isinstance(self._network, str) + assert tf is not None + tf_tensor = [ + None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors + ] + + result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType] + + assert isinstance(result, dict) + + # TODO: Use RDF's `outputs[i].id` here + result = list( # pyright: ignore[reportUnknownVariableType] + result.values() # pyright: ignore[reportUnknownArgumentType] + ) + + return [ # pyright: ignore[reportUnknownVariableType] + (None if r is None else r if isinstance(r, np.ndarray) else r.numpy()) + for r in result # pyright: ignore[reportUnknownVariableType] + ] + + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + if self.use_keras_api: + result = self._forward_keras( # pyright: ignore[reportUnknownVariableType] + *input_tensors + ) + else: + result = self._forward_tf( # pyright: ignore[reportUnknownVariableType] + *input_tensors + ) + + return [ + ( + None + if r is None + else Tensor(r, dims=axes) # pyright: ignore[reportUnknownArgumentType] + ) + for r, axes in zip( # pyright: ignore[reportUnknownVariableType] + result, # pyright: ignore[reportUnknownArgumentType] + self._internal_output_axes, + ) + ] + + def unload(self) -> None: + logger.warning( + "Device management is not implemented for keras yet, cannot unload model" + ) + + +class TensorflowModelAdapter(TensorflowModelAdapterBase): + weight_format = "tensorflow_saved_model_bundle" + + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + if model_description.weights.tensorflow_saved_model_bundle is None: + raise ValueError("missing tensorflow_saved_model_bundle weights") + + super().__init__( + devices=devices, + weights=model_description.weights.tensorflow_saved_model_bundle, + model_description=model_description, + ) + + +class KerasModelAdapter(TensorflowModelAdapterBase): + weight_format = "keras_hdf5" + + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + if model_description.weights.keras_hdf5 is None: + raise ValueError("missing keras_hdf5 weights") + + super().__init__( + model_description=model_description, + devices=devices, + weights=model_description.weights.keras_hdf5, + ) diff --git a/bioimageio/core/backends/torchscript_backend.py b/bioimageio/core/backends/torchscript_backend.py new file mode 100644 index 00000000..d1882180 --- /dev/null +++ b/bioimageio/core/backends/torchscript_backend.py @@ -0,0 +1,79 @@ +import gc +import warnings +from typing import Any, List, Optional, Sequence, Union + +import torch + +from bioimageio.spec._internal.type_guards import is_list, is_ndarray, is_tuple +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..digest_spec import get_axes_infos +from ..model_adapters import ModelAdapter +from ..tensor import Tensor + + +class TorchscriptModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + super().__init__() + if model_description.weights.torchscript is None: + raise ValueError( + f"No torchscript weights found for model {model_description.name}" + ) + + weight_path = download(model_description.weights.torchscript.source).path + if devices is None: + self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] + else: + self.devices = [torch.device(d) for d in devices] + + if len(self.devices) > 1: + warnings.warn( + "Multiple devices for single torchscript model not yet implemented" + ) + + self._model = torch.jit.load(weight_path) + self._model.to(self.devices[0]) + self._model = self._model.eval() + self._internal_output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in model_description.outputs + ] + + def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]: + with torch.no_grad(): + torch_tensor = [ + None if b is None else torch.from_numpy(b.data.data).to(self.devices[0]) + for b in batch + ] + _result: Any = self._model.forward(*torch_tensor) + if is_list(_result) or is_tuple(_result): + result: Sequence[Any] = _result + else: + result = [_result] + + result = [ + ( + None + if r is None + else r.cpu().numpy() if isinstance(r, torch.Tensor) else r + ) + for r in result + ] + + assert len(result) == len(self._internal_output_axes) + return [ + None if r is None else Tensor(r, dims=axes) if is_ndarray(r) else r + for r, axes in zip(result, self._internal_output_axes) + ] + + def unload(self) -> None: + self._devices = None + del self._model + _ = gc.collect() # deallocate memory + torch.cuda.empty_cache() # release reserved memory diff --git a/bioimageio/core/model_adapters.py b/bioimageio/core/model_adapters.py index 86fcfe4b..db92d013 100644 --- a/bioimageio/core/model_adapters.py +++ b/bioimageio/core/model_adapters.py @@ -1,8 +1,22 @@ -from ._create_model_adapter import create_model_adapter -from ._model_adapter import ModelAdapter, get_weight_formats +"""DEPRECATED""" + +from typing import List + +from .backends._model_adapter import ( + DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER, + ModelAdapter, + create_model_adapter, +) __all__ = [ "ModelAdapter", "create_model_adapter", "get_weight_formats", ] + + +def get_weight_formats() -> List[str]: + """ + Return list of supported weight types + """ + return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py deleted file mode 100644 index e69de29b..00000000