Skip to content

Commit

Permalink
Adds a stop event that works via a temp IPC file (#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenNneji authored Aug 12, 2024
1 parent 704a5e6 commit feb0c6d
Show file tree
Hide file tree
Showing 45 changed files with 253 additions and 252 deletions.
4 changes: 2 additions & 2 deletions 3rdParty/mcmcstat/%mcmcrun_compile.m
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@

rej=0; reju=0; ii=1; rejl = 0;
%% setup waitbar
if (wbarupd && display); triggerEvent(coderEnums.eventTypes.Progress, 'init', 0); end
if (wbarupd && display); triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 0); end

% covariance update uses these to store previous values
covchain = []; meanchain = []; wsum = initqcovn; lasti = 0;
Expand Down Expand Up @@ -914,7 +914,7 @@
end
end

triggerEvent(coderEnums.eventTypes.Progress, 'end', 1);
triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 1);



Expand Down
4 changes: 2 additions & 2 deletions 3rdParty/mcmcstat/mcmcrun.m
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@

rej=0; reju=0; ii=1; rejl = 0;
%% setup waitbar
if wbarupd; triggerEvent(coderEnums.eventTypes.Progress, 'init', 0); end
if wbarupd; triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 0); end

% covariance update uses these to store previous values
covchain = []; meanchain = []; wsum = initqcovn; lasti = 0;
Expand Down Expand Up @@ -841,7 +841,7 @@
end
end

triggerEvent(coderEnums.eventTypes.Progress, 'end', 1);
triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 1);



Expand Down
4 changes: 2 additions & 2 deletions 3rdParty/mcmcstat/mcmcrun_compile.m
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@

rej=0; reju=0; ii=1; rejl = 0;
%% setup waitbar
if wbarupd; triggerEvent(coderEnums.eventTypes.Progress, 'init', 0); end
if wbarupd; triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 0); end

% covariance update uses these to store previous values
covchain = []; meanchain = []; wsum = initqcovn; lasti = 0;
Expand Down Expand Up @@ -887,7 +887,7 @@
end
end

triggerEvent(coderEnums.eventTypes.Progress, 'end', 1);
triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 1);



Expand Down
4 changes: 2 additions & 2 deletions 3rdParty/mcmcstat/mcmcrun_compile_scaled.m
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@

rej=0; reju=0; ii=1; rejl = 0;
%% setup waitbar
if wbarupd; triggerEvent(coderEnums.eventTypes.Progress, 'init', 0); end
if wbarupd; triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 0); end

% covariance update uses these to store previous values
covchain = []; meanchain = []; wsum = initqcovn; lasti = 0;
Expand Down Expand Up @@ -887,7 +887,7 @@
end
end

triggerEvent(coderEnums.eventTypes.Progress, 'end', 1);
triggerEvent(coderEnums.eventTypes.Progress, 'Bayes', 1);



Expand Down
9 changes: 5 additions & 4 deletions API/RAT.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@
end

% Call the main RAT routine...

display = ~strcmpi(controls.display, displayOptions.Off.value);
textProgressBar(0, 0, display);
% If display is not silent print a line confirming RAT is starting
if ~strcmpi(controls.display, displayOptions.Off.value)
if display
fprintf('Starting RAT ________________________________________________________________________________________________\n\n');
end

tic
[problemStruct,result,bayesResults] = RATMain_mex(problemStruct,problemCells,problemLimits,controls,priors);

if ~strcmpi(controls.display, displayOptions.Off.value)
if display
toc
end

textProgressBar(0, 0, true);
if any(strcmpi(controls.procedure, {procedures.NS.value, procedures.Dream.value}))
result = mergeStructs(result, bayesResults);
end
Expand Down
2 changes: 1 addition & 1 deletion API/RATMain.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function [problemStruct,result,bayesResults] = RATMain(problemStruct,problemCells,problemLimits,controls,priors)
coderEnums.initialize()
coderEnums.initialise()

if strcmpi(problemStruct.TF, coderEnums.calculationTypes.Domains)
domains = true;
Expand Down
93 changes: 65 additions & 28 deletions API/controlsClass.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
classdef controlsClass < matlab.mixin.CustomDisplay
classdef controlsClass < handle & matlab.mixin.CustomDisplay

properties
% Parallelisation Option (Default: parallelOptions.Single)
Expand Down Expand Up @@ -41,7 +41,7 @@
% Target stopping tolerance for Nested Sampler (Default: 0.1)
nsTolerance = 0.1

