diff --git a/RATapi/controls.py b/RATapi/controls.py index b63193d4..8747be5b 100644 --- a/RATapi/controls.py +++ b/RATapi/controls.py @@ -6,7 +6,6 @@ Field, ValidationError, ValidatorFunctionWrapHandler, - field_validator, model_serializer, model_validator, ) @@ -14,7 +13,7 @@ from RATapi.utils.custom_errors import custom_pydantic_validation_error from RATapi.utils.enums import BoundHandling, Display, Parallel, Procedures, Strategies -common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleParams", "display"] +common_fields = ["procedure", "parallel", "calcSldDuringFit", "resampleMinAngle", "resampleNPoints", "display"] update_fields = ["updateFreq", "updatePlotFreq"] fields = { "calculate": common_fields, @@ -41,7 +40,8 @@ class Controls(BaseModel, validate_assignment=True, extra="forbid"): procedure: Procedures = Procedures.Calculate parallel: Parallel = Parallel.Single calcSldDuringFit: bool = False - resampleParams: list[float] = Field([0.9, 50], min_length=2, max_length=2) + resampleMinAngle: float = Field(0.9, le=1, gt=0) + resampleNPoints: int = Field(50, gt=0) display: Display = Display.Iter # Simplex xTolerance: float = Field(1.0e-6, gt=0.0) @@ -117,16 +117,6 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle return validated_self - @field_validator("resampleParams") - @classmethod - def check_resample_params(cls, values: list[float]) -> list[float]: - """Make sure each of the two values of resampleParams satisfy their conditions.""" - if not 0 < values[0] < 1: - raise ValueError("resampleParams[0] must be between 0 and 1") - if values[1] < 0: - raise ValueError("resampleParams[1] must be greater than or equal to 0") - return values - @model_serializer def serialize(self): """Filter fields so only those applying to the chosen procedure are serialized.""" diff --git a/RATapi/examples/absorption/absorption.py b/RATapi/examples/absorption/absorption.py index 1593f78a..0656d1dd 100644 --- a/RATapi/examples/absorption/absorption.py +++ b/RATapi/examples/absorption/absorption.py @@ -150,7 +150,7 @@ def absorption(): ) # Now make a controls block and run the code - controls = RAT.Controls(parallel="contrasts", resampleParams=[0.9, 150.0]) + controls = RAT.Controls(parallel="contrasts", resampleNPoints=150) problem, results = RAT.run(problem, controls) return problem, results diff --git a/RATapi/inputs.py b/RATapi/inputs.py index c8ebdefc..724fb472 100644 --- a/RATapi/inputs.py +++ b/RATapi/inputs.py @@ -436,7 +436,8 @@ def make_controls(input_controls: RATapi.Controls, checks: Checks) -> Control: controls.procedure = input_controls.procedure controls.parallel = input_controls.parallel controls.calcSldDuringFit = input_controls.calcSldDuringFit - controls.resampleParams = input_controls.resampleParams + controls.resampleMinAngle = input_controls.resampleMinAngle + controls.resampleNPoints = input_controls.resampleNPoints controls.display = input_controls.display # Simplex controls.xTolerance = input_controls.xTolerance diff --git a/cpp/RAT b/cpp/RAT index 1593003e..00ba077e 160000 --- a/cpp/RAT +++ b/cpp/RAT @@ -1 +1 @@ -Subproject commit 1593003e3c260ed95aac1923a11caaa6ab62386b +Subproject commit 00ba077e59cda3f49ca2f75086dc7d44eb969a3b diff --git a/cpp/rat.cpp b/cpp/rat.cpp index a569894c..59f2bab2 100644 --- a/cpp/rat.cpp +++ b/cpp/rat.cpp @@ -509,7 +509,8 @@ struct Control { real_T propScale {}; real_T nsTolerance {}; boolean_T calcSldDuringFit {}; - py::array_t resampleParams; + real_T resampleMinAngle {}; + real_T resampleNPoints {}; real_T updateFreq {}; real_T updatePlotFreq {}; real_T nSamples {}; @@ -914,8 +915,8 @@ RAT::struct2_T createStruct2T(const Control& control) stringToRatArray(control.procedure, control_struct.procedure.data, control_struct.procedure.size); stringToRatArray(control.display, control_struct.display.data, control_struct.display.size); control_struct.xTolerance = control.xTolerance; - control_struct.resampleParams[0] = control.resampleParams.at(0); - control_struct.resampleParams[1] = control.resampleParams.at(1); + control_struct.resampleMinAngle = control.resampleMinAngle; + control_struct.resampleNPoints = control.resampleNPoints; stringToRatArray(control.boundHandling, control_struct.boundHandling.data, control_struct.boundHandling.size); control_struct.adaptPCR = control.adaptPCR; control_struct.checks = createStruct3(control.checks); @@ -1616,7 +1617,8 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("propScale", &Control::propScale) .def_readwrite("nsTolerance", &Control::nsTolerance) .def_readwrite("calcSldDuringFit", &Control::calcSldDuringFit) - .def_readwrite("resampleParams", &Control::resampleParams) + .def_readwrite("resampleMinAngle", &Control::resampleMinAngle) + .def_readwrite("resampleNPoints", &Control::resampleNPoints) .def_readwrite("updateFreq", &Control::updateFreq) .def_readwrite("updatePlotFreq", &Control::updatePlotFreq) .def_readwrite("nSamples", &Control::nSamples) @@ -1633,14 +1635,14 @@ PYBIND11_MODULE(rat_core, m) { return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance, ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability, ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale, - ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleParams, ctrl.updateFreq, ctrl.updatePlotFreq, - ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR, - ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, ctrl.checks.fitQzshift, - ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut, + ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleMinAngle, ctrl.resampleNPoints, + ctrl.updateFreq, ctrl.updatePlotFreq, ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, + ctrl.boundHandling, ctrl.adaptPCR, ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, + ctrl.checks.fitQzshift, ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut, ctrl.checks.fitResolutionParam, ctrl.checks.fitDomainRatio); }, [](py::tuple t) { // __setstate__ - if (t.size() != 36) + if (t.size() != 37) throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!"); /* Create a new C++ instance */ @@ -1664,25 +1666,26 @@ PYBIND11_MODULE(rat_core, m) { ctrl.propScale = t[15].cast(); ctrl.nsTolerance = t[16].cast(); ctrl.calcSldDuringFit = t[17].cast(); - ctrl.resampleParams = t[18].cast>(); - ctrl.updateFreq = t[19].cast(); - ctrl.updatePlotFreq = t[20].cast(); - ctrl.nSamples = t[21].cast(); - ctrl.nChains = t[22].cast(); - ctrl.jumpProbability = t[23].cast(); - ctrl.pUnitGamma = t[24].cast(); - ctrl.boundHandling = t[25].cast(); - ctrl.adaptPCR = t[26].cast(); - ctrl.IPCFilePath = t[27].cast(); + ctrl.resampleMinAngle = t[18].cast(); + ctrl.resampleNPoints = t[19].cast(); + ctrl.updateFreq = t[20].cast(); + ctrl.updatePlotFreq = t[21].cast(); + ctrl.nSamples = t[22].cast(); + ctrl.nChains = t[23].cast(); + ctrl.jumpProbability = t[24].cast(); + ctrl.pUnitGamma = t[25].cast(); + ctrl.boundHandling = t[26].cast(); + ctrl.adaptPCR = t[27].cast(); + ctrl.IPCFilePath = t[28].cast(); - ctrl.checks.fitParam = t[28].cast>(); - ctrl.checks.fitBackgroundParam = t[29].cast>(); - ctrl.checks.fitQzshift = t[30].cast>(); - ctrl.checks.fitScalefactor = t[31].cast>(); - ctrl.checks.fitBulkIn = t[32].cast>(); - ctrl.checks.fitBulkOut = t[33].cast>(); - ctrl.checks.fitResolutionParam = t[34].cast>(); - ctrl.checks.fitDomainRatio = t[35].cast>(); + ctrl.checks.fitParam = t[29].cast>(); + ctrl.checks.fitBackgroundParam = t[30].cast>(); + ctrl.checks.fitQzshift = t[31].cast>(); + ctrl.checks.fitScalefactor = t[32].cast>(); + ctrl.checks.fitBulkIn = t[33].cast>(); + ctrl.checks.fitBulkOut = t[34].cast>(); + ctrl.checks.fitResolutionParam = t[35].cast>(); + ctrl.checks.fitDomainRatio = t[36].cast>(); return ctrl; })); diff --git a/setup.py b/setup.py index b38e84f2..5997af53 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from setuptools.command.build_clib import build_clib from setuptools.command.build_ext import build_ext -__version__ = "0.0.0.dev1" +__version__ = "0.0.0.dev2" PACKAGE_NAME = "RATapi" with open("README.md") as f: diff --git a/tests/test_controls.py b/tests/test_controls.py index d50b8ac7..82964abc 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -47,7 +47,8 @@ def table_str(self): "| procedure | calculate |\n" "| parallel | single |\n" "| calcSldDuringFit | False |\n" - "| resampleParams | [0.9, 50] |\n" + "| resampleMinAngle | 0.9 |\n" + "| resampleNPoints | 50 |\n" "| display | iter |\n" "+------------------+-----------+" ) @@ -59,7 +60,8 @@ def table_str(self): [ ("parallel", Parallel.Single), ("calcSldDuringFit", False), - ("resampleParams", [0.9, 50]), + ("resampleMinAngle", 0.9), + ("resampleNPoints", 50), ("display", Display.Iter), ("procedure", Procedures.Calculate), ], @@ -73,7 +75,8 @@ def test_calculate_property_values(self, control_property: str, value: Any) -> N [ ("parallel", Parallel.Points), ("calcSldDuringFit", True), - ("resampleParams", [0.2, 1]), + ("resampleMinAngle", 0.2), + ("resampleNPoints", 1), ("display", Display.Notify), ], ) @@ -180,31 +183,6 @@ def test_calculate_display_validation(self, value: Any) -> None: with pytest.raises(pydantic.ValidationError, match="Input should be 'off', 'iter', 'notify' or 'final'"): self.calculate.display = value - @pytest.mark.parametrize( - "value, msg", - [ - ([5.0], "List should have at least 2 items after validation, not 1"), - ([12, 13, 14], "List should have at most 2 items after validation, not 3"), - ], - ) - def test_calculate_resampleParams_length_validation(self, value: list, msg: str) -> None: - """Tests the resampleParams setter length validation in Calculate class.""" - with pytest.raises(pydantic.ValidationError, match=msg): - self.calculate.resampleParams = value - - @pytest.mark.parametrize( - "value, msg", - [ - ([1.0, 2], "Value error, resampleParams[0] must be between 0 and 1"), - ([0.5, -0.1], "Value error, resampleParams[1] must be greater than or equal to 0"), - ], - ) - def test_calculate_resampleParams_value_validation(self, value: list, msg: str) -> None: - """Tests the resampleParams setter value validation in Calculate class.""" - with pytest.raises(pydantic.ValidationError) as exp: - self.calculate.resampleParams = value - assert exp.value.errors()[0]["msg"] == msg - def test_str(self, table_str) -> None: """Tests the Calculate model __str__.""" assert self.calculate.__str__() == table_str @@ -220,21 +198,22 @@ def setup_class(self): @pytest.fixture def table_str(self): table_str = ( - "+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| procedure | simplex |\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resampleParams | [0.9, 50] |\n" - "| display | iter |\n" - "| xTolerance | 1e-06 |\n" - "| funcTolerance | 1e-06 |\n" - "| maxFuncEvals | 10000 |\n" - "| maxIterations | 1000 |\n" - "| updateFreq | 1 |\n" - "| updatePlotFreq | 20 |\n" - "+------------------+-----------+" + "+------------------+---------+\n" + "| Property | Value |\n" + "+------------------+---------+\n" + "| procedure | simplex |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resampleMinAngle | 0.9 |\n" + "| resampleNPoints | 50 |\n" + "| display | iter |\n" + "| xTolerance | 1e-06 |\n" + "| funcTolerance | 1e-06 |\n" + "| maxFuncEvals | 10000 |\n" + "| maxIterations | 1000 |\n" + "| updateFreq | 1 |\n" + "| updatePlotFreq | 20 |\n" + "+------------------+---------+" ) return table_str @@ -244,7 +223,8 @@ def table_str(self): [ ("parallel", Parallel.Single), ("calcSldDuringFit", False), - ("resampleParams", [0.9, 50]), + ("resampleMinAngle", 0.9), + ("resampleNPoints", 50), ("display", Display.Iter), ("procedure", Procedures.Simplex), ("xTolerance", 1e-6), @@ -264,7 +244,8 @@ def test_simplex_property_values(self, control_property: str, value: Any) -> Non [ ("parallel", Parallel.Points), ("calcSldDuringFit", True), - ("resampleParams", [0.2, 1]), + ("resampleMinAngle", 0.2), + ("resampleNPoints", 1), ("display", Display.Notify), ("xTolerance", 4e-6), ("funcTolerance", 3e-4), @@ -380,7 +361,8 @@ def table_str(self): "| procedure | de |\n" "| parallel | single |\n" "| calcSldDuringFit | False |\n" - "| resampleParams | [0.9, 50] |\n" + "| resampleMinAngle | 0.9 |\n" + "| resampleNPoints | 50 |\n" "| display | iter |\n" "| populationSize | 20 |\n" "| fWeight | 0.5 |\n" @@ -400,7 +382,8 @@ def table_str(self): [ ("parallel", Parallel.Single), ("calcSldDuringFit", False), - ("resampleParams", [0.9, 50]), + ("resampleMinAngle", 0.9), + ("resampleNPoints", 50), ("display", Display.Iter), ("procedure", Procedures.DE), ("populationSize", 20), @@ -420,7 +403,8 @@ def test_de_property_values(self, control_property: str, value: Any) -> None: [ ("parallel", Parallel.Points), ("calcSldDuringFit", True), - ("resampleParams", [0.2, 1]), + ("resampleMinAngle", 0.2), + ("resampleNPoints", 1), ("display", Display.Notify), ("populationSize", 20), ("fWeight", 0.3), @@ -544,19 +528,20 @@ def setup_class(self): @pytest.fixture def table_str(self): table_str = ( - "+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| procedure | ns |\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resampleParams | [0.9, 50] |\n" - "| display | iter |\n" - "| nLive | 150 |\n" - "| nMCMC | 0 |\n" - "| propScale | 0.1 |\n" - "| nsTolerance | 0.1 |\n" - "+------------------+-----------+" + "+------------------+--------+\n" + "| Property | Value |\n" + "+------------------+--------+\n" + "| procedure | ns |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resampleMinAngle | 0.9 |\n" + "| resampleNPoints | 50 |\n" + "| display | iter |\n" + "| nLive | 150 |\n" + "| nMCMC | 0 |\n" + "| propScale | 0.1 |\n" + "| nsTolerance | 0.1 |\n" + "+------------------+--------+" ) return table_str @@ -566,7 +551,8 @@ def table_str(self): [ ("parallel", Parallel.Single), ("calcSldDuringFit", False), - ("resampleParams", [0.9, 50]), + ("resampleMinAngle", 0.9), + ("resampleNPoints", 50), ("display", Display.Iter), ("procedure", Procedures.NS), ("nLive", 150), @@ -584,7 +570,8 @@ def test_ns_property_values(self, control_property: str, value: Any) -> None: [ ("parallel", Parallel.Points), ("calcSldDuringFit", True), - ("resampleParams", [0.2, 1]), + ("resampleMinAngle", 0.2), + ("resampleNPoints", 1), ("display", Display.Notify), ("nLive", 1500), ("nMCMC", 1), @@ -707,21 +694,22 @@ def setup_class(self): @pytest.fixture def table_str(self): table_str = ( - "+------------------+-----------+\n" - "| Property | Value |\n" - "+------------------+-----------+\n" - "| procedure | dream |\n" - "| parallel | single |\n" - "| calcSldDuringFit | False |\n" - "| resampleParams | [0.9, 50] |\n" - "| display | iter |\n" - "| nSamples | 20000 |\n" - "| nChains | 10 |\n" - "| jumpProbability | 0.5 |\n" - "| pUnitGamma | 0.2 |\n" - "| boundHandling | reflect |\n" - "| adaptPCR | True |\n" - "+------------------+-----------+" + "+------------------+---------+\n" + "| Property | Value |\n" + "+------------------+---------+\n" + "| procedure | dream |\n" + "| parallel | single |\n" + "| calcSldDuringFit | False |\n" + "| resampleMinAngle | 0.9 |\n" + "| resampleNPoints | 50 |\n" + "| display | iter |\n" + "| nSamples | 20000 |\n" + "| nChains | 10 |\n" + "| jumpProbability | 0.5 |\n" + "| pUnitGamma | 0.2 |\n" + "| boundHandling | reflect |\n" + "| adaptPCR | True |\n" + "+------------------+---------+" ) return table_str @@ -731,7 +719,8 @@ def table_str(self): [ ("parallel", Parallel.Single), ("calcSldDuringFit", False), - ("resampleParams", [0.9, 50]), + ("resampleMinAngle", 0.9), + ("resampleNPoints", 50), ("display", Display.Iter), ("procedure", Procedures.DREAM), ("nSamples", 20000), @@ -751,7 +740,8 @@ def test_dream_property_values(self, control_property: str, value: Any) -> None: [ ("parallel", Parallel.Points), ("calcSldDuringFit", True), - ("resampleParams", [0.2, 1]), + ("resampleMinAngle", 0.2), + ("resampleNPoints", 1), ("display", Display.Notify), ("nSamples", 500), ("nChains", 1000), diff --git a/tests/test_inputs.py b/tests/test_inputs.py index ed77e883..e850cd28 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -474,7 +474,8 @@ def standard_layers_controls(): controls.procedure = Procedures.Calculate controls.parallel = Parallel.Single controls.calcSldDuringFit = False - controls.resampleParams = [0.9, 50.0] + controls.resampleMinAngle = 0.9 + controls.resampleNPoints = 50.0 controls.display = Display.Iter controls.xTolerance = 1.0e-6 controls.funcTolerance = 1.0e-6 @@ -519,7 +520,8 @@ def custom_xy_controls(): controls.procedure = Procedures.Calculate controls.parallel = Parallel.Single controls.calcSldDuringFit = True - controls.resampleParams = [0.9, 50.0] + controls.resampleMinAngle = 0.9 + controls.resampleNPoints = 50.0 controls.display = Display.Iter controls.xTolerance = 1.0e-6 controls.funcTolerance = 1.0e-6 @@ -874,6 +876,8 @@ def check_controls_equal(actual_controls, expected_controls) -> None: "procedure", "parallel", "calcSldDuringFit", + "resampleMinAngle", + "resampleNPoints", "display", "xTolerance", "funcTolerance", @@ -909,8 +913,6 @@ def check_controls_equal(actual_controls, expected_controls) -> None: "fitDomainRatio", ] - # Check "resampleParams" separately as it is an array - assert (actual_controls.resampleParams == expected_controls.resampleParams).all() for field in controls_fields: assert getattr(actual_controls, field) == getattr(expected_controls, field) for field in checks_fields: