Skip to content

Commit

Permalink
Merge pull request #2637 from andrjohns/feature/ibeta_inv
Browse files Browse the repository at this point in the history
Incomplete Beta Function Inverse
  • Loading branch information
andrjohns authored Mar 26, 2022
2 parents 359f742 + 7715e90 commit cd230dc
Show file tree
Hide file tree
Showing 11 changed files with 448 additions and 8 deletions.
1 change: 1 addition & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <stan/math/fwd/fun/inv_erfc.hpp>
#include <stan/math/fwd/fun/inv_Phi.hpp>
#include <stan/math/fwd/fun/inv_cloglog.hpp>
#include <stan/math/fwd/fun/inv_inc_beta.hpp>
#include <stan/math/fwd/fun/inv_logit.hpp>
#include <stan/math/fwd/fun/inv_sqrt.hpp>
#include <stan/math/fwd/fun/inv_square.hpp>
Expand Down
87 changes: 87 additions & 0 deletions stan/math/fwd/fun/inv_inc_beta.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#ifndef STAN_MATH_FWD_FUN_INV_INC_BETA_HPP
#define STAN_MATH_FWD_FUN_INV_INC_BETA_HPP

#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/inv_inc_beta.hpp>
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/F32.hpp>

namespace stan {
namespace math {

/**
* The inverse of the normalized incomplete beta function of a, b, with
* probability p.
*
* Used to compute the inverse cumulative density function for the beta
* distribution.
*
* @param a Shape parameter a >= 0; a and b can't both be 0
* @param b Shape parameter b >= 0
* @param p Random variate. 0 <= p <= 1
* @throws if constraints are violated or if any argument is NaN
* @return The inverse of the normalized incomplete beta function.
*/
template <typename T1, typename T2, typename T3,
require_all_stan_scalar_t<T1, T2, T3>* = nullptr,
require_any_fvar_t<T1, T2, T3>* = nullptr>
inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a,
const T2& b,
const T3& p) {
using T_return = partials_return_t<T1, T2, T3>;
auto a_val = value_of(a);
auto b_val = value_of(b);
auto p_val = value_of(p);
T_return w = inv_inc_beta(a_val, b_val, p_val);
T_return log_w = log(w);
T_return log1m_w = log1m(w);
auto one_m_a = 1 - a_val;
auto one_m_b = 1 - b_val;
T_return one_m_w = 1 - w;
auto ap1 = a_val + 1;
auto bp1 = b_val + 1;
auto lbeta_ab = lbeta(a_val, b_val);
auto digamma_apb = digamma(a_val + b_val);

T_return inv_d_(0);

if (is_fvar<T1>::value) {
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
auto da2
= exp(a_val * log_w + 2 * lgamma(a_val)
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) - 2 * lgamma(ap1));
auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
* (log_w - digamma(a_val) + digamma_apb);
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3);
}

if (is_fvar<T2>::value) {
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
auto db2 = 2 * lgamma(b_val)
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w))
- 2 * lgamma(bp1) + b_val * log1m_w;

auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
* (log1m_w - digamma(b_val) + digamma_apb);

inv_d_ += forward_as<fvar<T_return>>(b).d_ * db1 * (exp(db2) - db3);
}

if (is_fvar<T3>::value) {
inv_d_ += forward_as<fvar<T_return>>(p).d_
* exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab);
}

