Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/issue 2814 warmup auto #729

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "stan"]
path = stan
url = https://github.com/stan-dev/stan
url = https://github.com/bbbales2/stan
17 changes: 17 additions & 0 deletions src/cmdstan/arguments/arg_auto_e.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef CMDSTAN_ARGUMENTS_ARG_AUTO_E_HPP
#define CMDSTAN_ARGUMENTS_ARG_AUTO_E_HPP

#include <cmdstan/arguments/unvalued_argument.hpp>

namespace cmdstan {

class arg_auto_e: public unvalued_argument {
public:
arg_auto_e() {
_name = "auto_e";
_description = "Euclidean manifold that chooses between dense/diagonal metric at warmup";
}
};

}
#endif
2 changes: 2 additions & 0 deletions src/cmdstan/arguments/arg_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define CMDSTAN_ARGUMENTS_ARG_METRIC_HPP

#include <cmdstan/arguments/arg_dense_e.hpp>
#include <cmdstan/arguments/arg_auto_e.hpp>
#include <cmdstan/arguments/arg_diag_e.hpp>
#include <cmdstan/arguments/arg_unit_e.hpp>
#include <cmdstan/arguments/list_argument.hpp>
Expand All @@ -17,6 +18,7 @@ class arg_metric : public list_argument {
_values.push_back(new arg_unit_e());
_values.push_back(new arg_diag_e());
_values.push_back(new arg_dense_e());
_values.push_back(new arg_auto_e());

_default_cursor = 1;
_cursor = _default_cursor;
Expand Down
64 changes: 61 additions & 3 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <stan/services/optimize/lbfgs.hpp>
#include <stan/services/optimize/newton.hpp>
#include <stan/services/sample/fixed_param.hpp>
#include <stan/services/sample/hmc_nuts_auto_e_adapt.hpp>
#include <stan/services/sample/hmc_nuts_dense_e.hpp>
#include <stan/services/sample/hmc_nuts_dense_e_adapt.hpp>
#include <stan/services/sample/hmc_nuts_diag_e.hpp>
Expand Down Expand Up @@ -424,7 +425,8 @@ int command(int argc, const char *argv[]) {
"The number of warmup samples (num_warmup) must be greater than "
"zero if adaptation is enabled.");
return_code = stan::services::error_codes::CONFIG;
} else if (engine->value() == "nuts" && metric->value() == "dense_e"

} else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "auto_e")
&& adapt_engaged == false && metric_supplied == false) {
int max_depth = dynamic_cast<int_argument *>(
dynamic_cast<categorical_argument *>(
Expand All @@ -436,7 +438,7 @@ int command(int argc, const char *argv[]) {
num_samples, num_thin, save_warmup, refresh, stepsize,
stepsize_jitter, max_depth, interrupt, logger, init_writer,
sample_writer, diagnostic_writer);
} else if (engine->value() == "nuts" && metric->value() == "dense_e"
} else if (engine->value() == "nuts" && (metric->value() == "dense_e" || metric->value() == "auto_e")
&& adapt_engaged == false && metric_supplied == true) {
int max_depth = dynamic_cast<int_argument *>(
dynamic_cast<categorical_argument *>(
Expand Down Expand Up @@ -504,7 +506,63 @@ int command(int argc, const char *argv[]) {
stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer,
term_buffer, window, interrupt, logger, init_writer, sample_writer,
diagnostic_writer);
} else if (engine->value() == "nuts" && metric->value() == "diag_e"
} else if (engine->value() == "nuts" && metric->value() == "auto_e"
&& adapt_engaged == true && metric_supplied == false) {
int max_depth = dynamic_cast<int_argument *>(
dynamic_cast<categorical_argument *>(
algo->arg("hmc")->arg("engine")->arg("nuts"))
->arg("max_depth"))
->value();
double delta
= dynamic_cast<real_argument *>(adapt->arg("delta"))->value();
double gamma
= dynamic_cast<real_argument *>(adapt->arg("gamma"))->value();
double kappa
= dynamic_cast<real_argument *>(adapt->arg("kappa"))->value();
double t0 = dynamic_cast<real_argument *>(adapt->arg("t0"))->value();
unsigned int init_buffer
= dynamic_cast<u_int_argument *>(adapt->arg("init_buffer"))
->value();
unsigned int term_buffer
= dynamic_cast<u_int_argument *>(adapt->arg("term_buffer"))
->value();
unsigned int window
= dynamic_cast<u_int_argument *>(adapt->arg("window"))->value();
return_code = stan::services::sample::hmc_nuts_auto_e_adapt(
model, *init_context, random_seed, id, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize,
stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer,
term_buffer, window, interrupt, logger, init_writer, sample_writer,
diagnostic_writer);
} else if (engine->value() == "nuts" && metric->value() == "auto_e"
&& adapt_engaged == true && metric_supplied == true) {
int max_depth = dynamic_cast<int_argument *>(
dynamic_cast<categorical_argument *>(
algo->arg("hmc")->arg("engine")->arg("nuts"))
->arg("max_depth"))
->value();
double delta
= dynamic_cast<real_argument *>(adapt->arg("delta"))->value();
double gamma
= dynamic_cast<real_argument *>(adapt->arg("gamma"))->value();
double kappa
= dynamic_cast<real_argument *>(adapt->arg("kappa"))->value();
double t0 = dynamic_cast<real_argument *>(adapt->arg("t0"))->value();
unsigned int init_buffer
= dynamic_cast<u_int_argument *>(adapt->arg("init_buffer"))
->value();
unsigned int term_buffer
= dynamic_cast<u_int_argument *>(adapt->arg("term_buffer"))
->value();
unsigned int window
= dynamic_cast<u_int_argument *>(adapt->arg("window"))->value();
return_code = stan::services::sample::hmc_nuts_auto_e_adapt(
model, *init_context, *metric_context, random_seed, id, init_radius,
num_warmup, num_samples, num_thin, save_warmup, refresh, stepsize,
stepsize_jitter, max_depth, delta, gamma, kappa, t0, init_buffer,
term_buffer, window, interrupt, logger, init_writer, sample_writer,
diagnostic_writer);
} else if (engine->value() == "nuts" && metric->value() == "diag_e"
&& adapt_engaged == false && metric_supplied == false) {
categorical_argument *base = dynamic_cast<categorical_argument *>(
algo->arg("hmc")->arg("engine")->arg("nuts"));
Expand Down
2 changes: 1 addition & 1 deletion stan
Submodule stan updated 32 files
+4 −3 src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp
+27 −4 src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp
+4 −4 src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp
+15 −2 src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp
+6 −2 src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp
+6 −2 src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
+4 −4 src/stan/mcmc/hmc/nuts/base_nuts.hpp
+6 −2 src/stan/mcmc/hmc/nuts_classic/adapt_dense_e_nuts_classic.hpp
+6 −2 src/stan/mcmc/hmc/nuts_classic/adapt_diag_e_nuts_classic.hpp
+4 −3 src/stan/mcmc/hmc/nuts_classic/dense_e_nuts_classic.hpp
+5 −3 src/stan/mcmc/hmc/nuts_classic/diag_e_nuts_classic.hpp
+6 −3 src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp
+6 −3 src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp
+4 −4 src/stan/mcmc/hmc/static/base_static_hmc.hpp
+7 −2 src/stan/mcmc/hmc/static_uniform/adapt_dense_e_static_uniform.hpp
+7 −2 src/stan/mcmc/hmc/static_uniform/adapt_diag_e_static_uniform.hpp
+6 −2 src/stan/mcmc/hmc/xhmc/adapt_dense_e_xhmc.hpp
+6 −2 src/stan/mcmc/hmc/xhmc/adapt_diag_e_xhmc.hpp
+1 −1 src/stan/services/sample/hmc_nuts_dense_e.hpp
+1 −1 src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
+1 −1 src/stan/services/sample/hmc_nuts_diag_e.hpp
+1 −1 src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
+1 −1 src/stan/services/sample/hmc_static_dense_e.hpp
+1 −1 src/stan/services/sample/hmc_static_dense_e_adapt.hpp
+1 −1 src/stan/services/sample/hmc_static_diag_e.hpp
+1 −1 src/stan/services/sample/hmc_static_diag_e_adapt.hpp
+1 −1 src/test/unit/mcmc/hmc/hamiltonians/dense_e_metric_test.cpp
+53 −0 src/test/unit/mcmc/hmc/hamiltonians/dense_e_point_test.cpp
+32 −0 src/test/unit/mcmc/hmc/hamiltonians/diag_e_point_test.cpp
+3 −1 src/test/unit/mcmc/hmc/integrators/expl_leapfrog_test.cpp
+3 −1 src/test/unit/mcmc/hmc/integrators/impl_leapfrog_test.cpp
+8 −8 src/test/unit/mcmc/hmc/nuts/derived_nuts_test.cpp