-
-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Adds boost root finders with reverse mode specializations #2720
base: develop
Are you sure you want to change the base?
Changes from 17 commits
c33ef74
b015dae
bd9fc69
9b2b477
d729001
f2f8c51
04dbda8
c0d4a81
b08f3ad
33e0b50
f468582
78f84c5
5bff38b
f5c762b
26ca3fc
4c6956f
91a748c
ee38928
344a450
ef8d4e0
4932068
19588e4
11c4a03
13a8886
8cd506b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef STAN_MATH_FWD_FUN_FREXP_HPP | ||
#define STAN_MATH_FWD_FUN_FREXP_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename T> | ||
inline auto frexp(const fvar<T>& x, int* exponent) noexcept { | ||
return std::frexp(value_of_rec(x), exponent); | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#ifndef STAN_MATH_FWD_FUN_SIGN_HPP | ||
#define STAN_MATH_FWD_FUN_SIGN_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
#include <stan/math/prim/fun/value_of_rec.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
template <typename T> | ||
inline auto sign(const fvar<T>& x) { | ||
double x_val = value_of_rec(x); | ||
return (0. < x_val) - (x_val < 0.); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
#ifndef STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP | ||
#define STAN_MATH_PRIM_FUNCTOR_ROOT_FINDER_HPP | ||
|
||
#include <stan/math/prim/fun/Eigen.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err/check_bounded.hpp> | ||
#include <stan/math/prim/err/check_positive.hpp> | ||
#include <stan/math/prim/functor/apply.hpp> | ||
#include <boost/math/tools/roots.hpp> | ||
#include <tuple> | ||
#include <utility> | ||
|
||
namespace stan { | ||
namespace math { | ||
namespace internal { | ||
template <typename Tuple, typename... Args> | ||
inline auto func_with_derivs(Tuple&& f_tuple, Args&&... args) { | ||
return stan::math::apply( | ||
[&args...](auto&&... funcs) { | ||
return [&args..., &funcs...](auto&& g) { | ||
return std::make_tuple(funcs(g, args...)...); | ||
}; | ||
}, | ||
f_tuple); | ||
} | ||
} // namespace internal | ||
|
||
/** | ||
* Solve for root using Boost's Halley method | ||
* @tparam FTuple A tuple holding functors whose signatures all match | ||
* `(GuessScalar g, Types&&... Args)`. | ||
* @tparam GuessScalar Scalar type | ||
* @tparam MinScalar Scalar type | ||
* @tparam MaxScalar Scalar type | ||
* @tparam Types Arg types to pass to functors in `f_tuple` | ||
* @param f_tuple A tuple of functors to calculate the value and any derivates | ||
* needed. | ||
* @param guess An initial guess at the root value | ||
* @param min The minimum possible value for the result, this is used as an | ||
* initial lower bracket | ||
* @param max The maximum possible value for the result, this is used as an | ||
* initial upper bracket | ||
* @param digits The desired number of binary digits precision | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indicate that digits cannot exceed the precision of |
||
* @param max_iter An optional maximum number of iterations to perform. On exit, | ||
* this is updated to the actual number of iterations performed | ||
* @param args Parameter pack of arguments to pass the the functors in `f_tuple` | ||
*/ | ||
template <typename SolverFun, typename FTuple, typename GuessScalar, | ||
typename MinScalar, typename MaxScalar, typename... Types, | ||
require_all_not_st_var<GuessScalar, MinScalar, MaxScalar, | ||
Types...>* = nullptr> | ||
auto root_finder_tol(SolverFun&& f_solver, FTuple&& f_tuple, | ||
const GuessScalar guess, const MinScalar min, | ||
const MaxScalar max, const int digits, | ||
std::uintmax_t& max_iter, Types&&... args) { | ||
check_bounded("root_finder", "initial guess", guess, min, max); | ||
check_positive("root_finder", "digits", digits); | ||
check_positive("root_finder", "max_iter", max_iter); | ||
using ret_t = return_type_t<GuessScalar, MinScalar, MaxScalar, Types...>; | ||
ret_t ret = 0; | ||
auto f_plus_div = internal::func_with_derivs(f_tuple, args...); | ||
try { | ||
ret = f_solver(f_plus_div, ret_t(guess), ret_t(min), ret_t(max), digits, | ||
max_iter); | ||
} catch (const std::exception& e) { | ||
std::cout << "err: \n" << e.what() << "\n"; | ||
throw e; | ||
} | ||
return ret; | ||
} | ||
|
||
template <typename FTuple, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
auto root_finder_halley_tol(FTuple&& f_tuple, const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
const int digits, std::uintmax_t& max_iter, | ||
Types&&... args) { | ||
return root_finder_tol( | ||
[](auto&&... args) { | ||
return boost::math::tools::halley_iterate(args...); | ||
}, | ||
std::forward<FTuple>(f_tuple), guess, min, max, digits, max_iter, | ||
std::forward<Types>(args)...); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This my lacking C++ skills speaking but could you describe how the call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. return root_finder_tol(
[](auto&&... args) {
return boost::math::tools::halley_iterate(args...);
},
std::forward<FTuple>(f_tuple), guess, min, max, digits, max_iter,
std::forward<Types>(args)...); The lambda is in this call to tell struct halley_iterator {
template <typename Types>
operator(Types&&... args) {
return boost::math::tools::halley_iterate(args...);
}
}; |
||
} | ||
|
||
template <typename FTuple, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
auto root_finder_newton_raphson_tol(FTuple&& f_tuple, const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
const int digits, std::uintmax_t& max_iter, | ||
Types&&... args) { | ||
return root_finder_tol( | ||
[](auto&&... args) { | ||
return boost::math::tools::newton_raphson_iterate(args...); | ||
}, | ||
std::forward<FTuple>(f_tuple), guess, min, max, digits, max_iter, | ||
std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FTuple, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
auto root_finder_schroder_tol(FTuple&& f_tuple, const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
const int digits, std::uintmax_t& max_iter, | ||
Types&&... args) { | ||
return root_finder_tol( | ||
[](auto&&... args) { | ||
return boost::math::tools::schroder_iterate(args...); | ||
}, | ||
std::forward<FTuple>(f_tuple), guess, min, max, digits, max_iter, | ||
std::forward<Types>(args)...); | ||
} | ||
|
||
/** | ||
* Solve for root using Boost Halley method with default values for the | ||
* tolerances | ||
* @tparam FTuple A tuple holding functors whose signatures all match | ||
* `(GuessScalar g, Types&&... Args)`. | ||
* @tparam GuessScalar Scalar type | ||
* @tparam MinScalar Scalar type | ||
* @tparam MaxScalar Scalar type | ||
* @tparam Types Arg types to pass to functors in `f_tuple` | ||
* @param f_tuple A tuple of functors to calculate the value and any derivates | ||
* needed. | ||
* @param guess An initial guess at the root value | ||
* @param min The minimum possible value for the result, this is used as an | ||
* initial lower bracket | ||
* @param max The maximum possible value for the result, this is used as an | ||
* initial upper bracket | ||
* @param args Parameter pack of arguments to pass the the functors in `f_tuple` | ||
*/ | ||
template <typename SolverFun, typename FTuple, typename GuessScalar, | ||
typename MinScalar, typename MaxScalar, typename... Types> | ||
auto root_finder(SolverFun&& f_solver, FTuple&& f_tuple, | ||
const GuessScalar guess, const MinScalar min, | ||
const MaxScalar max, Types&&... args) { | ||
constexpr int digits = 16; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how was this default chosen? Maybe add one sentence about this choice in the doxygen doc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need to update this to be in line with what boost's docs say
https://www.boost.org/doc/libs/1_62_0/libs/math/doc/html/math_toolkit/roots/roots_deriv.html |
||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_tol(std::forward<SolverFun>(f_solver), | ||
std::forward<FTuple>(f_tuple), guess, min, max, digits, | ||
max_iter, std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FTuple, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
auto root_finder_hailey(FTuple&& f_tuple, const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
Types&&... args) { | ||
constexpr int digits = 16; | ||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_halley_tol(std::forward<FTuple>(f_tuple), guess, min, max, | ||
digits, max_iter, std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FTuple, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
auto root_finder_newton_raphson(FTuple&& f_tuple, const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
Types&&... args) { | ||
constexpr int digits = 16; | ||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_newton_raphson_tol(std::forward<FTuple>(f_tuple), guess, | ||
min, max, digits, max_iter, | ||
std::forward<Types>(args)...); | ||
} | ||
|
||
template <typename FTuple, typename GuessScalar, typename MinScalar, | ||
typename MaxScalar, typename... Types> | ||
auto root_finder_schroder(FTuple&& f_tuple, const GuessScalar guess, | ||
const MinScalar min, const MaxScalar max, | ||
Types&&... args) { | ||
constexpr int digits = 16; | ||
std::uintmax_t max_iter = std::numeric_limits<std::uintmax_t>::max(); | ||
return root_finder_schroder_tol(std::forward<FTuple>(f_tuple), guess, min, | ||
max, digits, max_iter, | ||
std::forward<Types>(args)...); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#ifndef STAN_MATH_REV_FUN_FREXP_HPP | ||
#define STAN_MATH_REV_FUN_FREXP_HPP | ||
|
||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
inline auto frexp(stan::math::var x, int* exponent) noexcept { | ||
return std::frexp(x.val(), exponent); | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#ifndef STAN_MATH_REV_FUN_SIGN_HPP | ||
#define STAN_MATH_REV_FUN_SIGN_HPP | ||
|
||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
inline int sign(stan::math::var z) { return (z == 0) ? 0 : z < 0 ? -1 : 1; } | ||
} // namespace math | ||
} // namespace stan | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs some doxygen doc. It's not clear what this function does.