From aef5072a8f92eb677921410bd662118f2030d84a Mon Sep 17 00:00:00 2001 From: Ben Bales Date: Thu, 18 Apr 2019 11:22:55 -0700 Subject: [PATCH 1/3] Added in switching adaptation --- src/cmdstan/arguments/arg_metric.hpp | 2 ++ src/cmdstan/arguments/arg_switching_e.hpp | 17 ++++++++++ src/cmdstan/command.hpp | 39 +++++++++++++++++++++-- stan | 2 +- 4 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 src/cmdstan/arguments/arg_switching_e.hpp diff --git a/src/cmdstan/arguments/arg_metric.hpp b/src/cmdstan/arguments/arg_metric.hpp index 6b5386a588..e42645e136 100644 --- a/src/cmdstan/arguments/arg_metric.hpp +++ b/src/cmdstan/arguments/arg_metric.hpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace cmdstan { @@ -17,6 +18,7 @@ namespace cmdstan { _values.push_back(new arg_unit_e()); _values.push_back(new arg_diag_e()); _values.push_back(new arg_dense_e()); + _values.push_back(new arg_switching_e()); _default_cursor = 1; _cursor = _default_cursor; diff --git a/src/cmdstan/arguments/arg_switching_e.hpp b/src/cmdstan/arguments/arg_switching_e.hpp new file mode 100644 index 0000000000..bc3242b344 --- /dev/null +++ b/src/cmdstan/arguments/arg_switching_e.hpp @@ -0,0 +1,17 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_SWITCHING_E_HPP +#define CMDSTAN_ARGUMENTS_ARG_SWITCHING_E_HPP + +#include + +namespace cmdstan { + + class arg_switching_e: public unvalued_argument { + public: + arg_switching_e() { + _name = "switching_e"; + _description = "Euclidean manifold with sparsity of metric determined at warmup"; + } + }; + +} +#endif diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index d5b0789aa9..3db5389c83 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -345,7 +346,7 @@ namespace cmdstan { if (adapt_engaged == true && num_warmup == 0) { info("The number of warmup samples (num_warmup) must be greater than zero if adaptation is enabled."); return_code = stan::services::error_codes::CONFIG; - } else if (engine->value() == "nuts" && metric->value() == "dense_e" && adapt_engaged == false && metric_supplied == false) { + } else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "switching_e") && adapt_engaged == false && metric_supplied == false) { int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); return_code = stan::services::sample::hmc_nuts_dense_e(model, *init_context, @@ -365,7 +366,7 @@ namespace cmdstan { init_writer, sample_writer, diagnostic_writer); - } else if (engine->value() == "nuts" && metric->value() == "dense_e" && adapt_engaged == false && metric_supplied == true) { + } else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "switching_e") && adapt_engaged == false && metric_supplied == true) { int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); return_code = stan::services::sample::hmc_nuts_dense_e(model, *init_context, @@ -420,6 +421,40 @@ namespace cmdstan { init_writer, sample_writer, diagnostic_writer); + } else if (engine->value() == "nuts" && metric->value() == "switching_e" && adapt_engaged == true && metric_supplied == false) { + int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); + double delta = dynamic_cast(adapt->arg("delta"))->value(); + double gamma = dynamic_cast(adapt->arg("gamma"))->value(); + double kappa = dynamic_cast(adapt->arg("kappa"))->value(); + double t0 = dynamic_cast(adapt->arg("t0"))->value(); + unsigned int init_buffer = dynamic_cast(adapt->arg("init_buffer"))->value(); + unsigned int term_buffer = dynamic_cast(adapt->arg("term_buffer"))->value(); + unsigned int window = dynamic_cast(adapt->arg("window"))->value(); + return_code = stan::services::sample::hmc_nuts_switching_e_adapt(model, + *init_context, + random_seed, + id, + init_radius, + num_warmup, + num_samples, + num_thin, + save_warmup, + refresh, + stepsize, + stepsize_jitter, + max_depth, + delta, + gamma, + kappa, + t0, + init_buffer, + term_buffer, + window, + interrupt, + logger, + init_writer, + sample_writer, + diagnostic_writer); } else if (engine->value() == "nuts" && metric->value() == "dense_e" && adapt_engaged == true && metric_supplied == true) { int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); double delta = dynamic_cast(adapt->arg("delta"))->value(); diff --git a/stan b/stan index 9273ff8fb9..1952caf941 160000 --- a/stan +++ b/stan @@ -1 +1 @@ -Subproject commit 9273ff8fb9d6d5b4ffc2bceec5c75ef064edaa5b +Subproject commit 1952caf94101c363d7d7bcd530ffa59652d6d7f5 From 0b652a083622004fa20f3dc4eaa3cb8b7ea2dfed Mon Sep 17 00:00:00 2001 From: Ben Bales Date: Sat, 20 Apr 2019 15:32:14 -0700 Subject: [PATCH 2/3] Changed 'Switching' to 'Auto'. Pointed stan submodule at modified version of Stan in personal fork that has the new adaptation stuff --- .gitmodules | 2 +- src/cmdstan/arguments/arg_auto_e.hpp | 17 +++++++ src/cmdstan/arguments/arg_metric.hpp | 4 +- src/cmdstan/arguments/arg_switching_e.hpp | 17 ------- src/cmdstan/command.hpp | 58 +++++++++++------------ stan | 2 +- 6 files changed, 50 insertions(+), 50 deletions(-) create mode 100644 src/cmdstan/arguments/arg_auto_e.hpp delete mode 100644 src/cmdstan/arguments/arg_switching_e.hpp diff --git a/.gitmodules b/.gitmodules index 97ffbf03be..024d6c5791 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "stan"] path = stan - url = https://github.com/stan-dev/stan + url = https://github.com/bbbales2/stan diff --git a/src/cmdstan/arguments/arg_auto_e.hpp b/src/cmdstan/arguments/arg_auto_e.hpp new file mode 100644 index 0000000000..5cd7700622 --- /dev/null +++ b/src/cmdstan/arguments/arg_auto_e.hpp @@ -0,0 +1,17 @@ +#ifndef CMDSTAN_ARGUMENTS_ARG_AUTO_E_HPP +#define CMDSTAN_ARGUMENTS_ARG_AUTO_E_HPP + +#include + +namespace cmdstan { + + class arg_auto_e: public unvalued_argument { + public: + arg_auto_e() { + _name = "auto_e"; + _description = "Euclidean manifold that chooses between dense/diagonal metric at warmup"; + } + }; + +} +#endif diff --git a/src/cmdstan/arguments/arg_metric.hpp b/src/cmdstan/arguments/arg_metric.hpp index e42645e136..e60ce9fe70 100644 --- a/src/cmdstan/arguments/arg_metric.hpp +++ b/src/cmdstan/arguments/arg_metric.hpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include namespace cmdstan { @@ -18,7 +18,7 @@ namespace cmdstan { _values.push_back(new arg_unit_e()); _values.push_back(new arg_diag_e()); _values.push_back(new arg_dense_e()); - _values.push_back(new arg_switching_e()); + _values.push_back(new arg_auto_e()); _default_cursor = 1; _cursor = _default_cursor; diff --git a/src/cmdstan/arguments/arg_switching_e.hpp b/src/cmdstan/arguments/arg_switching_e.hpp deleted file mode 100644 index bc3242b344..0000000000 --- a/src/cmdstan/arguments/arg_switching_e.hpp +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef CMDSTAN_ARGUMENTS_ARG_SWITCHING_E_HPP -#define CMDSTAN_ARGUMENTS_ARG_SWITCHING_E_HPP - -#include - -namespace cmdstan { - - class arg_switching_e: public unvalued_argument { - public: - arg_switching_e() { - _name = "switching_e"; - _description = "Euclidean manifold with sparsity of metric determined at warmup"; - } - }; - -} -#endif diff --git a/src/cmdstan/command.hpp b/src/cmdstan/command.hpp index 3db5389c83..450008bbcc 100644 --- a/src/cmdstan/command.hpp +++ b/src/cmdstan/command.hpp @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include @@ -346,7 +346,7 @@ namespace cmdstan { if (adapt_engaged == true && num_warmup == 0) { info("The number of warmup samples (num_warmup) must be greater than zero if adaptation is enabled."); return_code = stan::services::error_codes::CONFIG; - } else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "switching_e") && adapt_engaged == false && metric_supplied == false) { + } else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "auto_e") && adapt_engaged == false && metric_supplied == false) { int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); return_code = stan::services::sample::hmc_nuts_dense_e(model, *init_context, @@ -366,7 +366,7 @@ namespace cmdstan { init_writer, sample_writer, diagnostic_writer); - } else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "switching_e") && adapt_engaged == false && metric_supplied == true) { + } else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "auto_e") && adapt_engaged == false && metric_supplied == true) { int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); return_code = stan::services::sample::hmc_nuts_dense_e(model, *init_context, @@ -421,7 +421,7 @@ namespace cmdstan { init_writer, sample_writer, diagnostic_writer); - } else if (engine->value() == "nuts" && metric->value() == "switching_e" && adapt_engaged == true && metric_supplied == false) { + } else if (engine->value() == "nuts" && metric->value() == "auto_e" && adapt_engaged == true && metric_supplied == false) { int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); double delta = dynamic_cast(adapt->arg("delta"))->value(); double gamma = dynamic_cast(adapt->arg("gamma"))->value(); @@ -430,31 +430,31 @@ namespace cmdstan { unsigned int init_buffer = dynamic_cast(adapt->arg("init_buffer"))->value(); unsigned int term_buffer = dynamic_cast(adapt->arg("term_buffer"))->value(); unsigned int window = dynamic_cast(adapt->arg("window"))->value(); - return_code = stan::services::sample::hmc_nuts_switching_e_adapt(model, - *init_context, - random_seed, - id, - init_radius, - num_warmup, - num_samples, - num_thin, - save_warmup, - refresh, - stepsize, - stepsize_jitter, - max_depth, - delta, - gamma, - kappa, - t0, - init_buffer, - term_buffer, - window, - interrupt, - logger, - init_writer, - sample_writer, - diagnostic_writer); + return_code = stan::services::sample::hmc_nuts_auto_e_adapt(model, + *init_context, + random_seed, + id, + init_radius, + num_warmup, + num_samples, + num_thin, + save_warmup, + refresh, + stepsize, + stepsize_jitter, + max_depth, + delta, + gamma, + kappa, + t0, + init_buffer, + term_buffer, + window, + interrupt, + logger, + init_writer, + sample_writer, + diagnostic_writer); } else if (engine->value() == "nuts" && metric->value() == "dense_e" && adapt_engaged == true && metric_supplied == true) { int max_depth = dynamic_cast(dynamic_cast(algo->arg("hmc")->arg("engine")->arg("nuts"))->arg("max_depth"))->value(); double delta = dynamic_cast(adapt->arg("delta"))->value(); diff --git a/stan b/stan index 1952caf941..8000691480 160000 --- a/stan +++ b/stan @@ -1 +1 @@ -Subproject commit 1952caf94101c363d7d7bcd530ffa59652d6d7f5 +Subproject commit 800069148021c77f173248da14af273ecd3a11c9 From f1fc292801343ca3a8e0ac438f2d3bbe71e8963b Mon Sep 17 00:00:00 2001 From: Ben Date: Wed, 11 Sep 2019 16:10:00 -0400 Subject: [PATCH 3/3] Updated submodule --- stan | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stan b/stan index a6ae242485..88225d08e3 160000 --- a/stan +++ b/stan @@ -1 +1 @@ -Subproject commit a6ae2424858628874a03973230edfa23faf5414f +Subproject commit 88225d08e30636073a310e9cc4eefc97e5212dbf