-
-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closures for ODEs #2094
Closures for ODEs #2094
Conversation
…4.1 (tags/RELEASE_600/final)
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
Thanks for doing this. Apologies for the radio silence. I'll look through this in the not-infinite future. Probably after the feature freeze? This will be really good to have. Has there been any discussion at the language level about how to do the functors? (and are there any big design decisions away from what you've coded here and we talked about previously?) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finally got around to looking at this! Apologies for the delayyy. Left a question about the todo list.
Now that you wrote it, do you think this design makes sense? I'm trying to figure out drawbacks and stuff like that.
The test at line 268 of test/unit/math/rev/functor/ode_rk45_rev_test.cpp is the thing that brings this all together: https://github.com/stan-dev/math/pull/2094/files#diff-818d2074acd865781bdb17e5b65c089d59c893e0c248853b14aab511145d778cR268
Edit: whoops, that last sentence ("The test at line...") you can ignore.
inline void check_finite(const char* function, const char* name, const T_y& y) { | ||
if (check_finite_screen(y)) { | ||
auto is_good = [](const auto& y) { return std::isfinite(y); }; | ||
elementwise_check(is_good, function, name, y, ", but must be finite!"); | ||
} | ||
} | ||
|
||
template <typename T, require_stan_closure_t<T>* = nullptr> | ||
inline void check_finite(const char* function, const char* name, const T& y) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess a closure has variables, so we'd want to check that the values of the variables in the closure are finite
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I guess. I was mostly fighting the compiler here.
I'm not sure why the ODEs try and check the inputs are finite. Infinite inputs don't necessarily cause infinite outputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I would like to get rid of these checks too.
|
||
template <typename F, typename T> | ||
auto from_lambda(F f, T a) { | ||
return simple_closure<F, T>(f, a); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be the key to how all this works. This accepts a lambda and an argument. This binds a
to the first argument of the function f
, and this object is treated as a closure (which can be passed around wherever).
THe closure itself, because it can contain vars, is treated like one in all the ODE code (so there are specializations for save_varis, deep_copy_vars, etc.).
This sound right?
Presumably we'd need to expand from_lamba
to take a variable list of arguments and modify all the other higher order functions to work with this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's right.
Actually, from_lambda
isn't necessary because my plan was to leave implementing closure structs to the stanc codegen. require_stan_closure_t<...>
recognizes closure simply by the presence of ::captured_scalar_t__
; they are effectively duck-typed. simple_closure
is one example of how to make one.
Of course, an alternative (complementary?) path forward is expanding simple_closure
so that the template type T
is a parameter pack.
Other higher-order functions just need support for variadic arguments. Shortly after submitting this PR I realized you could revert all the changes to ODEs and instead have an adapter like this
struct closure_adapter_ode {
template<typename F, typename T0, typename T1, typename... Args>
auto operator()(const T0& t, const Eigen::Matrix<T1, Eigen::Dynamic, 1>& y,
std::ostream* msgs, F& f, Args... args) {
return f(msgs, t, y, args...);
}
}
template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename... Args, require_stan_closure_t<F>* = nullptr>
std::vector<Eigen::Matrix<stan::return_type_t<T_y0, T_t0, T_ts, Args...>,
Eigen::Dynamic, 1>>
ode_rk45_tol_impl(const char* function_name, const F& f,
const Eigen::Matrix<T_y0, Eigen::Dynamic, 1>& y0_arg, T_t0 t0,
const std::vector<T_ts>& ts, double relative_tolerance,
double absolute_tolerance,
long int max_num_steps,
std::ostream* msgs, const Args&... args) {
closure_adapter_ode adapter;
return ode_rk45_tol_impl(function_name, adapter, y0_arg, t0, ts,
relative_tolerance, absolute_tolerance,
max_num_steps, msgs, f, args...);
}
Same goes for reduce_sum
and others.
std::vector<var> a1 = {0.75}; | ||
|
||
auto f = stan::math::from_lambda( | ||
[&](const auto& a, const auto& t, const auto& y, const auto& b, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this have no capture? If it captures by reference it might absorb something we don't want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's me not being very familiar with C++. I had assumed the capture default was just part of lambda syntax.
Cool beans. I guess we have a few options from here:
Which do you think is best? Or do you have some other idea? Feel free to allocate work to me. |
I can make something that works end-to-end. (The previous prototype was working end-to-end, though the stanc3 PR is way outdated by now.) You could help figuring out why this struct closure_adapter {
template<typename F, typename T_slice, typename... Args>
auto operator()(const T_slice& subslice, std::size_t start,
std::size_t end, std::ostream* msgs,
const F& f, Args... args) {
return f(msgs, subslice, start, end, args...);
}
};
TEST(StanMathRev_reduce_sum, grouped_gradient_closure) {
using stan::math::var;
using stan::math::from_lambda;
using stan::math::test::get_new_msg;
double lambda_d = 10.0;
const std::size_t groups = 10;
const std::size_t elems_per_group = 1000;
const std::size_t elems = groups * elems_per_group;
std::vector<int> data(elems);
std::vector<int> gidx(elems);
for (std::size_t i = 0; i != elems; ++i) {
data[i] = i;
gidx[i] = i / elems_per_group;
}
std::vector<var> vlambda_v;
for (std::size_t i = 0; i != groups; ++i)
vlambda_v.push_back(i + 0.2);
var lambda_v = vlambda_v[0];
auto functor = from_lambda(
[](auto& lambda, auto& slice, std::size_t start, std::size_t end, auto& gidx, std::ostream * msgs) {
const std::size_t num_terms = end - start + 1;
std::decay_t<decltype(lambda)> lambda_slice(num_terms);
for (std::size_t i = 0; i != num_terms; ++i)
lambda_slice[i] = lambda[gidx[start + i]];
return stan::math::poisson_lpmf(slice, lambda_slice);
}, vlambda_v);
var poisson_lpdf = stan::math::reduce_sum<closure_adapter>(
data, 5, get_new_msg(), functor, gidx);
std::vector<var> vref_lambda_v;
for (std::size_t i = 0; i != elems; ++i) {
vref_lambda_v.push_back(vlambda_v[gidx[i]]);
}
var lambda_ref = vlambda_v[0];
var poisson_lpdf_ref = stan::math::poisson_lpmf(data, vref_lambda_v);
EXPECT_FLOAT_EQ(value_of(poisson_lpdf), value_of(poisson_lpdf_ref));
stan::math::grad(poisson_lpdf_ref.vi_);
const double lambda_ref_adj = lambda_ref.adj();
stan::math::set_zero_all_adjoints();
stan::math::grad(poisson_lpdf.vi_);
const double lambda_adj = lambda_v.adj();
EXPECT_FLOAT_EQ(lambda_adj, lambda_ref_adj)
<< "ref value of poisson lpdf : " << poisson_lpdf_ref.val() << std::endl
<< "ref gradient wrt to lambda: " << lambda_ref_adj << std::endl
<< "value of poisson lpdf : " << poisson_lpdf.val() << std::endl
<< "gradient wrt to lambda: " << lambda_adj << std::endl;
var poisson_lpdf_static
= stan::math::reduce_sum_static<closure_adapter>(
data, 5, get_new_msg(), functor, gidx);
stan::math::set_zero_all_adjoints();
stan::math::grad(poisson_lpdf_static.vi_);
const double lambda_adj_static = lambda_v.adj();
EXPECT_FLOAT_EQ(lambda_adj_static, lambda_ref_adj);
stan::math::recover_memory();
} The input types are |
@nhuurre sounds good, I'll look at the reduce_sum thing tomorrow. I 'spose it's time to do the variadic thing for |
Looking at the ODE example how does |
Yes, |
You okay with me pushing code into this branch? I can do it separate with pull reqs if you want. I made some changes to reduce_sum. It wasn't too bad to code this up, though the address sanitizer is telling me there is a memory leak somewhere still. Before I went further I wanted to stop and ask about the |
Pushing here is fine. |
Alright I added a |
I was too lazy to resolve the merge conflicts so I opened a new PR with less code: #2384 |
Summary
Adds support for closures. A closure is a callable object that also captures some autodiff variables.
The autodiff API is an extension of the existing tools for variadic arguments. The following functions now support closures
These should be enough for any higher-order function that takes variadic arguments to (almost) automatically support closures. So far I've been working with the ODE solvers.
Tests
There's a couple of new tests for
ode_rk45
.Side Effects
None, I think.
Release notes
Add basic API for closure objects.
Checklist
Math issue Implement closures #2197
Copyright holder: Niko Huurre
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
the basic tests are passing
./runTests.py test/unit
)make test-headers
)make test-math-dependencies
)make doxygen
)make cpplint
)the code is written in idiomatic C++ and changes are documented in the doxygen
the new changes are tested