Skip to content

Commit

Permalink
Makes the arguments from RAT_main and events data pickleable (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenNneji authored Sep 4, 2024
1 parent b3452f3 commit fe0a45e
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 10 deletions.
283 changes: 273 additions & 10 deletions cpp/rat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,12 +1295,54 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("resample", &PlotEventData::resample)
.def_readwrite("dataPresent", &PlotEventData::dataPresent)
.def_readwrite("modelType", &PlotEventData::modelType)
.def_readwrite("contrastNames", &PlotEventData::contrastNames);
.def_readwrite("contrastNames", &PlotEventData::contrastNames)
.def(py::pickle(
[](const PlotEventData &evt) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(evt.reflectivity, evt.shiftedData, evt.sldProfiles, evt.resampledLayers, evt.subRoughs, evt.resample,
evt.dataPresent, evt.modelType, evt.contrastNames);
},
[](py::tuple t) { // __setstate__
if (t.size() != 9)
throw std::runtime_error("Encountered invalid state unpickling PlotEventData object!");

/* Create a new C++ instance */
PlotEventData evt;

evt.reflectivity = t[0].cast<py::list>();
evt.shiftedData = t[1].cast<py::list>();
evt.sldProfiles = t[2].cast<py::list>();
evt.resampledLayers = t[3].cast<py::list>();
evt.subRoughs = t[4].cast<py::array_t<double>>();
evt.resample = t[5].cast<py::array_t<double>>();
evt.dataPresent = t[6].cast<py::array_t<double>>();
evt.modelType = t[7].cast<std::string>();
evt.contrastNames = t[8].cast<py::list>();

return evt;
}));

py::class_<ProgressEventData>(m, "ProgressEventData")
.def(py::init<>())
.def_readwrite("message", &ProgressEventData::message)
.def_readwrite("percent", &ProgressEventData::percent);
.def_readwrite("percent", &ProgressEventData::percent)
.def(py::pickle(
[](const ProgressEventData &evt) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(evt.message, evt.percent);
},
[](py::tuple t) { // __setstate__
if (t.size() != 2)
throw std::runtime_error("Encountered invalid state unpickling ProgressEventData object!");

/* Create a new C++ instance */
ProgressEventData evt;

evt.message = t[0].cast<std::string>();
evt.percent = t[1].cast<double>();

return evt;
}));

py::class_<ConfidenceIntervals>(m, "ConfidenceIntervals")
.def(py::init<>())
Expand Down Expand Up @@ -1393,7 +1435,31 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("fitBulkIn", &Checks::fitBulkIn)
.def_readwrite("fitBulkOut", &Checks::fitBulkOut)
.def_readwrite("fitResolutionParam", &Checks::fitResolutionParam)
.def_readwrite("fitDomainRatio", &Checks::fitDomainRatio);
.def_readwrite("fitDomainRatio", &Checks::fitDomainRatio)
.def(py::pickle(
[](const Checks &chk) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(chk.fitParam, chk.fitBackgroundParam, chk.fitQzshift, chk.fitScalefactor, chk.fitBulkIn, chk.fitBulkOut,
chk.fitResolutionParam, chk.fitDomainRatio);
},
[](py::tuple t) { // __setstate__
if (t.size() != 8)
throw std::runtime_error("Encountered invalid state unpickling Checks object!");

/* Create a new C++ instance */
Checks chk;

chk.fitParam = t[0].cast<py::array_t<real_T>>();
chk.fitBackgroundParam = t[1].cast<py::array_t<real_T>>();
chk.fitQzshift = t[2].cast<py::array_t<real_T>>();
chk.fitScalefactor = t[3].cast<py::array_t<real_T>>();
chk.fitBulkIn = t[4].cast<py::array_t<real_T>>();
chk.fitBulkOut = t[5].cast<py::array_t<real_T>>();
chk.fitResolutionParam = t[6].cast<py::array_t<real_T>>();
chk.fitDomainRatio = t[7].cast<py::array_t<real_T>>();

return chk;
}));

