Exact chainrules derivatives for beta_inc and beta_inc_inv #506

Open
lrnv wants to merge 16 commits into JuliaMath:master from lrnv:chainrules-for-beta_inc-and-beta_inc_inv

lrnv commented Oct 4, 2025

Summary

This PR adds ChainRules coverage for the regularized incomplete beta function and its inverse:

  • Adds analytic frule/rrule for:
    • beta_inc(a, b, x) -> (p, q)
    • beta_inc(a, b, x, y) with y = 1 − x semantics
    • beta_inc_inv(a, b, p) -> (x, 1 − x)
  • Adds extensive tests in test/chainrules.jl.

I did this by translating the S-PLUS algorithm of Boik & Robinson-Cox (1998) for the partial derivatives with respect to a, b, and x.
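
For orientation, the quantities involved can be written out as below. This is a sketch of standard identities, not something taken from the PR's code; the notation $I_x(a,b)$ for the regularized incomplete beta function is an assumption of this note.

```latex
% Regularized incomplete beta function and its derivative in x (closed form)
I_x(a,b) = \frac{1}{B(a,b)} \int_0^x t^{a-1}(1-t)^{b-1}\,dt,
\qquad
\frac{\partial I_x(a,b)}{\partial x} = \frac{x^{a-1}(1-x)^{b-1}}{B(a,b)}.

% Inverse: x = x(a,b,p) solves I_x(a,b) = p, so by implicit differentiation
\frac{\partial x}{\partial p} = \left(\frac{\partial I_x}{\partial x}\right)^{-1},
\qquad
\frac{\partial x}{\partial a} = -\,\frac{\partial I_x/\partial a}{\partial I_x/\partial x},
\qquad
\frac{\partial x}{\partial b} = -\,\frac{\partial I_x/\partial b}{\partial I_x/\partial x}.
```

Only the partials with respect to a and b lack an elementary closed form, which is presumably why a dedicated numerical scheme is needed for them. Since q = 1 − p, the derivatives of the second output are the negatives of those of the first.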

Motivation

  • Allow AD to go through beta_inc and beta_inc_inv.
  • As hinted in "beta_inv and beta_inv_inc gradients" (#505), there are a lot of places where these derivatives are missing:
    • Recently on this Discourse thread
    • To fit a Distributions.MvTDist(), which was never fittable due to the lack of these derivatives.
    • In Copulas.jl, when using Distributions.Beta() marginals
    • In probably other issues that I do not know about.

I definitely think these derivatives belong in SpecialFunctions.jl's ChainRulesCore extension.

Final note

The reference for the algorithm is:

Boik, R. J., & Robinson-Cox, J. F. (1998). Derivatives of the incomplete beta function with respect to its parameters. Computational Statistics & Data Analysis, 27(1), 85–106.

Maybe we should cite it as a proper reference somewhere in the documentation? I don't know.

The algorithm is quite stable and has good precision (see the rtol values in the tests). It works for Float64 and Float32. I do not think I broke anything, so this should be a patch-level change.

Waiting for review :)

codecov bot commented Oct 4, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 94.94%. Comparing base (1743a8b) to head (b962b26).
⚠️ Report is 30 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #506      +/-   ##
==========================================
+ Coverage   94.03%   94.94%   +0.90%     
==========================================
  Files          14       14              
  Lines        2902     3242     +340     
==========================================
+ Hits         2729     3078     +349     
+ Misses        173      164       -9     
Flag Coverage Δ
unittests 94.94% <100.00%> (+0.90%) ⬆️

Author

lrnv commented Oct 5, 2025

@asinghvi17, @devmotion: I have switched implementations from the port of the original code (which was probably not licensed correctly, as we discussed yesterday) to a port of this repository, which is an independent implementation by @arzwa that is not yet licensed.

I have better hopes that @arzwa will allow us to use his code, and have emailed him. By looking at the diff you'll see that this is clearly an independent implementation, and thus, if @arzwa agrees, we are good license-wise.

I have also contacted the original authors, so if they come around faster we can still revert my last commit.

arzwa commented Oct 5, 2025

I agree, you can use the code as you wish, I added an MIT license to the repo.

Author

lrnv commented Oct 5, 2025

I knew it would be faster 🤣

arzwa commented Oct 5, 2025

> @asinghvi17, @devmotion: I have switched implementations from the porting of the original code (which was probably not licensed correctly as we discussed yesterday) to a port of this repository, which is an independent implementation from @arzwa, unlicensed yet.

Now that I'm reading the mentioned discussion, I would like to add that I definitely implemented this using the approach described in the mentioned paper (Boik & Robinson-Cox 1998). As far as I remember, I did not 'translate' the code but implemented the described numerical method 'independently', based on their mathematical description (in fact, I'm pretty sure I have never seen S-PLUS code in my life). However, I don't know what that means license-wise or whether it's OK to use an MIT license (I have never had to think about these kinds of issues before). Any insights on this?

@bdeonovic

I would test a much larger range of (a, b) values. Look at the code that computes the incomplete beta function in this repo: it uses different algorithms depending on the domain.

Member

devmotion left a comment

All the nested functions, the multiple assignments on a single line, and the general lack of comments make it quite challenging to follow the code. Maybe the implementation should be cleaned up a bit before it is considered for inclusion in SpecialFunctions.

Author

lrnv commented Oct 13, 2025

So, I have:

  1. Moved the internal methods out of the function for readability
  2. Added a lot of comments to help follow the code (the paper also helps)
  3. Greatly increased the coverage of (a, b, x) in my tests by looking at what was done in test/beta_inc.jl. I think I hit most of the branches that @bdeonovic mentioned.
  4. Fixed the behavior by "merging" two methods (a departure from what @arzwa implemented previously) to avoid cancellations.

One potentially problematic outcome: the tests for the new chain rules alone take ~40s on my machine, which is a large part of the package's total test time :/

Ready for review round 2.

@ViralBShah
Member

Test time is probably not a concern. If this is good to go, we should merge.

Author

lrnv commented Jan 7, 2026

@devmotion Thanks a lot for the review, and sorry for the delay! All your suggestions were taken into account (I marked them as resolved) except three: the first one is IMHO not a good idea for stability, and the remaining two are points on which I agree with you but do not know how to address myself. Maybe you could help?

Tests pass; the failures on PRE seem unrelated.

droodman commented Jan 28, 2026

Thank you @lrnv for doing this. Pending its incorporation into SpecialFunctions, I have copied your code directly into a project I'm working on.

I just wanted to check if this is a bug, in _beta_inc_grad():

    maxapp = max(1000, maxapp)
    minapp = max(5, minapp)

The asymmetry here looks funny, with max() in both lines. The function's signature sets a default of maxapp=200 and here it is boosted to 1000 even if the user specifies, say, 250?
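
To make the question concrete, here is a minimal sketch of what those two lines do to user-supplied values. `clamp_as_written` is a hypothetical helper used only for illustration; it is not a function in the PR, and it reproduces nothing beyond the two quoted lines and the default values mentioned above.

```julia
# Hypothetical helper: reproduces only the two quoted lines.
function clamp_as_written(; maxapp::Int=200, minapp::Int=3)
    maxapp = max(1000, maxapp)  # raises anything below 1000 up to 1000
    minapp = max(5, minapp)     # raises anything below 5 up to 5
    return (maxapp, minapp)
end

clamp_as_written()                        # (1000, 5): the defaults are overridden
clamp_as_written(maxapp=250)              # (1000, 5): a user-supplied 250 is ignored
clamp_as_written(maxapp=5000, minapp=10)  # (5000, 10): larger values pass through
```

So, as written, both keywords act as lower bounds, and any value at or below the floor has no effect.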

Author

lrnv commented Jan 28, 2026

The maxapp and minapp arguments specify the maximum and minimum number of approximants to use in the continued fraction evaluation. Requiring the first to be at least 1000 and the second at least 5 is consistent with what the original paper recommends, while @arzwa used defaults of minapp=3 and maxapp=200 and does not enforce any bounds on them. You are right that this is troubling. @droodman, did you investigate what happens to the tests if you simply remove these two lines?

droodman commented Jan 28, 2026

Sorry, I did not check the tests.
I feel a bit out of my depth here. Based on pattern-matching intuition, it seems that one line should use min() and the other max().

Separately:

  • I found that decorating all the small helper functions with @inline halved run-time whether with Float64 or ForwardDiff.Dual.

  • On that note, I got this working with ForwardDiff (which in my project is working best for automatic differentiation) by making a couple of changes. I changed all instances of AbstractFloat to Real so it would accept ForwardDiff.Duals. And I added:

using ForwardDiff, ForwardDiffChainRules
import SpecialFunctions.beta_inc
@ForwardDiff_frule beta_inc(a::ForwardDiff.Dual, b::ForwardDiff.Dual, x::ForwardDiff.Dual)

I assume the middle line is only needed because my copy of your code is outside of SpecialFunctions. I'm just using beta_inc() (actually cdf(TDist(...)...)) so I only bothered creating a ForwardDiff chain rule for that function.

Again, since I'm out of my depth, I leave it to you to judge if there's any practical upshot for you.

Author

lrnv commented Jan 29, 2026

Thanks @droodman for the hint on inlining; it did indeed help :)

I am not sure the inclusion of ForwardDiff here is really in the scope of this PR.

lrnv force-pushed the chainrules-for-beta_inc-and-beta_inc_inv branch from d7d26d4 to 500d94c on January 29, 2026 09:48

Author

lrnv commented Jan 29, 2026

Hmm... Apart from ExplicitImports, which complains about @horner not being exported from Base.Math (unrelated to this PR), everything looks great. I resolved one more comment by moving to ChainRulesCore.muladd() to combine the tangents and partials as requested, and I am now ready for a new review round.

@droodman

Hi @lrnv. Follow-up comments:

  • It might be out of scope to add the definitions of ForwardDiff-specific chain rules. But do you actually want to block users of ForwardDiff from using your code? I think declaring AbstractFloat instead of Real does that. In my use case, I would continue to get error messages when using ForwardDiff for automatic differentiation of an objective function that calls beta_inc() (or the cdf of the t or F distributions). I might turn to Zygote or the like, but in my previous explorations the alternatives crashed or ran vastly slower.
  • This sequence looks peculiar to me:
function _beta_inc_grad(a::T, b::T, x::T; maxapp::Int=200, minapp::Int=3, err::T=eps(T)*T(1e4))
...
ϵ = min(err, T(1e-14))

First the tolerance is set to eps(T)*T(1e4). Then it is lowered to 1e-14. I believe the original code uses 1e12. I tried changing to ϵ = err and also commenting out the min/maxapp lines I mentioned before and it passed all tests (but you should check me on that). This has modest performance implications.
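
A minimal sketch of how the default `err` and the quoted `min` line interact, assuming only what is shown in the signature above. `effective_tol` is a hypothetical name introduced for illustration; it is not part of the PR.

```julia
# Hypothetical helper: the default `err` from the quoted signature,
# followed by the quoted `min` line.
effective_tol(::Type{T}) where {T<:AbstractFloat} = min(eps(T) * T(1e4), T(1e-14))

for T in (Float64, Float32)
    default_err = eps(T) * T(1e4)
    println(T, ": eps = ", eps(T),
            ", default err = ", default_err,
            ", effective tolerance = ", effective_tol(T))
end

# For both types the default `err` is larger than T(1e-14), so the effective
# tolerance is always T(1e-14) and the default is never the value in use.
```

For Float32 the resulting tolerance is far below eps(Float32), so a relative-error stopping test against it could presumably only succeed on near-exact agreement; that is an inference from the arithmetic, not something checked against the PR's code.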

Author

lrnv commented Jan 29, 2026

@droodman I removed the minapp/maxapp/eps weirdness, thanks. I also removed the type restriction altogether so that you can pass anything through.

PS: I am not sure, however, that this is compliant with what SpecialFunctions.jl requires for merging ;)

lrnv requested review from asinghvi17 and devmotion on January 29, 2026 15:54

droodman commented Jan 31, 2026

In the last couple of days, this code basically became the inner loop in my project and was consuming hours of time. So I was highly motivated to optimize it. I got another 2X improvement (at least on an M4 Mac), in addition (or multiplication) to the 2X improvement from adding @inline to functions.

I forked @lrnv's fork of this repo and made a commit with the changes. However, I'm not that experienced with the complexities of GitHub--commits to different branches of different forks...--and I don't want to mess up anything that you are doing. Here is a link showing the new code block for ext/SpecialFunctionsChainRulesCoreExt.jl.

As noted there, the main improvement is removing redundant calculations. I also deleted the first of the four return values of _beta_inc_grad() since it is never used.

All tests pass.

Timing code:

test_points = (
                0.05, 0.08, 0.10, 0.12, 0.14, 0.18, 0.20, 0.22, 0.26,
                0.28, 0.30, 0.32, 0.35, 0.38, 0.40, 0.42, 0.45,
                0.49, 0.50, 0.51, 0.55, 0.58, 0.60, 0.62, 0.65,
                0.68, 0.70, 0.72, 0.76, 0.80, 0.85, 0.90
            )
ab = (0.4, 0.6, 0.9, 1.1, 2.5, 5.0, 16.0, 45.0, 100.5, 150.0)

using TimerOutputs
# run twice and ignore results first time
const to = TimerOutput()
for a in ab, b in ab, x in test_points
    @timeit to "total" SpecialFunctionsChainRulesCoreExt._beta_inc_grad(a, b, x)
end
show(to)

Old timings:

────────────────────────────────────────────────────────────────────
                           Time                    Allocations
                  ───────────────────────   ────────────────────────
Tot / % measured:     22.1ms /   5.6%           1.08MiB /  13.6%

Section   ncalls     time    %tot     avg     alloc    %tot      avg
────────────────────────────────────────────────────────────────────
total      3.20k   1.25ms  100.0%   390ns    150KiB  100.0%    48.0B
────────────────────────────────────────────────────────────────────

New timings:

────────────────────────────────────────────────────────────────────
                           Time                    Allocations
                  ───────────────────────   ────────────────────────
Tot / % measured:     11.9ms /   5.0%           1.03MiB /   9.5%

Section   ncalls     time    %tot     avg     alloc    %tot      avg
────────────────────────────────────────────────────────────────────
total      3.20k    596μs  100.0%   186ns    100KiB  100.0%    32.0B
────────────────────────────────────────────────────────────────────

Author

lrnv commented Jan 31, 2026

@droodman You created your branch from my master branch, whereas you should have built it on top of my feature branch. Anyway, since there is only one commit, I copy/pasted it here, thanks :)

Ready again for review.

lrnv requested a review from devmotion on February 2, 2026 14:48