-
-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closures #2384
Closures #2384
Changes from 26 commits
01f0a03
7382327
ee21600
a595b43
29c165f
5dce92a
bbabc92
a4032be
76a8991
a55ddb3
990c070
cbf48fa
977f54d
c043511
e72e6f4
e120064
1245fa6
9c25817
a749f61
6990768
2e3180a
610af2d
880e270
4f1f6eb
f883e42
75f6d30
0a609db
e0f6145
9eca190
9436a18
2b2bee2
ecec96a
03ab504
9fc569f
9fb3740
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||
#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||
|
||
#include <stan/math/prim/meta/error_index.hpp> | ||
#include <stan/math/prim/meta/return_type.hpp> | ||
#include <stan/math/prim/meta/is_stan_closure.hpp> | ||
#include <stan/math/prim/functor/apply.hpp> | ||
#include <ostream> | ||
|
||
namespace stan { | ||
namespace math { | ||
namespace internal { | ||
|
||
/** | ||
* A closure that wraps a C++ lambda and captures values. | ||
*/ | ||
template <bool Ref, typename F, typename... Ts> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This all needs docs for template parameters etc. |
||
struct base_closure { | ||
using captured_scalar_t__ = return_type_t<Ts...>; | ||
using ValueOf__ | ||
= base_closure<false, F, decltype(eval(value_of(std::declval<Ts>())))...>; | ||
using CopyOf__ = base_closure<false, F, Ts...>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be good to have docs for what these are as their definitions also change across the different types of closures |
||
F f_; | ||
std::tuple<capture_type_t<Ts, Ref>...> captures_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whats the higher level logic for ref? Aka why can't these always just be references? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also could you add some tests and docs for all these? It would help me understand what your design goal is and how things should work |
||
|
||
explicit base_closure(const F& f, const Ts&... args) | ||
: f_(f), captures_(args...) {} | ||
|
||
template <typename... Args> | ||
auto operator()(std::ostream* msgs, const Args&... args) const { | ||
return apply([this, msgs, &args...]( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [optional] I like having things captured by reference at the front and then things copied after |
||
const auto&... s) { return f_(s..., args..., msgs); }, | ||
captures_); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [optional] When capturing |
||
} | ||
}; | ||
|
||
/** | ||
* A closure that takes rng argument. | ||
*/ | ||
template <bool Ref, typename F, typename... Ts> | ||
struct closure_rng { | ||
using captured_scalar_t__ = double; | ||
using ValueOf__ = closure_rng<false, F, Ts...>; | ||
using CopyOf__ = closure_rng<false, F, Ts...>; | ||
F f_; | ||
std::tuple<capture_type_t<Ts, Ref>...> captures_; | ||
|
||
explicit closure_rng(const F& f, const Ts&... args) | ||
: f_(f), captures_(args...) {} | ||
|
||
template <typename Rng, typename... Args> | ||
auto operator()(Rng& rng, std::ostream* msgs, const Args&... args) const { | ||
return apply([this, &rng, msgs, &args...]( | ||
const auto&... s) { return f_(s..., args..., rng, msgs); }, | ||
captures_); | ||
} | ||
}; | ||
|
||
/** | ||
* A closure that can be called with `propto` template argument. | ||
*/ | ||
template <bool Propto, bool Ref, typename F, typename... Ts> | ||
struct closure_lpdf { | ||
using captured_scalar_t__ = return_type_t<Ts...>; | ||
using ValueOf__ = closure_lpdf<Propto, false, F, Ts...>; | ||
using CopyOf__ = closure_lpdf<Propto, false, F, Ts...>; | ||
F f_; | ||
std::tuple<capture_type_t<Ts, Ref>...> captures_; | ||
|
||
explicit closure_lpdf(const F& f, const Ts&... args) | ||
: f_(f), captures_(args...) {} | ||
|
||
template <bool propto> | ||
auto with_propto() { | ||
Comment on lines
+91
to
+92
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We use camelcase for template parameters. How is this propto different from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this for like lupdf or something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's an parameters {
real y[100];
}
model {
function
real higher_lpdf(real[] x, real(real[], int, int) f_lpdf) {
real lp = 0;
lp += reduce_sum(f_lpdf, x, 1); // <-- A
lp += reduce_sum(f_lupdf, x, 1); // <-- B
return lp;
}
function
real partial_lpdf(real[] x, int s, int e) {
return std_normal_lupdf(x|);
}
target += higher_lpdf( y| partial_lpdf); // <-- 1
target += higher_lupdf(y| partial_lupdf); // <-- 2
} Using The above compiles to C++ that looks something like auto higher_lpdf = from_lambda([&](auto f_lpdf) {
var lp = 0;
lp += reduce_sum(f_lpdf.with_propto<false>(), x, 1); // <-- A
lp += reduce_sum(f_lpdf.with_propto<true>(), x, 1); // <-- B
return lp;
});
auto partial_lpdf = from_lambda([]<bool propto>(auto x, int s, int e) {
return std_normal_lpdf<propto>(x);
});
lp_accum__.add(higher_lpdf(y, partial_lpdf.with_propto<false>()); // <-- 1
lp_accum__.add(higher_lpdf(y, partial_lpdf.with_propto<true>()); // <-- 2 Every time the closure object is passed to a higher-order function |
||
return apply( | ||
[this](const auto&... args) { | ||
return closure_lpdf < Propto && propto, true, F, | ||
Ts... > (f_, args...); | ||
}, | ||
captures_); | ||
} | ||
|
||
template <bool propto = false, typename... Args> | ||
auto operator()(std::ostream* msgs, const Args&... args) const { | ||
return apply( | ||
[this, msgs, &args...](const auto&... s) { | ||
return f_.template operator()<propto>(s..., args..., msgs); | ||
}, | ||
captures_); | ||
} | ||
}; | ||
|
||
/** | ||
* A closure that accesses logprob accumulator. | ||
*/ | ||
template <bool Propto, bool Ref, typename F, typename... Ts> | ||
struct closure_lp { | ||
using captured_scalar_t__ = return_type_t<Ts...>; | ||
using ValueOf__ = closure_lp<Propto, true, F, Ts...>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should have different names in the different classes as in the base it's the closure holding the partial type but in the others it's just the closure class with a Ref of true There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, these were supposed to be the same as in the base. It only works because they're never used. |
||
using CopyOf__ = closure_lp<Propto, true, F, Ts...>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The C++ standard reserves double underscore for the compiler implementations (yes we do this at the stanc3 level but it's not good and we should not do it here) |
||
F f_; | ||
std::tuple<capture_type_t<Ts, Ref>...> captures_; | ||
|
||
explicit closure_lp(const F& f, const Ts&... args) | ||
: f_(f), captures_(args...) {} | ||
|
||
template <bool propto = false, typename T_lp, typename T_lp_accum, | ||
typename... Args> | ||
auto operator()(T_lp& lp, T_lp_accum& lp_accum, std::ostream* msgs, | ||
const Args&... args) const { | ||
return apply( | ||
[this, &lp, &lp_accum, msgs, &args...](const auto&... s) { | ||
return f_.template operator()<propto>(s..., args..., lp, lp_accum, | ||
msgs); | ||
}, | ||
captures_); | ||
} | ||
}; | ||
|
||
} // namespace internal | ||
|
||
/** | ||
* Higher-order functor suitable for calling a closure inside variadic ODE | ||
* solvers. | ||
*/ | ||
struct ode_closure_adapter { | ||
template <typename F, typename T0, typename T1, typename... Args> | ||
auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, | ||
Args... args) const { | ||
return f(msgs, t, y, args...); | ||
} | ||
}; | ||
|
||
struct integrate_ode_closure_adapter { | ||
template <typename F, typename T0, typename T1, typename... Args> | ||
auto operator()(const T0& t, const T1& y, std::ostream* msgs, const F& f, | ||
Args... args) const { | ||
return to_vector(f(msgs, t, to_array_1d(y), args...)); | ||
} | ||
}; | ||
|
||
/** | ||
* Create a closure from a C++ lambda and captures. | ||
*/ | ||
template <typename F, typename... Ts> | ||
auto from_lambda(const F& f, const Ts&... a) { | ||
return internal::base_closure<true, F, Ts...>(f, a...); | ||
} | ||
|
||
/** | ||
* Create a closure from an rng functor. | ||
*/ | ||
template <typename F, typename... Ts> | ||
auto rng_from_lambda(const F& f, const Ts&... a) { | ||
return internal::closure_rng<true, F, Ts...>(f, a...); | ||
} | ||
|
||
/** | ||
* Create a closure from an lpdf functor. | ||
*/ | ||
template <bool propto, typename F, typename... Ts> | ||
auto lpdf_from_lambda(const F& f, const Ts&... a) { | ||
return internal::closure_lpdf<propto, true, F, Ts...>(f, a...); | ||
} | ||
|
||
/** | ||
* Create a closure from a functor that needs access to logprob accumulator. | ||
*/ | ||
template <bool Propto, typename F, typename... Ts> | ||
auto lp_from_lambda(const F& f, const Ts&... args) { | ||
return internal::closure_lp<Propto, true, F, Ts...>(f, args...); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So for each kind of function we might have in Stan, we can also have a closure version of that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Each function kind follows a different calling convention so each kind needs its own adapter closure. These aren't used in math library but stanc3 allows userdefined higher order functions that might need them. |
||
|
||
/** | ||
* Higher-order functor that invokes a closure inside a reduce_sum call. | ||
*/ | ||
struct reduce_sum_closure_adapter { | ||
template <typename F, typename T, typename... Args> | ||
auto operator()(const std::vector<T>& sub_slice, std::size_t start, | ||
std::size_t end, std::ostream* msgs, const F& f, | ||
Args... args) const { | ||
return f(msgs, sub_slice, start + error_index::value, | ||
end + error_index::value, args...); | ||
} | ||
}; | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -236,7 +236,7 @@ inline double integrate_1d_impl(const F& f, double a, double b, | |
* @param relative_tolerance tolerance passed to Boost quadrature | ||
* @return numeric integral of function f | ||
*/ | ||
template <typename F> | ||
template <typename F, require_not_stan_closure_t<F>* = nullptr> | ||
inline double integrate_1d(const F& f, double a, double b, | ||
const std::vector<double>& theta, | ||
const std::vector<double>& x_r, | ||
|
@@ -247,6 +247,18 @@ inline double integrate_1d(const F& f, double a, double b, | |
msgs, theta, x_r, x_i); | ||
} | ||
|
||
template <typename F, require_stan_closure_t<F>* = nullptr, | ||
require_arithmetic_t<return_type_t<F>>* = nullptr> | ||
inline double integrate_1d(const F& f, double a, double b, | ||
const std::vector<double>& theta, | ||
const std::vector<double>& x_r, | ||
const std::vector<int>& x_i, std::ostream* msgs, | ||
const double relative_tolerance | ||
= std::sqrt(EPSILON)) { | ||
return integrate_1d_impl(integrate_1d_closure_adapter(), a, b, | ||
relative_tolerance, msgs, f, theta, x_r, x_i); | ||
} | ||
Comment on lines
+250
to
+260
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Q] Think I'm just missing some context, why is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_RK45_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/functor/closure_adapter.hpp> | ||
#include <stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp> | ||
#include <stan/math/prim/functor/ode_rk45.hpp> | ||
#include <ostream> | ||
|
@@ -10,6 +11,38 @@ | |
namespace stan { | ||
namespace math { | ||
|
||
namespace internal { | ||
|
||
template <typename F, typename T_y0, typename T_param, typename T_t0, | ||
typename T_ts, require_not_stan_closure_t<F>* = nullptr> | ||
inline auto integrate_ode_rk45_impl( | ||
const F& f, const std::vector<T_y0>& y0, const T_t0& t0, | ||
const std::vector<T_ts>& ts, const std::vector<T_param>& theta, | ||
const std::vector<double>& x, const std::vector<int>& x_int, | ||
std::ostream* msgs, double relative_tolerance, double absolute_tolerance, | ||
int max_num_steps) { | ||
internal::integrate_ode_std_vector_interface_adapter<F> f_adapted(f); | ||
return ode_rk45_tol_impl("integrate_ode_rk45", f_adapted, to_vector(y0), t0, | ||
ts, relative_tolerance, absolute_tolerance, | ||
max_num_steps, msgs, theta, x, x_int); | ||
} | ||
|
||
template <typename F, typename T_y0, typename T_param, typename T_t0, | ||
typename T_ts, require_stan_closure_t<F>* = nullptr> | ||
inline auto integrate_ode_rk45_impl( | ||
const F& f, const std::vector<T_y0>& y0, const T_t0& t0, | ||
const std::vector<T_ts>& ts, const std::vector<T_param>& theta, | ||
const std::vector<double>& x, const std::vector<int>& x_int, | ||
std::ostream* msgs, double relative_tolerance, double absolute_tolerance, | ||
int max_num_steps) { | ||
return ode_rk45_tol_impl("integrate_ode_rk45", | ||
integrate_ode_closure_adapter(), to_vector(y0), t0, | ||
ts, relative_tolerance, absolute_tolerance, | ||
max_num_steps, msgs, f, theta, x, x_int); | ||
} | ||
|
||
} // namespace internal | ||
|
||
/** | ||
* @deprecated use <code>ode_rk45</code> | ||
*/ | ||
|
@@ -21,12 +54,11 @@ inline auto integrate_ode_rk45( | |
const std::vector<double>& x, const std::vector<int>& x_int, | ||
std::ostream* msgs = nullptr, double relative_tolerance = 1e-6, | ||
double absolute_tolerance = 1e-6, int max_num_steps = 1e6) { | ||
internal::integrate_ode_std_vector_interface_adapter<F> f_adapted(f); | ||
auto y = ode_rk45_tol_impl("integrate_ode_rk45", f_adapted, to_vector(y0), t0, | ||
ts, relative_tolerance, absolute_tolerance, | ||
max_num_steps, msgs, theta, x, x_int); | ||
auto y = internal::integrate_ode_rk45_impl(f, y0, t0, ts, theta, x, x_int, | ||
msgs, relative_tolerance, | ||
absolute_tolerance, max_num_steps); | ||
|
||
std::vector<std::vector<return_type_t<T_y0, T_param, T_t0, T_ts>>> | ||
std::vector<std::vector<fn_return_type_t<F, T_y0, T_param, T_t0, T_ts>>> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we just have the logic in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. I did have some logic for handling closures in |
||
y_converted; | ||
y_converted.reserve(y.size()); | ||
for (size_t i = 0; i < y.size(); ++i) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. We'll need to implement these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, thinking more, I don't think there's a big advantage to implementing this.
We could expose the variables captured by a closure to checks, but the Math checks wouldn't know in what order its getting them, and then depending on which function was accepting closures it would need to decide which checks to do on which inputs.
I think instead in the ODE solves we check only the arguments passed in explicitly (which this is effectively doing) or we get rid of the infinity checks on the inputs to the ODE solves. I'll make an issue and see if getting rid of the checks altogether is an option. (Edit: Issue #2406)