diff --git a/bioimageio/spec/model/v0_4.py b/bioimageio/spec/model/v0_4.py index 78451de7a..500441fab 100644 --- a/bioimageio/spec/model/v0_4.py +++ b/bioimageio/spec/model/v0_4.py @@ -209,6 +209,37 @@ def check_one_entry(self) -> Self: return self + def __getitem__( + self, + key: Literal[ + "keras_hdf5", + "onnx", + "pytorch_state_dict", + "tensorflow_js", + "tensorflow_saved_model_bundle", + "torchscript", + ], + ): + if key == "keras_hdf5": + ret = self.keras_hdf5 + elif key == "onnx": + ret = self.onnx + elif key == "pytorch_state_dict": + ret = self.pytorch_state_dict + elif key == "tensorflow_js": + ret = self.tensorflow_js + elif key == "tensorflow_saved_model_bundle": + ret = self.tensorflow_saved_model_bundle + elif key == "torchscript": + ret = self.torchscript + else: + raise KeyError(key) + + if ret is None: + raise KeyError(key) + + return ret + class WeightsEntryDescrBase(FileDescr): type: ClassVar[WeightsFormat] diff --git a/bioimageio/spec/model/v0_5.py b/bioimageio/spec/model/v0_5.py index be31c62be..01b8a9f19 100644 --- a/bioimageio/spec/model/v0_5.py +++ b/bioimageio/spec/model/v0_5.py @@ -2026,6 +2026,37 @@ def check_entries(self) -> Self: return self + def __getitem__( + self, + key: Literal[ + "keras_hdf5", + "onnx", + "pytorch_state_dict", + "tensorflow_js", + "tensorflow_saved_model_bundle", + "torchscript", + ], + ): + if key == "keras_hdf5": + ret = self.keras_hdf5 + elif key == "onnx": + ret = self.onnx + elif key == "pytorch_state_dict": + ret = self.pytorch_state_dict + elif key == "tensorflow_js": + ret = self.tensorflow_js + elif key == "tensorflow_saved_model_bundle": + ret = self.tensorflow_saved_model_bundle + elif key == "torchscript": + ret = self.torchscript + else: + raise KeyError(key) + + if ret is None: + raise KeyError(key) + + return ret + class ModelId(ResourceId): pass