% Total number of samples for DREAM (Default: 50000)
% Total number of samples for DREAM (Default: 20000)
nSamples = 20000;
% Number of MCMC chains (Default: 10)
nChains = 10
Expand All @@ -53,40 +53,45 @@
adaptPCR = true;
end


properties (SetAccess = private, Hidden = true)
IPCFilePath = ''
end

%------------------------- Set and Get ------------------------------
methods
function obj = set.parallel(obj,val)
function set.parallel(obj,val)
message = sprintf('parallel must be a parallelOptions enum or one of the following strings (%s)', ...
strjoin(parallelOptions.values(), ', '));
obj.parallel = validateOption(val, 'parallelOptions', message).value;
end

function obj = set.procedure(obj,val)
function set.procedure(obj,val)
message = sprintf('procedure must be a procedures enum or one of the following strings (%s)', ...
strjoin(procedures.values(), ', '));
obj.procedure = validateOption(val, 'procedures', message).value;
end

function obj = set.calcSldDuringFit(obj,val)
function set.calcSldDuringFit(obj,val)
validateLogical(val, 'calcSldDuringFit must be logical ''true'' or ''false''');
obj.calcSldDuringFit = val;
end

function obj = set.display(obj,val)
function set.display(obj,val)
message = sprintf('display must be a displayOptions enum or one of the following strings (%s)', ...
strjoin(displayOptions.values(), ', '));
obj.display = validateOption(val, 'displayOptions', message).value;
end

function obj = set.updatePlotFreq(obj, val)
function set.updatePlotFreq(obj, val)
validateNumber(val, 'updatePlotFreq must be a number');
if val < 1
throw(exceptions.invalidValue('updatePlotFreq must be greater or equal to 1'));
end
obj.updatePlotFreq = val;
end

function obj = set.resampleParams(obj,val)
function set.resampleParams(obj,val)
if length(val) ~= 2
throw(exceptions.invalidValue('resampleParams must have length of 2'));
end
Expand All @@ -103,59 +108,59 @@
end

% Simplex control methods
function obj = set.xTolerance(obj, val)
function set.xTolerance(obj, val)
obj.xTolerance = validateNumber(val, 'xTolerance must be a number');
end

function obj = set.funcTolerance(obj, val)
function set.funcTolerance(obj, val)
obj.funcTolerance = validateNumber(val, 'funcTolerance must be a number');
end

function obj = set.maxFuncEvals(obj, val)
function set.maxFuncEvals(obj, val)
obj.maxFuncEvals = validateNumber(val, 'maxFuncEvals must be a number');
end

function obj = set.maxIterations(obj, val)
function set.maxIterations(obj, val)
obj.maxIterations = validateNumber(val, 'maxIterations must be a number');
end

% DE controls methods
function obj = set.populationSize(obj, val)
function set.populationSize(obj, val)
validateNumber(val, 'populationSize must be a number');
if val < 1
throw(exceptions.invalidValue('populationSize must be greater or equal to 1'));
end
obj.populationSize = val;
end

function obj = set.fWeight(obj,val)
function set.fWeight(obj,val)
obj.fWeight = validateNumber(val,'fWeight must be a number');
end

function obj = set.crossoverProbability(obj,val)
function set.crossoverProbability(obj,val)
validateNumber(val, 'crossoverProbability must be a number');
if (val < 0 || val > 1)
throw(exceptions.invalidValue('crossoverProbability must be between 0 and 1'));
end
obj.crossoverProbability = val;
end

function obj = set.strategy(obj,val)
function set.strategy(obj,val)
message = sprintf('strategy must be a searchStrategy enum or one of the following integers (%s)', ...
strjoin(string(searchStrategy.values()), ', '));

obj.strategy = validateOption(val, 'searchStrategy', message).value;
end

function obj = set.targetValue(obj,val)
function set.targetValue(obj,val)
validateNumber(val, 'targetValue must be a number');
if val < 1
throw(exceptions.invalidValue('targetValue must be greater or equal to 1'));
end
obj.targetValue = val;
end

function obj = set.numGenerations(obj, val)
function set.numGenerations(obj, val)
validateNumber(val, 'numGenerations value must be a number');
if val < 1
throw(exceptions.invalidValue('numGenerations must be greater or equal to 1'));
Expand All @@ -164,31 +169,31 @@
end

% NS control methods
function obj = set.nLive(obj, val)
function set.nLive(obj, val)
validateNumber(val, 'nLive must be a number');
if val < 1
throw(exceptions.invalidValue('nLive must be greater or equal to 1'));
end
obj.nLive = val;
end

