Skip to content

Commit

Permalink
Refactors the controls module, adding custom errors (#17)
Browse files Browse the repository at this point in the history
* Changes "controlsClass" to factory function "set_controls".

* Removes "baseControls" class and fixes bug when initialising procedure field

* Adds error check to "set_controls"

* Adds formatted error for extra fields to "set_controls"

* Introduces logging for error reporting

* Substitutes logging for raising errors directly

* Adds routine "custom_pydantic_validation_error" to introduce custom error messages when raising a ValidationError
  • Loading branch information
DrPaulSharp authored Nov 1, 2023
1 parent 101365b commit 7f2ad53
Show file tree
Hide file tree
Showing 7 changed files with 380 additions and 340 deletions.
3 changes: 2 additions & 1 deletion RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from RAT.classlist import ClassList
from RAT.controls import Controls
from RAT.project import Project
import RAT.controls
import RAT.models
99 changes: 49 additions & 50 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import prettytable
from pydantic import BaseModel, Field, field_validator
from typing import Union
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions
from RAT.utils.custom_errors import custom_pydantic_validation_error


class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the base class with properties used in all five procedures."""
class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: ParallelOptions = ParallelOptions.Single
calcSldDuringFit: bool = False
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
Expand All @@ -21,15 +23,16 @@ def check_resamPars(cls, resamPars):
raise ValueError('resamPars[1] must be greater than or equal to 0')
return resamPars


class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure."""
procedure: Procedures = Field(Procedures.Calculate, frozen=True)
def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self.__dict__.items()])
return table.get_string()


class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the simplex procedure."""
procedure: Procedures = Field(Procedures.Simplex, frozen=True)
class Simplex(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the simplex procedure."""
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
tolX: float = Field(1.0e-6, gt=0.0)
tolFun: float = Field(1.0e-6, gt=0.0)
maxFunEvals: int = Field(10000, gt=0)
Expand All @@ -38,9 +41,9 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
updatePlotFreq: int = -1


class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Differential Evolution procedure."""
procedure: Procedures = Field(Procedures.DE, frozen=True)
class DE(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Differential Evolution procedure."""
procedure: Literal[Procedures.DE] = Procedures.DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
Expand All @@ -49,52 +52,48 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
numGenerations: int = Field(500, ge=1)


class NS(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Nested Sampler procedure."""
procedure: Procedures = Field(Procedures.NS, frozen=True)
class NS(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Nested Sampler procedure."""
procedure: Literal[Procedures.NS] = Procedures.NS
Nlive: int = Field(150, ge=1)
Nmcmc: float = Field(0.0, ge=0.0)
propScale: float = Field(0.1, gt=0.0, lt=1.0)
nsTolerance: float = Field(0.1, ge=0.0)


class Dream(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Dream procedure."""
procedure: Procedures = Field(Procedures.Dream, frozen=True)
class Dream(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Dream procedure."""
procedure: Literal[Procedures.Dream] = Procedures.Dream
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold


class Controls:

def __init__(self,
procedure: Procedures = Procedures.Calculate,
**properties) -> None:

if procedure == Procedures.Calculate:
self.controls = Calculate(**properties)
elif procedure == Procedures.Simplex:
self.controls = Simplex(**properties)
elif procedure == Procedures.DE:
self.controls = DE(**properties)
elif procedure == Procedures.NS:
self.controls = NS(**properties)
elif procedure == Procedures.Dream:
self.controls = Dream(**properties)

@property
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
return self._controls

@controls.setter
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
self._controls = value

def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self._controls.__dict__.items()])
return table.get_string()
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
-> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Procedures.DE: DE,
Procedures.NS: NS,
Procedures.Dream: Dream
}

try:
model = controls[procedure](**properties)
except KeyError:
members = list(Procedures.__members__.values())
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
except ValidationError as exc:
custom_error_msgs = {'extra_forbidden': f'Extra inputs are not permitted. The fields for the {procedure}'
f' controls procedure are:\n '
f'{", ".join(controls[procedure].model_fields.keys())}\n'
}
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None

return model
9 changes: 4 additions & 5 deletions RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from RAT.classlist import ClassList
import RAT.models
from RAT.utils.custom_errors import formatted_pydantic_error
from RAT.utils.custom_errors import custom_pydantic_validation_error

try:
from enum import StrEnum
Expand Down Expand Up @@ -524,11 +524,10 @@ def wrapped_func(*args, **kwargs):
try:
return_value = func(*args, **kwargs)
Project.model_validate(self)
except ValidationError as e:
except ValidationError as exc:
setattr(class_list, 'data', previous_state)
error_string = formatted_pydantic_error(e)
# Use ANSI escape sequences to print error text in red
print('\033[31m' + error_string + '\033[0m')
custom_error_list = custom_pydantic_validation_error(exc.errors())
raise ValidationError.from_exception_data(exc.title, custom_error_list) from None
except (TypeError, ValueError):
setattr(class_list, 'data', previous_state)
raise
Expand Down
40 changes: 25 additions & 15 deletions RAT/utils/custom_errors.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
"""Defines routines for custom error handling in RAT."""
import pydantic_core

from pydantic import ValidationError

def custom_pydantic_validation_error(error_list: list[pydantic_core.ErrorDetails], custom_errors: dict[str, str] = None
) -> list[pydantic_core.ErrorDetails]:
"""Run through the list of errors generated from a pydantic ValidationError, substituting the standard error for a
PydanticCustomError for a given set of error types.
def formatted_pydantic_error(error: ValidationError) -> str:
"""Write a custom string format for pydantic validation errors.
For errors that do not have a custom error message defined, we redefine them using a PydanticCustomError to remove
the url from the error message.
Parameters
----------
error : pydantic.ValidationError
A ValidationError produced by a pydantic model
error_list : list[pydantic_core.ErrorDetails]
A list of errors produced by pydantic.ValidationError.errors().
custom_errors: dict[str, str], optional
A dict of custom error messages for given error types.
Returns
-------
error_str : str
A string giving details of the ValidationError in a custom format.
new_error : list[pydantic_core.ErrorDetails]
A list of errors including PydanticCustomErrors in place of the error types in custom_errors.
"""
num_errors = error.error_count()
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'
for this_error in error.errors():
error_str += '\n'
if this_error['loc']:
error_str += ' '.join(this_error['loc']) + '\n'
error_str += ' ' + this_error['msg']
return error_str
if custom_errors is None:
custom_errors = {}
custom_error_list = []
for error in error_list:
if error['type'] in custom_errors:
RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], custom_errors[error['type']])
else:
RAT_custom_error = pydantic_core.PydanticCustomError(error['type'], error['msg'])
error['type'] = RAT_custom_error
custom_error_list.append(error)

return custom_error_list
Loading

0 comments on commit 7f2ad53

Please sign in to comment.