Skip to content

Commit

Permalink
Merge pull request #660 from bioimage-io/keras_env
Browse files Browse the repository at this point in the history
Add default Keras environments
  • Loading branch information
FynnBe authored Nov 18, 2024
2 parents 9dab46d + 3a23dfd commit a51a35a
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 12 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ To keep the bioimageio.spec Python package version in sync with the (model) desc

### bioimageio.spec Python package

#### bioimageio.spec 0.5.3.5

* fix loading tifffile in python 3.8 (pin tifffile)
* use default tensorflow environments for Keras H5 weights

#### bioimageio.spec 0.5.3.4

* support loading and saving from/to zipfile.ZipFile objects
Expand Down
2 changes: 1 addition & 1 deletion bioimageio/spec/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.5.3.4"
"version": "0.5.3.5"
}
13 changes: 13 additions & 0 deletions bioimageio/spec/conda_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ def _ensure_valid_conda_env_name(cls, value: Optional[str]) -> Optional[str]:
def wo_name(self):
return self.model_construct(**{k: v for k, v in self if k != "name"})

def _get_version(self, package: str):
"""Helper to return any verison pin for **package**
TODO: improve: interprete version pin and return structured information.
"""
for d in self.dependencies:
if isinstance(d, PipDeps):
for p in d.pip:
if p.startswith(package):
return p[len(package) :]
elif d.startswith(package):
return d[len(package) :]


