Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-introduce sampling of the derivative into CLI #619

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/storm-pars-cli/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,76 @@ struct SampleInformation {
};

template<template<typename, typename> class ModelCheckerType, typename ModelType, typename ValueType, typename SolveValueType = double>
void verifyPropertiesAtSamplePointsDerivative(ModelType const& model, cli::SymbolicInput const& input, SampleInformation<ValueType> const& samples) {
// When samples are provided, we create an instantiation model checker.
ModelCheckerType<ValueType, SolveValueType> modelchecker(model);

for (auto const& property : input.properties) {
storm::cli::printModelCheckingProperty(property);

modelchecker.specifyFormula(Environment(), storm::api::createTask<ValueType>(property.getRawFormula(), true));

storm::utility::parametric::Valuation<ValueType> valuation;

std::vector<typename storm::utility::parametric::VariableType<ValueType>::type> parameters;
std::vector<typename std::vector<typename storm::utility::parametric::CoefficientType<ValueType>::type>::const_iterator> iterators;
std::vector<typename std::vector<typename storm::utility::parametric::CoefficientType<ValueType>::type>::const_iterator> iteratorEnds;

storm::utility::Stopwatch watch(true);
for (auto const& product : samples.cartesianProducts) {
parameters.clear();
iterators.clear();
iteratorEnds.clear();

for (auto const& entry : product) {
parameters.push_back(entry.first);
iterators.push_back(entry.second.cbegin());
iteratorEnds.push_back(entry.second.cend());
}

bool done = false;
while (!done) {
// Read off valuation.
for (uint64_t i = 0; i < parameters.size(); ++i) {
valuation[parameters[i]] = *iterators[i];
}

for (auto const& parameter : parameters) {
storm::utility::Stopwatch valuationWatch(true);
std::unique_ptr<storm::modelchecker::CheckResult> result = modelchecker.check(Environment(), valuation, parameter);
valuationWatch.stop();

if (result) {
result->filter(storm::modelchecker::ExplicitQualitativeCheckResult(model.getInitialStates()));
}
STORM_PRINT_AND_LOG("Derivative w.r.t. " << parameter << ":\n");
printInitialStatesResult<ValueType>(result, &valuationWatch, &valuation);
}

for (uint64_t i = 0; i < parameters.size(); ++i) {
++iterators[i];
if (iterators[i] == iteratorEnds[i]) {
// Reset iterator and proceed to move next iterator.
iterators[i] = product.at(parameters[i]).cbegin();

// If the last iterator was removed, we are done.
if (i == parameters.size() - 1) {
done = true;
}
} else {
// If an iterator was moved but not reset, we have another valuation to check.
break;
}
}
}
}

watch.stop();
STORM_PRINT_AND_LOG("Overall time for sampling all instances: " << watch << "\n\n");
}
}

template<template<typename, typename> class ModelCheckerType, typename ModelType, typename ValueType, typename SolveValueType = double, bool Derivative = false>
void verifyPropertiesAtSamplePoints(ModelType const& model, cli::SymbolicInput const& input, SampleInformation<ValueType> const& samples) {
// When samples are provided, we create an instantiation model checker.
ModelCheckerType<ModelType, SolveValueType> modelchecker(model);
Expand Down Expand Up @@ -142,6 +212,17 @@ void verifyPropertiesAtSamplePoints(ModelType const& model, cli::SymbolicInput c
}
}

template<typename ValueType, typename SolveValueType = double>
void verifyPropertiesAtSamplePointsWithSparseEngineDerivatives(std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model,
cli::SymbolicInput const& input, SampleInformation<ValueType> const& samples) {
if (model->isOfType(storm::models::ModelType::Dtmc)) {
verifyPropertiesAtSamplePointsDerivative<storm::derivative::SparseDerivativeInstantiationModelChecker, storm::models::sparse::Dtmc<ValueType>,
ValueType, SolveValueType>(*model->template as<storm::models::sparse::Dtmc<ValueType>>(), input, samples);
} else {
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Sampling the derivative is currently only supported for DTMCs.");
}
}

