Skip to content

Commit

Permalink
Bit of cleanup and added adapter file (Issue #2197)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbbales2 committed Feb 24, 2021
1 parent 767ebb6 commit f84f7eb
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 46 deletions.
26 changes: 19 additions & 7 deletions stan/math/prim/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,24 @@ inline double integrate(const F& f, double a, double b,
return Q;
}

/**
* Compute the integral of the single variable function f from a to b to within
* a specified relative tolerance. a and b can be finite or infinite.
*
* @tparam T Type of f
* @param f the function to be integrated
* @param a lower limit of integration
* @param b upper limit of integration
* @param relative_tolerance tolerance passed to Boost quadrature
* @param[in, out] msgs the print stream for warning messages
* @param args additional arguments passed to f
* @return numeric integral of function f
*/
template <typename F, typename... Args,
require_all_not_st_var<Args...>* = nullptr>
inline double integrate_1d_new(const F& f, double a, double b,
double relative_tolerance,
std::ostream* msgs, const Args&... args) {
//const double relative_tolerance = std::sqrt(EPSILON);
inline double integrate_1d_impl(const F& f, double a, double b,
double relative_tolerance,
std::ostream* msgs, const Args&... args) {
static const char* function = "integrate_1d";
check_less_or_equal(function, "lower limit", a, b);

Expand Down Expand Up @@ -207,9 +219,9 @@ inline double integrate_1d(const F& f, double a, double b,
const std::vector<int>& x_i, std::ostream* msgs,
const double relative_tolerance
= std::sqrt(EPSILON)) {
return integrate_1d_new(integrate_1d_adapter<F>(f), a, b,
relative_tolerance, msgs,
theta, x_r, x_i);
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b,
relative_tolerance, msgs,
theta, x_r, x_i);
}

} // namespace math
Expand Down
29 changes: 29 additions & 0 deletions stan/math/prim/functor/integrate_1d_adapter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP

#include <ostream>
#include <vector>

/**
* Adapt the non-variadic integrate_1d arguments to the variadic
* integrate_1d_impl interface
*
* @tparam F type of function to adapt
*/
template <typename F>
struct integrate_1d_adapter {
const F& f_;

explicit integrate_1d_adapter(const F& f) : f_(f) {}

template <typename T_a, typename T_b, typename T_theta>
auto operator()(const T_a& x, const T_b& xc,
std::ostream *msgs,
const std::vector<T_theta> &theta,
const std::vector<double> &x_r,
const std::vector<int> &x_i) const {
return f_(x, xc, theta, x_r, x_i, msgs);
}
};

#endif
69 changes: 30 additions & 39 deletions stan/math/rev/functor/integrate_1d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,41 +28,17 @@ namespace math {
* NaN, a std::domain_error is thrown
*
* @tparam F type of f
* @tparam Args types of arguments to f
* @param f function to compute gradients of
* @param x location at which to evaluate gradients
* @param xc complement of location (if bounded domain of integration)
* @param n compute gradient with respect to nth parameter
* @param msgs stream for messages
* @param args other arguments to pass to f
*/
template <typename F>
inline double gradient_of_f(const F &f, const double &x, const double &xc,
const std::vector<double> &theta_vals,
const std::vector<double> &x_r,
const std::vector<int> &x_i, size_t n,
std::ostream *msgs) {
double gradient = 0.0;

// Run nested autodiff in this scope
nested_rev_autodiff nested;

std::vector<var> theta_var(theta_vals.size());
for (size_t i = 0; i < theta_vals.size(); i++) {
theta_var[i] = theta_vals[i];
}
var fx = f(x, xc, theta_var, x_r, x_i, msgs);
fx.grad();
gradient = theta_var[n].adj();
if (is_nan(gradient)) {
if (fx.val() == 0) {
gradient = 0;
} else {
throw_domain_error("gradient_of_f", "The gradient of f", n,
"is nan for parameter ", "");
}
}

return gradient;
}

template <typename F, typename... Args>
inline double gradient_of_f_new(const F &f, const double &x, const double &xc,
size_t n, std::ostream *msgs,
const Args&... args) {
inline double gradient_of_f(const F &f, const double &x, const double &xc,
size_t n, std::ostream *msgs, const Args&... args) {
double gradient = 0.0;

// Run nested autodiff in this scope
Expand Down Expand Up @@ -98,11 +74,26 @@ inline double gradient_of_f_new(const F &f, const double &x, const double &xc,
return gradient;
}


/**
* Return the integral of f from a to b to the given relative tolerance
*
* @tparam T_a type of first limit
* @tparam T_b type of second limit
* @tparam T_theta type of parameters
* @tparam T Type of f
*
* @param f the functor to integrate
* @param a lower limit of integration
* @param b upper limit of integration
* @param relative_tolerance relative tolerance passed to Boost quadrature
* @param[in, out] msgs the print stream for warning messages
* @param args additional arguments to pass to f
* @return numeric integral of function f
*/
template <typename F, typename T_a, typename T_b,
typename... Args,
require_any_st_var<T_a, T_b, Args...>* = nullptr>
inline return_type_t<T_a, T_b, Args...> integrate_1d_new(
inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
const F &f, const T_a &a, const T_b &b,
double relative_tolerance,
std::ostream *msgs, const Args&... args) {
Expand Down Expand Up @@ -160,7 +151,7 @@ inline return_type_t<T_a, T_b, Args...> integrate_1d_new(

for (size_t n = 0; n < num_vars_args; ++n) {
*partials_ptr = integrate(
std::bind<double>(gradient_of_f_new<F, Args...>, f,
std::bind<double>(gradient_of_f<F, Args...>, f,
std::placeholders::_1, std::placeholders::_2,
n, msgs, args...),
a_val, b_val, relative_tolerance);
Expand Down Expand Up @@ -234,9 +225,9 @@ inline return_type_t<T_a, T_b, T_theta> integrate_1d(
const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta,
const std::vector<double> &x_r, const std::vector<int> &x_i,
std::ostream *msgs, const double relative_tolerance = std::sqrt(EPSILON)) {
return integrate_1d_new(integrate_1d_adapter<F>(f), a, b,
relative_tolerance, msgs,
theta, x_r, x_i);
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b,
relative_tolerance, msgs,
theta, x_r, x_i);
}

} // namespace math
Expand Down

0 comments on commit f84f7eb

Please sign in to comment.