use a plain serializer to enable pystac Classification in BaseModel #3
Hello @rbavery, thank you for submitting a PR!
MLMClassification = Annotated[Classification,
    PlainSerializer(lambda x: x.to_dict(), return_type=Dict[str, Any], when_used='json')]
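For readers unfamiliar with the pattern in this diff, here is a minimal sketch of how an Annotated type with a PlainSerializer behaves inside a pydantic v2 model. FakeClassification and Output are hypothetical stand-ins (not part of pystac or this repository), and arbitrary_types_allowed is an assumption needed for the stand-in to be accepted as a field type.

```python
from typing import Annotated, Any, Dict, List

from pydantic import BaseModel, ConfigDict, PlainSerializer


class FakeClassification:
    """Hypothetical stand-in for pystac's Classification: dict-backed, exposes to_dict()."""

    def __init__(self, properties: Dict[str, Any]) -> None:
        self.properties = properties

    def to_dict(self) -> Dict[str, Any]:
        return dict(self.properties)


# Same pattern as the diff: keep the original type and attach a serializer
# that is only applied when dumping in JSON mode.
SerializableClassification = Annotated[
    FakeClassification,
    PlainSerializer(lambda x: x.to_dict(), return_type=Dict[str, Any], when_used="json"),
]


class Output(BaseModel):
    # Assumption: required so pydantic accepts the non-pydantic stand-in class.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    classes: List[SerializableClassification]


output = Output(classes=[FakeClassification({"value": 9, "name": "SeaLake"})])
json_ready = output.model_dump(mode="json")
```

With this setup the wrapped objects stay ordinary instances of the original class; only the JSON-mode dump goes through to_dict().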
I wonder if a model_serializer decorator would not be more explicit and help users understand what is happening?
Would something like this work?
class MLMClassification(BaseModel, Classification):
    @model_serializer()
    def serialize_model(self) -> Dict[str, Any]:
        return self.to_dict()
Replacing the Annotated type with this class doesn't work, unfortunately:
/home/rave/work/dlm-extension/tests/test_schema.py::test_model_metadata_to_dict failed with error: Test failed with exception
tests/test_schema.py:8: in mlmodel_metadata_item
model_metadata_stac_item = eurosat_resnet()
stac_model/examples.py:107: in eurosat_resnet
class_objects = [
stac_model/examples.py:108: in <listcomp>
MLMClassification.create(value=class_map[class_name], description=class_name, name=class_name)
../../.cache/pypoetry/virtualenvs/stac-model-YQXQDXJF-py3.10/lib/python3.10/site-packages/pystac/extensions/classification.py:112: in create
c = cls({})
E TypeError: BaseModel.__init__() takes 1 positional argument but 2 were given
I suggest we stick with this for now since it works; we can always change the underlying implementation later if it proves confusing for contributors.
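The TypeError in the traceback above can be reproduced without pydantic or pystac. The sketch below uses two hypothetical stand-in classes: KeywordOnlyModel mimics a base whose __init__ accepts keyword arguments only, and DictBacked mimics a class whose create() calls cls({}) positionally, as shown in the traceback. Because the keyword-only base comes first in the method resolution order, its __init__ receives the positional dict and fails.

```python
from typing import Any, Dict, Optional


class KeywordOnlyModel:
    """Hypothetical stand-in for BaseModel: __init__ takes keyword data only."""

    def __init__(self, **data: Any) -> None:
        self.data = data


class DictBacked:
    """Hypothetical stand-in for Classification: wraps a properties dict."""

    def __init__(self, properties: Dict[str, Any]) -> None:
        self.properties = properties

    @classmethod
    def create(cls, value: int) -> "DictBacked":
        c = cls({})  # positional dict, mirroring `c = cls({})` in the traceback
        c.properties["value"] = value
        return c


class Combined(KeywordOnlyModel, DictBacked):
    """KeywordOnlyModel is first in the MRO, so its __init__ is the one called."""


error_message: Optional[str] = None
try:
    Combined.create(value=9)
except TypeError as exc:
    # The keyword-only __init__ rejects the positional dict passed by create().
    error_message = str(exc)

print(error_message)
```

This is why mixing the two base classes needs either an overridden __init__ or a different composition approach such as Annotated.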
@@ -106,9 +105,10 @@ def eurosat_resnet():
         "SeaLake": 9,
     }
     class_objects = [
-        ClassObject(value=class_map[class_name], name=class_name)
+        MLMClassification.create(value=class_map[class_name], description=class_name, name=class_name)
If we defined class MLMClassification(BaseModel, Classification) (see other comment), would it be possible to use MLMClassification(value=...) and MLMClassification.create(value=...) interchangeably? Maybe an __init__ calling create could be required since the fields would not be defined.
I think the use of Annotated is clear enough that we can stick with it. We could override the init, but I think the use of Annotated is more succinct.
I have tried using the Annotated approach, but I ended up getting multiple unhashable dict errors when using MLMClassification within another BaseModel definition.
I had to do the following to make it all work together:
dlm-extension/stac_model/output.py
Lines 34 to 65 in 2d6c70b
class MLMClassification(BaseModel, Classification):
    @model_serializer()
    def model_dump(self, *_, **__) -> Dict[str, Any]:
        return self.to_dict()

    def __init__(
        self,
        value: int,
        description: Optional[str] = None,
        name: Optional[str] = None,
        color_hint: Optional[str] = None
    ) -> None:
        Classification.__init__(self, {})
        if not name and not description:
            raise ValueError("Class name or description is required!")
        self.apply(
            value=value,
            name=name or description,
            description=description or name,
            color_hint=color_hint,
        )

    def __hash__(self) -> int:
        return sum(map(hash, self.to_dict().items()))

    def __setattr__(self, key: str, value: Any) -> None:
        if key == "properties":
            Classification.__setattr__(self, key, value)
        else:
            BaseModel.__setattr__(self, key, value)

    model_config = ConfigDict(arbitrary_types_allowed=True)
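The __hash__ in the snippet above sums the hashes of the to_dict() items. As a rough illustration of what that buys and where it stops working, here is a standalone sketch. DictBackedItem is a hypothetical stand-in, and its __eq__ is added only for the illustration (it is not taken from the snippet above).

```python
from typing import Any, Dict


class DictBackedItem:
    """Hypothetical stand-in for a dict-backed object such as MLMClassification."""

    def __init__(self, **properties: Any) -> None:
        self.properties = properties

    def to_dict(self) -> Dict[str, Any]:
        return dict(self.properties)

    def __eq__(self, other: object) -> bool:
        # Added for illustration: equal when the underlying dicts are equal.
        return isinstance(other, DictBackedItem) and self.to_dict() == other.to_dict()

    def __hash__(self) -> int:
        # Same idea as the snippet above: a sum does not depend on key order.
        return sum(map(hash, self.to_dict().items()))


a = DictBackedItem(value=1, name="Forest")
b = DictBackedItem(name="Forest", value=1)  # same content, different key order
unique = {a, b}  # usable in a set because __hash__ is defined

# Limitation: every value must itself be hashable, so a list value fails.
nested = DictBackedItem(value=1, tags=["x"])
try:
    hash(nested)
    nested_hashable = True
except TypeError:
    nested_hashable = False
```

The approach is fine for flat properties such as value, name, description and color_hint, but it would need adjusting if a property ever held a list or a nested dict.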
I don't think we need to include MLMClassification in any other BaseModel definitions in this code base, so it shouldn't be a problem? I haven't encountered that error with Annotated, either when serializing or when loading the data back in with pystac. The class above is more verbose and complex for contributors than the use of Annotated, imo. Can we keep it as is since it works?
Yes. It must be included in the model output. We employ the same definition as classification:classes, but they are not currently validated by that extension, because the field is nested under a specific MLM output rather than at the Item properties or Asset level (see stac-extensions/classification#48).
Using the above definition, serialize/load works both ways, as evaluated by this test:
dlm-extension/tests/test_schema.py
Lines 36 to 39 in 4eb30da
def test_validate_model_against_schema(eurosat_resnet, mlm_validator):
    mlm_item = pystac.read_dict(eurosat_resnet.item.to_dict())
    validated = pystac.validation.validate(mlm_item, validator=mlm_validator)
    assert SCHEMA_URI in validated
using this definition:
dlm-extension/stac_model/examples.py
Lines 106 to 116 in 4eb30da
    class_objects = [
        MLMClassification(value=class_value, name=class_name)
        for class_name, class_value in class_map.items()
    ]
    output = ModelOutput(
        name="classification",
        tasks={"classification"},
        classes=class_objects,
        result=result_array,
        post_processing_function=None,
    )
Ok gotcha. I'm still a bit confused why the use of Annotated wasn't working out but trust this is best.
Since you have this working in your PR I'll close this @fmigneault
Tests pass for serializing the pystac item and reading in the item with pystac when using the annotated Classification class from pystac in the output BaseModel. If this approach looks good I can do the same for Statistics and Band @fmigneault