Skip to content

Commit

Permalink
feat: make Map implement MutableMapping interface (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 authored Mar 15, 2024
1 parent 80d6a40 commit 7f08c11
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ repos:
additional_dependencies:
- pydantic>=2
- pydantic-compat
- xsdata>=24
- xsdata==24.2.1
- Pint
- types-lxml; python_version > '3.8'
2 changes: 2 additions & 0 deletions src/ome_autogen/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
ALLOW_RESERVED_NAMES = {"type", "Type", "Union"}
OME_FORMAT = "OME"
MIXIN_MODULE = "ome_types._mixins"
# class_name, import_string, whether-to-prepend
MIXINS: list[tuple[str, str, bool]] = [
(".*", f"{MIXIN_MODULE}._base_type.OMEType", False), # base type on every class
("OME", f"{MIXIN_MODULE}._ome.OMEMixin", True),
("Instrument", f"{MIXIN_MODULE}._instrument.InstrumentMixin", False),
("Reference", f"{MIXIN_MODULE}._reference.ReferenceMixin", True),
("Map", f"{MIXIN_MODULE}._map_mixin.MapMixin", False),
("Union", f"{MIXIN_MODULE}._collections.ShapeUnionMixin", True),
(
"StructuredAnnotations",
Expand Down
19 changes: 16 additions & 3 deletions src/ome_autogen/_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing as typing_module
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterator, NamedTuple, cast
Expand Down Expand Up @@ -54,6 +55,12 @@
lambda c: c.name == "ROI",
"\n\n_v_shape_union = field_validator('union', mode='before')(validate_shape_union)", # noqa: E501
),
(
lambda c: c.name == "Map",
"\n\n_v_map = model_validator(mode='before')(validate_map_annotation)"
"\ndict: ClassVar = MapMixin._pydict"
"\n__iter__: ClassVar = MapMixin.__iter__",
),
]


Expand Down Expand Up @@ -99,15 +106,21 @@ class Override(NamedTuple):
"ome_types._mixins._validators": {
"any_elements_validator": ["any_elements_validator"],
"bin_data_root_validator": ["bin_data_root_validator"],
"pixels_root_validator": ["pixels_root_validator"],
"xml_value_validator": ["xml_value_validator"],
"pixel_type_to_numpy_dtype": ["pixel_type_to_numpy_dtype"],
"validate_structured_annotations": ["validate_structured_annotations"],
"pixels_root_validator": ["pixels_root_validator"],
"validate_map_annotation": ["validate_map_annotation"],
"validate_shape_union": ["validate_shape_union"],
"validate_structured_annotations": ["validate_structured_annotations"],
"xml_value_validator": ["xml_value_validator"],
},
}
)

# not all typing names appear to be added by xsdata
IMPORT_PATTERNS.setdefault("typing", {}).update(
{n: [f": {n}"] for n in dir(typing_module) if not n.startswith("_")}
)


class OmeGenerator(DataclassGenerator):
@classmethod
Expand Down
10 changes: 9 additions & 1 deletion src/ome_types/_mixins/_base_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class OMEType(BaseModel):
"arbitrary_types_allowed": True,
"validate_assignment": True,
"validate_default": True,
"coerce_numbers_to_str": True,
}

# allow use with weakref
Expand All @@ -93,16 +94,23 @@ class OMEType(BaseModel):

_vid = field_validator("id", mode="before", check_fields=False)(validate_id)

def __iter__(self) -> Any:
return super().__iter__()

def __init__(self, **data: Any) -> None:
warn_extra = data.pop("warn_extra", True)
field_names = set(self.model_fields)
_move_deprecated_fields(data, field_names)
super().__init__(**data)
if type(self).__name__ == "Map":
# special escape hack for Map subclass, which can convert any
# dict into appropriate key-value pairs
return
kwargs = set(data.keys())
extra = kwargs - field_names
if extra and warn_extra:
warnings.warn(
f"Unrecognized fields for type {type(self)}: {kwargs - field_names}",
f"Unrecognized fields for type {type(self)}: {extra}",
stacklevel=3,
)

Expand Down
54 changes: 54 additions & 0 deletions src/ome_types/_mixins/_map_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, MutableMapping, Optional