return fvar<T_return>(w, inv_d_);
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
#include <stan/math/prim/fun/int_step.hpp>
#include <stan/math/prim/fun/inv.hpp>
#include <stan/math/prim/fun/inv_Phi.hpp>
#include <stan/math/prim/fun/inv_inc_beta.hpp>
#include <stan/math/prim/fun/inv_cloglog.hpp>
#include <stan/math/prim/fun/inv_erfc.hpp>
#include <stan/math/prim/fun/inv_logit.hpp>
Expand Down
22 changes: 14 additions & 8 deletions stan/math/prim/fun/F32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,29 +49,35 @@ namespace math {
* @param[in] precision precision of the infinite sum. defaults to 1e-6
* @param[in] max_steps number of steps to take. defaults to 1e5
*/
template <typename T>
T F32(const T& a1, const T& a2, const T& a3, const T& b1, const T& b2,
const T& z, double precision = 1e-6, int max_steps = 1e5) {
template <typename Ta1, typename Ta2, typename Ta3, typename Tb1, typename Tb2,
typename Tz>
return_type_t<Ta1, Ta2, Ta3, Tb1, Tb2, Tz> F32(const Ta1& a1, const Ta2& a2,
const Ta3& a3, const Tb1& b1,
const Tb2& b2, const Tz& z,
double precision = 1e-6,
int max_steps = 1e5) {
check_3F2_converges("F32", a1, a2, a3, b1, b2, z);

using T_return = return_type_t<Ta1, Ta2, Ta3, Tb1, Tb2, Tz>;
using std::exp;
using std::fabs;
using std::log;

T t_acc = 1.0;
T log_t = 0.0;
T log_z = log(z);
T_return t_acc = 1.0;
T_return log_t = 0.0;
Tz log_z = log(z);
double t_sign = 1.0;

for (int k = 0; k <= max_steps; ++k) {
T p = (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (k + 1));
T_return p
= (a1 + k) * (a2 + k) * (a3 + k) / ((b1 + k) * (b2 + k) * (k + 1));
if (p == 0.0) {
return t_acc;
}

log_t += log(fabs(p)) + log_z;
t_sign = p >= 0.0 ? t_sign : -t_sign;
T t_new = t_sign > 0.0 ? exp(log_t) : -exp(log_t);
T_return t_new = t_sign > 0.0 ? exp(log_t) : -exp(log_t);
t_acc += t_new;

if (fabs(t_new) <= precision) {
Expand Down
37 changes: 37 additions & 0 deletions stan/math/prim/fun/inv_inc_beta.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef STAN_MATH_PRIM_FUN_INV_INC_BETA_HPP
#define STAN_MATH_PRIM_FUN_INV_INC_BETA_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/boost_policy.hpp>
#include <boost/math/special_functions/beta.hpp>

namespace stan {
namespace math {

/**
* The inverse of the normalized incomplete beta function of a, b, with
* probability p.
*
* Used to compute the inverse cumulative density function for the beta
* distribution.
*
* @param a Shape parameter a >= 0; a and b can't both be 0
* @param b Shape parameter b >= 0
* @param p Random variate. 0 <= p <= 1
* @throws if constraints are violated or if any argument is NaN
* @return The inverse of the normalized incomplete beta function.
*/
inline double inv_inc_beta(double a, double b, double p) {
check_not_nan("inv_inc_beta", "a", a);
check_not_nan("inv_inc_beta", "b", b);
check_not_nan("inv_inc_beta", "p", p);
check_positive("inv_inc_beta", "a", a);
check_positive("inv_inc_beta", "b", b);
check_bounded("inv_inc_beta", "p", p, 0, 1);
return boost::math::ibeta_inv(a, b, p, boost_policy_t<>());
}

} // namespace math
} // namespace stan
#endif
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
#include <stan/math/rev/fun/identity_free.hpp>
#include <stan/math/rev/fun/if_else.hpp>
#include <stan/math/rev/fun/inc_beta.hpp>
#include <stan/math/rev/fun/inv_inc_beta.hpp>
#include <stan/math/rev/fun/initialize_fill.hpp>
#include <stan/math/rev/fun/initialize_variable.hpp>
#include <stan/math/rev/fun/inv.hpp>
Expand Down
105 changes: 105 additions & 0 deletions stan/math/rev/fun/inv_inc_beta.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#ifndef STAN_MATH_REV_FUN_INV_INC_BETA_HPP
#define STAN_MATH_REV_FUN_INV_INC_BETA_HPP

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/inv_inc_beta.hpp>
#include <stan/math/prim/fun/inc_beta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/log.hpp>
#include <stan/math/prim/fun/log_diff_exp.hpp>
#include <stan/math/prim/fun/lbeta.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/digamma.hpp>
#include <stan/math/prim/fun/F32.hpp>
#include <stan/math/prim/fun/is_any_nan.hpp>

namespace stan {
namespace math {

/**
* The inverse of the normalized incomplete beta function of a, b, with
* probability p.
*
* Used to compute the inverse cumulative density function for the beta
* distribution.
*
\f[
\frac{\partial }{\partial a} =
(1-w)^{1-b}w^{1-a}
\left(
w^a\Gamma(a)^2 {}_3\tilde{F}_2(a,a,1-b;a+1,a+1;w)
- B(a,b)I_w(a,b)\left(\log(w)-\psi(a) + \psi(a+b)\right)
\right)/;w=I_z^{-1}(a,b)
\f]
\f[
\frac{\partial }{\partial b} =
(1-w)^{-b}w^{1-a}(w-1)
\left(
(1-w)^{b}\Gamma(b)^2 {}_3\tilde{F}_2(b,b,1-a;b+1,b+1;1-w)
- B_{1-w}(b,a)\left(\log(1-w)-\psi(b) + \psi(a+b)\right)
\right)/;w=I_z^{-1}(a,b)
\f]
\f[
\frac{\partial }{\partial z} = (1-w)^{1-b}w^{1-a}B(a,b)/;w=I_z^{-1}(a,b)
\f]
*
* @param a Shape parameter a >= 0; a and b can't both be 0
* @param b Shape parameter b >= 0
* @param p Random variate. 0 <= p <= 1
* @throws if constraints are violated or if any argument is NaN
* @return The inverse of the normalized incomplete beta function.
*/
template <typename T1, typename T2, typename T3,
require_all_stan_scalar_t<T1, T2, T3>* = nullptr,
require_any_var_t<T1, T2, T3>* = nullptr>
inline var inv_inc_beta(const T1& a, const T2& b, const T3& p) {
double a_val = value_of(a);
double b_val = value_of(b);
double p_val = value_of(p);
double w = inv_inc_beta(a_val, b_val, p_val);
return make_callback_var(w, [a, b, p, a_val, b_val, p_val, w](auto& vi) {
double log_w = log(w);
double log1m_w = log1m(w);
double one_m_a = 1 - a_val;
double one_m_b = 1 - b_val;
double one_m_w = 1 - w;
double ap1 = a_val + 1;
double bp1 = b_val + 1;
double lbeta_ab = lbeta(a_val, b_val);
double digamma_apb = digamma(a_val + b_val);

if (!is_constant_all<T1>::value) {
double da1 = exp(one_m_b * log1m_w + one_m_a * log_w);
double da2 = a_val * log_w + 2 * lgamma(a_val)
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w))
- 2 * lgamma(ap1);
double da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab)
* (log_w - digamma(a_val) + digamma_apb);

forward_as<var>(a).adj() += vi.adj() * da1 * (exp(da2) - da3);
}

if (!is_constant_all<T2>::value) {
double db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w);
double db2 = 2 * lgamma(b_val)
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w))
- 2 * lgamma(bp1) + b_val * log1m_w;

double db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab)
* (log1m_w - digamma(b_val) + digamma_apb);

forward_as<var>(b).adj() += vi.adj() * db1 * (exp(db2) - db3);
}

if (!is_constant_all<T3>::value) {
forward_as<var>(p).adj()
+= vi.adj() * exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab);
}
});
}

} // namespace math
} // namespace stan
#endif
33 changes: 33 additions & 0 deletions test/unit/math/fwd/fun/inv_inc_beta_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <stan/math/fwd.hpp>
#include <gtest/gtest.h>

TEST(AgradFwdMatrixIncBetaInv, fd_scalar) {
using stan::math::fvar;
using stan::math::inv_inc_beta;
fvar<double> a = 6;
fvar<double> b = 2;
fvar<double> p = 0.9;
a.d_ = 1.0;
b.d_ = 1.0;
p.d_ = 1.0;

fvar<double> res = inv_inc_beta(a, b, p);

EXPECT_FLOAT_EQ(res.d_, 0.0117172527399 - 0.0680999818473 + 0.455387298585);
}

TEST(AgradFwdMatrixIncBetaInv, ffd_scalar) {
using stan::math::fvar;
using stan::math::inv_inc_beta;
fvar<fvar<double>> a = 7;
fvar<fvar<double>> b = 4;
fvar<fvar<double>> p = 0.15;
a.val_.d_ = 1.0;
b.val_.d_ = 1.0;
p.val_.d_ = 1.0;

fvar<fvar<double>> res = inv_inc_beta(a, b, p);

EXPECT_FLOAT_EQ(res.val_.d_,
0.0428905418857 - 0.0563420377808 + 0.664919819507);
}
Loading

0 comments on commit cd230dc

Please sign in to comment.