py::class_<Limits>(m, "Limits")
.def(py::init<>())
Expand All @@ -1404,7 +1470,31 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("bulkIn", &Limits::bulkIn)
.def_readwrite("bulkOut", &Limits::bulkOut)
.def_readwrite("resolutionParam", &Limits::resolutionParam)
.def_readwrite("domainRatio", &Limits::domainRatio);
.def_readwrite("domainRatio", &Limits::domainRatio)
.def(py::pickle(
[](const Limits &lim) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(lim.param, lim.backgroundParam, lim.qzshift, lim.scalefactor, lim.bulkIn, lim.bulkOut,
lim.resolutionParam, lim.domainRatio);
},
[](py::tuple t) { // __setstate__
if (t.size() != 8)
throw std::runtime_error("Encountered invalid state unpickling Limits object!");

/* Create a new C++ instance */
Limits lim;

lim.param = t[0].cast<py::array_t<real_T>>();
lim.backgroundParam = t[1].cast<py::array_t<real_T>>();
lim.qzshift = t[2].cast<py::array_t<real_T>>();
lim.scalefactor = t[3].cast<py::array_t<real_T>>();
lim.bulkIn = t[4].cast<py::array_t<real_T>>();
lim.bulkOut = t[5].cast<py::array_t<real_T>>();
lim.resolutionParam = t[6].cast<py::array_t<real_T>>();
lim.domainRatio = t[7].cast<py::array_t<real_T>>();

return lim;
}));

py::class_<Priors>(m, "Priors")
.def(py::init<>())
Expand All @@ -1417,8 +1507,34 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("resolutionParam", &Priors::resolutionParam)
.def_readwrite("domainRatio", &Priors::domainRatio)
.def_readwrite("priorNames", &Priors::priorNames)
.def_readwrite("priorValues", &Priors::priorValues);

.def_readwrite("priorValues", &Priors::priorValues)
.def(py::pickle(
[](const Priors &prior) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(prior.param, prior.backgroundParam, prior.qzshift, prior.scalefactor, prior.bulkIn,
prior.bulkOut, prior.resolutionParam, prior.domainRatio, prior.priorNames, prior.priorValues);
},
[](py::tuple t) { // __setstate__
if (t.size() != 10)
throw std::runtime_error("Encountered invalid state unpickling Limits object!");

/* Create a new C++ instance */
Priors prior;

prior.param = t[0].cast<py::list>();
prior.backgroundParam = t[1].cast<py::list>();
prior.qzshift = t[2].cast<py::list>();
prior.scalefactor = t[3].cast<py::list>();
prior.bulkIn = t[4].cast<py::list>();
prior.bulkOut = t[5].cast<py::list>();
prior.resolutionParam = t[6].cast<py::list>();
prior.domainRatio = t[7].cast<py::list>();
prior.priorNames = t[8].cast<py::list>();
prior.priorValues = t[9].cast<py::array_t<real_T>>();

return prior;
}));

py::class_<Cells>(m, "Cells")
.def(py::init<>())
.def_readwrite("f1", &Cells::f1)
Expand All @@ -1441,7 +1557,44 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("f18", &Cells::f18)
.def_readwrite("f19", &Cells::f19)
.def_readwrite("f20", &Cells::f20)
.def_readwrite("f21", &Cells::f21);
.def_readwrite("f21", &Cells::f21)
.def(py::pickle(
[](const Cells &cell) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(cell.f1, cell.f2, cell.f3, cell.f4, cell.f5, cell.f6, cell.f7, cell.f8, cell.f9, cell.f10, cell.f11,
cell.f12, cell.f13, cell.f14, cell.f15, cell.f16, cell.f17, cell.f18, cell.f19, cell.f20, cell.f21);
},
[](py::tuple t) { // __setstate__
if (t.size() != 21)
throw std::runtime_error("Encountered invalid state unpickling Cells object!");

/* Create a new C++ instance */
Cells cell;

cell.f1 = t[0].cast<py::list>();
cell.f2 = t[1].cast<py::list>();
cell.f3 = t[2].cast<py::list>();
cell.f4 = t[3].cast<py::list>();
cell.f5 = t[4].cast<py::list>();
cell.f6 = t[5].cast<py::list>();
cell.f7 = t[6].cast<py::list>();
cell.f8 = t[7].cast<py::list>();
cell.f9 = t[8].cast<py::list>();
cell.f10 = t[9].cast<py::list>();
cell.f11 = t[10].cast<py::list>();
cell.f12 = t[11].cast<py::list>();
cell.f13 = t[12].cast<py::list>();
cell.f14 = t[13].cast<py::list>();
cell.f15 = t[14].cast<py::list>();
cell.f16 = t[15].cast<py::list>();
cell.f17 = t[16].cast<py::list>();
cell.f18 = t[17].cast<py::list>();
cell.f19 = t[18].cast<py::list>();
cell.f20 = t[19].cast<py::list>();
cell.f21 = t[20].cast<py::list>();

return cell;
}));

