Skip to content

Commit

Permalink
Splits resampleParams into two parameters and bumps version to 0.0.…
Browse files Browse the repository at this point in the history
…0.dev2 (#75)

* split resample params in source code

* split resample params in rat.cpp

* split resample params in tests

* made min angle consistent with MATLAB

* update submodule

* Update controls.py

* bumped version
  • Loading branch information
alexhroom authored Sep 6, 2024
1 parent fe0a45e commit 9a6c196
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 128 deletions.
16 changes: 3 additions & 13 deletions RATapi/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
Field,
ValidationError,
ValidatorFunctionWrapHandler,
field_validator,
model_serializer,
model_validator,
)

from RATapi.utils.custom_errors import custom_pydantic_validation_error
from RATapi.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies

common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleParams", "display"]
common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleMinAngle", "resampleNPoints", "display"]
update_fields = ["updateFreq", "updatePlotFreq"]
fields = {
"calculate": common_fields,
Expand All @@ -41,7 +40,8 @@ class Controls(BaseModel, validate_assignment=True, extra="forbid"):
procedure: Procedures = Procedures.Calculate
parallel: Parallel = Parallel.Single
calcSldDuringFit: bool = False
resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2)
resampleMinAngle: float = Field(0.9, le=1, gt=0)
resampleNPoints: int = Field(50, gt=0)
display: Display = Display.Iter
# Simplex
xTolerance: float = Field(1.0e-6, gt=0.0)
Expand Down Expand Up @@ -117,16 +117,6 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle

return validated_self

@field_validator("resampleParams")
@classmethod
def check_resample_params(cls, values: list[float]) -> list[float]:
"""Make sure each of the two values of resampleParams satisfy their conditions."""
if not 0 < values[0] < 1:
raise ValueError("resampleParams[0] must be between 0 and 1")
if values[1] < 0:
raise ValueError("resampleParams[1] must be greater than or equal to 0")
return values

@model_serializer
def serialize(self):
"""Filter fields so only those applying to the chosen procedure are serialized."""
Expand Down
2 changes: 1 addition & 1 deletion RATapi/examples/absorption/absorption.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def absorption():
)

# Now make a controls block and run the code
controls = RAT.Controls(parallel="contrasts", resampleParams=[0.9, 150.0])
controls = RAT.Controls(parallel="contrasts", resampleNPoints=150)
problem, results = RAT.run(problem, controls)

