Skip to content

Commit

Permalink
Adds additional validators and "write_script" routine (#15)
Browse files Browse the repository at this point in the history
* Adds code to remove layers for non-standard layers model types

* Adds code to ensure protected parameters cannot be removed

* Adds code to restore default domain ratios when switching calc_type

* Adds validators to "Data" model

* Adds "write_script" routine to "project.py"

* Addresses review comments

* Remove specific wrap data tests from "test_project", instead using new Data __eq__ method

* Adds test for Data.__eq__ to improve test coverage

* Adds fixture to setup and teardown temp directory
  • Loading branch information
DrPaulSharp authored Oct 24, 2023
1 parent c580435 commit bbfebd4
Show file tree
Hide file tree
Showing 5 changed files with 396 additions and 147 deletions.
2 changes: 2 additions & 0 deletions RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from RAT.classlist import ClassList
from RAT.project import Project
107 changes: 88 additions & 19 deletions RAT/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
from typing import Any

try:
from enum import StrEnum
Expand Down Expand Up @@ -43,7 +44,6 @@ class Languages(StrEnum):
class Priors(StrEnum):
Uniform = 'uniform'
Gaussian = 'gaussian'
Jeffreys = 'jeffreys'


class Types(StrEnum):
Expand All @@ -52,7 +52,19 @@ class Types(StrEnum):
Function = 'function'


class Background(BaseModel, validate_assignment=True, extra='forbid'):
class RATModel(BaseModel):
"""A BaseModel where enums are represented by their value."""
def __repr__(self):
fields_repr = (', '.join(repr(v) if a is None else
f'{a}={v.value!r}' if isinstance(v, StrEnum) else
f'{a}={v!r}'
for a, v in self.__repr_args__()
)
)
return f'{self.__repr_name__()}({fields_repr})'


class Background(RATModel, validate_assignment=True, extra='forbid'):
"""Defines the Backgrounds in RAT."""
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number), min_length=1)
type: Types = Types.Constant
Expand All @@ -63,7 +75,7 @@ class Background(BaseModel, validate_assignment=True, extra='forbid'):
value_5: str = ''


class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
class Contrast(RATModel, validate_assignment=True, extra='forbid'):
"""Groups together all of the components of the model."""
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
Expand All @@ -76,7 +88,7 @@ class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
model: list[str] = []


class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
class ContrastWithRatio(RATModel, validate_assignment=True, extra='forbid'):
"""Groups together all of the components of the model including domain terms."""
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
Expand All @@ -90,20 +102,20 @@ class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
model: list[str] = []


class CustomFile(BaseModel, validate_assignment=True, extra='forbid'):
class CustomFile(RATModel, validate_assignment=True, extra='forbid'):
"""Defines the files containing functions to run when using custom models."""
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number), min_length=1)
filename: str = ''
language: Languages = Languages.Python
path: str = 'pwd' # Should later expand to find current file path


class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
class Data(RATModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
"""Defines the dataset required for each contrast."""
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number), min_length=1)
data: np.ndarray[float] = np.empty([0, 3])
data_range: list[float] = []
simulation_range: list[float] = [0.005, 0.7]
data: np.ndarray[np.float64] = np.empty([0, 3])
data_range: list[float] = Field(default=[], min_length=2, max_length=2)
simulation_range: list[float] = Field(default=[], min_length=2, max_length=2)

@field_validator('data')
@classmethod
Expand All @@ -120,22 +132,79 @@ def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:

@field_validator('data_range', 'simulation_range')
@classmethod
def check_list_elements(cls, limits: list[float], info: ValidationInfo) -> list[float]:
"""The data range and simulation range must contain exactly two parameters."""
if len(limits) != 2:
raise ValueError(f'{info.field_name} must contain exactly two values')
def check_min_max(cls, limits: list[float], info: ValidationInfo) -> list[float]:
"""The data range and simulation range maximum must be greater than the minimum."""
if limits[0] > limits[1]:
raise ValueError(f'{info.field_name} "min" value is greater than the "max" value')
return limits

