Skip to content

Commit

Permalink
add __getitem__ for WeightsDescr
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Dec 19, 2024
1 parent 15ea0f7 commit 53b7558
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
31 changes: 31 additions & 0 deletions bioimageio/spec/model/v0_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,37 @@ def check_one_entry(self) -> Self:

return self

def __getitem__(
self,
key: Literal[
"keras_hdf5",
"onnx",
"pytorch_state_dict",
"tensorflow_js",
"tensorflow_saved_model_bundle",
"torchscript",
],
):
if key == "keras_hdf5":
ret = self.keras_hdf5
elif key == "onnx":
ret = self.onnx
elif key == "pytorch_state_dict":
ret = self.pytorch_state_dict
elif key == "tensorflow_js":
ret = self.tensorflow_js
elif key == "tensorflow_saved_model_bundle":
ret = self.tensorflow_saved_model_bundle
elif key == "torchscript":
ret = self.torchscript
else:
raise KeyError(key)

if ret is None:
raise KeyError(key)

return ret


class WeightsEntryDescrBase(FileDescr):
type: ClassVar[WeightsFormat]
Expand Down
31 changes: 31 additions & 0 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2026,6 +2026,37 @@ def check_entries(self) -> Self:

return self

def __getitem__(
self,
key: Literal[
"keras_hdf5",
"onnx",
"pytorch_state_dict",
"tensorflow_js",
"tensorflow_saved_model_bundle",
"torchscript",
],
):
if key == "keras_hdf5":
ret = self.keras_hdf5
elif key == "onnx":
ret = self.onnx
elif key == "pytorch_state_dict":
ret = self.pytorch_state_dict
elif key == "tensorflow_js":
ret = self.tensorflow_js
elif key == "tensorflow_saved_model_bundle":
ret = self.tensorflow_saved_model_bundle
elif key == "torchscript":
ret = self.torchscript
else:
raise KeyError(key)

if ret is None:
raise KeyError(key)

return ret


class ModelId(ResourceId):
pass
Expand Down

0 comments on commit 53b7558

Please sign in to comment.