Skip to content

Commit

Permalink
improve increase_available_weight_formats
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Jan 24, 2025
1 parent e35735d commit 34d5478
Showing 1 changed file with 78 additions and 23 deletions.
101 changes: 78 additions & 23 deletions bioimageio/core/weight_converters/_add_weights.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,92 @@
from copy import deepcopy
from pathlib import Path
from typing import List, Optional, Sequence, Union
from typing import Optional, Sequence

from bioimageio.spec.model import v0_4, v0_5
from loguru import logger
from pydantic import DirectoryPath

from bioimageio.core._resource_tests import test_model
from bioimageio.spec import load_model_description, save_bioimageio_package_as_folder
from bioimageio.spec._internal.types import AbsoluteTolerance, RelativeTolerance
from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat


def increase_available_weight_formats(
model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr],
model_descr: ModelDescr,
*,
source_format: Optional[v0_5.WeightsFormat] = None,
target_format: Optional[v0_5.WeightsFormat] = None,
output_path: Path,
devices: Optional[Sequence[str]] = None,
) -> Union[v0_4.ModelDescr, v0_5.ModelDescr]:
"""Convert neural network weights to other formats and add them to the model description"""
if not isinstance(model_descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
raise TypeError(
f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_descr)}"
)
output_path: DirectoryPath,
source_format: Optional[WeightsFormat] = None,
target_format: Optional[WeightsFormat] = None,
devices: Sequence[str] = ("cpu",),
) -> ModelDescr:
"""Convert model weights to other formats and add them to the model description
Args:
output_path: Path to save updated model package to.
source_format: convert from a specific weights format.
Default: choose automatically from any available.
target_format: convert to a specific weights format.
Default: attempt to convert to any missing format.
devices: Devices that may be used during conversion.
"""
if not isinstance(model_descr, ModelDescr):
raise TypeError(type(model_descr))

# save model to local folder
output_path = save_bioimageio_package_as_folder(
model_descr, output_path=output_path
)
# reload from local folder to make sure we do not edit the given model
_model_descr = load_model_description(output_path)
assert isinstance(_model_descr, ModelDescr)
model_descr = _model_descr
del _model_descr

if source_format is None:
available = [wf for wf, w in model_descr.weights if w is not None]
missing = [wf for wf, w in model_descr.weights if w is None]
available = set(model_descr.weights.available_formats)
else:
available = {source_format}

if target_format is None:
missing = set(model_descr.weights.missing_formats)
else:
available = [source_format]
missing = [target_format]
missing = {target_format}

if "pytorch_state_dict" in available and "onnx" in missing:
from .pytorch_to_onnx import convert

onnx = convert(model_descr)
try:
model_descr.weights.onnx = convert(
model_descr,
output_path=output_path,
use_tracing=False,
)
except Exception as e:
logger.error(e)
else:
available.add("onnx")
missing.discard("onnx")

else:
raise NotImplementedError(
f"Converting from '{source_format}' to '{target_format}' is not yet implemented. Please create an issue at https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
if "pytorch_state_dict" in available and "torchscript" in missing:
from .pytorch_to_torchscript import convert

try:
model_descr.weights.torchscript = convert(
model_descr,
output_path=output_path,
use_tracing=False,
)
except Exception as e:
logger.error(e)
else:
available.add("torchscript")
missing.discard("torchscript")

if missing:
logger.warning(
f"Converting from any of the available weights formats {available} to any"
+ f" of {missing} is not yet implemented. Please create an issue at"
+ " https://github.com/bioimage-io/core-bioimage-io-python/issues/new/choose"
+ " if you would like bioimageio.core to support a particular conversion."
)

test_model(model_descr).display()
return model_descr

0 comments on commit 34d5478

Please sign in to comment.