use a plain serializer to enable pystac Classification in BaseModel #3

Closed
Owner

@rbavery rbavery commented Apr 1, 2024

Tests pass for serializing the pystac item and reading it back in with pystac when the output BaseModel uses the annotated Classification class from pystac. If this approach looks good, I can do the same for Statistics and Band, @fmigneault.
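A minimal sketch of the approach described above, assuming pydantic v2. `PlainClassification` and `Output` are hypothetical stand-ins (not the real pystac `Classification` or any stac-model class); only the `Annotated` plus `PlainSerializer` pattern is taken from this PR.

```python
from typing import Annotated, Any, Dict, List

from pydantic import BaseModel, ConfigDict, PlainSerializer


class PlainClassification:
    """Hypothetical stand-in for a non-pydantic class that exposes to_dict()."""

    def __init__(self, properties: Dict[str, Any]) -> None:
        self.properties = properties

    def to_dict(self) -> Dict[str, Any]:
        return dict(self.properties)


# Attach a JSON-mode serializer to the plain class without subclassing it.
SerializableClassification = Annotated[
    PlainClassification,
    PlainSerializer(lambda x: x.to_dict(), return_type=Dict[str, Any], when_used="json"),
]


class Output(BaseModel):
    # Required because PlainClassification is not a type pydantic can validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    classes: List[SerializableClassification]


output = Output(classes=[PlainClassification({"value": 0, "name": "Forest"})])

# The serializer only runs in JSON mode because of when_used="json".
dumped = output.model_dump(mode="json")
```

With this pattern the wrapped class keeps its own constructor and class methods, which is why it stays compatible with a `create(...)`-style factory.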

@github-actions github-actions bot left a comment

Hello @rbavery, thank you for submitting a PR!

Comment on lines +35 to +36
MLMClassification = Annotated[Classification,
    PlainSerializer(lambda x: x.to_dict(), return_type=Dict[str, Any], when_used='json')]
Collaborator

I wonder if a model_serializer decorator would not be more explicit and help users understand what is happening?

Would something like this work?

class MLMClassification(BaseModel, Classification):
    @model_serializer()
    def serialize_model(self) -> Dict[str, Any]:
        return self.to_dict()
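
For comparison, a minimal sketch of `model_serializer` on a model that declares its own fields, assuming pydantic v2. `ClassEntry` is a hypothetical model invented for illustration; it does not inherit from pystac's `Classification`, so it says nothing about whether the mixed-inheritance class above works.

```python
from typing import Any, Dict

from pydantic import BaseModel, model_serializer


class ClassEntry(BaseModel):
    """Hypothetical pure-pydantic model, not part of stac-model or pystac."""

    value: int
    name: str

    @model_serializer()
    def serialize_model(self) -> Dict[str, Any]:
        # Replaces pydantic's default field-by-field dump for this model.
        return {"value": self.value, "name": self.name, "description": self.name}


entry = ClassEntry(value=3, name="Forest")
dumped = entry.model_dump()
```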

Owner Author

Unfortunately, replacing the Annotated type with this class doesn't work:

/home/rave/work/dlm-extension/tests/test_schema.py::test_model_metadata_to_dict failed with error: Test failed with exception
tests/test_schema.py:8: in mlmodel_metadata_item
    model_metadata_stac_item = eurosat_resnet()
stac_model/examples.py:107: in eurosat_resnet
    class_objects = [
stac_model/examples.py:108: in <listcomp>
    MLMClassification.create(value=class_map[class_name], description=class_name, name=class_name)
../../.cache/pypoetry/virtualenvs/stac-model-YQXQDXJF-py3.10/lib/python3.10/site-packages/pystac/extensions/classification.py:112: in create
    c = cls({})
E   TypeError: BaseModel.__init__() takes 1 positional argument but 2 were given

I suggest we stick with this for now since it works; we can always change the underlying implementation later if it turns out to be confusing for contributors.
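
The failure in the traceback can be reproduced without pydantic or pystac. The sketch below uses only the standard library; `KeywordOnlyModel`, `PropertiesBacked`, and `Mixed` are hypothetical stand-ins. `KeywordOnlyModel` assumes a constructor that accepts keyword arguments only, which is what the error message implies for `BaseModel.__init__`.

```python
class KeywordOnlyModel:
    """Hypothetical stand-in: constructor accepts keyword arguments only."""

    def __init__(self, /, **data):
        self.data = data


class PropertiesBacked:
    """Hypothetical stand-in: constructor takes one positional properties dict."""

    def __init__(self, properties):
        self.properties = properties

    @classmethod
    def create(cls, **fields):
        c = cls({})  # positional dict, like the `cls({})` line in the traceback
        c.properties.update(fields)
        return c


class Mixed(KeywordOnlyModel, PropertiesBacked):
    """KeywordOnlyModel comes first in the MRO, so its __init__ is the one called."""


# On its own, the properties-backed class builds fine through create().
plain = PropertiesBacked.create(value=1, name="Forest")

# Through the mixed class, cls({}) reaches the keyword-only __init__ and fails.
try:
    Mixed.create(value=1, name="Forest")
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Under these assumptions, any factory that calls `cls({})` would need the subclass to define its own `__init__`.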

@@ -106,9 +105,10 @@ def eurosat_resnet():
         "SeaLake": 9,
     }
     class_objects = [
-        ClassObject(value=class_map[class_name], name=class_name)
+        MLMClassification.create(value=class_map[class_name], description=class_name, name=class_name)
Collaborator

If we defined class MLMClassification(BaseModel, Classification) (see other comment), would it be possible to use MLMClassification(value=...) and MLMClassification.create(value=...) interchangeably? An __init__ that calls create might be required, since the fields would not be defined.

Owner Author

I think the use of Annotated is clear enough that we can stick with it. We could override the init, but I think the use of Annotated is more succinct.

Collaborator

I have tried using the Annotated approach, but I ended up getting multiple unhashable dict errors when using MLMClassification within another BaseModel definition.

Collaborator

I had to do the following to make it all work together:

class MLMClassification(BaseModel, Classification):
    @model_serializer()
    def model_dump(self, *_, **__) -> Dict[str, Any]:
        return self.to_dict()

    def __init__(
        self,
        value: int,
        description: Optional[str] = None,
        name: Optional[str] = None,
        color_hint: Optional[str] = None
    ) -> None:
        Classification.__init__(self, {})
        if not name and not description:
            raise ValueError("Class name or description is required!")
        self.apply(
            value=value,
            name=name or description,
            description=description or name,
            color_hint=color_hint,
        )

    def __hash__(self) -> int:
        return sum(map(hash, self.to_dict().items()))

    def __setattr__(self, key: str, value: Any) -> None:
        if key == "properties":
            Classification.__setattr__(self, key, value)
        else:
            BaseModel.__setattr__(self, key, value)

    model_config = ConfigDict(arbitrary_types_allowed=True)
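
The `__hash__` in the class above hashes the (key, value) pairs of the serialized dict instead of the dict itself. The standard-library sketch below shows that technique in isolation; `items_hash` is a hypothetical helper, and it makes no claim about where the unhashable-dict errors came from in the original code.

```python
from typing import Any, Dict


def items_hash(properties: Dict[str, Any]) -> int:
    # Sum of per-item hashes: addition is commutative, so key order is irrelevant.
    return sum(map(hash, properties.items()))


a = {"value": 1, "name": "Forest"}
b = {"name": "Forest", "value": 1}

# A plain dict cannot be hashed directly.
try:
    hash(a)
    dict_is_hashable = True
except TypeError:
    dict_is_hashable = False

# Equal dicts produce equal hashes regardless of insertion order.
same_hash = items_hash(a) == items_hash(b)

# Caveat: every value must itself be hashable, so a list value still fails.
try:
    items_hash({"value": 1, "tags": ["a"]})
    nested_ok = True
except TypeError:
    nested_ok = False
```

A summed hash is cheap but collides more easily than hashing a `frozenset` of the items, which may or may not matter for a small set of classes.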

Owner Author

I don't think we need to include MLMClassification in any other BaseModel definitions in this code base, though, so it shouldn't be a problem? I haven't encountered that error with Annotated, either when serializing or when loading the data back in with pystac. The class above is more verbose and complex for contributors than Annotated, imo. Can we keep it as is, since it works?

Collaborator

@fmigneault fmigneault Apr 5, 2024

Yes. It must be included in the model output. We employ the same definition as classification:classes, but they are not currently validated by that extension, because the field is nested under a specific MLM output rather than at the Item properties or Asset level (see stac-extensions/classification#48).

Using the above definition, serialize/load works both ways, as evaluated by this test:

def test_validate_model_against_schema(eurosat_resnet, mlm_validator):
    mlm_item = pystac.read_dict(eurosat_resnet.item.to_dict())
    validated = pystac.validation.validate(mlm_item, validator=mlm_validator)
    assert SCHEMA_URI in validated

using this definition:
class_objects = [
    MLMClassification(value=class_value, name=class_name)
    for class_name, class_value in class_map.items()
]
output = ModelOutput(
    name="classification",
    tasks={"classification"},
    classes=class_objects,
    result=result_array,
    post_processing_function=None,
)

Owner Author

OK, gotcha. I'm still a bit confused about why Annotated wasn't working out, but I trust this is best.

Since you have this working in your PR, I'll close this, @fmigneault.

@rbavery rbavery closed this Apr 5, 2024