diff --git a/cpp/rat.cpp b/cpp/rat.cpp index 988d40e9..a569894c 100644 --- a/cpp/rat.cpp +++ b/cpp/rat.cpp @@ -1295,12 +1295,54 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("resample", &PlotEventData::resample) .def_readwrite("dataPresent", &PlotEventData::dataPresent) .def_readwrite("modelType", &PlotEventData::modelType) - .def_readwrite("contrastNames", &PlotEventData::contrastNames); + .def_readwrite("contrastNames", &PlotEventData::contrastNames) + .def(py::pickle( + [](const PlotEventData &evt) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(evt.reflectivity, evt.shiftedData, evt.sldProfiles, evt.resampledLayers, evt.subRoughs, evt.resample, + evt.dataPresent, evt.modelType, evt.contrastNames); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 9) + throw std::runtime_error("Encountered invalid state unpickling PlotEventData object!"); + + /* Create a new C++ instance */ + PlotEventData evt; + + evt.reflectivity = t[0].cast(); + evt.shiftedData = t[1].cast(); + evt.sldProfiles = t[2].cast(); + evt.resampledLayers = t[3].cast(); + evt.subRoughs = t[4].cast>(); + evt.resample = t[5].cast>(); + evt.dataPresent = t[6].cast>(); + evt.modelType = t[7].cast(); + evt.contrastNames = t[8].cast(); + + return evt; + })); py::class_(m, "ProgressEventData") .def(py::init<>()) .def_readwrite("message", &ProgressEventData::message) - .def_readwrite("percent", &ProgressEventData::percent); + .def_readwrite("percent", &ProgressEventData::percent) + .def(py::pickle( + [](const ProgressEventData &evt) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(evt.message, evt.percent); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2) + throw std::runtime_error("Encountered invalid state unpickling ProgressEventData object!"); + + /* Create a new C++ instance */ + ProgressEventData evt; + + evt.message = t[0].cast(); + evt.percent = t[1].cast(); + + return evt; + })); py::class_(m, "ConfidenceIntervals") .def(py::init<>()) @@ -1393,7 +1435,31 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("fitBulkIn", &Checks::fitBulkIn) .def_readwrite("fitBulkOut", &Checks::fitBulkOut) .def_readwrite("fitResolutionParam", &Checks::fitResolutionParam) - .def_readwrite("fitDomainRatio", &Checks::fitDomainRatio); + .def_readwrite("fitDomainRatio", &Checks::fitDomainRatio) + .def(py::pickle( + [](const Checks &chk) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(chk.fitParam, chk.fitBackgroundParam, chk.fitQzshift, chk.fitScalefactor, chk.fitBulkIn, chk.fitBulkOut, + chk.fitResolutionParam, chk.fitDomainRatio); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 8) + throw std::runtime_error("Encountered invalid state unpickling Checks object!"); + + /* Create a new C++ instance */ + Checks chk; + + chk.fitParam = t[0].cast>(); + chk.fitBackgroundParam = t[1].cast>(); + chk.fitQzshift = t[2].cast>(); + chk.fitScalefactor = t[3].cast>(); + chk.fitBulkIn = t[4].cast>(); + chk.fitBulkOut = t[5].cast>(); + chk.fitResolutionParam = t[6].cast>(); + chk.fitDomainRatio = t[7].cast>(); + + return chk; + })); py::class_(m, "Limits") .def(py::init<>()) @@ -1404,7 +1470,31 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("bulkIn", &Limits::bulkIn) .def_readwrite("bulkOut", &Limits::bulkOut) .def_readwrite("resolutionParam", &Limits::resolutionParam) - .def_readwrite("domainRatio", &Limits::domainRatio); + .def_readwrite("domainRatio", &Limits::domainRatio) + .def(py::pickle( + [](const Limits &lim) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(lim.param, lim.backgroundParam, lim.qzshift, lim.scalefactor, lim.bulkIn, lim.bulkOut, + lim.resolutionParam, lim.domainRatio); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 8) + throw std::runtime_error("Encountered invalid state unpickling Limits object!"); + + /* Create a new C++ instance */ + Limits lim; + + lim.param = t[0].cast>(); + lim.backgroundParam = t[1].cast>(); + lim.qzshift = t[2].cast>(); + lim.scalefactor = t[3].cast>(); + lim.bulkIn = t[4].cast>(); + lim.bulkOut = t[5].cast>(); + lim.resolutionParam = t[6].cast>(); + lim.domainRatio = t[7].cast>(); + + return lim; + })); py::class_(m, "Priors") .def(py::init<>()) @@ -1417,8 +1507,34 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("resolutionParam", &Priors::resolutionParam) .def_readwrite("domainRatio", &Priors::domainRatio) .def_readwrite("priorNames", &Priors::priorNames) - .def_readwrite("priorValues", &Priors::priorValues); - + .def_readwrite("priorValues", &Priors::priorValues) + .def(py::pickle( + [](const Priors &prior) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(prior.param, prior.backgroundParam, prior.qzshift, prior.scalefactor, prior.bulkIn, + prior.bulkOut, prior.resolutionParam, prior.domainRatio, prior.priorNames, prior.priorValues); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 10) + throw std::runtime_error("Encountered invalid state unpickling Limits object!"); + + /* Create a new C++ instance */ + Priors prior; + + prior.param = t[0].cast(); + prior.backgroundParam = t[1].cast(); + prior.qzshift = t[2].cast(); + prior.scalefactor = t[3].cast(); + prior.bulkIn = t[4].cast(); + prior.bulkOut = t[5].cast(); + prior.resolutionParam = t[6].cast(); + prior.domainRatio = t[7].cast(); + prior.priorNames = t[8].cast(); + prior.priorValues = t[9].cast>(); + + return prior; + })); + py::class_(m, "Cells") .def(py::init<>()) .def_readwrite("f1", &Cells::f1) @@ -1441,7 +1557,44 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("f18", &Cells::f18) .def_readwrite("f19", &Cells::f19) .def_readwrite("f20", &Cells::f20) - .def_readwrite("f21", &Cells::f21); + .def_readwrite("f21", &Cells::f21) + .def(py::pickle( + [](const Cells &cell) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(cell.f1, cell.f2, cell.f3, cell.f4, cell.f5, cell.f6, cell.f7, cell.f8, cell.f9, cell.f10, cell.f11, + cell.f12, cell.f13, cell.f14, cell.f15, cell.f16, cell.f17, cell.f18, cell.f19, cell.f20, cell.f21); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 21) + throw std::runtime_error("Encountered invalid state unpickling Cells object!"); + + /* Create a new C++ instance */ + Cells cell; + + cell.f1 = t[0].cast(); + cell.f2 = t[1].cast(); + cell.f3 = t[2].cast(); + cell.f4 = t[3].cast(); + cell.f5 = t[4].cast(); + cell.f6 = t[5].cast(); + cell.f7 = t[6].cast(); + cell.f8 = t[7].cast(); + cell.f9 = t[8].cast(); + cell.f10 = t[9].cast(); + cell.f11 = t[10].cast(); + cell.f12 = t[11].cast(); + cell.f13 = t[12].cast(); + cell.f14 = t[13].cast(); + cell.f15 = t[14].cast(); + cell.f16 = t[15].cast(); + cell.f17 = t[16].cast(); + cell.f18 = t[17].cast(); + cell.f19 = t[18].cast(); + cell.f20 = t[19].cast(); + cell.f21 = t[20].cast(); + + return cell; + })); py::class_(m, "Control") .def(py::init<>()) @@ -1473,8 +1626,67 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("boundHandling", &Control::boundHandling) .def_readwrite("adaptPCR", &Control::adaptPCR) .def_readwrite("checks", &Control::checks) - .def_readwrite("IPCFilePath", &Control::IPCFilePath); - + .def_readwrite("IPCFilePath", &Control::IPCFilePath) + .def(py::pickle( + [](const Control &ctrl) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance, + ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability, + ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale, + ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleParams, ctrl.updateFreq, ctrl.updatePlotFreq, + ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR, + ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, ctrl.checks.fitQzshift, + ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut, + ctrl.checks.fitResolutionParam, ctrl.checks.fitDomainRatio); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 36) + throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!"); + + /* Create a new C++ instance */ + Control ctrl; + + ctrl.parallel = t[0].cast(); + ctrl.procedure = t[1].cast(); + ctrl.display = t[2].cast(); + ctrl.xTolerance = t[3].cast(); + ctrl.funcTolerance = t[4].cast(); + ctrl.maxFuncEvals = t[5].cast(); + ctrl.maxIterations = t[6].cast(); + ctrl.populationSize = t[7].cast(); + ctrl.fWeight = t[8].cast(); + ctrl.crossoverProbability = t[9].cast(); + ctrl.targetValue = t[10].cast(); + ctrl.numGenerations = t[11].cast(); + ctrl.strategy = t[12].cast(); + ctrl.nLive = t[13].cast(); + ctrl.nMCMC = t[14].cast(); + ctrl.propScale = t[15].cast(); + ctrl.nsTolerance = t[16].cast(); + ctrl.calcSldDuringFit = t[17].cast(); + ctrl.resampleParams = t[18].cast>(); + ctrl.updateFreq = t[19].cast(); + ctrl.updatePlotFreq = t[20].cast(); + ctrl.nSamples = t[21].cast(); + ctrl.nChains = t[22].cast(); + ctrl.jumpProbability = t[23].cast(); + ctrl.pUnitGamma = t[24].cast(); + ctrl.boundHandling = t[25].cast(); + ctrl.adaptPCR = t[26].cast(); + ctrl.IPCFilePath = t[27].cast(); + + ctrl.checks.fitParam = t[28].cast>(); + ctrl.checks.fitBackgroundParam = t[29].cast>(); + ctrl.checks.fitQzshift = t[30].cast>(); + ctrl.checks.fitScalefactor = t[31].cast>(); + ctrl.checks.fitBulkIn = t[32].cast>(); + ctrl.checks.fitBulkOut = t[33].cast>(); + ctrl.checks.fitResolutionParam = t[34].cast>(); + ctrl.checks.fitDomainRatio = t[35].cast>(); + + return ctrl; + })); + py::class_(m, "ProblemDefinition") .def(py::init<>()) .def_readwrite("contrastBackgroundParams", &ProblemDefinition::contrastBackgroundParams) @@ -1507,7 +1719,58 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("fitParams", &ProblemDefinition::fitParams) .def_readwrite("otherParams", &ProblemDefinition::otherParams) .def_readwrite("fitLimits", &ProblemDefinition::fitLimits) - .def_readwrite("otherLimits", &ProblemDefinition::otherLimits); + .def_readwrite("otherLimits", &ProblemDefinition::otherLimits) + .def(py::pickle( + [](const ProblemDefinition &p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(p.contrastBackgroundParams, p.contrastBackgroundActions, p.TF, p.resample, p.dataPresent, p.oilChiDataPresent, + p.numberOfContrasts, p.geometry, p.useImaginary, p.contrastQzshifts, p.contrastScalefactors, + p.contrastBulkIns, p.contrastBulkOuts, p.contrastResolutionParams, p.backgroundParams, + p.qzshifts, p.scalefactors, p.bulkIn, p.bulkOut, p.resolutionParams, p.params, + p.numberOfLayers, p.modelType, p.contrastCustomFiles, p.contrastDomainRatios, + p.domainRatio, p.numberOfDomainContrasts, p.fitParams, p.otherParams, p.fitLimits, p.otherLimits); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 31) + throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!"); + + /* Create a new C++ instance */ + ProblemDefinition p; + + p.contrastBackgroundParams = t[0].cast>(); + p.contrastBackgroundActions = t[1].cast>(); + p.TF = t[2].cast(); + p.resample = t[3].cast>(); + p.dataPresent = t[4].cast>(); + p.oilChiDataPresent = t[5].cast>(); + p.numberOfContrasts = t[6].cast(); + p.geometry = t[7].cast(); + p.useImaginary = t[8].cast(); + p.contrastQzshifts = t[9].cast>(); + p.contrastScalefactors = t[10].cast>(); + p.contrastBulkIns = t[11].cast>(); + p.contrastBulkOuts = t[12].cast>(); + p.contrastResolutionParams = t[13].cast>(); + p.backgroundParams = t[14].cast>(); + p.qzshifts = t[15].cast>(); + p.scalefactors = t[16].cast>(); + p.bulkIn= t[17].cast>(); + p.bulkOut= t[18].cast>(); + p.resolutionParams= t[19].cast>(); + p.params = t[20].cast>(), + p.numberOfLayers = t[21].cast(); + p.modelType = t[22].cast(); + p.contrastCustomFiles = t[23].cast>(); + p.contrastDomainRatios = t[24].cast>(), + p.domainRatio = t[25].cast>(); + p.numberOfDomainContrasts = t[26].cast(); + p.fitParams = t[27].cast>(); + p.otherParams = t[28].cast>(); + p.fitLimits = t[29].cast>(); + p.otherLimits = t[30].cast>(); + + return p; + })); m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation."); diff --git a/tests/test_events.py b/tests/test_events.py index 185f7800..551b9d42 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1,5 +1,7 @@ +import pickle from unittest import mock +import numpy as np import pytest import RATapi.events @@ -64,3 +66,36 @@ def test_event_notify() -> None: assert first_callback.call_count == 1 assert second_callback.call_count == 1 assert third_callback.call_count == 1 + + +def test_event_data_pickle(): + data = RATapi.events.ProgressEventData() + data.message = "Hello" + data.percent = 0.5 + pickled_data = pickle.loads(pickle.dumps(data)) + assert pickled_data.message == data.message + assert pickled_data.percent == data.percent + + data = RATapi.events.PlotEventData() + data.modelType = "custom layers" + data.dataPresent = np.ones(2) + data.subRoughs = np.ones((20, 2)) + data.resample = np.ones(2) + data.resampledLayers = [np.ones((20, 2)), np.ones((20, 2))] + data.reflectivity = [np.ones((20, 2)), np.ones((20, 2))] + data.shiftedData = [np.ones((20, 2)), np.ones((20, 2))] + data.sldProfiles = [np.ones((20, 2)), np.ones((20, 2))] + data.contrastNames = ["D2O", "SMW"] + + pickled_data = pickle.loads(pickle.dumps(data)) + + assert pickled_data.modelType == data.modelType + assert (pickled_data.dataPresent == data.dataPresent).all() + assert (pickled_data.subRoughs == data.subRoughs).all() + assert (pickled_data.resample == data.resample).all() + for i in range(2): + assert (pickled_data.resampledLayers[i] == data.resampledLayers[i]).all() + assert (pickled_data.reflectivity[i] == data.reflectivity[i]).all() + assert (pickled_data.shiftedData[i] == data.shiftedData[i]).all() + assert (pickled_data.sldProfiles[i] == data.sldProfiles[i]).all() + assert pickled_data.contrastNames == data.contrastNames diff --git a/tests/test_inputs.py b/tests/test_inputs.py index a3ae2b7e..ed77e883 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -1,6 +1,7 @@ """Test the inputs module.""" import pathlib +import pickle from itertools import chain from unittest import mock @@ -640,18 +641,23 @@ def test_make_input(test_project, test_problem, test_cells, test_limits, test_pr ), mock.patch.object(RATapi.wrappers.DylibWrapper, "getHandle", mock.MagicMock(return_value=dummy_function)): problem, cells, limits, priors, controls = make_input(test_project, RATapi.Controls()) + problem = pickle.loads(pickle.dumps(problem)) check_problem_equal(problem, test_problem) + cells = pickle.loads(pickle.dumps(cells)) check_cells_equal(cells, test_cells) + limits = pickle.loads(pickle.dumps(limits)) for limit_field in parameter_fields: assert (getattr(limits, limit_field) == getattr(test_limits, limit_field)).all() + priors = pickle.loads(pickle.dumps(priors)) for prior_field in parameter_fields: assert getattr(priors, prior_field) == getattr(test_priors, prior_field) assert priors.priorNames == test_priors.priorNames assert (priors.priorValues == test_priors.priorValues).all() + controls = pickle.loads(pickle.dumps(controls)) check_controls_equal(controls, test_controls)