py::class_<Control>(m, "Control")
.def(py::init<>())
Expand Down Expand Up @@ -1473,8 +1626,67 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("boundHandling", &Control::boundHandling)
.def_readwrite("adaptPCR", &Control::adaptPCR)
.def_readwrite("checks", &Control::checks)
.def_readwrite("IPCFilePath", &Control::IPCFilePath);

.def_readwrite("IPCFilePath", &Control::IPCFilePath)
.def(py::pickle(
[](const Control &ctrl) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance,
ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability,
ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale,
ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.resampleParams, ctrl.updateFreq, ctrl.updatePlotFreq,
ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR,
ctrl.IPCFilePath, ctrl.checks.fitParam, ctrl.checks.fitBackgroundParam, ctrl.checks.fitQzshift,
ctrl.checks.fitScalefactor, ctrl.checks.fitBulkIn, ctrl.checks.fitBulkOut,
ctrl.checks.fitResolutionParam, ctrl.checks.fitDomainRatio);
},
[](py::tuple t) { // __setstate__
if (t.size() != 36)
throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!");

/* Create a new C++ instance */
Control ctrl;

ctrl.parallel = t[0].cast<std::string>();
ctrl.procedure = t[1].cast<std::string>();
ctrl.display = t[2].cast<std::string>();
ctrl.xTolerance = t[3].cast<real_T>();
ctrl.funcTolerance = t[4].cast<real_T>();
ctrl.maxFuncEvals = t[5].cast<real_T>();
ctrl.maxIterations = t[6].cast<real_T>();
ctrl.populationSize = t[7].cast<real_T>();
ctrl.fWeight = t[8].cast<real_T>();
ctrl.crossoverProbability = t[9].cast<real_T>();
ctrl.targetValue = t[10].cast<real_T>();
ctrl.numGenerations = t[11].cast<real_T>();
ctrl.strategy = t[12].cast<real_T>();
ctrl.nLive = t[13].cast<real_T>();
ctrl.nMCMC = t[14].cast<real_T>();
ctrl.propScale = t[15].cast<real_T>();
ctrl.nsTolerance = t[16].cast<real_T>();
ctrl.calcSldDuringFit = t[17].cast<boolean_T>();
ctrl.resampleParams = t[18].cast<py::array_t<real_T>>();
ctrl.updateFreq = t[19].cast<real_T>();
ctrl.updatePlotFreq = t[20].cast<real_T>();
ctrl.nSamples = t[21].cast<real_T>();
ctrl.nChains = t[22].cast<real_T>();
ctrl.jumpProbability = t[23].cast<real_T>();
ctrl.pUnitGamma = t[24].cast<real_T>();
ctrl.boundHandling = t[25].cast<std::string>();
ctrl.adaptPCR = t[26].cast<boolean_T>();
ctrl.IPCFilePath = t[27].cast<std::string>();

ctrl.checks.fitParam = t[28].cast<py::array_t<real_T>>();
ctrl.checks.fitBackgroundParam = t[29].cast<py::array_t<real_T>>();
ctrl.checks.fitQzshift = t[30].cast<py::array_t<real_T>>();
ctrl.checks.fitScalefactor = t[31].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkIn = t[32].cast<py::array_t<real_T>>();
ctrl.checks.fitBulkOut = t[33].cast<py::array_t<real_T>>();
ctrl.checks.fitResolutionParam = t[34].cast<py::array_t<real_T>>();
ctrl.checks.fitDomainRatio = t[35].cast<py::array_t<real_T>>();

return ctrl;
}));

