Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds model validator to ensure consistent renaming #9

Merged
merged 4 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions RAT/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""

import collections
import contextlib
import copy
import functools
import numpy as np
Expand Down Expand Up @@ -46,7 +48,7 @@ class Geometries(StrEnum):
'custom_files': 'CustomFile',
'data': 'Data',
'layers': 'Layer',
'contrasts': 'Contrast'
'contrasts': 'Contrast',
}

values_defined_in = {'backgrounds.value_1': 'background_parameters',
Expand All @@ -70,6 +72,23 @@ class Geometries(StrEnum):
'contrasts.resolution': 'resolutions',
}

AllFields = collections.namedtuple('AllFields', ['attribute', 'fields'])
model_names_used_in = {'background_parameters': AllFields('backgrounds', ['value_1', 'value_2', 'value_3', 'value_4',
'value_5']),
'resolution_parameters': AllFields('resolutions', ['value_1', 'value_2', 'value_3', 'value_4',
'value_5']),
'parameters': AllFields('layers', ['thickness', 'SLD', 'roughness']),
'data': AllFields('contrasts', ['data']),
'backgrounds': AllFields('contrasts', ['background']),
'bulk_in': AllFields('contrasts', ['nba']),
'bulk_out': AllFields('contrasts', ['nbs']),
'scalefactors': AllFields('contrasts', ['scalefactor']),
'resolutions': AllFields('contrasts', ['resolution']),
}

class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', 'backgrounds',
'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', 'contrasts']


class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
"""Defines the input data for a reflectivity calculation in RAT.
Expand Down Expand Up @@ -120,6 +139,8 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
layers: ClassList = ClassList()
contrasts: ClassList = ClassList()

_all_names: dict

@field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
'contrasts')
Expand All @@ -133,8 +154,8 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
return value

def model_post_init(self, __context: Any) -> None:
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
ClassList routines to control revalidation.
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, gets names of all
defined parameters and wraps ClassList routines to control revalidation.
"""
for field_name, model in model_in_classlist.items():
field = getattr(self, field_name)
Expand All @@ -145,17 +166,33 @@ def model_post_init(self, __context: Any) -> None:
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
sigma=np.inf))

self._all_names = self.get_all_names()

# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
# model, handle errors and reset previous values if necessary.
class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
'contrasts']
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
for class_list in class_lists:
attribute = getattr(self, class_list)
for methodName in methods_to_wrap:
setattr(attribute, methodName, self._classlist_wrapper(attribute, getattr(attribute, methodName)))

@model_validator(mode='after')
def update_renamed_models(self) -> 'Project':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function scares me with all the nested for loops. Is it possible to implement using signals/slot (observer pattern) instead?

"""When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project."""
for class_list in class_lists:
old_names = self._all_names[class_list]
new_names = getattr(self, class_list).get_names()
if len(old_names) == len(new_names):
name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new]
for (old_name, new_name) in name_diff:
with contextlib.suppress(KeyError):
for element in getattr(self, model_names_used_in[class_list].attribute):
for field in model_names_used_in[class_list].fields:
if getattr(element, field) == old_name:
setattr(element, field, new_name)
self._all_names = self.get_all_names()
return self

@model_validator(mode='after')
def cross_check_model_values(self) -> 'Project':
"""Certain model fields should contain values defined elsewhere in the project."""
Expand Down Expand Up @@ -185,6 +222,10 @@ def __repr__(self):
output += value.value + '\n\n'
return output

def get_all_names(self):
"""Record the names of all models defined in the project."""
return {class_list: getattr(self, class_list).get_names() for class_list in class_lists}

def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
"""Check the values of the given fields in the given model are in the supplied list of allowed values.

Expand Down
Loading