-
Notifications
You must be signed in to change notification settings - Fork 5
Mooncake reverse rules #85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Your PR no longer requires formatting changes. Thank you for your contribution! |
3af6855 to
c665fc4
Compare
743950b to
f82d3b7
Compare
lkdvos
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some comments throughout, I haven't looked at the tests yet though.
One thing that I am still struggling with is understanding why the adjoint functions return NoRData() for all of the input arguments. Am I correct in assuming that this only works for mutable data? Should we maybe reconsider the strictness of our rules in that case?
Yes, afaik it's set up this way because all the reverse-data you need should be "packaged in" to the tangents of the |
|
OK, everything has been addressed except for the |
lkdvos
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for so quickly going into the comments and resolving them!
I tried going over this again, and I have another round of questions (apologies for that, I think pieces are falling together as I start understanding what is going on):
In the rrule!s, I see we are consistently copying the state of the provided memory for the output of the primal call. (arg1c = copy(arg1) etc). I think I understand that this is what is required by Mooncake, as we need to restore the state before the function call to ensure rules that need to happen earlier in the function call get the correct data.
- If we have to make a copy of the current state anyways, and we aren't required by our interface to ensure that our factorizations are in-place, would it maybe be simpler to call the non-inplace version of the factorization function to begin with, and not mutate?
For example, something like this:
function Mooncake.rrule!!(::CoDual{typeof(eig_full!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
# argument sanitizing with arrayify etc
DV = eig_full(A, alg)
function adjoint(::Mooncake.NoRData)
eig_pullback!(dA, A, DV..., dD, dV; kwargs...)
MatrixAlgebraKit.zero!.(dDV)
return NoRData(), ...
end
...
endwhere we effectively say that eig_full!(A, DV, alg) = eig_full(A, alg) for AD purposes.
Along the same lines of reasoning, a different implementation style + optimization might be to actually implement rrules for the non-mutating functions first, i.e. eig_full(A, alg), and then add a rrule for eig_full!(A, DV, alg) that simply calls the eig_full rrule.
Here, the nice part is that we are not keeping copies of state that we have just allocated, and know we wont be needing later on, and additionally that we might be simplifying some of the logic.
I still have to actually look at the test functionality to see what is going on there, but one thing at a time 🙃
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that I have to copy this implementation in TensorKit, and failed horrendously:
https://github.com/QuantumKitHub/TensorKit.jl/blob/bbeb8e510727fd2ce77c579ae9833f5056c85cc8/test/autodiff/ad.jl#L105-L120
I think it would be nice to have this as part of the main package, and if not, to turn this into a "test module" that can be included from within downstream packages (probably mostly by adding some comments to this file to indicate that its location and contents are considered "public" and cannot be moved/removed without breaking changes)
Yes, exactly
Yes, we can do that also. However generally the preference is to implement the rule for the "lowest" method which is why I chose the inplace ones. I can do the ones that copy instead and see if we like that better? |
|
FWIW for Enzyme, there is no requirement to restore the state of he variables on the reverse pass which is the other reason I did things this way -- to be able to compare the two "like to like". |
|
BTW for the in place methods, I think there is no way to avoid copying |
lkdvos
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we for now comment out the 1.12 tests for mooncake so we have an idea of the tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, let me summarize what I think for this PR specifically, also keep in mind that this is my opinion and you are allowed to just overrule that:
-
kwargs...in therrule!!
I don't actually have a great solution for the fact that we don't really support keyword arguments on top of the factorizations.
I don't think the current way of just adding these as kwargs to therrule!!works (although it doesn't hurt either), since the primal call can't actually handle any keyword arguments (e.g.).function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm)
The only precedences for this we currently have are either KrylovKit, where we ended up adding aalg_rrulekeyword argument to all of the primal computations that is simply ignored, which could be an emptyNamedTuplein our case.
The other is PEPSKit, where I designed ahook_pullbackfunction that collaborates with the ChainRules.jl system to insert keyword arguments into a function that are passed on to the pullback: https://github.com/QuantumKitHub/PEPSKit.jl/blob/master/src/utility/hook_pullback.jl
Importantly, we don't need to solve that for this PR, since we are more or less everywhere just using the defaults anyways, so we can tackle that later. I would therefore maybe suggest to remove the keywords for now, since adding that would be non-breaking so that's easier to do in a follow up. (I'm fine leaving them as is too, it shouldn't matter since these shouldnt be able to ever be generated anyways) -
Did we decide what to do with the
@is_primitivedirections?
I remember talking about them being more or less specific, but can't remember what we decided on. What is worrying me slightly is that we are clearly assuming everywhere that the inputs and outputs are mutable, which I guess doesn't always have to be the case. In the interest of moving things forwards, I guess the easiest solution here is to just restrict these definitions more aggressively, claiming that the AD engine should then figure out how to handle other cases and start adding cases back in, rather than the other direction? -
There's a little bit of leftover commented out code in the tests, which could do with either removing or perhaps some explanatory comments about why they are there/why they are commented out
-
Overall it might be slightly nicer to have some additional comments to explain what is going on, as it really takes a bit of time to understand all of this. I'm not saying this has to be done neither by you nor now, just a general remark that this could actually benefit the overall progress here a bit.
Codecov Report❌ Patch coverage is
... and 1 file with indirect coverage changes 🚀 New features to boost your workflow:
|
fcb67e7 to
e90e842
Compare
0fc8e63 to
21cfa85
Compare
|
Ok, I think this is the last set of comments and questions from my side. |
580555b to
da00887
Compare
|
Ok, my approval still stands, so if @lkdvos agrees this can be merged. |
lkdvos
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks a lot for the hard work (and the hard patience!).
I only have some cosmetic comments left, you can choose to ignore this if you like but I think it keeps everything somewhat in the same style.
These should otherwise not make any difference though:
- I think you imported
NoRDatabut then qualifyMooncake.NoRDataeverywhere. Since we are already usingCoDualeverywhere and this is a Mooncake extension, it might be fine to simply useNoRData()everywhere to lighten the notation a bit. - In the
@evalfor loops there is the combination ofSymbols andFunctionobjects. I have absolutely no clue about any performance differences, or burdens on the compiler, but I think in most other parts of the ecosystem we consistently useSymbols, and I think here it wouldn't make a difference either? Correct me if I'm wrong though, no need to spend a whole lot of effort here. - For the adjoint function names, I don't think the extra
d*at the front is totally necessary,qr_adjointlooks maybe a little cleaner? I checked the Mooncake rrules to see if they have any consistent conventions, and from what I can tell here there seems to be a combination off_adjointandf_pb!!. I don't have any strong preference, so maybef_adjoint?
Otherwise definitely a great PR that we will benefit from greatly :)
|
Thanks! There are a few other small cosmetic improvements I'll make with imports as well. |
* Bump v0.6 * rename `gaugefix` -> `fixgauge` * reduce unnecessary warnings * fix `copy_input` signatures in Mooncake tests * Add changelog to docs
No description provided.