return problem, results
Expand Down
3 changes: 2 additions & 1 deletion RATapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ def make_controls(input_controls: RATapi.Controls, checks: Checks) -> Control:
controls.procedure = input_controls.procedure
controls.parallel = input_controls.parallel
controls.calcSldDuringFit = input_controls.calcSldDuringFit
controls.resampleParams = input_controls.resampleParams
controls.resampleMinAngle = input_controls.resampleMinAngle
controls.resampleNPoints = input_controls.resampleNPoints
controls.display = input_controls.display
# Simplex
controls.xTolerance = input_controls.xTolerance
Expand Down
57 changes: 30 additions & 27 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,8 @@ struct Control {
real_T propScale {};
real_T nsTolerance {};
boolean_T calcSldDuringFit {};
py::array_t<real_T> resampleParams;
real_T resampleMinAngle {};
real_T resampleNPoints {};
real_T updateFreq {};
real_T updatePlotFreq {};
real_T nSamples {};
Expand Down Expand Up @@ -914,8 +915,8 @@ RAT::struct2_T createStruct2T(const Control& control)
stringToRatArray(control.procedure, control_struct.procedure.data, control_struct.procedure.size);
stringToRatArray(control.display, control_struct.display.data, control_struct.display.size);
control_struct.xTolerance = control.xTolerance;
control_struct.resampleParams[0] = control.resampleParams.at(0);
control_struct.resampleParams[1] = control.resampleParams.at(1);
control_struct.resampleMinAngle = control.resampleMinAngle;
control_struct.resampleNPoints = control.resampleNPoints;
stringToRatArray(control.boundHandling, control_struct.boundHandling.data, control_struct.boundHandling.size);
control_struct.adaptPCR = control.adaptPCR;
control_struct.checks = createStruct3(control.checks);
Expand Down Expand Up @@ -1616,7 +1617,8 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("propScale", &Control::propScale)
.def_readwrite("nsTolerance", &Control::nsTolerance)
.def_readwrite("calcSldDuringFit", &Control::calcSldDuringFit)
.def_readwrite("resampleParams", &Control::resampleParams)
.def_readwrite("resampleMinAngle", &Control::resampleMinAngle)
.def_readwrite("resampleNPoints", &Control::resampleNPoints)
.def_readwrite("updateFreq", &Control::updateFreq)
.def_readwrite("updatePlotFreq", &Control::updatePlotFreq)
.def_readwrite("nSamples", &Control::nSamples)
Expand All @@ -1633,14 +1635,14 @@ PYBIND11_MODULE(rat_core, m) {
return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance,
ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability,
ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale,
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleParams, ctrl.updateFreq, ctrl.updatePlotFreq,
ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR,
ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, ctrl.checks.fitQzshift,
ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleMinAngle, ctrl.resampleNPoints,
ctrl.updateFreq, ctrl.updatePlotFreq, ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma,
ctrl.boundHandling, ctrl.adaptPCR, ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam,
ctrl.checks.fitQzshift, ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
ctrl.checks.fitResolutionParam, ctrl.checks.fitDomainRatio);
},
[](py::tuple t) { // __setstate__
if (t.size() != 36)
if (t.size() != 37)
throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!");

/* Create a new C++ instance */
Expand All @@ -1664,25 +1666,26 @@ PYBIND11_MODULE(rat_core, m) {
ctrl.propScale = t[15].cast<real_T>();
ctrl.nsTolerance = t[16].cast<real_T>();
ctrl.calcSldDuringFit = t[17].cast<boolean_T>();
ctrl.resampleParams = t[18].cast<py::array_t<real_T>>();
ctrl.updateFreq = t[19].cast<real_T>();
ctrl.updatePlotFreq = t[20].cast<real_T>();
ctrl.nSamples = t[21].cast<real_T>();
ctrl.nChains = t[22].cast<real_T>();
ctrl.jumpProbability = t[23].cast<real_T>();
ctrl.pUnitGamma = t[24].cast<real_T>();
ctrl.boundHandling = t[25].cast<std::string>();
ctrl.adaptPCR = t[26].cast<boolean_T>();
ctrl.IPCFilePath = t[27].cast<std::string>();
ctrl.resampleMinAngle = t[18].cast<real_T>();
ctrl.resampleNPoints = t[19].cast<real_T>();
ctrl.updateFreq = t[20].cast<real_T>();
ctrl.updatePlotFreq = t[21].cast<real_T>();
ctrl.nSamples = t[22].cast<real_T>();
ctrl.nChains = t[23].cast<real_T>();
ctrl.jumpProbability = t[24].cast<real_T>();
ctrl.pUnitGamma = t[25].cast<real_T>();
ctrl.boundHandling = t[26].cast<std::string>();
ctrl.adaptPCR = t[27].cast<boolean_T>();
ctrl.IPCFilePath = t[28].cast<std::string>();

ctrl.checks.fitParam = t[28].cast<py::array_t<real_T>>();
ctrl.checks.fitBackgroundParam = t[29].cast<py::array_t<real_T>>();
ctrl.checks.fitQzshift = t[30].cast<py::array_t<real_T>>();
ctrl.checks.fitScalefactor = t[31].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkIn = t[32].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkOut = t[33].cast<py::array_t<real_T>>();
ctrl.checks.fitResolutionParam = t[34].cast<py::array_t<real_T>>();
ctrl.checks.fitDomainRatio = t[35].cast<py::array_t<real_T>>();
ctrl.checks.fitParam = t[29].cast<py::array_t<real_T>>();
ctrl.checks.fitBackgroundParam = t[30].cast<py::array_t<real_T>>();
ctrl.checks.fitQzshift = t[31].cast<py::array_t<real_T>>();
ctrl.checks.fitScalefactor = t[32].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkIn = t[33].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkOut = t[34].cast<py::array_t<real_T>>();
ctrl.checks.fitResolutionParam = t[35].cast<py::array_t<real_T>>();
ctrl.checks.fitDomainRatio = t[36].cast<py::array_t<real_T>>();

return ctrl;
}));
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from setuptools.command.build_clib import build_clib
from setuptools.command.build_ext import build_ext

__version__ = "0.0.0.dev1"
__version__ = "0.0.0.dev2"
PACKAGE_NAME = "RATapi"

with open("README.md") as f:
Expand Down
Loading

0 comments on commit 9a6c196

Please sign in to comment.