diff --git a/README.md b/README.md
index 4cb751f88..53bbaf8f8 100644
--- a/README.md
+++ b/README.md
@@ -122,6 +122,11 @@ To keep the bioimageio.spec Python package version in sync with the (model) desc
### bioimageio.spec Python package
+#### bioimageio.spec 0.5.3.5
+
+* fix loading tifffile in python 3.8 (pin tifffile)
+* use default tensorflow environments for Keras H5 weights
+
#### bioimageio.spec 0.5.3.4
* support loading and saving from/to zipfile.ZipFile objects
diff --git a/bioimageio/spec/VERSION b/bioimageio/spec/VERSION
index b4960cf9f..aa7631d22 100644
--- a/bioimageio/spec/VERSION
+++ b/bioimageio/spec/VERSION
@@ -1,3 +1,3 @@
{
- "version": "0.5.3.4"
+ "version": "0.5.3.5"
}
diff --git a/bioimageio/spec/conda_env.py b/bioimageio/spec/conda_env.py
index 9cb9d4793..90bd6dfe1 100644
--- a/bioimageio/spec/conda_env.py
+++ b/bioimageio/spec/conda_env.py
@@ -43,6 +43,19 @@ def _ensure_valid_conda_env_name(cls, value: Optional[str]) -> Optional[str]:
def wo_name(self):
return self.model_construct(**{k: v for k, v in self if k != "name"})
+ def _get_version(self, package: str):
+ """Helper to return any verison pin for **package**
+
+ TODO: improve: interprete version pin and return structured information.
+ """
+ for d in self.dependencies:
+ if isinstance(d, PipDeps):
+ for p in d.pip:
+ if p.startswith(package):
+ return p[len(package) :]
+ elif d.startswith(package):
+ return d[len(package) :]
+
class BioimageioCondaEnv(CondaEnv):
"""A special `CondaEnv` that
diff --git a/bioimageio/spec/get_conda_env.py b/bioimageio/spec/get_conda_env.py
index 7b6430795..675c2a16e 100644
--- a/bioimageio/spec/get_conda_env.py
+++ b/bioimageio/spec/get_conda_env.py
@@ -11,10 +11,12 @@
from .utils import download
SupportedWeightsEntry = Union[
+ v0_4.KerasHdf5WeightsDescr,
v0_4.OnnxWeightsDescr,
v0_4.PytorchStateDictWeightsDescr,
v0_4.TensorflowSavedModelBundleWeightsDescr,
v0_4.TorchscriptWeightsDescr,
+ v0_5.KerasHdf5WeightsDescr,
v0_5.OnnxWeightsDescr,
v0_5.PytorchStateDictWeightsDescr,
v0_5.TensorflowSavedModelBundleWeightsDescr,
@@ -58,6 +60,11 @@ def get_conda_env(
conda_env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
else:
conda_env = _get_env_from_deps(entry.dependencies)
+ elif isinstance(
+ entry,
+ (v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr),
+ ):
+ conda_env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
else:
assert_never(entry)
@@ -133,7 +140,7 @@ def _get_default_pytorch_env(
deps = [f"pytorch=={v}", "torchvision==0.20.0", "torchaudio==2.5.0"]
else:
set_github_warning(
- "UPDATE NEEDED", "Specify pins for additional pytorch dependencies!"
+ "UPDATE NEEDED", f"Specify pins for additional pytorch=={v} dependencies!"
)
deps = [f"pytorch=={v}", "torchvision", "torchaudio"]
diff --git a/bioimageio/spec/summary.py b/bioimageio/spec/summary.py
index 8d50ada0e..adb670b31 100644
--- a/bioimageio/spec/summary.py
+++ b/bioimageio/spec/summary.py
@@ -338,21 +338,29 @@ def format_loc(loc: Loc):
if d.recommended_env is not None:
rec_env = StringIO()
- json_env = d.recommended_env.model_dump(mode="json")
+ json_env = d.recommended_env.model_dump(
+ mode="json", exclude_defaults=True
+ )
assert is_yaml_value(json_env)
write_yaml(json_env, rec_env)
rec_env_code = rec_env.getvalue().replace("\n", "
")
details.append(
[
"🐍",
- "recommended conda env",
- f"{rec_env_code}
",
+ format_loc(d.loc),
+ f"recommended conda env ({d.name})
"
+ + f"{rec_env_code}
",
]
)
if d.conda_compare:
details.append(
- ["🐍", "conda compare", d.conda_compare.replace("\n", "
")]
+ [
+ "🐍",
+ format_loc(d.loc),
+ "conda compare ({d.name}):
"
+ + d.conda_compare.replace("\n", "
"),
+ ]
)
for entry in d.errors:
diff --git a/setup.py b/setup.py
index 529d8f644..6c7a82415 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,7 @@
"requests",
"rich",
"ruyaml",
- "tifffile",
+ "tifffile >=2020.7.4",
"tqdm",
"typing-extensions",
"zipp",
diff --git a/tests/test_get_conda_env.py b/tests/test_get_conda_env.py
index c35b28e4f..7a2b102f5 100644
--- a/tests/test_get_conda_env.py
+++ b/tests/test_get_conda_env.py
@@ -1,17 +1,139 @@
from pathlib import Path
-from typing import Any, Mapping
+from typing import Any, Dict, List, Mapping, Optional, Union
+
+import pytest
from bioimageio.spec._internal.validation_context import ValidationContext
+from bioimageio.spec.model.v0_5 import (
+ OnnxWeightsDescr,
+ PytorchStateDictWeightsDescr,
+ TorchscriptWeightsDescr,
+)
-def test_get_conda_env(unet2d_data: Mapping[str, Any], unet2d_path: Path):
+@pytest.mark.parametrize(
+ "descr_class,w",
+ [
+ (
+ PytorchStateDictWeightsDescr,
+ dict(
+ authors=[
+ dict(
+ name="Constantin Pape;@bioimage-io",
+ affiliation="EMBL Heidelberg",
+ orcid="0000-0001-6562-7187",
+ )
+ ],
+ sha256="e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2",
+ source="https://zenodo.org/records/3446812/files/unet2d_weights.torch",
+ architecture=dict(
+ callable="UNet2d",
+ source="unet2d.py",
+ sha256="7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1",
+ kwargs=dict(input_channels=1, output_channels=1),
+ ),
+ dependencies=dict(
+ source="environment.yaml",
+ sha256="129d589d2ec801398719b1a6d1bf20ea36b3632f14ccb56a24700df7d719fd10",
+ ),
+ pytorch_version="1.5.1",
+ ),
+ ),
+ (
+ OnnxWeightsDescr,
+ dict(
+ sha256="f1f086d5e340f9d4d7001a1b62a2b835f9b87a2fb5452c4fe7d8cc821bdf539c",
+ source="weights.onnx",
+ opset_version=12,
+ parent="pytorch_state_dict",
+ ),
+ ),
+ (
+ TorchscriptWeightsDescr,
+ dict(
+ sha256="62fa1c39923bee7d58a192277e0dd58f2da9ee810662addadd0f44a3784d9210",
+ source="weights.pt",
+ parent="pytorch_state_dict",
+ pytorch_version="1.5.1",
+ ),
+ ),
+ ],
+)
+def test_get_conda_env(
+ descr_class: Union[
+ PytorchStateDictWeightsDescr, OnnxWeightsDescr, TorchscriptWeightsDescr
+ ],
+ w: Mapping[str, Any],
+ unet2d_path: Path,
+):
from bioimageio.spec.get_conda_env import get_conda_env
- from bioimageio.spec.model.v0_5 import PytorchStateDictWeightsDescr
with ValidationContext(perform_io_checks=False, root=unet2d_path.parent):
- w = PytorchStateDictWeightsDescr(**unet2d_data["weights"]["pytorch_state_dict"])
+ w_descr = descr_class.model_validate(w)
- conda_env = get_conda_env(entry=w)
+ conda_env = get_conda_env(entry=w_descr)
assert conda_env.channels
assert conda_env.dependencies
+
+
+def test_get_default_pytorch_env():
+ from bioimageio.spec._internal.version_type import Version
+ from bioimageio.spec.get_conda_env import (
+ _get_default_pytorch_env, # pyright: ignore[reportPrivateUsage]
+ )
+
+ versions: Dict[str, List[Optional[str]]] = {
+ "pytorch": [
+ "1.6.0",
+ "1.7.0",
+ "1.7.1",
+ "1.8.0",
+ "1.8.1",
+ "1.9.0",
+ "1.9.1",
+ "1.10.0",
+ "1.10.1",
+ "1.11.0",
+ "1.12.0",
+ "1.12.1",
+ "1.13.0",
+ "1.13.1",
+ "2.0.0",
+ "2.0.1",
+ "2.1.0",
+ "2.1.1",
+ "2.1.2",
+ "2.2.0",
+ "2.2.1",
+ "2.2.2",
+ "2.3.0",
+ "2.3.1",
+ "2.4.0",
+ "2.4.1",
+ "2.5.0",
+ ]
+ }
+ envs = [
+ _get_default_pytorch_env(pytorch_version=Version.model_validate(v))
+ for v in versions["pytorch"]
+ ]
+ for p in ["torchvision", "torchaudio"]:
+ versions[p] = [
+ env._get_version(p) for env in envs # pyright: ignore[reportPrivateUsage]
+ ]
+
+ def assert_lt(p: str, i: int):
+ vs = versions[p]
+ a, b = vs[i], vs[i + 1]
+ assert a is not None, (vs[i], vs[i + 1])
+ assert b is not None, (vs[i], vs[i + 1])
+ av = Version(a.strip("="))
+ bv = Version(b.strip("="))
+ assert av < bv, (vs[i], vs[i + 1])
+
+ for i in range(len(versions["pytorch"]) - 1):
+ assert_lt("pytorch", i)
+ assert_lt("torchvision", i)
+ if i > 0:
+ assert_lt("torchaudio", i)