try:
from pydantic import model_serializer
except ImportError:
model_serializer = None # type: ignore


if TYPE_CHECKING:
from typing import Protocol

from ome_types._autogenerated.ome_2016_06.map import Map

class HasMsProtocol(Protocol):
@property
def ms(self) -> List["Map.M"]: ...


class MapMixin(MutableMapping[str, Optional[str]]):
def __delitem__(self: "HasMsProtocol", key: str) -> None:
for m in self.ms:
if m.k == key:
self.ms.remove(m)
return

def __len__(self: "HasMsProtocol") -> int:
return len(self.ms)

def __iter__(self: "HasMsProtocol") -> Iterator[str]:
yield from (m.k for m in self.ms if m.k is not None)

def __getitem__(self: "HasMsProtocol", key: str) -> Optional[str]:
return next((m.value for m in self.ms if m.k == key), None)

def __setitem__(self: "HasMsProtocol", key: str, value: Optional[str]) -> None:
for m in self.ms:
if m.k == key:
m.value = value or ""
return
from ome_types.model import Map

self.ms.append(Map.M(k=key, value=value))

def _pydict(self: "HasMsProtocol", **kwargs: Any) -> Dict[str, str]:
return {m.k: m.value for m in self.ms if m.k is not None}

def dict(self, **kwargs: Any) -> Dict[str, Any]:
return self._pydict() # type: ignore

if model_serializer is not None:

@model_serializer(mode="wrap")
def serialize_root(self, handler, _info) -> dict: # type: ignore
return self._pydict() # type: ignore
16 changes: 14 additions & 2 deletions src/ome_types/_mixins/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
OME,
ROI,
BinData,
Map,
Pixels,
PixelType,
StructuredAnnotations,
Expand All @@ -19,7 +20,7 @@
from xsdata_pydantic_basemodel.compat import AnyElement


# @root_validator(pre=True)
# @model_validator(mode='before')
def bin_data_root_validator(cls: "BinData", values: dict) -> Dict[str, Any]:
# This catches the case of <BinData Length="0"/>, where the parser may have
# omitted value from the dict, and sets value to b""
Expand All @@ -34,7 +35,7 @@ def bin_data_root_validator(cls: "BinData", values: dict) -> Dict[str, Any]:
return values


# @root_validator(pre=True)
# @model_validator(mode='before')
def pixels_root_validator(cls: "Pixels", value: dict) -> dict:
if "metadata_only" in value:
if isinstance(value["metadata_only"], bool):
Expand Down Expand Up @@ -124,3 +125,14 @@ def validate_shape_union(cls: "ROI", v: Any) -> "ROI.Union":
_values.setdefault(ROI.Union._field_name(item), []).append(item)
v = _values
return v


# @model_validator(mode="before")
def validate_map_annotation(cls: "Map", v: Any) -> "Map | dict":
from ome_types.model import Map

if isinstance(v, dict):
if len(v) == 1 and "ms" in v:
return v
return {"ms": [Map.M(k=k, value=v) for k, v in v.items()]}
return v
32 changes: 32 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,35 @@ def test_transformations() -> None:
# SHOULD warn
with pytest.warns(match="Casting invalid"):
from_xml(DATA / "MMStack.ome.xml")


def test_map_annotations() -> None:
from ome_types.model import Map, MapAnnotation

data = {"a": "string", "b": 2}

# can be created from a dict
map_annotation = MapAnnotation(value=data)
map_val = map_annotation.value
assert isinstance(map_annotation, MapAnnotation)
assert isinstance(map_val, Map)

out = map_annotation.value.model_dump()
assert out == {k: str(v) for k, v in data.items()} # all values cast to str

# it's a mutable mapping
map_val["c"] = "new"
assert map_val.get("c") == "new"
assert map_val.get("X") is None
assert len(map_val) == 3
assert set(map_val) == {"a", "b", "c"}
assert dict(map_val) == map_val.model_dump() == {**out, "c": "new"}
del map_val["c"]
assert len(map_val) == 2

_ = map_annotation.to_xml() # shouldn't fail

# only strings are allowed as values
data["nested"] = [1, 2]
with pytest.raises(ValidationError):
MapAnnotation(value=data)

0 comments on commit 7f08c11

Please sign in to comment.