class BioimageioCondaEnv(CondaEnv):
"""A special `CondaEnv` that
Expand Down
9 changes: 8 additions & 1 deletion bioimageio/spec/get_conda_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from .utils import download

SupportedWeightsEntry = Union[
v0_4.KerasHdf5WeightsDescr,
v0_4.OnnxWeightsDescr,
v0_4.PytorchStateDictWeightsDescr,
v0_4.TensorflowSavedModelBundleWeightsDescr,
v0_4.TorchscriptWeightsDescr,
v0_5.KerasHdf5WeightsDescr,
v0_5.OnnxWeightsDescr,
v0_5.PytorchStateDictWeightsDescr,
v0_5.TensorflowSavedModelBundleWeightsDescr,
Expand Down Expand Up @@ -58,6 +60,11 @@ def get_conda_env(
conda_env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
else:
conda_env = _get_env_from_deps(entry.dependencies)
elif isinstance(
entry,
(v0_4.KerasHdf5WeightsDescr, v0_5.KerasHdf5WeightsDescr),
):
conda_env = _get_default_tf_env(tensorflow_version=entry.tensorflow_version)
else:
assert_never(entry)

Expand Down Expand Up @@ -133,7 +140,7 @@ def _get_default_pytorch_env(
deps = [f"pytorch=={v}", "torchvision==0.20.0", "torchaudio==2.5.0"]
else:
set_github_warning(
"UPDATE NEEDED", "Specify pins for additional pytorch dependencies!"
"UPDATE NEEDED", f"Specify pins for additional pytorch=={v} dependencies!"
)
deps = [f"pytorch=={v}", "torchvision", "torchaudio"]

Expand Down
16 changes: 12 additions & 4 deletions bioimageio/spec/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,21 +338,29 @@ def format_loc(loc: Loc):

if d.recommended_env is not None:
rec_env = StringIO()
json_env = d.recommended_env.model_dump(mode="json")
json_env = d.recommended_env.model_dump(
mode="json", exclude_defaults=True
)
assert is_yaml_value(json_env)
write_yaml(json_env, rec_env)
rec_env_code = rec_env.getvalue().replace("\n", "</code><br><code>")
details.append(
[
"🐍",
"recommended conda env",
f"<pre><code>{rec_env_code}</code></pre>",
format_loc(d.loc),
f"recommended conda env ({d.name})<br>"
+ f"<pre><code>{rec_env_code}</code></pre>",
]
)

if d.conda_compare:
details.append(
["🐍", "conda compare", d.conda_compare.replace("\n", "<br>")]
[
"🐍",
format_loc(d.loc),
"conda compare ({d.name}):<br>"
+ d.conda_compare.replace("\n", "<br>"),
]
)

for entry in d.errors:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"requests",
"rich",
"ruyaml",
"tifffile",
"tifffile >=2020.7.4",
"tqdm",
"typing-extensions",
"zipp",
Expand Down
132 changes: 127 additions & 5 deletions tests/test_get_conda_env.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,139 @@
from pathlib import Path
from typing import Any, Mapping
from typing import Any, Dict, List, Mapping, Optional, Union

import pytest

from bioimageio.spec._internal.validation_context import ValidationContext
from bioimageio.spec.model.v0_5 import (
OnnxWeightsDescr,
PytorchStateDictWeightsDescr,
TorchscriptWeightsDescr,
)


def test_get_conda_env(unet2d_data: Mapping[str, Any], unet2d_path: Path):
@pytest.mark.parametrize(
"descr_class,w",
[
(
PytorchStateDictWeightsDescr,
dict(
authors=[
dict(
name="Constantin Pape;@bioimage-io",
affiliation="EMBL Heidelberg",
orcid="0000-0001-6562-7187",
)
],
sha256="e4d3885bccbe41cbf6c1d825f3cd2b707c7021ead5593156007e407a16b27cf2",
source="https://zenodo.org/records/3446812/files/unet2d_weights.torch",
architecture=dict(
callable="UNet2d",
source="unet2d.py",
sha256="7cdd8332dc3e3735e71c328f81b63a9ac86c028f80522312484ca9a4027d4ce1",
kwargs=dict(input_channels=1, output_channels=1),
),
dependencies=dict(
source="environment.yaml",
sha256="129d589d2ec801398719b1a6d1bf20ea36b3632f14ccb56a24700df7d719fd10",
),
pytorch_version="1.5.1",
),
),
(
OnnxWeightsDescr,
dict(
sha256="f1f086d5e340f9d4d7001a1b62a2b835f9b87a2fb5452c4fe7d8cc821bdf539c",
source="weights.onnx",
opset_version=12,
parent="pytorch_state_dict",
),
),
(
TorchscriptWeightsDescr,
dict(
sha256="62fa1c39923bee7d58a192277e0dd58f2da9ee810662addadd0f44a3784d9210",
source="weights.pt",
parent="pytorch_state_dict",
pytorch_version="1.5.1",
),
),
],
)
def test_get_conda_env(
descr_class: Union[
PytorchStateDictWeightsDescr, OnnxWeightsDescr, TorchscriptWeightsDescr
],
w: Mapping[str, Any],
unet2d_path: Path,
):
from bioimageio.spec.get_conda_env import get_conda_env
from bioimageio.spec.model.v0_5 import PytorchStateDictWeightsDescr

with ValidationContext(perform_io_checks=False, root=unet2d_path.parent):
w = PytorchStateDictWeightsDescr(**unet2d_data["weights"]["pytorch_state_dict"])
w_descr = descr_class.model_validate(w)

conda_env = get_conda_env(entry=w)
conda_env = get_conda_env(entry=w_descr)

assert conda_env.channels
assert conda_env.dependencies


def test_get_default_pytorch_env():
from bioimageio.spec._internal.version_type import Version
from bioimageio.spec.get_conda_env import (
_get_default_pytorch_env, # pyright: ignore[reportPrivateUsage]
)

versions: Dict[str, List[Optional[str]]] = {
"pytorch": [
"1.6.0",
"1.7.0",
"1.7.1",
"1.8.0",
"1.8.1",
"1.9.0",
"1.9.1",
"1.10.0",
"1.10.1",
"1.11.0",
"1.12.0",
"1.12.1",
"1.13.0",
"1.13.1",
"2.0.0",
"2.0.1",
"2.1.0",
"2.1.1",
"2.1.2",
"2.2.0",
"2.2.1",
"2.2.2",
"2.3.0",
"2.3.1",
"2.4.0",
"2.4.1",
"2.5.0",
]
}
envs = [
_get_default_pytorch_env(pytorch_version=Version.model_validate(v))
for v in versions["pytorch"]
]
for p in ["torchvision", "torchaudio"]:
versions[p] = [
env._get_version(p) for env in envs # pyright: ignore[reportPrivateUsage]
]

def assert_lt(p: str, i: int):
vs = versions[p]
a, b = vs[i], vs[i + 1]
assert a is not None, (vs[i], vs[i + 1])
assert b is not None, (vs[i], vs[i + 1])
av = Version(a.strip("="))
bv = Version(b.strip("="))
assert av < bv, (vs[i], vs[i + 1])

for i in range(len(versions["pytorch"]) - 1):
assert_lt("pytorch", i)
assert_lt("torchvision", i)
if i > 0:
assert_lt("torchaudio", i)

0 comments on commit a51a35a

Please sign in to comment.