Skip to content

Commit

Permalink
Adds model validator to ensure consistent renaming (#9)
Browse files Browse the repository at this point in the history
* Adds model validator to ensure renamed models are updated throughout the project

* Adds tests for renaming code

* Uses project dicts to tidy up tests

* Adds routine "get_all_matches" to "classList.py"
  • Loading branch information
DrPaulSharp authored Sep 19, 2023
1 parent 1abd5ca commit 5355150
Show file tree
Hide file tree
Showing 4 changed files with 349 additions and 235 deletions.
16 changes: 16 additions & 0 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,22 @@ def get_names(self) -> list[str]:
"""
return [getattr(model, self.name_field) for model in self.data if hasattr(model, self.name_field)]

def get_all_matches(self, value: Any) -> list[tuple]:
"""Return a list of all (index, field) tuples where the value of the field is equal to the given value.
Parameters
----------
value : str
The value we are searching for in the ClassList.
Returns
-------
: list [tuple]
A list of (index, field) tuples matching the given value.
"""
return [(index, field) for index, element in enumerate(self.data) for field in vars(element)
if getattr(element, field) == value]

def _validate_name_field(self, input_args: dict[str, Any]) -> None:
"""Raise a ValueError if the name_field attribute is passed as an object parameter, and its value is already
used within the ClassList.
Expand Down
53 changes: 47 additions & 6 deletions RAT/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""

import collections
import contextlib
import copy
import functools
import numpy as np
Expand Down Expand Up @@ -46,7 +48,7 @@ class Geometries(StrEnum):
'custom_files': 'CustomFile',
'data': 'Data',
'layers': 'Layer',
'contrasts': 'Contrast'
'contrasts': 'Contrast',
}

values_defined_in = {'backgrounds.value_1': 'background_parameters',
Expand All @@ -70,6 +72,23 @@ class Geometries(StrEnum):
'contrasts.resolution': 'resolutions',
}

AllFields = collections.namedtuple('AllFields', ['attribute', 'fields'])
model_names_used_in = {'background_parameters': AllFields('backgrounds', ['value_1', 'value_2', 'value_3', 'value_4',
'value_5']),
'resolution_parameters': AllFields('resolutions', ['value_1', 'value_2', 'value_3', 'value_4',
'value_5']),
'parameters': AllFields('layers', ['thickness', 'SLD', 'roughness']),
'data': AllFields('contrasts', ['data']),
'backgrounds': AllFields('contrasts', ['background']),
'bulk_in': AllFields('contrasts', ['nba']),
'bulk_out': AllFields('contrasts', ['nbs']),
'scalefactors': AllFields('contrasts', ['scalefactor']),
'resolutions': AllFields('contrasts', ['resolution']),
}

class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', 'backgrounds',
'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', 'contrasts']


class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
"""Defines the input data for a reflectivity calculation in RAT.
Expand Down Expand Up @@ -120,6 +139,8 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
layers: ClassList = ClassList()
contrasts: ClassList = ClassList()

_all_names: dict

@field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
'contrasts')
Expand All @@ -133,8 +154,8 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
return value

def model_post_init(self, __context: Any) -> None:
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
ClassList routines to control revalidation.
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, gets names of all
defined parameters and wraps ClassList routines to control revalidation.
"""
for field_name, model in model_in_classlist.items():
field = getattr(self, field_name)
Expand All @@ -145,17 +166,33 @@ def model_post_init(self, __context: Any) -> None:
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
sigma=np.inf))

self._all_names = self.get_all_names()

# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
# model, handle errors and reset previous values if necessary.
class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
'contrasts']
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
for class_list in class_lists:
attribute = getattr(self, class_list)
for methodName in methods_to_wrap:
setattr(attribute, methodName, self._classlist_wrapper(attribute, getattr(attribute, methodName)))

@model_validator(mode='after')
def update_renamed_models(self) -> 'Project':
"""When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project."""
for class_list in class_lists:
old_names = self._all_names[class_list]
new_names = getattr(self, class_list).get_names()
if len(old_names) == len(new_names):
name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new]
for (old_name, new_name) in name_diff:
with contextlib.suppress(KeyError):
model_names_list = getattr(self, model_names_used_in[class_list].attribute)
all_matches = model_names_list.get_all_matches(old_name)
for (index, field) in all_matches:
setattr(model_names_list[index], field, new_name)
self._all_names = self.get_all_names()
return self

@model_validator(mode='after')
def cross_check_model_values(self) -> 'Project':
"""Certain model fields should contain values defined elsewhere in the project."""
Expand Down Expand Up @@ -185,6 +222,10 @@ def __repr__(self):
output += value.value + '\n\n'
return output

def get_all_names(self):
"""Record the names of all models defined in the project."""
return {class_list: getattr(self, class_list).get_names() for class_list in class_lists}

def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
"""Check the values of the given fields in the given model are in the supplied list of allowed values.
Expand Down
11 changes: 11 additions & 0 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,17 @@ def test_get_names(class_list: 'ClassList', expected_names: list[str]) -> None:
assert class_list.get_names() == expected_names


@pytest.mark.parametrize(["class_list", "expected_matches"], [
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), [(0, 'name')]),
(ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob', id='Alice')]), [(0, 'name'), (1, 'id')]),
(ClassList([InputAttributes(surname='Morgan'), InputAttributes(surname='Terwilliger')]), []),
(ClassList(InputAttributes()), []),
])
def test_get_all_matches(class_list: 'ClassList', expected_matches: list[tuple]) -> None:
"""We should get a list of (index, field) tuples matching the given value in the ClassList."""
assert class_list.get_all_matches("Alice") == expected_matches


@pytest.mark.parametrize("input_dict", [
({'name': 'Eve'}),
({'surname': 'Polastri'}),
Expand Down
Loading

0 comments on commit 5355150

Please sign in to comment.