Aliases now also work for nested fields; Only retrieve data required for constructing a response from the database. #1304
base: main
Changes from all commits
@@ -1,6 +1,7 @@
import re
import warnings
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Set, Tuple, Type, Union

from lark import Transformer
@@ -11,6 +12,7 @@
from optimade.server.config import CONFIG, SupportedBackend
from optimade.server.mappers import BaseResourceMapper
from optimade.server.query_params import EntryListingQueryParams, SingleEntryQueryParams
from optimade.utils import set_field_to_none_if_missing_in_dict
from optimade.warnings import (
    FieldValueNotRecognized,
    QueryParamNotUsed,
@@ -121,13 +123,7 @@ def count(self, **kwargs: Any) -> int:

    def find(
        self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
    ) -> Tuple[
        Union[List[EntryResource], EntryResource],
        int,
        bool,
        Set[str],
        Set[str],
    ]:
    ) -> Tuple[Union[List[EntryResource], EntryResource], int, bool, Set[str]]:
        """
        Fetches results and indicates if more data is available.
@@ -146,23 +142,49 @@ def find(
        criteria = self.handle_query_params(params)
        single_entry = isinstance(params, SingleEntryQueryParams)
        response_fields = criteria.pop("fields")
        response_fields_set = criteria.pop("response_fields_set", False)

        raw_results, data_returned, more_data_available = self._run_db_query(
            criteria, single_entry
        )

        exclude_fields = self.all_fields - response_fields

        results: List = [self.resource_mapper.map_back(doc) for doc in raw_results]

        self.check_and_add_missing_fields(results, response_fields, response_fields_set)

        if results:
            results = self.resource_mapper.deserialize(results)

        if single_entry:
            raw_results = raw_results[0] if raw_results else None  # type: ignore[assignment]
            results = results[0] if results else None  # type: ignore[assignment]

            if data_returned > 1:
                raise NotFound(
                    detail=f"Instead of a single entry, {data_returned} entries were found",
                )

        exclude_fields = self.all_fields - response_fields
        return results, data_returned, more_data_available, exclude_fields

    def check_and_add_missing_fields(
        self, results: List[dict], response_fields: set, response_fields_set: bool
    ):
        """Checks whether the response_fields and mandatory fields are present.
        If they are not present the values are set to None, so the deserialization works correctly.
        It also checks whether all fields in the response have been defined either in the model or in the config file.
        If not it raises an appropriate error or warning."""
        include_fields = (
            response_fields - self.resource_mapper.TOP_LEVEL_NON_ATTRIBUTES_FIELDS
        )
        # Include missing fields
        for result in results:
            for field in include_fields:
                set_field_to_none_if_missing_in_dict(result["attributes"], field)
Comment on lines +181 to +183

Reviewer: This is going to incur a significant performance overhead, but I guess you want to do it so that you don't have to pull, e.g., entire trajectories from the database each time, yet you still want to deserialize the JSON into your classes? I would suggest we instead have a per-collection deserialization flag, as presumably you only want to deserialize trajectories once (on database insertion) anyway. Does that make sense? If you want to retain this approach, it might be cleaner to do it at the pydantic level, e.g., a …

Author: I do not think this is particularly heavy compared to all the other things we do in the code. For biomolecular data, a structure can easily have 10,000 atoms, so retrieving them from the database and putting them in the model would take some time. This way we can avoid that if species_at_sites and cartesian_site_positions are not in the response_fields. (I also made a patch in the code for Barcelona that allowed them to specify the default response fields, so they can choose not to have these fields in the response by default.) I did not want to make the change even bigger by bypassing the rest of the validator (as in your second suggestion). I tried the root validator idea, but it seems I already get an error before the root_validator is executed, so I do not think that solution will work.

Reviewer: Hmm, fair enough; it just looks a bit scarier as a double for loop alongside the recursive descent into dictionaries to get the nested aliases. It's quite hard to reason about this, so I might set up a separate repo for measuring performance in the extreme limits (1 structure of 10,000 atoms vs 10,000 structures of a few atoms, i.e., what we have now, ignoring pagination of course). Did you use …

Author: I just noticed I made a mistake in my test script.

Author: I have added a root_validator to the attributes class that, if a flag is set, checks whether all required fields are present and, if not, adds them and sets them to 0. I'll try to put the handling of the other include fields back to the place where it happened originally so the code changes less.
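The helper imported from optimade.utils is only called in this hunk, not defined here. As a rough illustration of the recursive descent into nested dictionaries that the thread refers to, a minimal version with the same call signature might look like the sketch below; the PR's actual implementation may differ.

```python
def set_field_to_none_if_missing_in_dict(entry: dict, field: str) -> None:
    """Illustrative sketch only; the real helper lives in optimade.utils.

    Walk a dot-separated field name into ``entry`` and set the leaf to ``None``
    when it is missing, so later deserialization does not fail on absent keys.
    """
    *parents, leaf = field.split(".")
    current = entry
    for key in parents:
        if key not in current:
            # The nested container is absent entirely: mark the top of the
            # missing path as None and stop descending.
            current[key] = None
            return
        value = current[key]
        if not isinstance(value, dict):
            # Cannot descend further (e.g. the value is a list); leave it alone.
            return
        current = value
    current.setdefault(leaf, None)
```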
        if response_fields_set:
            for result in results:
                result["attributes"]["set_missing_to_none"] = True

        bad_optimade_fields = set()
        bad_provider_fields = set()
@@ -189,19 +211,6 @@ def find(
                    detail=f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
                )

        if raw_results is not None:
            results = self.resource_mapper.deserialize(raw_results)
        else:
            results = None

        return (
            results,
            data_returned,
            more_data_available,
            exclude_fields,
            include_fields,
        )

    @abstractmethod
    def _run_db_query(
        self, criteria: Dict[str, Any], single_entry: bool = False
@@ -244,6 +253,7 @@ def all_fields(self) -> Set[str]:

        return self._all_fields

    @lru_cache(maxsize=4)
    def get_attribute_fields(self) -> Set[str]:
        """Get the set of attribute fields
@@ -327,16 +337,16 @@ def handle_query_params(
            cursor_kwargs["limit"] = CONFIG.page_limit

        # response_fields
        cursor_kwargs["projection"] = {
            f"{self.resource_mapper.get_backend_field(f)}": True
            for f in self.all_fields
        }

        if getattr(params, "response_fields", False):
            cursor_kwargs["response_fields_set"] = True
            response_fields = set(params.response_fields.split(","))
            response_fields |= self.resource_mapper.get_required_fields()
        else:
            response_fields = self.all_fields.copy()
        cursor_kwargs["projection"] = {
            f"{self.resource_mapper.get_backend_field(f)}": True
            for f in response_fields
        }

        cursor_kwargs["fields"] = response_fields
Reviewer: This looks on the right track! I've just played around with something too after spotting something in the pydantic docs about default factories. Would the snippet also solve this issue? We can then patch the underlying OptimadeField and StrictField wrappers to default to having a default_factory that returns null in cases where there is no default value for the field to fall back on, and we can do this without modifying the schema or the models. The only concern is that this functionality might get removed from pydantic: … Though it has already lasted a few versions.
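A rough sketch of the default-factory idea being discussed, using pydantic v1's `Field(default_factory=...)`. The wrapper below is a simplified, hypothetical stand-in for the real StrictField/OptimadeField helpers, not their actual implementation:

```python
from typing import Any, Optional

from pydantic import BaseModel, Field


def StrictFieldSketch(*args: Any, **kwargs: Any) -> Any:
    """Simplified stand-in for optimade's StrictField/OptimadeField wrappers.

    If the caller supplies neither a default nor a default_factory, fall back
    to a factory returning None, so partially projected database documents
    still validate even when some declared fields are absent.
    """
    if not args and "default" not in kwargs and "default_factory" not in kwargs:
        kwargs["default_factory"] = lambda: None
    return Field(*args, **kwargs)


class ExampleAttributes(BaseModel):
    # Hypothetical fields; missing ones now fall back to None automatically.
    nelements: Optional[int] = StrictFieldSketch(description="Number of elements")
    chemical_formula_reduced: Optional[str] = StrictFieldSketch()


print(ExampleAttributes(nelements=2))  # chemical_formula_reduced defaults to None
```

This keeps the schema and models untouched, since the None fallback is injected at the field-wrapper level rather than in each model definition.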