-
-
Notifications
You must be signed in to change notification settings - Fork 189
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2637 from andrjohns/feature/ibeta_inv
Incomplete Beta Function Inverse
- Loading branch information
Showing
11 changed files
with
448 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#ifndef STAN_MATH_FWD_FUN_INV_INC_BETA_HPP | ||
#define STAN_MATH_FWD_FUN_INV_INC_BETA_HPP | ||
|
||
#include <stan/math/fwd/meta.hpp> | ||
#include <stan/math/fwd/core.hpp> | ||
#include <stan/math/prim/fun/inv_inc_beta.hpp> | ||
#include <stan/math/prim/fun/inc_beta.hpp> | ||
#include <stan/math/prim/fun/exp.hpp> | ||
#include <stan/math/prim/fun/log.hpp> | ||
#include <stan/math/prim/fun/log_diff_exp.hpp> | ||
#include <stan/math/prim/fun/lbeta.hpp> | ||
#include <stan/math/prim/fun/lgamma.hpp> | ||
#include <stan/math/prim/fun/digamma.hpp> | ||
#include <stan/math/prim/fun/F32.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* The inverse of the normalized incomplete beta function of a, b, with | ||
* probability p. | ||
* | ||
* Used to compute the inverse cumulative density function for the beta | ||
* distribution. | ||
* | ||
* @param a Shape parameter a >= 0; a and b can't both be 0 | ||
* @param b Shape parameter b >= 0 | ||
* @param p Random variate. 0 <= p <= 1 | ||
* @throws if constraints are violated or if any argument is NaN | ||
* @return The inverse of the normalized incomplete beta function. | ||
*/ | ||
template <typename T1, typename T2, typename T3, | ||
require_all_stan_scalar_t<T1, T2, T3>* = nullptr, | ||
require_any_fvar_t<T1, T2, T3>* = nullptr> | ||
inline fvar<partials_return_t<T1, T2, T3>> inv_inc_beta(const T1& a, | ||
const T2& b, | ||
const T3& p) { | ||
using T_return = partials_return_t<T1, T2, T3>; | ||
auto a_val = value_of(a); | ||
auto b_val = value_of(b); | ||
auto p_val = value_of(p); | ||
T_return w = inv_inc_beta(a_val, b_val, p_val); | ||
T_return log_w = log(w); | ||
T_return log1m_w = log1m(w); | ||
auto one_m_a = 1 - a_val; | ||
auto one_m_b = 1 - b_val; | ||
T_return one_m_w = 1 - w; | ||
auto ap1 = a_val + 1; | ||
auto bp1 = b_val + 1; | ||
auto lbeta_ab = lbeta(a_val, b_val); | ||
auto digamma_apb = digamma(a_val + b_val); | ||
|
||
T_return inv_d_(0); | ||
|
||
if (is_fvar<T1>::value) { | ||
auto da1 = exp(one_m_b * log1m_w + one_m_a * log_w); | ||
auto da2 | ||
= exp(a_val * log_w + 2 * lgamma(a_val) | ||
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) - 2 * lgamma(ap1)); | ||
auto da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab) | ||
* (log_w - digamma(a_val) + digamma_apb); | ||
inv_d_ += forward_as<fvar<T_return>>(a).d_ * da1 * (da2 - da3); | ||
} | ||
|
||
if (is_fvar<T2>::value) { | ||
auto db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w); | ||
auto db2 = 2 * lgamma(b_val) | ||
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w)) | ||
- 2 * lgamma(bp1) + b_val * log1m_w; | ||
|
||
auto db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab) | ||
* (log1m_w - digamma(b_val) + digamma_apb); | ||
|
||
inv_d_ += forward_as<fvar<T_return>>(b).d_ * db1 * (exp(db2) - db3); | ||
} | ||
|
||
if (is_fvar<T3>::value) { | ||
inv_d_ += forward_as<fvar<T_return>>(p).d_ | ||
* exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab); | ||
} | ||
|
||
return fvar<T_return>(w, inv_d_); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_INV_INC_BETA_HPP | ||
#define STAN_MATH_PRIM_FUN_INV_INC_BETA_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err.hpp> | ||
#include <stan/math/prim/fun/boost_policy.hpp> | ||
#include <boost/math/special_functions/beta.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* The inverse of the normalized incomplete beta function of a, b, with | ||
* probability p. | ||
* | ||
* Used to compute the inverse cumulative density function for the beta | ||
* distribution. | ||
* | ||
* @param a Shape parameter a >= 0; a and b can't both be 0 | ||
* @param b Shape parameter b >= 0 | ||
* @param p Random variate. 0 <= p <= 1 | ||
* @throws if constraints are violated or if any argument is NaN | ||
* @return The inverse of the normalized incomplete beta function. | ||
*/ | ||
inline double inv_inc_beta(double a, double b, double p) { | ||
check_not_nan("inv_inc_beta", "a", a); | ||
check_not_nan("inv_inc_beta", "b", b); | ||
check_not_nan("inv_inc_beta", "p", p); | ||
check_positive("inv_inc_beta", "a", a); | ||
check_positive("inv_inc_beta", "b", b); | ||
check_bounded("inv_inc_beta", "p", p, 0, 1); | ||
return boost::math::ibeta_inv(a, b, p, boost_policy_t<>()); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
#ifndef STAN_MATH_REV_FUN_INV_INC_BETA_HPP | ||
#define STAN_MATH_REV_FUN_INV_INC_BETA_HPP | ||
|
||
#include <stan/math/rev/meta.hpp> | ||
#include <stan/math/rev/core.hpp> | ||
#include <stan/math/prim/fun/constants.hpp> | ||
#include <stan/math/prim/fun/inv_inc_beta.hpp> | ||
#include <stan/math/prim/fun/inc_beta.hpp> | ||
#include <stan/math/prim/fun/exp.hpp> | ||
#include <stan/math/prim/fun/log.hpp> | ||
#include <stan/math/prim/fun/log_diff_exp.hpp> | ||
#include <stan/math/prim/fun/lbeta.hpp> | ||
#include <stan/math/prim/fun/lgamma.hpp> | ||
#include <stan/math/prim/fun/digamma.hpp> | ||
#include <stan/math/prim/fun/F32.hpp> | ||
#include <stan/math/prim/fun/is_any_nan.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* The inverse of the normalized incomplete beta function of a, b, with | ||
* probability p. | ||
* | ||
* Used to compute the inverse cumulative density function for the beta | ||
* distribution. | ||
* | ||
\f[ | ||
\frac{\partial }{\partial a} = | ||
(1-w)^{1-b}w^{1-a} | ||
\left( | ||
w^a\Gamma(a)^2 {}_3\tilde{F}_2(a,a,1-b;a+1,a+1;w) | ||
- B(a,b)I_w(a,b)\left(\log(w)-\psi(a) + \psi(a+b)\right) | ||
\right)/;w=I_z^{-1}(a,b) | ||
\f] | ||
\f[ | ||
\frac{\partial }{\partial b} = | ||
(1-w)^{-b}w^{1-a}(w-1) | ||
\left( | ||
(1-w)^{b}\Gamma(b)^2 {}_3\tilde{F}_2(b,b,1-a;b+1,b+1;1-w) | ||
- B_{1-w}(b,a)\left(\log(1-w)-\psi(b) + \psi(a+b)\right) | ||
\right)/;w=I_z^{-1}(a,b) | ||
\f] | ||
\f[ | ||
\frac{\partial }{\partial z} = (1-w)^{1-b}w^{1-a}B(a,b)/;w=I_z^{-1}(a,b) | ||
\f] | ||
* | ||
* @param a Shape parameter a >= 0; a and b can't both be 0 | ||
* @param b Shape parameter b >= 0 | ||
* @param p Random variate. 0 <= p <= 1 | ||
* @throws if constraints are violated or if any argument is NaN | ||
* @return The inverse of the normalized incomplete beta function. | ||
*/ | ||
template <typename T1, typename T2, typename T3, | ||
require_all_stan_scalar_t<T1, T2, T3>* = nullptr, | ||
require_any_var_t<T1, T2, T3>* = nullptr> | ||
inline var inv_inc_beta(const T1& a, const T2& b, const T3& p) { | ||
double a_val = value_of(a); | ||
double b_val = value_of(b); | ||
double p_val = value_of(p); | ||
double w = inv_inc_beta(a_val, b_val, p_val); | ||
return make_callback_var(w, [a, b, p, a_val, b_val, p_val, w](auto& vi) { | ||
double log_w = log(w); | ||
double log1m_w = log1m(w); | ||
double one_m_a = 1 - a_val; | ||
double one_m_b = 1 - b_val; | ||
double one_m_w = 1 - w; | ||
double ap1 = a_val + 1; | ||
double bp1 = b_val + 1; | ||
double lbeta_ab = lbeta(a_val, b_val); | ||
double digamma_apb = digamma(a_val + b_val); | ||
|
||
if (!is_constant_all<T1>::value) { | ||
double da1 = exp(one_m_b * log1m_w + one_m_a * log_w); | ||
double da2 = a_val * log_w + 2 * lgamma(a_val) | ||
+ log(F32(a_val, a_val, one_m_b, ap1, ap1, w)) | ||
- 2 * lgamma(ap1); | ||
double da3 = inc_beta(a_val, b_val, w) * exp(lbeta_ab) | ||
* (log_w - digamma(a_val) + digamma_apb); | ||
|
||
forward_as<var>(a).adj() += vi.adj() * da1 * (exp(da2) - da3); | ||
} | ||
|
||
if (!is_constant_all<T2>::value) { | ||
double db1 = (w - 1) * exp(-b_val * log1m_w + one_m_a * log_w); | ||
double db2 = 2 * lgamma(b_val) | ||
+ log(F32(b_val, b_val, one_m_a, bp1, bp1, one_m_w)) | ||
- 2 * lgamma(bp1) + b_val * log1m_w; | ||
|
||
double db3 = inc_beta(b_val, a_val, one_m_w) * exp(lbeta_ab) | ||
* (log1m_w - digamma(b_val) + digamma_apb); | ||
|
||
forward_as<var>(b).adj() += vi.adj() * db1 * (exp(db2) - db3); | ||
} | ||
|
||
if (!is_constant_all<T3>::value) { | ||
forward_as<var>(p).adj() | ||
+= vi.adj() * exp(one_m_b * log1m_w + one_m_a * log_w + lbeta_ab); | ||
} | ||
}); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#include <stan/math/fwd.hpp> | ||
#include <gtest/gtest.h> | ||
|
||
TEST(AgradFwdMatrixIncBetaInv, fd_scalar) { | ||
using stan::math::fvar; | ||
using stan::math::inv_inc_beta; | ||
fvar<double> a = 6; | ||
fvar<double> b = 2; | ||
fvar<double> p = 0.9; | ||
a.d_ = 1.0; | ||
b.d_ = 1.0; | ||
p.d_ = 1.0; | ||
|
||
fvar<double> res = inv_inc_beta(a, b, p); | ||
|
||
EXPECT_FLOAT_EQ(res.d_, 0.0117172527399 - 0.0680999818473 + 0.455387298585); | ||
} | ||
|
||
TEST(AgradFwdMatrixIncBetaInv, ffd_scalar) { | ||
using stan::math::fvar; | ||
using stan::math::inv_inc_beta; | ||
fvar<fvar<double>> a = 7; | ||
fvar<fvar<double>> b = 4; | ||
fvar<fvar<double>> p = 0.15; | ||
a.val_.d_ = 1.0; | ||
b.val_.d_ = 1.0; | ||
p.val_.d_ = 1.0; | ||
|
||
fvar<fvar<double>> res = inv_inc_beta(a, b, p); | ||
|
||
EXPECT_FLOAT_EQ(res.val_.d_, | ||
0.0428905418857 - 0.0563420377808 + 0.664919819507); | ||
} |
Oops, something went wrong.