Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't do inverse metric decomposition every draw #2894

Open
wants to merge 26 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
45fff72
Made it so that the decomposition of the inverse dense metric is done…
bbbales2 Mar 6, 2020
7582cce
Applying suggested changes from review (Issue #2881)
bbbales2 Mar 15, 2020
548f992
Merge remote-tracking branch 'origin/develop' into feature/issue-2881…
bbbales2 Jul 15, 2020
83ef803
Added unit tests to diag_e_point/diag_e_metric interface (Issue #2881)
bbbales2 Jul 16, 2020
d4dd953
Merge branch 'develop' into feature/issue-2881-dense-metric-decomposi…
bbbales2 Oct 28, 2020
1da6121
Converted some other samplers over to use accessor functions on dense…
bbbales2 Oct 28, 2020
a182f78
Removed more direct metric accesses (Issue #2881)
bbbales2 Oct 29, 2020
6f165f9
Fixed lint issues (Issue #2881)
bbbales2 Oct 29, 2020
ed71201
Merge commit '84535f53ef5e75029273f063f753663aa74d85a9' into HEAD
yashikno Oct 29, 2020
7889b27
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 29, 2020
f8b5be7
Fixed inv_metric access in test (Issue #2881)
bbbales2 Oct 30, 2020
a67e9e5
Fixed test (Issue #2881)
bbbales2 Oct 31, 2020
a8e7500
Fixed order of updated inverse metric with stepsize (Issue #2881)
bbbales2 Nov 2, 2020
c4f30f6
Merge remote-tracking branch 'origin/develop' into feature/issue-2881…
bbbales2 Nov 28, 2020
cfdb3a2
Merge remote-tracking branch 'origin/develop' into feature/issue-2881…
bbbales2 Mar 9, 2021
6925d77
Updated to use forwarding (Issue #2881)
bbbales2 Mar 9, 2021
13d139f
Only close over rng (Issue #2881)
bbbales2 Mar 9, 2021
83f551f
Fix universal references (Issue #2881)
bbbales2 Mar 15, 2021
6482ba4
Merge commit 'f915220425003cfab85185fc0b6bf45b485545a8' into HEAD
yashikno Mar 15, 2021
2977426
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 15, 2021
be91f7c
Merge remote-tracking branch 'origin/develop' into feature/issue-2881…
bbbales2 Mar 15, 2021
7072224
Merge branch 'feature/issue-2881-dense-metric-decomposition' of githu…
bbbales2 Mar 15, 2021
479783e
Merge remote-tracking branch 'origin/develop' into feature/issue-2881…
bbbales2 Mar 21, 2021
9093b0f
Merge remote-tracking branch 'origin/develop' into feature/issue-2881…
bbbales2 Apr 3, 2021
fd36a26
Store matrixL of llt instead of not matrixU
bbbales2 Apr 3, 2021
1c610a2
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/stan/mcmc/hmc/hamiltonians/dense_e_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class dense_e_metric : public base_hamiltonian<Model, dense_e_point, BaseRNG> {
: base_hamiltonian<Model, dense_e_point, BaseRNG>(model) {}

double T(dense_e_point& z) {
return 0.5 * z.p.transpose() * z.inv_e_metric_ * z.p;
return 0.5 * z.p.transpose() * z.get_inv_metric() * z.p;
}

double tau(dense_e_point& z) { return T(z); }
Expand All @@ -35,7 +35,7 @@ class dense_e_metric : public base_hamiltonian<Model, dense_e_point, BaseRNG> {
return Eigen::VectorXd::Zero(this->model_.num_params_r());
}

Eigen::VectorXd dtau_dp(dense_e_point& z) { return z.inv_e_metric_ * z.p; }
Eigen::VectorXd dtau_dp(dense_e_point& z) { return z.get_inv_metric() * z.p; }

Eigen::VectorXd dphi_dq(dense_e_point& z, callbacks::logger& logger) {
return z.g;
Expand All @@ -51,7 +51,7 @@ class dense_e_metric : public base_hamiltonian<Model, dense_e_point, BaseRNG> {
for (idx_t i = 0; i < u.size(); ++i)
u(i) = rand_dense_gaus();

z.p = z.inv_e_metric_.llt().matrixU().solve(u);
z.p = z.get_transpose_llt_inv_metric().triangularView<Eigen::Upper>().solve(u);
}
};

Expand Down
33 changes: 29 additions & 4 deletions src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,54 @@ namespace mcmc {
* Euclidean manifold with dense metric
*/
class dense_e_point : public ps_point {
public:
private:
/**
* Inverse mass matrix.
*/
Eigen::MatrixXd inv_e_metric_;
Eigen::MatrixXd inv_e_metric_llt_matrixU_;
public:

/**
* Construct a dense point in n-dimensional phase space
* with identity matrix as inverse mass matrix.
*
* @param n number of dimensions
*/
explicit dense_e_point(int n) : ps_point(n), inv_e_metric_(n, n) {
explicit dense_e_point(int n) : ps_point(n),
inv_e_metric_(n, n),
inv_e_metric_llt_matrixU_(n, n) {
inv_e_metric_.setIdentity();
inv_e_metric_llt_matrixU_.setIdentity();
}

/**
* Set elements of mass matrix
* Set inverse metric
*
* @param inv_e_metric initial mass matrix
*/
void set_metric(const Eigen::MatrixXd& inv_e_metric) {
void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) {
SteveBronder marked this conversation as resolved.
Show resolved Hide resolved
inv_e_metric_ = inv_e_metric;
inv_e_metric_llt_matrixU_ = inv_e_metric_.llt().matrixU();
}
bbbales2 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Get inverse metric
*
* @return reference to the inverse metric
*/
const Eigen::MatrixXd& get_inv_metric() const {
return inv_e_metric_;
}

/**
* Get the transpose of the lower Cholesky factor
* of the inverse metric
*
* @return reference to transpose of Cholesky factor
*/
const Eigen::MatrixXd& get_transpose_llt_inv_metric() const {
return inv_e_metric_llt_matrixU_;
}

/**
Expand Down
9 changes: 5 additions & 4 deletions src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class diag_e_metric : public base_hamiltonian<Model, diag_e_point, BaseRNG> {
: base_hamiltonian<Model, diag_e_point, BaseRNG>(model) {}

double T(diag_e_point& z) {
return 0.5 * z.p.dot(z.inv_e_metric_.cwiseProduct(z.p));
return 0.5 * z.p.dot(z.get_inv_metric().cwiseProduct(z.p));
}

double tau(diag_e_point& z) { return T(z); }
Expand All @@ -34,7 +34,7 @@ class diag_e_metric : public base_hamiltonian<Model, diag_e_point, BaseRNG> {
}

Eigen::VectorXd dtau_dp(diag_e_point& z) {
return z.inv_e_metric_.cwiseProduct(z.p);
return z.get_inv_metric().cwiseProduct(z.p);
}

Eigen::VectorXd dphi_dq(diag_e_point& z, callbacks::logger& logger) {
Expand All @@ -45,8 +45,9 @@ class diag_e_metric : public base_hamiltonian<Model, diag_e_point, BaseRNG> {
boost::variate_generator<BaseRNG&, boost::normal_distribution<> >
rand_diag_gaus(rng, boost::normal_distribution<>());

for (int i = 0; i < z.p.size(); ++i)
z.p(i) = rand_diag_gaus() / sqrt(z.inv_e_metric_(i));
z.p = z.get_inv_metric().unaryExpr([&](auto&& x) {
return rand_diag_gaus() / sqrt(x);
});
}
};

Expand Down
14 changes: 12 additions & 2 deletions src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ namespace mcmc {
* Euclidean manifold with diagonal metric
*/
class diag_e_point : public ps_point {
public:
private:
/**
* Vector of diagonal elements of inverse mass matrix.
*/
Eigen::VectorXd inv_e_metric_;

public:
/**
* Construct a diag point in n-dimensional phase space
* with vector of ones for diagonal elements of inverse mass matrix.
Expand All @@ -32,10 +33,19 @@ class diag_e_point : public ps_point {
*
* @param inv_e_metric initial mass matrix
*/
void set_metric(const Eigen::VectorXd& inv_e_metric) {
void set_inv_metric(const Eigen::VectorXd& inv_e_metric) {
inv_e_metric_ = inv_e_metric;
}

/**
* Get inverse metric
*
* @return reference to the inverse metric
*/
const Eigen::VectorXd& get_inv_metric() const {
return inv_e_metric_;
}

/**
* Write elements of mass matrix to string and handoff to writer.
*
Expand Down
7 changes: 5 additions & 2 deletions src/stan/mcmc/hmc/nuts/adapt_dense_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ class adapt_dense_e_nuts : public dense_e_nuts<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->covar_adaptation_.learn_covariance(
this->z_.inv_e_metric_, this->z_.q);
Eigen::MatrixXd inv_metric;

bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q);
bbbales2 marked this conversation as resolved.
Show resolved Hide resolved

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
9 changes: 6 additions & 3 deletions src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG>,
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
s.accept_stat());

bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
this->z_.q);

Eigen::VectorXd inv_metric;

bool update = this->var_adaptation_.learn_variance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);

this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
}
Expand Down
8 changes: 4 additions & 4 deletions src/stan/mcmc/hmc/nuts/base_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {

~base_nuts() {}

void set_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

void set_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

void set_max_depth(int d) {
Expand Down
6 changes: 4 additions & 2 deletions src/stan/mcmc/hmc/static/adapt_dense_e_static_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ class adapt_dense_e_static_hmc : public dense_e_static_hmc<Model, BaseRNG>,
s.accept_stat());
this->update_L_();

bool update = this->covar_adaptation_.learn_covariance(
this->z_.inv_e_metric_, this->z_.q);
Eigen::MatrixXd inv_metric;

bool update = this->covar_adaptation_.learn_covariance(inv_metric, this->z_.q);

if (update) {
this->init_stepsize(logger);
this->update_L_();
this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
Expand Down
5 changes: 4 additions & 1 deletion src/stan/mcmc/hmc/static/adapt_diag_e_static_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ class adapt_diag_e_static_hmc : public diag_e_static_hmc<Model, BaseRNG>,
s.accept_stat());
this->update_L_();

bool update = this->var_adaptation_.learn_variance(this->z_.inv_e_metric_,
Eigen::VectorXd inv_metric;

bool update = this->var_adaptation_.learn_variance(inv_metric,
this->z_.q);

if (update) {
this->init_stepsize(logger);
this->update_L_();
this->z_.set_inv_metric(inv_metric);

this->stepsize_adaptation_.set_mu(log(10 * this->nom_epsilon_));
this->stepsize_adaptation_.restart();
Expand Down
8 changes: 4 additions & 4 deletions src/stan/mcmc/hmc/static/base_static_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class base_static_hmc

~base_static_hmc() {}

void set_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::MatrixXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

void set_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_metric(inv_e_metric);
void set_inv_metric(const Eigen::VectorXd& inv_e_metric) {
this->z_.set_inv_metric(inv_e_metric);
}

sample transition(sample& init_sample, callbacks::logger& logger) {
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_dense_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init,

stan::mcmc::dense_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);

sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ int hmc_nuts_dense_e_adapt(

stan::mcmc::adapt_dense_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);

sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_diag_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ int hmc_nuts_diag_e(Model& model, const stan::io::var_context& init,

stan::mcmc::diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
sampler.set_max_depth(max_depth);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ int hmc_nuts_diag_e_adapt(

stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize(stepsize);
sampler.set_stepsize_jitter(stepsize_jitter);
sampler.set_max_depth(max_depth);
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_dense_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ int hmc_static_dense_e(

stan::mcmc::dense_e_static_hmc<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ int hmc_static_dense_e_adapt(
stan::mcmc::adapt_dense_e_static_hmc<Model, boost::ecuyer1988> sampler(model,
rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_diag_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ int hmc_static_diag_e(Model& model, const stan::io::var_context& init,

stan::mcmc::diag_e_static_hmc<Model, boost::ecuyer1988> sampler(model, rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/sample/hmc_static_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ int hmc_static_diag_e_adapt(
stan::mcmc::adapt_diag_e_static_hmc<Model, boost::ecuyer1988> sampler(model,
rng);

sampler.set_metric(inv_metric);
sampler.set_inv_metric(inv_metric);
sampler.set_nominal_stepsize_and_T(stepsize, int_time);
sampler.set_stepsize_jitter(stepsize_jitter);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ TEST(McmcDenseEMetric, sample_p) {

stan::mcmc::dense_e_metric<stan::mcmc::mock_model, rng_t> metric(model);
stan::mcmc::dense_e_point z(2);
z.set_metric(m_inv);
z.set_inv_metric(m_inv);

int n_samples = 1000;

Expand Down