# Also need model validators for data range compared to data etc -- need more details.
def model_post_init(self, __context: Any) -> None:
"""If the "data_range" and "simulation_range" fields are not set, but "data" is supplied, the ranges should be
set to the min and max values of the first column (assumed to be q) of the supplied data.
"""
if len(self.data[:, 0]) > 0:
data_min = np.min(self.data[:, 0])
data_max = np.max(self.data[:, 0])
for field in ["data_range", "simulation_range"]:
if field not in self.model_fields_set:
getattr(self, field).extend([data_min, data_max])

@model_validator(mode='after')
def check_ranges(self) -> 'Data':
"""The limits of the "data_range" field must lie within the range of the supplied data, whilst the limits
of the "simulation_range" field must lie outside of the range of the supplied data.
"""
if len(self.data[:, 0]) > 0:
data_min = np.min(self.data[:, 0])
data_max = np.max(self.data[:, 0])
if "data_range" in self.model_fields_set and (self.data_range[0] < data_min or
self.data_range[1] > data_max):
raise ValueError(f'The data_range value of: {self.data_range} must lie within the min/max values of '
f'the data: [{data_min}, {data_max}]')
if "simulation_range" in self.model_fields_set and (self.simulation_range[0] > data_min or
self.simulation_range[1] < data_max):
raise ValueError(f'The simulation_range value of: {self.simulation_range} must lie outside of the '
f'min/max values of the data: [{data_min}, {data_max}]')
return self

def __eq__(self, other: Any) -> bool:
if isinstance(other, BaseModel):
# When comparing instances of generic types for equality, as long as all field values are equal,
# only require their generic origin types to be equal, rather than exact type equality.
# This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1).
self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__
other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__

return (
self_type == other_type
and self.name == other.name
and (self.data == other.data).all()
and self.data_range == other.data_range
and self.simulation_range == other.simulation_range
and self.__pydantic_private__ == other.__pydantic_private__
and self.__pydantic_extra__ == other.__pydantic_extra__
)
else:
return NotImplemented # delegate to the other item in the comparison

def __repr__(self):
"""Only include the name if the data is empty."""
fields_repr = (f"name={self.name!r}" if self.data.size == 0 else
", ".join(repr(v) if a is None else
f"{a}={v!r}"
for a, v in self.__repr_args__()
)
)
return f'{self.__repr_name__()}({fields_repr})'


class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'):
class DomainContrast(RATModel, validate_assignment=True, extra='forbid'):
"""Groups together the layers required for each domain."""
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number), min_length=1)
model: list[str] = []


class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
class Layer(RATModel, validate_assignment=True, extra='forbid', populate_by_name=True):
"""Combines parameters into defined layers."""
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
thickness: str = ''
Expand All @@ -145,7 +214,7 @@ class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_nam
hydrate_with: Hydration = Hydration.BulkOut


class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
class AbsorptionLayer(RATModel, validate_assignment=True, extra='forbid', populate_by_name=True):
"""Combines parameters into defined layers including absorption terms."""
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
thickness: str = ''
Expand All @@ -156,7 +225,7 @@ class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', popul
hydrate_with: Hydration = Hydration.BulkOut


class Parameter(BaseModel, validate_assignment=True, extra='forbid'):
class Parameter(RATModel, validate_assignment=True, extra='forbid'):
"""Defines parameters needed to specify the model."""
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number), min_length=1)
min: float = 0.0
Expand All @@ -180,7 +249,7 @@ class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
name: str = Field(frozen=True, min_length=1)


class Resolution(BaseModel, validate_assignment=True, extra='forbid'):
class Resolution(RATModel, validate_assignment=True, extra='forbid'):
"""Defines Resolutions in RAT."""
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number), min_length=1)
type: Types = Types.Constant
Expand Down
Loading

0 comments on commit bbfebd4

Please sign in to comment.