template<typename ValueType, typename SolveValueType = double>
void verifyPropertiesAtSamplePointsWithSparseEngine(std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model, cli::SymbolicInput const& input,
SampleInformation<ValueType> const& samples) {
Expand Down
16 changes: 13 additions & 3 deletions src/storm-pars-cli/storm-pars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,11 +510,21 @@ void processInputWithValueTypeAndDdlib(cli::SymbolicInput& input, storm::cli::Mo
if (!samples.empty()) {
STORM_LOG_TRACE("Sampling the model at given points.");

if (samples.exact) {
verifyPropertiesAtSamplePointsWithSparseEngine<ValueType, storm::RationalNumber>(model->as<storm::models::sparse::Model<ValueType>>(), input,
if (sampleSettings.isSampleDerivativeSet()) {
if (samples.exact) {
verifyPropertiesAtSamplePointsWithSparseEngineDerivatives<ValueType, storm::RationalNumber>(
model->as<storm::models::sparse::Model<ValueType>>(), input, samples);
} else {
verifyPropertiesAtSamplePointsWithSparseEngineDerivatives<ValueType, double>(model->as<storm::models::sparse::Model<ValueType>>(), input,
samples);
}
} else {
verifyPropertiesAtSamplePointsWithSparseEngine<ValueType, double>(model->as<storm::models::sparse::Model<ValueType>>(), input, samples);
if (samples.exact) {
verifyPropertiesAtSamplePointsWithSparseEngine<ValueType, storm::RationalNumber>(model->as<storm::models::sparse::Model<ValueType>>(),
input, samples);
} else {
verifyPropertiesAtSamplePointsWithSparseEngine<ValueType, double>(model->as<storm::models::sparse::Model<ValueType>>(), input, samples);
}
}
}
} else {
Expand Down
14 changes: 0 additions & 14 deletions src/storm-pars/settings/modules/DerivativeSettings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace modules {

const std::string DerivativeSettings::moduleName = "derivative";
const std::string DerivativeSettings::feasibleInstantiationSearch = "gradient-descent";
const std::string DerivativeSettings::derivativeAtInstantiation = "compute-derivative";
const std::string DerivativeSettings::learningRate = "learning-rate";
const std::string DerivativeSettings::miniBatchSize = "batch-size";
const std::string DerivativeSettings::adamParams = "adam-params";
Expand All @@ -32,11 +31,6 @@ DerivativeSettings::DerivativeSettings() : ModuleSettings(moduleName) {
this->addOption(storm::settings::OptionBuilder(moduleName, feasibleInstantiationSearch, false,
"Search for a feasible instantiation (restart with new instantiation while not feasible)")
.build());
this->addOption(storm::settings::OptionBuilder(moduleName, derivativeAtInstantiation, false, "Compute the derivative at an input instantiation")
.addArgument(storm::settings::ArgumentBuilder::createStringArgument(derivativeAtInstantiation,
"Instantiation at which the derivative should be computed")
.build())
.build());
this->addOption(storm::settings::OptionBuilder(moduleName, learningRate, false, "Sets the learning rate of gradient descent")
.addArgument(storm::settings::ArgumentBuilder::createDoubleArgument(learningRate, "The learning rate of the gradient descent")
.setDefaultValueDouble(0.1)
Expand Down Expand Up @@ -90,14 +84,6 @@ bool DerivativeSettings::isFeasibleInstantiationSearchSet() const {
return this->getOption(feasibleInstantiationSearch).getHasOptionBeenSet();
}

boost::optional<std::string> DerivativeSettings::getDerivativeAtInstantiation() const {
if (this->getOption(derivativeAtInstantiation).getHasOptionBeenSet()) {
return this->getOption(derivativeAtInstantiation).getArgumentByName(derivativeAtInstantiation).getValueAsString();
} else {
return boost::none;
}
}

double DerivativeSettings::getLearningRate() const {
return this->getOption(learningRate).getArgumentByName(learningRate).getValueAsDouble();
}
Expand Down
5 changes: 0 additions & 5 deletions src/storm-pars/settings/modules/DerivativeSettings.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ class DerivativeSettings : public ModuleSettings {
*/
bool isFeasibleInstantiationSearchSet() const;

/*!
* Retrieves whether an extremum should be found by Gradient Descent.
*/
boost::optional<std::string> getDerivativeAtInstantiation() const;

/*!
* Retrieves the learning rate for the gradient descent.
*/
Expand Down
6 changes: 6 additions & 0 deletions src/storm-pars/settings/modules/SamplingSettings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const std::string SamplingSettings::moduleName = "sampling";
const std::string samplesOptionName = "samples";
const std::string samplesGraphPreservingOptionName = "samples-graph-preserving";
const std::string sampleExactOptionName = "sample-exact";
const std::string sampleDerivativeOptionName = "sample-derivative";

SamplingSettings::SamplingSettings() : ModuleSettings(moduleName) {
this->addOption(
Expand All @@ -25,6 +26,7 @@ SamplingSettings::SamplingSettings() : ModuleSettings(moduleName) {
"Sets whether it can be assumed that the samples are graph-preserving.")
.build());
this->addOption(storm::settings::OptionBuilder(moduleName, sampleExactOptionName, false, "Sets whether to sample using exact arithmetic.").build());
this->addOption(storm::settings::OptionBuilder(moduleName, sampleDerivativeOptionName, false, "Sets whether to sample the derivatives instead..").build());
}

std::string SamplingSettings::getSamples() const {
Expand All @@ -38,4 +40,8 @@ bool SamplingSettings::isSamplesAreGraphPreservingSet() const {
bool SamplingSettings::isSampleExactSet() const {
return this->getOption(sampleExactOptionName).getHasOptionBeenSet();
}

bool SamplingSettings::isSampleDerivativeSet() const {
return this->getOption(sampleDerivativeOptionName).getHasOptionBeenSet();
}
} // namespace storm::settings::modules
5 changes: 5 additions & 0 deletions src/storm-pars/settings/modules/SamplingSettings.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ class SamplingSettings : public ModuleSettings {
*/
bool isSampleExactSet() const;

/*!
* Retrieves whether samples are to be from the derivative.
*/
bool isSampleDerivativeSet() const;

static const std::string moduleName;
};

Expand Down
Loading