py::class_<ProblemDefinition>(m, "ProblemDefinition")
.def(py::init<>())
.def_readwrite("contrastBackgroundParams", &ProblemDefinition::contrastBackgroundParams)
Expand Down Expand Up @@ -1507,7 +1719,58 @@ PYBIND11_MODULE(rat_core, m) {
.def_readwrite("fitParams", &ProblemDefinition::fitParams)
.def_readwrite("otherParams", &ProblemDefinition::otherParams)
.def_readwrite("fitLimits", &ProblemDefinition::fitLimits)
.def_readwrite("otherLimits", &ProblemDefinition::otherLimits);
.def_readwrite("otherLimits", &ProblemDefinition::otherLimits)
.def(py::pickle(
[](const ProblemDefinition &p) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(p.contrastBackgroundParams, p.contrastBackgroundActions, p.TF, p.resample, p.dataPresent, p.oilChiDataPresent,
p.numberOfContrasts, p.geometry, p.useImaginary, p.contrastQzshifts, p.contrastScalefactors,
p.contrastBulkIns, p.contrastBulkOuts, p.contrastResolutionParams, p.backgroundParams,
p.qzshifts, p.scalefactors, p.bulkIn, p.bulkOut, p.resolutionParams, p.params,
p.numberOfLayers, p.modelType, p.contrastCustomFiles, p.contrastDomainRatios,
p.domainRatio, p.numberOfDomainContrasts, p.fitParams, p.otherParams, p.fitLimits, p.otherLimits);
},
[](py::tuple t) { // __setstate__
if (t.size() != 31)
throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!");

/* Create a new C++ instance */
ProblemDefinition p;

p.contrastBackgroundParams = t[0].cast<py::array_t<real_T>>();
p.contrastBackgroundActions = t[1].cast<py::array_t<real_T>>();
p.TF = t[2].cast<std::string>();
p.resample = t[3].cast<py::array_t<real_T>>();
p.dataPresent = t[4].cast<py::array_t<real_T>>();
p.oilChiDataPresent = t[5].cast<py::array_t<real_T>>();
p.numberOfContrasts = t[6].cast<real_T>();
p.geometry = t[7].cast<std::string>();
p.useImaginary = t[8].cast<bool>();
p.contrastQzshifts = t[9].cast<py::array_t<real_T>>();
p.contrastScalefactors = t[10].cast<py::array_t<real_T>>();
p.contrastBulkIns = t[11].cast<py::array_t<real_T>>();
p.contrastBulkOuts = t[12].cast<py::array_t<real_T>>();
p.contrastResolutionParams = t[13].cast<py::array_t<real_T>>();
p.backgroundParams = t[14].cast<py::array_t<real_T>>();
p.qzshifts = t[15].cast<py::array_t<real_T>>();
p.scalefactors = t[16].cast<py::array_t<real_T>>();
p.bulkIn= t[17].cast<py::array_t<real_T>>();
p.bulkOut= t[18].cast<py::array_t<real_T>>();
p.resolutionParams= t[19].cast<py::array_t<real_T>>();
p.params = t[20].cast<py::array_t<real_T>>(),
p.numberOfLayers = t[21].cast<real_T>();
p.modelType = t[22].cast<std::string>();
p.contrastCustomFiles = t[23].cast<py::array_t<real_T>>();
p.contrastDomainRatios = t[24].cast<py::array_t<real_T>>(),
p.domainRatio = t[25].cast<py::array_t<real_T>>();
p.numberOfDomainContrasts = t[26].cast<real_T>();
p.fitParams = t[27].cast<py::array_t<real_T>>();
p.otherParams = t[28].cast<py::array_t<real_T>>();
p.fitLimits = t[29].cast<py::array_t<real_T>>();
p.otherLimits = t[30].cast<py::array_t<real_T>>();

return p;
}));

m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");

Expand Down
35 changes: 35 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pickle
from unittest import mock

import numpy as np
import pytest

import RATapi.events
Expand Down Expand Up @@ -64,3 +66,36 @@ def test_event_notify() -> None:
assert first_callback.call_count == 1
assert second_callback.call_count == 1
assert third_callback.call_count == 1


def test_event_data_pickle():
data = RATapi.events.ProgressEventData()
data.message = "Hello"
data.percent = 0.5
pickled_data = pickle.loads(pickle.dumps(data))
assert pickled_data.message == data.message
assert pickled_data.percent == data.percent

data = RATapi.events.PlotEventData()
data.modelType = "custom layers"
data.dataPresent = np.ones(2)
data.subRoughs = np.ones((20, 2))
data.resample = np.ones(2)
data.resampledLayers = [np.ones((20, 2)), np.ones((20, 2))]
data.reflectivity = [np.ones((20, 2)), np.ones((20, 2))]
data.shiftedData = [np.ones((20, 2)), np.ones((20, 2))]
data.sldProfiles = [np.ones((20, 2)), np.ones((20, 2))]
data.contrastNames = ["D2O", "SMW"]

pickled_data = pickle.loads(pickle.dumps(data))

assert pickled_data.modelType == data.modelType
assert (pickled_data.dataPresent == data.dataPresent).all()
assert (pickled_data.subRoughs == data.subRoughs).all()
assert (pickled_data.resample == data.resample).all()
for i in range(2):
assert (pickled_data.resampledLayers[i] == data.resampledLayers[i]).all()
assert (pickled_data.reflectivity[i] == data.reflectivity[i]).all()
assert (pickled_data.shiftedData[i] == data.shiftedData[i]).all()
assert (pickled_data.sldProfiles[i] == data.sldProfiles[i]).all()
assert pickled_data.contrastNames == data.contrastNames
Loading

0 comments on commit fe0a45e

Please sign in to comment.