From 53551509244e0db5b4230ec0fe6b7b7b858ac38c Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Tue, 19 Sep 2023 15:46:39 +0100 Subject: [PATCH] Adds model validator to ensure consistent renaming (#9) * Adds model validator to ensure renamed models are updated throughout the project * Adds tests for renaming code * Uses project dicts to tidy up tests * Adds routine "get_all_matches" to "classList.py" --- RAT/classlist.py | 16 ++ RAT/project.py | 53 ++++- tests/test_classlist.py | 11 + tests/test_project.py | 504 ++++++++++++++++++++++------------------ 4 files changed, 349 insertions(+), 235 deletions(-) diff --git a/RAT/classlist.py b/RAT/classlist.py index ce1384c5..263f6e82 100644 --- a/RAT/classlist.py +++ b/RAT/classlist.py @@ -217,6 +217,22 @@ def get_names(self) -> list[str]: """ return [getattr(model, self.name_field) for model in self.data if hasattr(model, self.name_field)] + def get_all_matches(self, value: Any) -> list[tuple]: + """Return a list of all (index, field) tuples where the value of the field is equal to the given value. + + Parameters + ---------- + value : str + The value we are searching for in the ClassList. + + Returns + ------- + : list [tuple] + A list of (index, field) tuples matching the given value. + """ + return [(index, field) for index, element in enumerate(self.data) for field in vars(element) + if getattr(element, field) == value] + def _validate_name_field(self, input_args: dict[str, Any]) -> None: """Raise a ValueError if the name_field attribute is passed as an object parameter, and its value is already used within the ClassList. diff --git a/RAT/project.py b/RAT/project.py index 039614dd..7d36b02a 100644 --- a/RAT/project.py +++ b/RAT/project.py @@ -1,5 +1,7 @@ """The project module. Defines and stores all the input data required for reflectivity calculations in RAT.""" +import collections +import contextlib import copy import functools import numpy as np @@ -46,7 +48,7 @@ class Geometries(StrEnum): 'custom_files': 'CustomFile', 'data': 'Data', 'layers': 'Layer', - 'contrasts': 'Contrast' + 'contrasts': 'Contrast', } values_defined_in = {'backgrounds.value_1': 'background_parameters', @@ -70,6 +72,23 @@ class Geometries(StrEnum): 'contrasts.resolution': 'resolutions', } +AllFields = collections.namedtuple('AllFields', ['attribute', 'fields']) +model_names_used_in = {'background_parameters': AllFields('backgrounds', ['value_1', 'value_2', 'value_3', 'value_4', + 'value_5']), + 'resolution_parameters': AllFields('resolutions', ['value_1', 'value_2', 'value_3', 'value_4', + 'value_5']), + 'parameters': AllFields('layers', ['thickness', 'SLD', 'roughness']), + 'data': AllFields('contrasts', ['data']), + 'backgrounds': AllFields('contrasts', ['background']), + 'bulk_in': AllFields('contrasts', ['nba']), + 'bulk_out': AllFields('contrasts', ['nbs']), + 'scalefactors': AllFields('contrasts', ['scalefactor']), + 'resolutions': AllFields('contrasts', ['resolution']), + } + +class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', 'backgrounds', + 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', 'contrasts'] + class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True): """Defines the input data for a reflectivity calculation in RAT. @@ -120,6 +139,8 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ layers: ClassList = ClassList() contrasts: ClassList = ClassList() + _all_names: dict + @field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', 'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', 'contrasts') @@ -133,8 +154,8 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList: return value def model_post_init(self, __context: Any) -> None: - """Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps - ClassList routines to control revalidation. + """Initialises the class in the ClassLists for empty data fields, sets protected parameters, gets names of all + defined parameters and wraps ClassList routines to control revalidation. """ for field_name, model in model_in_classlist.items(): field = getattr(self, field_name) @@ -145,17 +166,33 @@ def model_post_init(self, __context: Any) -> None: fit=True, prior_type=RAT.models.Priors.Uniform, mu=0, sigma=np.inf)) + self._all_names = self.get_all_names() + # Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the # model, handle errors and reset previous values if necessary. - class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', - 'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', - 'contrasts'] methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend'] for class_list in class_lists: attribute = getattr(self, class_list) for methodName in methods_to_wrap: setattr(attribute, methodName, self._classlist_wrapper(attribute, getattr(attribute, methodName))) + @model_validator(mode='after') + def update_renamed_models(self) -> 'Project': + """When models defined in the ClassLists are renamed, we need to update that name elsewhere in the project.""" + for class_list in class_lists: + old_names = self._all_names[class_list] + new_names = getattr(self, class_list).get_names() + if len(old_names) == len(new_names): + name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new] + for (old_name, new_name) in name_diff: + with contextlib.suppress(KeyError): + model_names_list = getattr(self, model_names_used_in[class_list].attribute) + all_matches = model_names_list.get_all_matches(old_name) + for (index, field) in all_matches: + setattr(model_names_list[index], field, new_name) + self._all_names = self.get_all_names() + return self + @model_validator(mode='after') def cross_check_model_values(self) -> 'Project': """Certain model fields should contain values defined elsewhere in the project.""" @@ -185,6 +222,10 @@ def __repr__(self): output += value.value + '\n\n' return output + def get_all_names(self): + """Record the names of all models defined in the project.""" + return {class_list: getattr(self, class_list).get_names() for class_list in class_lists} + def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None: """Check the values of the given fields in the given model are in the supplied list of allowed values. diff --git a/tests/test_classlist.py b/tests/test_classlist.py index d9a80461..deee1d2f 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -476,6 +476,17 @@ def test_get_names(class_list: 'ClassList', expected_names: list[str]) -> None: assert class_list.get_names() == expected_names +@pytest.mark.parametrize(["class_list", "expected_matches"], [ + (ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob')]), [(0, 'name')]), + (ClassList([InputAttributes(name='Alice'), InputAttributes(name='Bob', id='Alice')]), [(0, 'name'), (1, 'id')]), + (ClassList([InputAttributes(surname='Morgan'), InputAttributes(surname='Terwilliger')]), []), + (ClassList(InputAttributes()), []), +]) +def test_get_all_matches(class_list: 'ClassList', expected_matches: list[tuple]) -> None: + """We should get a list of (index, field) tuples matching the given value in the ClassList.""" + assert class_list.get_all_matches("Alice") == expected_matches + + @pytest.mark.parametrize("input_dict", [ ({'name': 'Eve'}), ({'surname': 'Polastri'}), diff --git a/tests/test_project.py b/tests/test_project.py index f0715d68..78a96a36 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -19,10 +19,10 @@ def test_project(): test_project = RAT.project.Project() test_project.data[0] = {'data': np.array([[1, 1, 1]])} test_project.parameters.append(name='Test SLD') - test_project.custom_files.append() - test_project.layers.append(SLD='Test SLD') - test_project.contrasts.append(data='Simulation', background='Background 1', nba='SLD Air', nbs='SLD D2O', - scalefactor='Scalefactor 1', resolution='Resolution 1') + test_project.custom_files.append(name='Test Custom File') + test_project.layers.append(name='Test Layer', SLD='Test SLD') + test_project.contrasts.append(name='Test Contrast', data='Simulation', background='Background 1', nba='SLD Air', + nbs='SLD D2O', scalefactor='Scalefactor 1', resolution='Resolution 1') return test_project @@ -87,7 +87,7 @@ def test_classlists(test_project) -> None: assert class_list._class_handle.__name__ == value -@pytest.mark.parametrize("model", [ +@pytest.mark.parametrize("input_model", [ RAT.models.Background, RAT.models.Contrast, RAT.models.CustomFile, @@ -95,15 +95,15 @@ def test_classlists(test_project) -> None: RAT.models.Layer, RAT.models.Resolution, ]) -def test_initialise_wrong_classes(model: Callable) -> None: +def test_initialise_wrong_classes(input_model: Callable) -> None: """If the "Project" model is initialised with incorrect classes, we should raise a ValidationError.""" with pytest.raises(pydantic.ValidationError, match='1 validation error for Project\nparameters\n Assertion ' 'failed, "parameters" ClassList contains objects other than ' '"Parameter"'): - RAT.project.Project(parameters=ClassList(model())) + RAT.project.Project(parameters=ClassList(input_model())) -@pytest.mark.parametrize(["field", "input_model"], [ +@pytest.mark.parametrize(["field", "wrong_input_model"], [ ('backgrounds', RAT.models.Resolution), ('contrasts', RAT.models.Layer), ('custom_files', RAT.models.Data), @@ -112,44 +112,77 @@ def test_initialise_wrong_classes(model: Callable) -> None: ('parameters', RAT.models.CustomFile), ('resolutions', RAT.models.Background), ]) -def test_assign_wrong_classes(test_project, field: str, input_model: Callable) -> None: +def test_assign_wrong_classes(test_project, field: str, wrong_input_model: Callable) -> None: """If we assign incorrect classes to the "Project" model, we should raise a ValidationError.""" with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n{field}\n Assertion failed, ' f'"{field}" ClassList contains objects other than ' f'"{RAT.project.model_in_classlist[field]}"'): - setattr(test_project, field, ClassList(input_model())) + setattr(test_project, field, ClassList(wrong_input_model())) -@pytest.mark.parametrize(["field", "input_model"], [ - ('backgrounds', RAT.models.Background), - ('contrasts', RAT.models.Contrast), - ('custom_files', RAT.models.CustomFile), - ('data', RAT.models.Data), - ('layers', RAT.models.Layer), - ('parameters', RAT.models.Parameter), - ('resolutions', RAT.models.Resolution), +@pytest.mark.parametrize("field", [ + 'backgrounds', + 'contrasts', + 'custom_files', + 'data', + 'layers', + 'parameters', + 'resolutions', ]) -def test_assign_models(field: str, input_model: Callable) -> None: - """If the "Project" model is initialised with models rather than ClassLists, we should raise a ValidationError. - """ - empty_project = RAT.project.Project.model_construct() +def test_assign_models(test_project, field: str) -> None: + """If the "Project" model is initialised with models rather than ClassLists, we should raise a ValidationError.""" + input_model = getattr(RAT.models, RAT.project.model_in_classlist[field]) with pytest.raises(pydantic.ValidationError, match=f'1 validation error for Project\n{field}\n Input should be an ' f'instance of ClassList'): - setattr(empty_project, field, input_model()) + setattr(test_project, field, input_model()) def test_wrapped_routines(test_project) -> None: """When initialising a project, several ClassList routines should be wrapped.""" - class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', - 'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', - 'contrasts'] wrapped_methods = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend'] - for class_list in class_lists: + for class_list in RAT.project.class_lists: attribute = getattr(test_project, class_list) for methodName in wrapped_methods: assert hasattr(getattr(attribute, methodName), '__wrapped__') +@pytest.mark.parametrize(["model", "field"], [ + ('background_parameters', 'value_1'), + ('resolution_parameters', 'value_1'), + ('parameters', 'SLD'), + ('data', 'data'), + ('backgrounds', 'background'), + ('bulk_in', 'nba'), + ('bulk_out', 'nbs'), + ('scalefactors', 'scalefactor'), + ('resolutions', 'resolution'), +]) +def test_rename_models(test_project, model: str, field: str) -> None: + """When renaming a model in the project, the new name should be recorded when that model is referred to elsewhere + in the project. + """ + getattr(test_project, model)[-1] = {'name': 'New Name'} + attribute = RAT.project.model_names_used_in[model].attribute + assert getattr(getattr(test_project, attribute)[-1], field) == 'New Name' + + +def test_get_all_names(test_project) -> None: + """We should be able to get the names of all the models defined in the project.""" + assert test_project.get_all_names() == {'parameters': ['Substrate Roughness', 'Test SLD'], + 'bulk_in': ['SLD Air'], + 'bulk_out': ['SLD D2O'], + 'qz_shifts': ['Qz shift 1'], + 'scalefactors': ['Scalefactor 1'], + 'background_parameters': ['Background Param 1'], + 'backgrounds': ['Background 1'], + 'resolution_parameters': ['Resolution Param 1'], + 'resolutions': ['Resolution 1'], + 'custom_files': ['Test Custom File'], + 'data': ['Simulation'], + 'layers': ['Test Layer'], + 'contrasts': ['Test Contrast']} + + @pytest.mark.parametrize("field", [ 'value_1', 'value_2', @@ -247,28 +280,28 @@ def test_check_allowed_values_not_on_list(test_value: str) -> None: test_project.check_allowed_values("backgrounds", ["value_1"], ["Background Param 1"]) -@pytest.mark.parametrize(["class_list", "input_model", "field"], [ - ('backgrounds', RAT.models.Background, 'value_1'), - ('backgrounds', RAT.models.Background, 'value_2'), - ('backgrounds', RAT.models.Background, 'value_3'), - ('backgrounds', RAT.models.Background, 'value_4'), - ('backgrounds', RAT.models.Background, 'value_5'), - ('resolutions', RAT.models.Resolution, 'value_1'), - ('resolutions', RAT.models.Resolution, 'value_2'), - ('resolutions', RAT.models.Resolution, 'value_3'), - ('resolutions', RAT.models.Resolution, 'value_4'), - ('resolutions', RAT.models.Resolution, 'value_5'), - ('layers', RAT.models.Layer, 'thickness'), - ('layers', RAT.models.Layer, 'SLD'), - ('layers', RAT.models.Layer, 'roughness'), - ('contrasts', RAT.models.Contrast, 'data'), - ('contrasts', RAT.models.Contrast, 'background'), - ('contrasts', RAT.models.Contrast, 'nba'), - ('contrasts', RAT.models.Contrast, 'nbs'), - ('contrasts', RAT.models.Contrast, 'scalefactor'), - ('contrasts', RAT.models.Contrast, 'resolution'), +@pytest.mark.parametrize(["class_list", "field"], [ + ('backgrounds', 'value_1'), + ('backgrounds', 'value_2'), + ('backgrounds', 'value_3'), + ('backgrounds', 'value_4'), + ('backgrounds', 'value_5'), + ('resolutions', 'value_1'), + ('resolutions', 'value_2'), + ('resolutions', 'value_3'), + ('resolutions', 'value_4'), + ('resolutions', 'value_5'), + ('layers', 'thickness'), + ('layers', 'SLD'), + ('layers', 'roughness'), + ('contrasts', 'data'), + ('contrasts', 'background'), + ('contrasts', 'nba'), + ('contrasts', 'nbs'), + ('contrasts', 'scalefactor'), + ('contrasts', 'resolution'), ]) -def test_wrap_set(test_project, class_list: str, input_model: Callable, field: str) -> None: +def test_wrap_set(test_project, class_list: str, field: str) -> None: """If we set the field values of a model in a ClassList as undefined values, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) @@ -282,17 +315,17 @@ def test_wrap_set(test_project, class_list: str, input_model: Callable, field: s assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('background_parameters', 'Background Param 1', 'backgrounds', 'value_1'), - ('resolution_parameters', 'Resolution Param 1', 'resolutions', 'value_1'), - ('parameters', 'Test SLD', 'layers', 'SLD'), - ('backgrounds', 'Background 1', 'contrasts', 'background'), - ('bulk_in', 'SLD Air', 'contrasts', 'nba'), - ('bulk_out', 'SLD D2O', 'contrasts', 'nbs'), - ('scalefactors', 'Scalefactor 1', 'contrasts', 'scalefactor'), - ('resolutions', 'Resolution 1', 'contrasts', 'resolution'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('background_parameters', 'Background Param 1', 'value_1'), + ('resolution_parameters', 'Resolution Param 1', 'value_1'), + ('parameters', 'Test SLD', 'SLD'), + ('backgrounds', 'Background 1', 'background'), + ('bulk_in', 'SLD Air', 'nba'), + ('bulk_out', 'SLD D2O', 'nbs'), + ('scalefactors', 'Scalefactor 1', 'scalefactor'), + ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_del(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_del(test_project, class_list: str, parameter: str, field: str) -> None: """If we delete a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) @@ -300,17 +333,18 @@ def test_wrap_del(test_project, class_list: str, parameter: str, parent_list: st index = test_attribute.index(parameter) with contextlib.redirect_stdout(io.StringIO()) as print_str: del test_attribute[index] - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure model was not deleted assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('data', 'Simulation', 'contrasts', 'data'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('data', 'Simulation', 'data'), ]) -def test_wrap_del_data(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_del_data(test_project, class_list: str, parameter: str, field: str) -> None: """If we delete a Data model in a ClassList containing values defined elsewhere, we should raise a ValidationError. """ test_attribute = getattr(test_project, class_list) @@ -319,9 +353,10 @@ def test_wrap_del_data(test_project, class_list: str, parameter: str, parent_lis index = test_attribute.index(parameter) with contextlib.redirect_stdout(io.StringIO()) as print_str: del test_attribute[index] - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure model was not deleted assert test_attribute[index].name == orig_class_list[index].name @@ -330,31 +365,32 @@ def test_wrap_del_data(test_project, class_list: str, parameter: str, parent_lis assert test_attribute[index].simulation_range == orig_class_list[index].simulation_range -@pytest.mark.parametrize(["class_list", "input_model", "field"], [ - ('backgrounds', RAT.models.Background, 'value_1'), - ('backgrounds', RAT.models.Background, 'value_2'), - ('backgrounds', RAT.models.Background, 'value_3'), - ('backgrounds', RAT.models.Background, 'value_4'), - ('backgrounds', RAT.models.Background, 'value_5'), - ('resolutions', RAT.models.Resolution, 'value_1'), - ('resolutions', RAT.models.Resolution, 'value_2'), - ('resolutions', RAT.models.Resolution, 'value_3'), - ('resolutions', RAT.models.Resolution, 'value_4'), - ('resolutions', RAT.models.Resolution, 'value_5'), - ('layers', RAT.models.Layer, 'thickness'), - ('layers', RAT.models.Layer, 'SLD'), - ('layers', RAT.models.Layer, 'roughness'), - ('contrasts', RAT.models.Contrast, 'data'), - ('contrasts', RAT.models.Contrast, 'background'), - ('contrasts', RAT.models.Contrast, 'nba'), - ('contrasts', RAT.models.Contrast, 'nbs'), - ('contrasts', RAT.models.Contrast, 'scalefactor'), - ('contrasts', RAT.models.Contrast, 'resolution'), +@pytest.mark.parametrize(["class_list", "field"], [ + ('backgrounds', 'value_1'), + ('backgrounds', 'value_2'), + ('backgrounds', 'value_3'), + ('backgrounds', 'value_4'), + ('backgrounds', 'value_5'), + ('resolutions', 'value_1'), + ('resolutions', 'value_2'), + ('resolutions', 'value_3'), + ('resolutions', 'value_4'), + ('resolutions', 'value_5'), + ('layers', 'thickness'), + ('layers', 'SLD'), + ('layers', 'roughness'), + ('contrasts', 'data'), + ('contrasts', 'background'), + ('contrasts', 'nba'), + ('contrasts', 'nbs'), + ('contrasts', 'scalefactor'), + ('contrasts', 'resolution'), ]) -def test_wrap_iadd(test_project, class_list: str, input_model: Callable, field: str) -> None: +def test_wrap_iadd(test_project, class_list: str, field: str) -> None: """If we add a model containing undefined values to a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) + input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute += [input_model(**{field: 'undefined'})] @@ -364,31 +400,32 @@ def test_wrap_iadd(test_project, class_list: str, input_model: Callable, field: # Ensure invalid model was not added assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "input_model", "field"], [ - ('backgrounds', RAT.models.Background, 'value_1'), - ('backgrounds', RAT.models.Background, 'value_2'), - ('backgrounds', RAT.models.Background, 'value_3'), - ('backgrounds', RAT.models.Background, 'value_4'), - ('backgrounds', RAT.models.Background, 'value_5'), - ('resolutions', RAT.models.Resolution, 'value_1'), - ('resolutions', RAT.models.Resolution, 'value_2'), - ('resolutions', RAT.models.Resolution, 'value_3'), - ('resolutions', RAT.models.Resolution, 'value_4'), - ('resolutions', RAT.models.Resolution, 'value_5'), - ('layers', RAT.models.Layer, 'thickness'), - ('layers', RAT.models.Layer, 'SLD'), - ('layers', RAT.models.Layer, 'roughness'), - ('contrasts', RAT.models.Contrast, 'data'), - ('contrasts', RAT.models.Contrast, 'background'), - ('contrasts', RAT.models.Contrast, 'nba'), - ('contrasts', RAT.models.Contrast, 'nbs'), - ('contrasts', RAT.models.Contrast, 'scalefactor'), - ('contrasts', RAT.models.Contrast, 'resolution'), +@pytest.mark.parametrize(["class_list", "field"], [ + ('backgrounds', 'value_1'), + ('backgrounds', 'value_2'), + ('backgrounds', 'value_3'), + ('backgrounds', 'value_4'), + ('backgrounds', 'value_5'), + ('resolutions', 'value_1'), + ('resolutions', 'value_2'), + ('resolutions', 'value_3'), + ('resolutions', 'value_4'), + ('resolutions', 'value_5'), + ('layers', 'thickness'), + ('layers', 'SLD'), + ('layers', 'roughness'), + ('contrasts', 'data'), + ('contrasts', 'background'), + ('contrasts', 'nba'), + ('contrasts', 'nbs'), + ('contrasts', 'scalefactor'), + ('contrasts', 'resolution'), ]) -def test_wrap_append(test_project, class_list: str, input_model: Callable, field: str) -> None: +def test_wrap_append(test_project, class_list: str, field: str) -> None: """If we append a model containing undefined values to a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) + input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.append(input_model(**{field: 'undefined'})) @@ -398,31 +435,32 @@ def test_wrap_append(test_project, class_list: str, input_model: Callable, field # Ensure invalid model was not appended assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "input_model", "field"], [ - ('backgrounds', RAT.models.Background, 'value_1'), - ('backgrounds', RAT.models.Background, 'value_2'), - ('backgrounds', RAT.models.Background, 'value_3'), - ('backgrounds', RAT.models.Background, 'value_4'), - ('backgrounds', RAT.models.Background, 'value_5'), - ('resolutions', RAT.models.Resolution, 'value_1'), - ('resolutions', RAT.models.Resolution, 'value_2'), - ('resolutions', RAT.models.Resolution, 'value_3'), - ('resolutions', RAT.models.Resolution, 'value_4'), - ('resolutions', RAT.models.Resolution, 'value_5'), - ('layers', RAT.models.Layer, 'thickness'), - ('layers', RAT.models.Layer, 'SLD'), - ('layers', RAT.models.Layer, 'roughness'), - ('contrasts', RAT.models.Contrast, 'data'), - ('contrasts', RAT.models.Contrast, 'background'), - ('contrasts', RAT.models.Contrast, 'nba'), - ('contrasts', RAT.models.Contrast, 'nbs'), - ('contrasts', RAT.models.Contrast, 'scalefactor'), - ('contrasts', RAT.models.Contrast, 'resolution'), +@pytest.mark.parametrize(["class_list", "field"], [ + ('backgrounds', 'value_1'), + ('backgrounds', 'value_2'), + ('backgrounds', 'value_3'), + ('backgrounds', 'value_4'), + ('backgrounds', 'value_5'), + ('resolutions', 'value_1'), + ('resolutions', 'value_2'), + ('resolutions', 'value_3'), + ('resolutions', 'value_4'), + ('resolutions', 'value_5'), + ('layers', 'thickness'), + ('layers', 'SLD'), + ('layers', 'roughness'), + ('contrasts', 'data'), + ('contrasts', 'background'), + ('contrasts', 'nba'), + ('contrasts', 'nbs'), + ('contrasts', 'scalefactor'), + ('contrasts', 'resolution'), ]) -def test_wrap_insert(test_project, class_list: str, input_model: Callable, field: str) -> None: +def test_wrap_insert(test_project, class_list: str, field: str) -> None: """If we insert a model containing undefined values into a ClassList, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) + input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.insert(0, input_model(**{field: 'undefined'})) @@ -433,31 +471,32 @@ def test_wrap_insert(test_project, class_list: str, input_model: Callable, field assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "input_model", "field"], [ - ('backgrounds', RAT.models.Background, 'value_1'), - ('backgrounds', RAT.models.Background, 'value_2'), - ('backgrounds', RAT.models.Background, 'value_3'), - ('backgrounds', RAT.models.Background, 'value_4'), - ('backgrounds', RAT.models.Background, 'value_5'), - ('resolutions', RAT.models.Resolution, 'value_1'), - ('resolutions', RAT.models.Resolution, 'value_2'), - ('resolutions', RAT.models.Resolution, 'value_3'), - ('resolutions', RAT.models.Resolution, 'value_4'), - ('resolutions', RAT.models.Resolution, 'value_5'), - ('layers', RAT.models.Layer, 'thickness'), - ('layers', RAT.models.Layer, 'SLD'), - ('layers', RAT.models.Layer, 'roughness'), - ('contrasts', RAT.models.Contrast, 'data'), - ('contrasts', RAT.models.Contrast, 'background'), - ('contrasts', RAT.models.Contrast, 'nba'), - ('contrasts', RAT.models.Contrast, 'nbs'), - ('contrasts', RAT.models.Contrast, 'scalefactor'), - ('contrasts', RAT.models.Contrast, 'resolution'), +@pytest.mark.parametrize(["class_list", "field"], [ + ('backgrounds', 'value_1'), + ('backgrounds', 'value_2'), + ('backgrounds', 'value_3'), + ('backgrounds', 'value_4'), + ('backgrounds', 'value_5'), + ('resolutions', 'value_1'), + ('resolutions', 'value_2'), + ('resolutions', 'value_3'), + ('resolutions', 'value_4'), + ('resolutions', 'value_5'), + ('layers', 'thickness'), + ('layers', 'SLD'), + ('layers', 'roughness'), + ('contrasts', 'data'), + ('contrasts', 'background'), + ('contrasts', 'nba'), + ('contrasts', 'nbs'), + ('contrasts', 'scalefactor'), + ('contrasts', 'resolution'), ]) -def test_wrap_insert_type_error(test_project, class_list: str, input_model: Callable, field: str) -> None: +def test_wrap_insert_type_error(test_project, class_list: str, field: str) -> None: """If we raise a TypeError using the wrapped insert routine, we should re-raise the error.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) + input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) with pytest.raises(TypeError): test_attribute.insert(input_model(**{field: 'undefined'})) @@ -466,17 +505,17 @@ def test_wrap_insert_type_error(test_project, class_list: str, input_model: Call assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('background_parameters', 'Background Param 1', 'backgrounds', 'value_1'), - ('resolution_parameters', 'Resolution Param 1', 'resolutions', 'value_1'), - ('parameters', 'Test SLD', 'layers', 'SLD'), - ('backgrounds', 'Background 1', 'contrasts', 'background'), - ('bulk_in', 'SLD Air', 'contrasts', 'nba'), - ('bulk_out', 'SLD D2O', 'contrasts', 'nbs'), - ('scalefactors', 'Scalefactor 1', 'contrasts', 'scalefactor'), - ('resolutions', 'Resolution 1', 'contrasts', 'resolution'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('background_parameters', 'Background Param 1', 'value_1'), + ('resolution_parameters', 'Resolution Param 1', 'value_1'), + ('parameters', 'Test SLD', 'SLD'), + ('backgrounds', 'Background 1', 'background'), + ('bulk_in', 'SLD Air', 'nba'), + ('bulk_out', 'SLD D2O', 'nbs'), + ('scalefactors', 'Scalefactor 1', 'scalefactor'), + ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_pop(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_pop(test_project, class_list: str, parameter: str, field: str) -> None: """If we pop a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) @@ -484,17 +523,18 @@ def test_wrap_pop(test_project, class_list: str, parameter: str, parent_list: st index = test_attribute.index(parameter) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.pop(index) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure model was not popped assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('data', 'Simulation', 'contrasts', 'data'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('data', 'Simulation', 'data'), ]) -def test_wrap_pop_data(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_pop_data(test_project, class_list: str, parameter: str, field: str) -> None: """If we pop a Data model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) @@ -502,9 +542,10 @@ def test_wrap_pop_data(test_project, class_list: str, parameter: str, parent_lis index = test_attribute.index(parameter) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.pop(index) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure model was not popped assert test_attribute[index].name == orig_class_list[index].name @@ -513,34 +554,35 @@ def test_wrap_pop_data(test_project, class_list: str, parameter: str, parent_lis assert test_attribute[index].simulation_range == orig_class_list[index].simulation_range -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('background_parameters', 'Background Param 1', 'backgrounds', 'value_1'), - ('resolution_parameters', 'Resolution Param 1', 'resolutions', 'value_1'), - ('parameters', 'Test SLD', 'layers', 'SLD'), - ('backgrounds', 'Background 1', 'contrasts', 'background'), - ('bulk_in', 'SLD Air', 'contrasts', 'nba'), - ('bulk_out', 'SLD D2O', 'contrasts', 'nbs'), - ('scalefactors', 'Scalefactor 1', 'contrasts', 'scalefactor'), - ('resolutions', 'Resolution 1', 'contrasts', 'resolution'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('background_parameters', 'Background Param 1', 'value_1'), + ('resolution_parameters', 'Resolution Param 1', 'value_1'), + ('parameters', 'Test SLD', 'SLD'), + ('backgrounds', 'Background 1', 'background'), + ('bulk_in', 'SLD Air', 'nba'), + ('bulk_out', 'SLD D2O', 'nbs'), + ('scalefactors', 'Scalefactor 1', 'scalefactor'), + ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_remove(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_remove(test_project, class_list: str, parameter: str, field: str) -> None: """If we remove a model in a ClassList containing values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.remove(parameter) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure model was not removed assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('data', 'Simulation', 'contrasts', 'data'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('data', 'Simulation', 'data'), ]) -def test_wrap_remove_data(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_remove_data(test_project, class_list: str, parameter: str, field: str) -> None: """If we remove a Data model in a ClassList containing values defined elsewhere, we should raise a ValidationError. """ test_attribute = getattr(test_project, class_list) @@ -549,9 +591,10 @@ def test_wrap_remove_data(test_project, class_list: str, parameter: str, parent_ with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.remove(parameter) - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure model was not removed assert test_attribute[index].name == orig_class_list[index].name @@ -560,34 +603,35 @@ def test_wrap_remove_data(test_project, class_list: str, parameter: str, parent_ assert test_attribute[index].simulation_range == orig_class_list[index].simulation_range -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('background_parameters', 'Background Param 1', 'backgrounds', 'value_1'), - ('resolution_parameters', 'Resolution Param 1', 'resolutions', 'value_1'), - ('parameters', 'Test SLD', 'layers', 'SLD'), - ('backgrounds', 'Background 1', 'contrasts', 'background'), - ('bulk_in', 'SLD Air', 'contrasts', 'nba'), - ('bulk_out', 'SLD D2O', 'contrasts', 'nbs'), - ('scalefactors', 'Scalefactor 1', 'contrasts', 'scalefactor'), - ('resolutions', 'Resolution 1', 'contrasts', 'resolution'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('background_parameters', 'Background Param 1', 'value_1'), + ('resolution_parameters', 'Resolution Param 1', 'value_1'), + ('parameters', 'Test SLD', 'SLD'), + ('backgrounds', 'Background 1', 'background'), + ('bulk_in', 'SLD Air', 'nba'), + ('bulk_out', 'SLD D2O', 'nbs'), + ('scalefactors', 'Scalefactor 1', 'scalefactor'), + ('resolutions', 'Resolution 1', 'resolution'), ]) -def test_wrap_clear(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_clear(test_project, class_list: str, parameter: str, field: str) -> None: """If we clear a ClassList containing models with values defined elsewhere, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.clear() - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure list was not cleared assert test_attribute == orig_class_list -@pytest.mark.parametrize(["class_list", "parameter", "parent_list", "field"], [ - ('data', 'Simulation', 'contrasts', 'data'), +@pytest.mark.parametrize(["class_list", "parameter", "field"], [ + ('data', 'Simulation', 'data'), ]) -def test_wrap_clear_data(test_project, class_list: str, parameter: str, parent_list: str, field: str) -> None: +def test_wrap_clear_data(test_project, class_list: str, parameter: str, field: str) -> None: """If we clear a ClassList containing Data models with values defined elsewhere, we should raise a ValidationError. """ test_attribute = getattr(test_project, class_list) @@ -595,9 +639,10 @@ def test_wrap_clear_data(test_project, class_list: str, parameter: str, parent_l with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.clear() - assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}"' - f' in the "{field}" field of "{parent_list}" must be defined in ' - f'"{class_list}".\033[0m\n') + assert print_str.getvalue() == (f'\033[31m1 validation error for Project\n Value error, The value "{parameter}" ' + f'in the "{field}" field of ' + f'"{RAT.project.model_names_used_in[class_list].attribute}" ' + f'must be defined in "{class_list}".\033[0m\n') # Ensure list was not cleared for index in range(len(test_attribute)): assert test_attribute[index].name == orig_class_list[index].name @@ -606,31 +651,32 @@ def test_wrap_clear_data(test_project, class_list: str, parameter: str, parent_l assert test_attribute[index].simulation_range == orig_class_list[index].simulation_range -@pytest.mark.parametrize(["class_list", "input_model", "field"], [ - ('backgrounds', RAT.models.Background, 'value_1'), - ('backgrounds', RAT.models.Background, 'value_2'), - ('backgrounds', RAT.models.Background, 'value_3'), - ('backgrounds', RAT.models.Background, 'value_4'), - ('backgrounds', RAT.models.Background, 'value_5'), - ('resolutions', RAT.models.Resolution, 'value_1'), - ('resolutions', RAT.models.Resolution, 'value_2'), - ('resolutions', RAT.models.Resolution, 'value_3'), - ('resolutions', RAT.models.Resolution, 'value_4'), - ('resolutions', RAT.models.Resolution, 'value_5'), - ('layers', RAT.models.Layer, 'thickness'), - ('layers', RAT.models.Layer, 'SLD'), - ('layers', RAT.models.Layer, 'roughness'), - ('contrasts', RAT.models.Contrast, 'data'), - ('contrasts', RAT.models.Contrast, 'background'), - ('contrasts', RAT.models.Contrast, 'nba'), - ('contrasts', RAT.models.Contrast, 'nbs'), - ('contrasts', RAT.models.Contrast, 'scalefactor'), - ('contrasts', RAT.models.Contrast, 'resolution'), +@pytest.mark.parametrize(["class_list", "field"], [ + ('backgrounds', 'value_1'), + ('backgrounds', 'value_2'), + ('backgrounds', 'value_3'), + ('backgrounds', 'value_4'), + ('backgrounds', 'value_5'), + ('resolutions', 'value_1'), + ('resolutions', 'value_2'), + ('resolutions', 'value_3'), + ('resolutions', 'value_4'), + ('resolutions', 'value_5'), + ('layers', 'thickness'), + ('layers', 'SLD'), + ('layers', 'roughness'), + ('contrasts', 'data'), + ('contrasts', 'background'), + ('contrasts', 'nba'), + ('contrasts', 'nbs'), + ('contrasts', 'scalefactor'), + ('contrasts', 'resolution'), ]) -def test_wrap_extend(test_project, class_list: str, input_model: Callable, field: str) -> None: +def test_wrap_extend(test_project, class_list: str, field: str) -> None: """If we extend a ClassList with model containing undefined values, we should raise a ValidationError.""" test_attribute = getattr(test_project, class_list) orig_class_list = copy.deepcopy(test_attribute) + input_model = getattr(RAT.models, RAT.project.model_in_classlist[class_list]) with contextlib.redirect_stdout(io.StringIO()) as print_str: test_attribute.extend([input_model(**{field: 'undefined'})])