Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds model validator to ensure consistent renaming #9

Merged
merged 4 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,22 @@ def get_names(self) -> list[str]:
"""
return [getattr(model, self.name_field) for model in self.data if hasattr(model, self.name_field)]

def get_all_matches(self, value: Any) -> list[tuple]:
"""Return a list of all (index, field) tuples where the value of the field is equal to the given value.

Parameters
----------
value : str
The value we are searching for in the ClassList.

Returns
-------
: list [tuple]
A list of (index, field) tuples matching the given value.
"""
return [(index, field) for index, element in enumerate(self.data) for field in vars(element)
if getattr(element, field) == value]

def _validate_name_field(self, input_args: dict[str, Any]) -> None:
"""Raise a ValueError if the name_field attribute is passed as an object parameter, and its value is already
used within the ClassList.
Expand Down
53 changes: 47 additions & 6 deletions RAT/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""

import collections
import contextlib
import copy
import functools
import numpy as np
Expand Down Expand Up @@ -46,7 +48,7 @@ class Geometries(StrEnum):
'custom_files': 'CustomFile',
'data': 'Data',
'layers': 'Layer',
'contrasts': 'Contrast'
'contrasts': 'Contrast',
}

values_defined_in = {'backgrounds.value_1': 'background_parameters',
Expand All @@ -70,6 +72,23 @@ class Geometries(StrEnum):
'contrasts.resolution': 'resolutions',
}

AllFields = collections.namedtuple('AllFields', ['attribute', 'fields'])
model_names_used_in = {'background_parameters': AllFields('backgrounds', ['value_1', 'value_2', 'value_3', 'value_4',
'value_5']),
'resolution_parameters': AllFields('resolutions', ['value_1', 'value_2', 'value_3', 'value_4',
'value_5']),
'parameters': AllFields('layers', ['thickness', 'SLD', 'roughness']),
'data': AllFields('contrasts', ['data']),
'backgrounds': AllFields('contrasts', ['background']),
'bulk_in': AllFields('contrasts', ['nba']),
'bulk_out': AllFields('contrasts', ['nbs']),
'scalefactors': AllFields('contrasts', ['scalefactor']),
'resolutions': AllFields('contrasts', ['resolution']),
}

class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', 'backgrounds',
'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', 'contrasts']


class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
"""Defines the input data for a reflectivity calculation in RAT.
Expand Down Expand Up @@ -120,6 +139,8 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
layers: ClassList = ClassList()
contrasts: ClassList = ClassList()

_all_names: dict

@field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
'contrasts')
Expand All @@ -133,8 +154,8 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
return value

def model_post_init(self, __context: Any) -> None:
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
ClassList routines to control revalidation.
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, gets names of all
defined parameters and wraps ClassList routines to control revalidation.
"""
for field_name, model in model_in_classlist.items():
field = getattr(self, field_name)
Expand All @@ -145,17 +166,33 @@ def model_post_init(self, __context: Any) -> None:
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
sigma=np.inf))

self._all_names = self.get_all_names()

# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
# model, handle errors and reset previous values if necessary.
class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
'contrasts']
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
for class_list in class_lists:
attribute = getattr(self, class_list)
for methodName in methods_to_wrap:
setattr(attribute, methodName, self._classlist_wrapper(attribute, getattr(attribute, methodName)))

@model_validator(mode='after')
def update_renamed_models(self) -> 'Project':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function scares me with all the nested for loops. Is it possible to implement using signals/slot (observer pattern) instead?

"""When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project."""
for class_list in class_lists:
old_names = self._all_names[class_list]
new_names = getattr(self, class_list).get_names()
if len(old_names) == len(new_names):
name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new]
for (old_name, new_name) in name_diff:
with contextlib.suppress(KeyError):
model_names_list = getattr(self, model_names_used_in[class_list].attribute)
all_matches = model_names_list.get_all_matches(old_name)
for (index, field) in all_matches:
setattr(model_names_list[index], field, new_name)
self._all_names = self.get_all_names()
return self

@model_validator(mode='after')
def cross_check_model_values(self) -> 'Project':
"""Certain model fields should contain values defined elsewhere in the project."""
Expand Down Expand Up @@ -185,6 +222,10 @@ def __repr__(self):
output += value.value + '\n\n'
return output

def get_all_names(self):
"""Record the names of all models defined in the project."""
return {class_list: getattr(self, class_list).get_names() for class_list in class_lists}

def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
"""Check the values of the given fields in the given model are in the supplied list of allowed values.

Expand Down
11 changes: 11 additions & 0 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,17 @@ def test_get_names(class_list: 'ClassList', expected_names: list[str]) -> None:
assert class_list.get_names() == expected_names


@pytest.mark.parametrize(["class_list", "expected_matches"], [
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), [(0, 'name')]),
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob', id='Alice')]), [(0, 'name'), (1, 'id')]),
(ClassList([InputAttributes(surname='Morgan'), InputAttributes(surname='Terwilliger')]), []),
(ClassList(InputAttributes()), []),
])
def test_get_all_matches(class_list: 'ClassList', expected_matches: list[tuple]) -> None:
"""We should get a list of (index, field) tuples matching the given value in the ClassList."""
assert class_list.get_all_matches("Alice") == expected_matches


@pytest.mark.parametrize("input_dict", [
({'name': 'Eve'}),
({'surname': 'Polastri'}),
Expand Down
Loading