Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Models and Project for RAT API #6

Merged
merged 20 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
c277ed0
Adds "models.py" with initial draft of pydantic models for API classes
DrPaulSharp Jul 31, 2023
98851d5
Adds validators for pydantic models
DrPaulSharp Jul 31, 2023
5ffb29b
Fixes parameter names in background model
DrPaulSharp Aug 1, 2023
75be6f6
Adds new model ProtectedParameter
DrPaulSharp Aug 1, 2023
4e66327
Adds "project.py" with initial draft of the high level "Project" class
DrPaulSharp Aug 3, 2023
b1f4574
Adds "model_post_init" routine for the "project" model
DrPaulSharp Aug 3, 2023
a05b35a
Adds "__repr__" routine for the "project" model
DrPaulSharp Aug 4, 2023
2ede5a1
Adds code to work with updated ClassList
DrPaulSharp Aug 7, 2023
760e686
Moves validators to cross-check project fields from "models.py" to "p…
DrPaulSharp Aug 11, 2023
d216c8c
Replaces annotated validators with single field validator in "project…
DrPaulSharp Aug 11, 2023
5e98ba0
Add contrasts to cross-checking model validator in "project.py"
DrPaulSharp Aug 11, 2023
035511c
Changes data model to accept numpy array
DrPaulSharp Aug 14, 2023
e39a252
Adds docs and modifies the Project class's "model_post_init" to ensur…
DrPaulSharp Aug 14, 2023
a5aeb48
Adds tests "test_models.py"
DrPaulSharp Aug 15, 2023
19051e6
Removes unused routine "get_all_names" from "project.py"
DrPaulSharp Aug 16, 2023
2ca3aa3
Adds tests "test_project.py"
DrPaulSharp Aug 16, 2023
4b69089
Tidies up model and project classes and tests
DrPaulSharp Aug 16, 2023
2257bbb
Adds code to fix enums for all python versions
DrPaulSharp Aug 17, 2023
c6ec305
Adds code to stop "test_repr" in "test_project.py" writing to console
DrPaulSharp Aug 17, 2023
55e07a1
Addresses review comments
DrPaulSharp Aug 21, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _check_classes(self, input_list: Iterable[object]) -> None:
Raised if the input list defines objects of different types.
"""
if not (all(isinstance(element, self._class_handle) for element in input_list)):
raise ValueError(f"Input list contains elements of type other than '{self._class_handle}'")
raise ValueError(f"Input list contains elements of type other than '{self._class_handle.__name__}'")

def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object, str]:
"""Return the object with the given value of the name_field attribute in the ClassList.
Expand Down
166 changes: 166 additions & 0 deletions RAT/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""The models module. Contains the pydantic models used by RAT to store project parameters."""

import numpy as np
from pydantic import BaseModel, Field, FieldValidationInfo, field_validator, model_validator

try:
from enum import StrEnum
except ImportError:
from strenum import StrEnum


def int_sequence():
    """Generate the positive integers as strings, for numbering default model names.

    Yields
    ------
    str
        The next integer in the sequence '1', '2', '3', ..., without end.
    """
    counter = 0
    while True:
        counter += 1
        yield str(counter)


# One independent counter per model, so that the default names of each model type are numbered separately
(
    background_number,
    contrast_number,
    custom_file_number,
    data_number,
    domain_contrast_number,
    layer_number,
    parameter_number,
    resolution_number,
) = (int_sequence() for _ in range(8))


class Hydration(StrEnum):
    """String options for the "hydrate_with" field of a Layer."""
    None_ = 'none'  # Trailing underscore because "None" is a reserved word in Python
    BulkIn = 'bulk in'
    BulkOut = 'bulk out'
    Oil = 'oil'


class Languages(StrEnum):
    """String options for the "language" field of a CustomFile."""
    Python = 'python'
    Matlab = 'matlab'


class Priors(StrEnum):
    """String options for the "prior_type" field of a Parameter."""
    Uniform = 'uniform'
    Gaussian = 'gaussian'
    Jeffreys = 'jeffreys'


class Types(StrEnum):
    """String options for the "type" field of the Background and Resolution models."""
    Constant = 'constant'
    Data = 'data'
    Function = 'function'


class Background(BaseModel, validate_assignment=True, extra='forbid'):
    """A named background definition for RAT.

    If no name is given, the background is called "New Background <n>", where <n> is the next value of the
    module-level background counter. Assignments are validated and unknown fields are rejected.
    """
    name: str = Field(default_factory=lambda: f'New Background {next(background_number)}')
    type: Types = Types.Constant
    # Up to five values; an empty string means the value is not set
    value_1: str = ''
    value_2: str = ''
    value_3: str = ''
    value_4: str = ''
    value_5: str = ''


class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
    """Collects, by name, the components that make up a single contrast of the model.

    If no name is given, the contrast is called "New Contrast <n>", where <n> is the next value of the module-level
    contrast counter. Assignments are validated and unknown fields are rejected.
    """
    name: str = Field(default_factory=lambda: f'New Contrast {next(contrast_number)}')
    # Each of the following string fields defaults to empty, meaning "not set"
    data: str = ''
    background: str = ''
    nba: str = ''
    nbs: str = ''
    scalefactor: str = ''
    resolution: str = ''
    resample: bool = False
    model: list[str] = []  # But how many strings? How to deal with this?


class CustomFile(BaseModel, validate_assignment=True, extra='forbid'):
    """Describes a file holding the functions to run when using custom models.

    If no name is given, the file entry is called "New Custom File <n>", where <n> is the next value of the
    module-level custom file counter. Assignments are validated and unknown fields are rejected.
    """
    name: str = Field(default_factory=lambda: f'New Custom File {next(custom_file_number)}')
    filename: str = ''
    language: Languages = Languages.Python
    path: str = 'pwd'  # Should later expand to find current file path


class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
    """Holds the dataset required for each contrast.

    If no name is given, the dataset is called "New Data <n>", where <n> is the next value of the module-level data
    counter. Arbitrary types are allowed so that the data itself can be stored as a numpy array.
    """
    name: str = Field(default_factory=lambda: f'New Data {next(data_number)}')
    data: np.ndarray[float] = np.empty([0, 3])
    data_range: list[float] = []
    simulation_range: list[float] = [0.005, 0.7]

    @field_validator('data')
    @classmethod
    def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:
        """Ensure the data array has a second axis, and that this axis has length three or more.

        Raises
        ------
        ValueError
            Raised if the array has fewer than two dimensions, or fewer than three columns.
        """
        if len(data.shape) < 2:
            raise ValueError('"data" must have at least two dimensions')
        if data.shape[1] < 3:
            raise ValueError('"data" must have at least three columns')
        return data

    @field_validator('data_range', 'simulation_range')
    @classmethod
    def check_list_elements(cls, limits: list[float], info: FieldValidationInfo) -> list[float]:
        """Ensure the data range and simulation range each consist of exactly two values.

        Raises
        ------
        ValueError
            Raised if the list does not contain exactly two values.
        """
        if len(limits) == 2:
            return limits
        raise ValueError(f'{info.field_name} must contain exactly two values')

    # Also need model validators for data range compared to data etc -- need more details.


class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'):
    """Collects, by name, the layers required for each domain.

    If no name is given, the domain contrast is called "New Domain Contrast <n>", where <n> is the next value of the
    module-level domain contrast counter. Assignments are validated and unknown fields are rejected.
    """
    name: str = Field(default_factory=lambda: f'New Domain Contrast {next(domain_contrast_number)}')
    model: list[str] = []


class Layer(BaseModel, validate_assignment=True, extra='forbid'):
    """Builds a layer out of named parameters.

    If no name is given, the layer is called "New Layer <n>", where <n> is the next value of the module-level layer
    counter. Assignments are validated and unknown fields are rejected.
    """
    name: str = Field(default_factory=lambda: f'New Layer {next(layer_number)}')
    # Each of the following string fields defaults to empty, meaning "not set"
    thickness: str = ''
    SLD: str = ''
    roughness: str = ''
    hydration: str = ''
    hydrate_with: Hydration = Hydration.BulkOut


class Parameter(BaseModel, validate_assignment=True, extra='forbid'):
    """A named, bounded value used to specify the model.

    If no name is given, the parameter is called "New Parameter <n>", where <n> is the next value of the module-level
    parameter counter. Assignments are validated and unknown fields are rejected.
    """
    name: str = Field(default_factory=lambda: f'New Parameter {next(parameter_number)}')
    min: float = 0.0
    value: float = 0.0
    max: float = 0.0
    fit: bool = False
    prior_type: Priors = Priors.Uniform
    mu: float = 0.0
    sigma: float = np.inf

    @model_validator(mode='after')
    def check_value_in_range(self) -> 'Parameter':
        """Ensure the parameter's value is no smaller than "min" and no larger than "max".

        Raises
        ------
        ValueError
            Raised if the value lies outside the bounds.
        """
        below_min = self.value < self.min
        above_max = self.value > self.max
        if below_min or above_max:
            raise ValueError(f'value {self.value} is not within the defined range: {self.min} <= value <= {self.max}')
        return self


class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
    """A Parameter with a fixed name.

    All other fields and the range validation are inherited unchanged from Parameter.
    """
    # No default is given, so the name must be supplied on construction; "frozen" stops it being reassigned afterwards
    name: str = Field(frozen=True)


class Resolution(BaseModel, validate_assignment=True, extra='forbid'):
    """A named resolution definition for RAT.

    If no name is given, the resolution is called "New Resolution <n>", where <n> is the next value of the
    module-level resolution counter. Assignments are validated and unknown fields are rejected.
    """
    name: str = Field(default_factory=lambda: f'New Resolution {next(resolution_number)}')
    type: Types = Types.Constant
    # Up to five values; an empty string means the value is not set
    value_1: str = ''
    value_2: str = ''
    value_3: str = ''
    value_4: str = ''
    value_5: str = ''
174 changes: 174 additions & 0 deletions RAT/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""

import numpy as np
from pydantic import BaseModel, FieldValidationInfo, field_validator, model_validator
from typing import Any

from RAT.classlist import ClassList
import RAT.models

try:
from enum import StrEnum
except ImportError:
from strenum import StrEnum


class CalcTypes(StrEnum):
    """String options for the "calc_type" field of a Project."""
    NonPolarised = 'non polarised'
    Domains = 'domains'
    OilWater = 'oil water'


class ModelTypes(StrEnum):
    """String options for the "model" field of a Project."""
    CustomLayers = 'custom layers'
    CustomXY = 'custom xy'
    StandardLayers = 'standard layers'


class Geometries(StrEnum):
    """String options for the "geometry" field of a Project."""
    AirSubstrate = 'air/substrate'
    SubstrateLiquid = 'substrate/liquid'


# For each ClassList field of the project, the name of the pydantic model (in RAT.models) that the list should hold
model_in_classlist = {
    **dict.fromkeys(
        ('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
         'resolution_parameters'),
        'Parameter',
    ),
    'backgrounds': 'Background',
    'resolutions': 'Resolution',
    'custom_files': 'CustomFile',
    'data': 'Data',
    'layers': 'Layer',
    'contrasts': 'Contrast',
}


class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
    """Defines the input data for a reflectivity calculation in RAT.

    This class combines the data defined in each of the pydantic models included in "models.py" into the full set of
    inputs required for a reflectivity calculation. Each ClassList field is checked to hold only the model named for
    it in "model_in_classlist", and fields that refer to other entries by name are cross-checked after validation.
    """
    name: str = ''
    calc_type: CalcTypes = CalcTypes.NonPolarised
    model: ModelTypes = ModelTypes.StandardLayers
    geometry: Geometries = Geometries.AirSubstrate
    absorption: bool = False

    parameters: ClassList = ClassList()

    bulk_in: ClassList = ClassList(RAT.models.Parameter(name='SLD Air', min=0, value=0, max=0, fit=False,
                                                        prior_type=RAT.models.Priors.Uniform, mu=0, sigma=np.inf))

    bulk_out: ClassList = ClassList(RAT.models.Parameter(name='SLD D2O', min=6.2e-6, value=6.35e-6, max=6.35e-6,
                                                         fit=False, prior_type=RAT.models.Priors.Uniform, mu=0,
                                                         sigma=np.inf))

    qz_shifts: ClassList = ClassList(RAT.models.Parameter(name='Qz shift 1', min=-1e-4, value=0, max=1e-4, fit=False,
                                                          prior_type=RAT.models.Priors.Uniform, mu=0, sigma=np.inf))

    scalefactors: ClassList = ClassList(RAT.models.Parameter(name='Scalefactor 1', min=0.02, value=0.23, max=0.25,
                                                             fit=False, prior_type=RAT.models.Priors.Uniform, mu=0,
                                                             sigma=np.inf))

    background_parameters: ClassList = ClassList(RAT.models.Parameter(name='Background Param 1', min=1e-7, value=1e-6,
                                                                      max=1e-5, fit=False,
                                                                      prior_type=RAT.models.Priors.Uniform, mu=0,
                                                                      sigma=np.inf))

    backgrounds: ClassList = ClassList(RAT.models.Background(name='Background 1', type=RAT.models.Types.Constant.value,
                                                             value_1='Background Param 1'))

    resolution_parameters: ClassList = ClassList(RAT.models.Parameter(name='Resolution Param 1', min=0.01, value=0.03,
                                                                      max=0.05, fit=False,
                                                                      prior_type=RAT.models.Priors.Uniform, mu=0,
                                                                      sigma=np.inf))

    resolutions: ClassList = ClassList(RAT.models.Resolution(name='Resolution 1', type=RAT.models.Types.Constant.value,
                                                             value_1='Resolution Param 1'))

    custom_files: ClassList = ClassList()
    data: ClassList = ClassList(RAT.models.Data(name='Simulation'))
    layers: ClassList = ClassList()
    contrasts: ClassList = ClassList()

    @field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
                     'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
                     'contrasts')
    @classmethod
    def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
        """Each of the data fields should be a ClassList of the appropriate model.

        Raises
        ------
        ValueError
            Raised if the ClassList contains any object that is not an instance of the model named for this field in
            "model_in_classlist".
        """
        model_name = model_in_classlist[info.field_name]
        model = getattr(RAT.models, model_name)
        # Raise explicitly rather than using "assert", which is removed when python runs with the -O flag
        if not all(isinstance(element, model) for element in value):
            raise ValueError(f'"{info.field_name}" ClassList contains objects other than "{model_name}"')
        return value

    def model_post_init(self, __context: Any) -> None:
        """Initialises the class in the ClassLists for empty data fields, and sets protected parameters."""
        for field_name, model in model_in_classlist.items():
            field = getattr(self, field_name)
            # A ClassList without a "_class_handle" is given the model expected for this field
            if not hasattr(field, "_class_handle"):
                setattr(field, "_class_handle", getattr(RAT.models, model))

        # The protected parameter is always placed at the front of the parameters list
        self.parameters.insert(0, RAT.models.ProtectedParameter(name='Substrate Roughness', min=1, value=3, max=5,
                                                                fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
                                                                sigma=np.inf))

    @model_validator(mode='after')
    def cross_check_model_values(self) -> 'Project':
        """Certain model fields should contain values defined elsewhere in the project."""
        value_fields = ['value_1', 'value_2', 'value_3', 'value_4', 'value_5']
        self.check_allowed_values('backgrounds', value_fields, self.background_parameters.get_names())
        self.check_allowed_values('resolutions', value_fields, self.resolution_parameters.get_names())
        self.check_allowed_values('layers', ['thickness', 'SLD', 'roughness'], self.parameters.get_names())

        self.check_allowed_values('contrasts', ['data'], self.data.get_names())
        self.check_allowed_values('contrasts', ['background'], self.backgrounds.get_names())
        self.check_allowed_values('contrasts', ['nba'], self.bulk_in.get_names())
        self.check_allowed_values('contrasts', ['nbs'], self.bulk_out.get_names())
        self.check_allowed_values('contrasts', ['scalefactor'], self.scalefactors.get_names())
        self.check_allowed_values('contrasts', ['resolution'], self.resolutions.get_names())
        return self

    def __repr__(self) -> str:
        """Return a section for every truthy field: a dash-padded heading followed by the field's contents."""
        sections = []
        for key, value in self.__dict__.items():
            # Fields with falsy values (e.g. empty strings, False) are left out of the output
            if not value:
                continue
            heading = key.replace('_', ' ').title() + ': '
            sections.append(f'{heading:-<100}\n\n')
            try:
                contents = value.value  # For enums
            except AttributeError:
                contents = repr(value)
            sections.append(contents + '\n\n')
        return ''.join(sections)

    def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
        """Check the values of the given fields in the given model are in the supplied list of allowed values.

        Empty values are not checked. The first invalid value found is reset to an empty string before the error is
        raised.

        Parameters
        ----------
        attribute : str
            The attribute of Project being validated.
        field_list : list [str]
            The fields of the attribute to be checked for valid values.
        allowed_values : list [str]
            The list of allowed values for the fields given in field_list.

        Raises
        ------
        ValueError
            Raised if any field in field_list has a value not specified in allowed_values.
        """
        class_list = getattr(self, attribute)
        for model in class_list:
            for field in field_list:
                value = getattr(model, field)
                if value and value not in allowed_values:
                    setattr(model, field, '')
                    raise ValueError(f'The parameter "{value}" has not been defined in the list of allowed values.')
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
numpy >= 1.20
pydantic >= 2.0.3
pytest >= 7.4.0
pytest-cov >= 4.1.0
StrEnum >= 0.4.15
tabulate >= 0.9.0
Loading