function obj = set.nMCMC(obj, val)
function set.nMCMC(obj, val)
validateNumber(val, 'nMCMC must be a number');
if val < 0
throw(exceptions.invalidValue('nMCMC must be greater or equal than 0'));
end
obj.nMCMC = val;
end

function obj = set.propScale(obj, val)
function set.propScale(obj, val)
validateNumber(val, 'propScale must be a number');
if (val < 0 || val > 1)
throw(exceptions.invalidValue('propScale must be between 0 and 1'));
end
obj.propScale = val;
end

function obj = set.nsTolerance(obj,val)
function set.nsTolerance(obj,val)
validateNumber(val, 'nsTolerance must be a number ');
if val < 0
throw(exceptions.invalidValue('nsTolerance must be greater or equal to 0'));
Expand All @@ -197,45 +202,45 @@
end

% DREAM methods
function obj = set.nSamples(obj,val)
function set.nSamples(obj,val)
validateNumber(val, 'nSample must be a number ');
if val < 0
throw(exceptions.invalidValue('nSample must be greater or equal to 0'));
end
obj.nSamples = val;
end

function obj = set.nChains(obj,val)
function set.nChains(obj,val)
validateNumber(val, 'nChains must be a number ');
if (~(round(val) == val) || val <= 0 || isnan(val) || isinf(val))
throw(exceptions.invalidValue('nChains must be a finite integer greater than 0'));
end
obj.nChains = val;
end

function obj = set.jumpProbability(obj,val)
function set.jumpProbability(obj,val)
validateNumber(val, 'jumpProbability must be a number');
if (val < 0 || val > 1)
throw(exceptions.invalidValue('JumpProbability must be a fraction between 0 and 1'));
end
obj.jumpProbability = val;
end

function obj = set.pUnitGamma(obj,val)
function set.pUnitGamma(obj,val)
validateNumber(val, 'pUnitGamma must be a number');
if (val < 0 || val > 1)
throw(exceptions.invalidValue('pUnitGamma must be a fraction between 0 and 1'));
end
obj.pUnitGamma = val;
end

function obj = set.boundHandling(obj,val)
function set.boundHandling(obj,val)
message = sprintf('boundHandling must be a boundHandlingOptions enum or one of the following strings (%s)', ...
strjoin(boundHandlingOptions.values(), ', '));
obj.boundHandling = validateOption(val, 'boundHandlingOptions', message).value;
end

function obj = set.adaptPCR(obj,val)
function set.adaptPCR(obj,val)
validateLogical(val, 'adaptPCR must be logical ''true'' or ''false''');
obj.adaptPCR = val;
end
Expand Down Expand Up @@ -299,6 +304,38 @@
end

end

function obj = initialiseIPC(obj)
% Method setup the inter-process communication file.
%
% USAGE:
% obj.initialiseIPC()
obj.IPCFilePath = tempname();
fileID = fopen(obj.IPCFilePath, 'w');
fwrite(fileID, false, 'uchar');
fclose(fileID);
end

function path = getIPCFilePath(obj)
% Returns the path of the IPC file.
%
% USAGE:
% path = obj.getIPCFilePath()
path = obj.IPCFilePath;
end

function obj = sendStopEvent(obj)
% Sends the stop event via IPC file.
%
% USAGE:
% obj.sendStopEvent()
if isempty(obj.IPCFilePath)
return
end
fileID = fopen(obj.IPCFilePath, 'w');
fwrite(fileID, true, 'uchar');
fclose(fileID);
end
end

%------------------------- Display Methods --------------------------
Expand Down
2 changes: 1 addition & 1 deletion API/enums/coderEnums.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
supportedLanguages = supportedLanguages.toStruct()
end
methods (Static)
function initialize()
function initialise()
% initializes enum struct field as a coder const like so
% coder.const(coderEnums.procedures.Dream);
props = properties(coderEnums);
Expand Down
15 changes: 15 additions & 0 deletions API/events/isRATStopped.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
function state = isRATStopped(IPCFilePath)
% Checks if the stop event was set via the IPC file. The expected input
% is the path to the inter-process communication file
%
% stopped = isRATStopped(filePath);
if isempty(IPCFilePath)
state = false;
return
end

fileID = fopen(IPCFilePath);
state = logical(fread(fileID, 1, '*uchar'));
fclose(fileID);
end

Loading

0 comments on commit feb0c6d

Please sign in to comment.