Skip to content
Open
19 changes: 13 additions & 6 deletions src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,25 @@ namespace mcmc {
* with a Gaussian-Euclidean disintegration and adaptive
* diagonal metric and adaptive step size
*/
template <class Model, class BaseRNG>
class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG>,
template <class Model, class BaseRNG, bool ParallelBase = false>
class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG, ParallelBase>,
public stepsize_var_adapter {
public:
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<!ParallelBase_>* = nullptr>
adapt_diag_e_nuts(const Model& model, BaseRNG& rng)
: diag_e_nuts<Model, BaseRNG>(model, rng),
: diag_e_nuts<Model, BaseRNG, ParallelBase>(model, rng),
stepsize_var_adapter(model.num_params_r()) {}

~adapt_diag_e_nuts() {}
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<ParallelBase_>* = nullptr>
adapt_diag_e_nuts(const Model& model, std::vector<BaseRNG>& thread_rngs)
: diag_e_nuts<Model, BaseRNG, ParallelBase>(model, thread_rngs),
stepsize_var_adapter(model.num_params_r()) {}

sample transition(sample& init_sample, callbacks::logger& logger) {
sample s = diag_e_nuts<Model, BaseRNG>::transition(init_sample, logger);
inline sample transition(sample& init_sample, callbacks::logger& logger) {
sample s = diag_e_nuts<Model, BaseRNG, ParallelBase>::transition(
init_sample, logger);

if (this->adapt_flag_) {
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
Expand Down
Loading