Exact chainrules derivatives for beta_inc and beta_inc_inv#506
Exact chainrules derivatives for beta_inc and beta_inc_inv#506lrnv wants to merge 16 commits intoJuliaMath:masterfrom
beta_inc and beta_inc_inv#506Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #506 +/- ##
==========================================
+ Coverage 94.03% 94.94% +0.90%
==========================================
Files 14 14
Lines 2902 3242 +340
==========================================
+ Hits 2729 3078 +349
+ Misses 173 164 -9
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
@asinghvi17, @devmotion: I have switched implementations from the porting of the original code (which was probably not licensed correctly as we discussed yesterday) to a port of this repository, which is an independent implementation from @arzwa, unlicensed yet. I have better hopes that @arzwa will allow us to use his code, and have emailed him. By looking at the diff you'll see that this is clearly an independent implementation, and thus, if @arzwa agrees, we are good license-wise. I also have contacted the original authors, so if they come around faster we can still revert my last commit. |
|
I agree, you can use the code as you wish, I added an MIT license to the repo. |
|
I knew it'll be faster 🤣 |
Now that I'm reading mentioned discussion, I would like to add that I definitely implemented this using the approach described in the mentioned paper (Boik & Robinson-Cox 1998). As far as I remember, I did not 'translate' the code but implemented the described numerical method 'independently', based on their mathematical description (in fact I'm pretty sure I have never seen S-plus code in my life). However, I don't know what that would mean license-wise and whether it's OK to use an MIT license (I have never had to think about those kind of issues before). Any insights on this? |
|
I would test a much larger range of a,b values. Look at the code for calculation of the incomplete beta in this repo. They use different algorithms depending on the domain. |
devmotion
left a comment
There was a problem hiding this comment.
All the nested functions, multiple assignments on a single line and generally lack of comments make it quite challenging to follow the code. Maybe the implementation should be cleaned up a bit before it should be considered for inclusion in SpecialFunctions.
|
So, I have :
A slightly potentially-problematic outcome : the tests of the new chainrule alone last for ~40s on my machine, which is a large part of the total test time of the package :/ Ready for review round 2. |
|
Test time is probably not a concern. If this is good to go, we should merge. |
|
@devmotion Thanks a lot for the review and sorry for the delay ! All your suggestions were taken into account (I marked them as resolved), except 3: the first one is IMHO not a good idea for stability, the remaining two are points on which i agree with you but do not know how to do it myself, maybe you could help ? Tests passes, faillures on PRE seems unrelated. |
|
Thank you @lrnv for doing this. Pending its incorporation into SpecialFunctions, I have copied your code directly into a project I'm working on. I just wanted to check if this is a bug, in The asymmetry here looks funny, with |
|
The |
|
Sorry, I did not check the tests. Separately:
I assume the middle line is only needed because my copy of your code is outside of SpecialFunctions. I'm just using beta_inc() (actually Again, since I'm out of my depth, I leave it to you to judge if there's any practical upshot for you. |
|
Thanks @droodman for the hint on inlining, indeed it helped :) I am not sure the inclusion of ForwardDiff here is really in the scope of this PR. |
d7d26d4 to
500d94c
Compare
|
Hum... Apart ExplicitImports which yells about |
|
Hi @lrnv. Follow-up comments:
First the tolerance is set to eps(T)*T(1e4). Then it is lowered to 1e-14. I believe the original code uses 1e12. I tried changing to |
|
@droodman I removed the minapp/maxapp/eps weridness, thanks. I also removed the type restriction alltogether so that you can pass anything through. PS: I am not sure however that this is compliant with what SpecialFunctions.jl requires for merging ;) |
|
In the last couple of days, this code basically became the inner loop in my project and was consuming hours of time. So I was highly motivated to optimize it. I got another 2X improvement (at least on an M4 Mac), in addition (or multiplication) to the 2X improvement from adding @inline to functions. I forked @lrnv's fork of this repo and made a commit with the changes. However, I'm not that experienced with the complexities of GitHub--commits to different branches of different forks...--and I don't want to mess up anything that you are doing. Here is a link showing the new code block for ext/SpecialFunctionsChainRulesCoreExt.jl. As noted there, the main improvement is removing redundant calculations. I also deleted the first of the four return values of _beta_inc_grad() since it is never used. All tests pass. Timing code: Old timings: New timings: |
|
@droodman You created your branch out of my master branch, while you should have build it on top of my feature branch. Anyway sicne there is only one commit, I copy/pasted it here, thanks :) Ready again for review |
Summary
This PR adds ChainRules coverage for the regularized incomplete beta function and its inverse:
frule/rrulefor:beta_inc(a, b, x) -> (p, q)beta_inc(a, b, x, y)withy = 1 − xsemanticsbeta_inc_inv(a, b, p) -> (x, 1 − x)test/chainrules.jl.I did that by translating the S+ algorithm of Boik & Robinson-Cox (1998) for the partial derivatives w.r.t.
a,b,x.Motivation
beta_incandbeta_inc_inv.beta_invandbeta_inv_incgradients. #505, there a lot of places where these derivatives are missing:Distributions.MvTDist(), which was never fittable due to the lack of these derivatives.Copulas.jl, when usingDistributions.Beta()marginalsI definitely think these derivatives belong to
SpecialFunctions.jl'sChainRule's extension.Final note
The reference of the algo is :
Maybe we should have it as a proper reference in some documentation somewhere ? Dont know.
The algorithm is quite stable, and has good precision (see tests
rtol). It works forFloat64andFloat32. I do not think I broke anything so this should be patch-level change.Waiting for review :)