-
-
Notifications
You must be signed in to change notification settings - Fork 189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closures #2384
Closures #2384
Changes from 4 commits
01f0a03
7382327
ee21600
a595b43
29c165f
5dce92a
bbabc92
a4032be
76a8991
a55ddb3
990c070
cbf48fa
977f54d
c043511
e72e6f4
e120064
1245fa6
9c25817
a749f61
6990768
2e3180a
610af2d
880e270
4f1f6eb
f883e42
75f6d30
0a609db
e0f6145
9eca190
9436a18
2b2bee2
ecec96a
03ab504
9fc569f
9fb3740
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,6 +69,19 @@ inline auto value_of(EigMat&& M) { | |
std::forward<EigMat>(M)); | ||
} | ||
|
||
/** | ||
* Closures that capture non-arithmetic types have value_of__() method. | ||
* | ||
* @tparam F Input element type | ||
* @param[in] f Input closure | ||
* @return closure | ||
**/ | ||
template <typename F, require_stan_closure_t<F>* = nullptr, | ||
require_not_st_arithmetic<F>* = nullptr> | ||
inline auto value_of(const F& f) { | ||
return f.value_of__(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh wow I vaguely remember this. Is this used anywhere yet? |
||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,231 @@ | ||||||||
#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||||||||
#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP | ||||||||
|
||||||||
#include <stan/math/prim/meta/is_stan_closure.hpp> | ||||||||
#include <stan/math/prim/meta/return_type.hpp> | ||||||||
#include <ostream> | ||||||||
|
||||||||
namespace stan { | ||||||||
namespace math { | ||||||||
|
||||||||
template <typename F> | ||||||||
struct empty_closure { | ||||||||
using captured_scalar_t__ = double; | ||||||||
using ValueOf__ = empty_closure<F>; | ||||||||
using CopyOf__ = empty_closure<F>; | ||||||||
F f_; | ||||||||
|
||||||||
explicit empty_closure(const F& f) : f_(f) {} | ||||||||
|
||||||||
template <typename... Args> | ||||||||
auto operator()(std::ostream* msgs, Args... args) const { | ||||||||
return f_(args..., msgs); | ||||||||
} | ||||||||
size_t count_vars__() const { return 0; } | ||||||||
auto value_of__() const { return ValueOf__(f_); } | ||||||||
auto copy_of__() const { return CopyOf__(f_); } | ||||||||
auto deep_copy_vars__() const { return CopyOf__(f_); } | ||||||||
void zero_adjoints__() const {} | ||||||||
double* accumulate_adjoints__(double* dest) const { return dest; } | ||||||||
template <typename Vari> | ||||||||
Vari** save_varis(Vari** dest) const { | ||||||||
return dest; | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
template <bool Ref, typename F, typename T> | ||||||||
struct one_arg_closure { | ||||||||
using captured_scalar_t__ = return_type_t<T>; | ||||||||
using ValueOf__ | ||||||||
= one_arg_closure<false, F, decltype(value_of(std::declval<T>()))>; | ||||||||
using CopyOf__ = one_arg_closure<false, F, T>; | ||||||||
F f_; | ||||||||
capture_type_t<T, Ref> s_; | ||||||||
|
||||||||
explicit one_arg_closure(const F& f, const T& s) : f_(f), s_(s) {} | ||||||||
|
||||||||
template <typename... Args> | ||||||||
auto operator()(std::ostream* msgs, Args... args) const { | ||||||||
return f_(s_, args..., msgs); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having functions where the variadic arguments aren't last makes writing the C++ painful (cause no type deduction). Could it go |
||||||||
} | ||||||||
size_t count_vars__() const { return count_vars(s_); } | ||||||||
auto value_of__() const { return ValueOf__(f_, value_of(s_)); } | ||||||||
auto copy_of__() const { return CopyOf__(f_, s_); } | ||||||||
auto deep_copy_vars__() const { return CopyOf__(f_, deep_copy_vars(s_)); } | ||||||||
void zero_adjoints__() { zero_adjoints(s_); } | ||||||||
double* accumulate_adjoints__(double* dest) const { | ||||||||
return accumulate_adjoints(dest, s_); | ||||||||
} | ||||||||
template <typename Vari> | ||||||||
Vari** save_varis__(Vari** dest) const { | ||||||||
return save_varis(dest, s_); | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
template <typename F> | ||||||||
struct empty_closure_rng { | ||||||||
using captured_scalar_t__ = double; | ||||||||
using ValueOf__ = empty_closure_rng<F>; | ||||||||
using CopyOf__ = empty_closure_rng<F>; | ||||||||
F f_; | ||||||||
|
||||||||
explicit empty_closure_rng(const F& f) : f_(f) {} | ||||||||
|
||||||||
template <typename Rng, typename... Args> | ||||||||
auto operator()(const Rng& rng, std::ostream* msgs, Args... args) const { | ||||||||
return f_(args..., rng, msgs); | ||||||||
} | ||||||||
size_t count_vars__() const { return 0; } | ||||||||
auto value_of__() const { return ValueOf__(f_); } | ||||||||
auto copy_of__() const { return CopyOf__(f_); } | ||||||||
auto deep_copy_vars__() const { return CopyOf__(f_); } | ||||||||
void zero_adjoints__() const {} | ||||||||
double* accumulate_adjoints__(double* dest) const { return dest; } | ||||||||
template <typename Vari> | ||||||||
Vari** save_varis(Vari** dest) const { | ||||||||
return dest; | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
template <typename F> | ||||||||
struct empty_closure_lpdf { | ||||||||
using captured_scalar_t__ = double; | ||||||||
using ValueOf__ = empty_closure_lpdf<F>; | ||||||||
using CopyOf__ = empty_closure_lpdf<F>; | ||||||||
F f_; | ||||||||
|
||||||||
explicit empty_closure_lpdf(const F& f) : f_(f) {} | ||||||||
|
||||||||
template <bool propto = false, typename... Args> | ||||||||
auto operator()(std::ostream* msgs, Args... args) const { | ||||||||
return f_.template operator()<propto>(args..., msgs); | ||||||||
} | ||||||||
size_t count_vars__() const { return 0; } | ||||||||
auto value_of__() const { return ValueOf__(f_); } | ||||||||
auto copy_of__() const { return CopyOf__(f_); } | ||||||||
auto deep_copy_vars__() const { return CopyOf__(f_); } | ||||||||
void zero_adjoints__() const {} | ||||||||
double* accumulate_adjoints__(double* dest) const { return dest; } | ||||||||
template <typename Vari> | ||||||||
Vari** save_varis(Vari** dest) const { | ||||||||
return dest; | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
template <typename F> | ||||||||
struct empty_closure_lp { | ||||||||
using captured_scalar_t__ = double; | ||||||||
using ValueOf__ = empty_closure_lp<F>; | ||||||||
using CopyOf__ = empty_closure_lp<F>; | ||||||||
static const size_t vars_count__ = 0; | ||||||||
F f_; | ||||||||
|
||||||||
explicit empty_closure_lp(const F& f) : f_(f) {} | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
template <typename T_lp, typename T_lp_accum, typename... Args> | ||||||||
auto operator()(T_lp_accum& lp, T_lp& lp_accum, std::ostream* msgs, | ||||||||
Args... args) const { | ||||||||
return f_(args..., lp, lp_accum, msgs); | ||||||||
} | ||||||||
size_t count_vars__() const { return 0; } | ||||||||
auto value_of__() const { return ValueOf__(f_); } | ||||||||
auto copy_of__() const { return CopyOf__(f_); } | ||||||||
auto deep_copy_vars__() const { return CopyOf__(f_); } | ||||||||
void zero_adjoints__() const {} | ||||||||
double* accumulate_adjoints__(double* dest) const { return dest; } | ||||||||
template <typename Vari> | ||||||||
Vari** save_varis(Vari** dest) const { | ||||||||
return dest; | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
/** | ||||||||
* Create a closure object from a callable. | ||||||||
*/ | ||||||||
template <typename F> | ||||||||
auto from_lambda(const F& f) { | ||||||||
return empty_closure<F>(f); | ||||||||
} | ||||||||
|
||||||||
/** | ||||||||
* Create a closure that captures a single argument. | ||||||||
*/ | ||||||||
template <typename F, typename T> | ||||||||
auto from_lambda(const F& f, const T& a) { | ||||||||
return one_arg_closure<true, F, T>(f, a); | ||||||||
} | ||||||||
|
||||||||
template <typename F> | ||||||||
auto rng_from_lambda(const F& f) { | ||||||||
return empty_closure_rng<F>(f); | ||||||||
} | ||||||||
|
||||||||
template <typename F> | ||||||||
auto lpdf_from_lambda(const F& f) { | ||||||||
return empty_closure_lpdf<F>(f); | ||||||||
} | ||||||||
|
||||||||
template <typename F> | ||||||||
auto lp_from_lambda(const F& f) { | ||||||||
return empty_closure_lp<F>(f); | ||||||||
} | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So for each kind of function we might have in Stan, we can also have a closure version of that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Each function kind follows a different calling convention so each kind needs its own adapter closure. These aren't used in math library but stanc3 allows userdefined higher order functions that might need them. |
||||||||
|
||||||||
template <bool Propto, typename F, bool Ref> | ||||||||
struct lpdf_wrapper { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So you know how one can do real func_lpdf(real[] slice, ...) {
...
}
target += reduce_sum(func_lpdf, ...);
target += reduce_sum(func_lupdf, ...); and both |
||||||||
using captured_scalar_t__ = return_type_t<F>; | ||||||||
using ValueOf__ | ||||||||
= lpdf_wrapper<Propto, decltype(std::declval<F>().value_of__()), false>; | ||||||||
using CopyOf__ | ||||||||
= lpdf_wrapper<Propto, decltype(std::declval<F>().copy_of__()), false>; | ||||||||
capture_type_t<F, Ref> f_; | ||||||||
|
||||||||
explicit lpdf_wrapper(const F& f) : f_(f) {} | ||||||||
|
||||||||
template <bool propto> | ||||||||
auto with_propto() { | ||||||||
return lpdf_wrapper < Propto && propto, F, true > (f_); | ||||||||
} | ||||||||
|
||||||||
template <bool propto = Propto, typename... Args> | ||||||||
auto operator()(Args... args) const { | ||||||||
return f_.template operator() < Propto && propto > (args...); | ||||||||
} | ||||||||
size_t count_vars__() const { return count_vars(f_); } | ||||||||
auto value_of__() const { return ValueOf__(value_of(f_)); } | ||||||||
auto deep_copy_vars__() const { return CopyOf__(deep_copy_vars(f_)); } | ||||||||
auto copy_of__() const { return CopyOf__(f_.copy_of__()); } | ||||||||
void zero_adjoints__() { zero_adjoints(f_); } | ||||||||
double* accumulate_adjoints__(double* dest) const { | ||||||||
return accumulate_adjoints(dest, f_); | ||||||||
} | ||||||||
template <typename Vari> | ||||||||
Vari** save_varis__(Vari** dest) const { | ||||||||
return save_varis(dest, f_); | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
struct reduce_sum_closure_adapter { | ||||||||
template <typename F, typename T, typename... Args> | ||||||||
auto operator()(const std::vector<T>& sub_slice, std::size_t start, | ||||||||
std::size_t end, std::ostream* msgs, const F& f, | ||||||||
Args... args) const { | ||||||||
return f(msgs, sub_slice, start, end, args...); | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
namespace internal { | ||||||||
|
||||||||
struct ode_closure_adapter { | ||||||||
template <typename F, typename T0, typename T1, typename... Args> | ||||||||
auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y, | ||||||||
std::ostream* msgs, const F& f, Args... args) const { | ||||||||
return f(msgs, t, y, args...); | ||||||||
} | ||||||||
}; | ||||||||
|
||||||||
} // namespace internal | ||||||||
|
||||||||
} // namespace math | ||||||||
} // namespace stan | ||||||||
|
||||||||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ inline auto integrate_ode_rk45( | |
ts, relative_tolerance, absolute_tolerance, | ||
max_num_steps, msgs, theta, x, x_int); | ||
|
||
std::vector<std::vector<return_type_t<T_y0, T_param, T_t0, T_ts>>> | ||
std::vector<std::vector<fn_return_type_t<F, T_y0, T_param, T_t0, T_ts>>> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we just have the logic in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. I did have some logic for handling closures in |
||
y_converted; | ||
y_converted.reserve(y.size()); | ||
for (size_t i = 0; i < y.size(); ++i) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. We'll need to implement these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, thinking more, I don't think there's a big advantage to implementing this.
We could expose the variables captured by a closure to checks, but the Math checks wouldn't know in what order its getting them, and then depending on which function was accepting closures it would need to decide which checks to do on which inputs.
I think instead in the ODE solves we check only the arguments passed in explicitly (which this is effectively doing) or we get rid of the infinity checks on the inputs to the ODE solves. I'll make an issue and see if getting rid of the checks altogether is an option. (Edit: Issue #2406)