WIP Enzyme and Mooncake rules #84
|
I haven't reviewed anything yet, but I would already like to suggest using "pushforward" instead of "pull forward" 😄. If you make this change, may I also request naming the folder "pushforwards"? (I am not too enthusiastic about ad hoc abbreviations such as leaving out vowels.) |
|
OK, I had a brief look already, only to conclude that this will be a tough PR to review 😸. If I have time this week, I will try to go through the generic pushforward definitions already. For all the Moonzyme-specific stuff, I will probably need a bit of an introduction first (or read the respective manuals). |
|
@Jutho would it make more sense to remove the forward mode code entirely from this PR and shunt it off into a new one? |
|
Maybe, as you wish. But then you cannot test it in that PR I assume. Which is not necessary for it to be reviewed, so it could still be helpful. |
|
I do think moving the fwd mode stuff out might make this substantially easier to review, actually. We can always test the forward rules once the (working) reverse ones are in place. |
|
I think I agree, it might be nice to split this up into several separated parts to get things moving easier. I would propose the following:
The rest definitely looks great, do you think it could be reasonable to schedule a meeting to just go over the general Mooncake and/or Enzyme approach as a whole? Obviously I don't want to put this on you if you aren't up for it, and I can try and read through them myself, but if you would be okay it might be nice to get the explanation from someone who already has some experience with it :) |
|
Points 1-3 I 100% agree on. Would also be happy to chat/do a walkthrough of both packages (as much as I can...). |
    ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp))
    ltRtmp .= zero(eltype(Rtmp))
    dR11 .= Rtmp * R11
    dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
dR11 * invR11 is Rtmp, so that can be simplified.
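The simplification can be checked numerically. The sketch below is only an illustration: the sizes and the random matrices are assumptions standing in for the real `R11`, `Q1`, `dA1`, and `Rtmp` of the QR pushforward, and it relies only on `dR11 = Rtmp * R11` with `invR11 = inv(R11)`.

```julia
using LinearAlgebra

# Hypothetical sizes and data; only the algebraic relation between them matters.
m, n = 6, 4
R11 = triu(rand(n, n)) + I      # invertible upper-triangular block (diagonal >= 1)
invR11 = inv(R11)
Q1 = rand(m, n)                 # stand-in for the leading columns of Q
dA1 = rand(m, n)                # stand-in for the input tangent block
Rtmp = triu(rand(n, n))         # stand-in for Rtmp after its lower triangle is zeroed

dR11 = Rtmp * R11

# Expression as written in the snippet under review.
dQ1_original = dA1 * invR11 - Q1 * dR11 * invR11

# Simplified form: dR11 * invR11 == Rtmp * R11 * inv(R11) == Rtmp.
dQ1_simplified = dA1 * invR11 - Q1 * Rtmp

@assert dQ1_original ≈ dQ1_simplified
```

Besides saving one matrix product, the simplified form avoids multiplying by `R11` and then by its inverse, which may lose some accuracy when `R11` is ill-conditioned.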
src/pullfwds/qr.jl
Outdated
    dQ, dR = dQR
    dQ1 = view(dQ, 1:m, 1:m1)
    dQ2 = view(dQ, 1:m, m1+1:m2+m1)
    dQ3 = m1+m2+1 < size(dQ, 2) ? view(dQ, 1:m, m1+m2+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0))
Do you mean m1+m2 < size(dQ, 2) or m1+m2+1 <= size(dQ, 2)?
Also, doesn't just using view(dQ, 1:m, m1+m2+1:size(dQ,2)) always work? In the case that m1+m2 == size(dQ, 2), we automatically get a view with size(dQ3, 2) == 0, no?
This logic is a bit out of date (working on a new cleaner pushforwards PR 😉 ) but yes indeed I think these are correct observations
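Both observations can be reproduced on a small array. This is a minimal sketch with made-up sizes (`A`, `m1`, and `m2` here are hypothetical stand-ins for `dQ` and its block sizes): a range whose start exceeds its stop is empty, and `view` accepts it, so no guard is needed; the guard as written also skips the case where exactly one column remains.

```julia
A = reshape(collect(1.0:12.0), 3, 4)    # 3×4 stand-in for dQ

# Case 1: m1 + m2 == size(A, 2), so the trailing range is 5:4, which is empty.
m1, m2 = 1, 3
A3_empty = view(A, 1:3, (m1 + m2 + 1):size(A, 2))
@assert size(A3_empty) == (3, 0)

# Case 2: exactly one trailing column, m1 + m2 + 1 == size(A, 2).
m1, m2 = 1, 2
guard = m1 + m2 + 1 < size(A, 2)        # condition as written in the reviewed line
A3_one = view(A, 1:3, (m1 + m2 + 1):size(A, 2))
@assert guard == false                  # the guard would fall back to the empty array
@assert size(A3_one) == (3, 1)          # although one column is actually left
```

Note that the unguarded view keeps the row dimension (`(3, 0)`), whereas the fallback `similar(dQ, eltype(dQ), (0, 0))` does not.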
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
index a80dce6..2d795ed 100644
--- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
+++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
@@ -16,43 +16,47 @@ using LinearAlgebra
# two-argument factorizations like LQ, QR, EIG
-for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
- (qr_compact!, qr_pullback!, qr_pushforward!),
- (lq_full!, lq_pullback!, lq_pushforward!),
- (lq_compact!, lq_pullback!, lq_pushforward!),
- (eig_full!, eig_pullback!, eig_pushforward!),
- (left_polar!, left_polar_pullback!, left_polar_pushforward!),
- (right_polar!, right_polar_pullback!, right_polar_pushforward!),
- )
+for (f, pb, pf) in (
+ (qr_full!, qr_pullback!, qr_pushforward!),
+ (qr_compact!, qr_pullback!, qr_pushforward!),
+ (lq_full!, lq_pullback!, lq_pushforward!),
+ (lq_compact!, lq_pullback!, lq_pushforward!),
+ (eig_full!, eig_pullback!, eig_pushforward!),
+ (left_polar!, left_polar_pullback!, left_polar_pushforward!),
+ (right_polar!, right_polar_pullback!, right_polar_pushforward!),
+ )
@eval begin
- function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_arg = nothing
# form cache if needed
- cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing
+ cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing
func.val(A.val, arg.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg))
end
- function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...) where {RT}
+ function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...
+ ) where {RT}
cache_A, cache_arg = cache
argval = arg.val
- Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂arg = isa(arg, Const) ? nothing : arg.dval
+ Aval = !isnothing(cache_A) ? cache_A : A.val
+ ∂arg = isa(arg, Const) ? nothing : arg.dval
if !isa(A, Const) && !isa(arg, Const)
A.dval .= zero(eltype(Aval))
$pb(A.dval, A.val, argval, ∂arg; kwargs...)
@@ -60,24 +64,25 @@ for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
!isa(arg, Const) && make_zero!(arg.dval)
return (nothing, nothing, nothing)
end
- function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
- ret = func.val(A.val, arg.val, alg.val; kwargs...)
+ function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
+ ret = func.val(A.val, arg.val, alg.val; kwargs...)
arg1, arg2 = ret
m, n = size(A.val)
if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
darg1, darg2 = arg.dval
- dA = A.dval
+ dA = A.dval
darg1, darg2 = $pf(dA, A.val, ret, arg.dval)
- dA .= zero(eltype(A.val))
- shadow = (darg1, darg2)
+ dA .= zero(eltype(A.val))
+ shadow = (darg1, darg2)
elseif isa(A, Const) && !!isa(arg, Union{Duplicated, DuplicatedNoNeed})
make_zero!(arg.dval)
shadow = arg.dval
@@ -96,54 +101,59 @@ for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
end
end
-for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!),
- (lq_null!, lq_null_pullback!, lq_null_pushforward!),
- )
+for (f, pb, pf) in (
+ (qr_null!, qr_null_pullback!, qr_null_pushforward!),
+ (lq_null!, lq_null_pullback!, lq_null_pushforward!),
+ )
@eval begin
- function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:AbstractMatrix},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:AbstractMatrix},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_arg = nothing
# form cache if needed
cache_A = nothing #copy(A.val)
func.val(copy(A.val), arg.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg))
end
-
- function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:AbstractMatrix},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(arg.val),
- rank_atol::Real=tol,
- gauge_atol::Real=tol,
- kwargs...) where {RT}
+
+ function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:AbstractMatrix},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ tol::Real = MatrixAlgebraKit.default_pullback_gaugetol(arg.val),
+ rank_atol::Real = tol,
+ gauge_atol::Real = tol,
+ kwargs...
+ ) where {RT}
cache_A, cache_arg = cache
- Aval = isnothing(cache_A) ? A.val : cache_A
+ Aval = isnothing(cache_A) ? A.val : cache_A
if !isa(A, Const) && !isa(arg, Const)
A.dval .= zero(eltype(A.val))
$pb(A.dval, A.val, arg.val, arg.dval; kwargs...)
end
return (nothing, nothing, nothing)
end
- function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- arg::Annotation{<:AbstractMatrix},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ arg::Annotation{<:AbstractMatrix},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
ret = func.val(A.val, arg.val, alg.val; kwargs...)
if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
@@ -170,15 +180,16 @@ for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!),
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(svd_compact!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
- ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(svd_compact!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
+ ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
shadow = if EnzymeRules.needs_shadow(config)
svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
else
@@ -196,20 +207,21 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
end
# TODO
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(svd_full!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
- ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(svd_full!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
+ ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
shadow = if EnzymeRules.needs_shadow(config)
- svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
- else
- nothing
- end
+ svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
+ else
+ nothing
+ end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated(ret, shadow)
elseif EnzymeRules.needs_shadow(config)
@@ -222,33 +234,36 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
end
for f in (:svd_compact!, :svd_full!)
@eval begin
- function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+ function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
# form cache if needed
- cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val) : nothing
- cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val) : nothing
+ cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val) : nothing
+ cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val) : nothing
func.val(A.val, USVᴴ.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? USVᴴ.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ))
end
- function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof($f)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...) where {RT}
+ function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof($f)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...
+ ) where {RT}
cache_A, cache_USVᴴ = cache
USVᴴval = !isnothing(cache_USVᴴ) ? cache_USVᴴ : USVᴴ.val
- ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval
+ ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval
if !isa(A, Const) && !isa(USVᴴ, Const)
A.dval .= zero(eltype(A.dval))
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴval, ∂USVᴴ; kwargs...)
@@ -261,26 +276,27 @@ for f in (:svd_compact!, :svd_full!)
end
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(svd_trunc!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...,
- ) where {RT, T<:Real}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(svd_trunc!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...,
+ ) where {RT, T <: Real}
# form cache if needed
- cache_A = copy(A.val)
+ cache_A = copy(A.val)
svd_compact!(A.val, USVᴴ.val, alg.val.alg)
- cache_USVᴴ = copy.(USVᴴ.val)
- USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
- ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
- primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
+ cache_USVᴴ = copy.(USVᴴ.val)
+ USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
+ ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
+ primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
dU, dS, dVᴴ = USVᴴ.dval
- dStrunc = Diagonal(diagview(dS)[ind])
- dUtrunc = dU[:, ind]
+ dStrunc = Diagonal(diagview(dS)[ind])
+ dUtrunc = dU[:, ind]
dVᴴtrunc = dVᴴ[ind, :]
(dUtrunc, dStrunc, dVᴴtrunc)
else
@@ -289,17 +305,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
shadow = EnzymeRules.needs_shadow(config) ? (shadow_USVᴴ..., ϵ.dval) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(svd_trunc!)},
- dret::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...) where {RT, T<:Real}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(svd_trunc!)},
+ dret::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...
+ ) where {RT, T <: Real}
cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
- U, S, Vᴴ = cache_USVᴴ
+ U, S, Vᴴ = cache_USVᴴ
dU, dS, dVᴴ = shadow_USVᴴ
if !isa(A, Const) && !isa(USVᴴ, Const)
A.dval .= zero(eltype(A.val))
@@ -314,21 +332,22 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing, nothing)
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(eigh_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(eigh_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
Dmat, V = eigh_full(A.val; kwargs...)
if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
- ∂K = inv(V) * A.dval * V
- ∂Kdiag = diag(∂K)
+ ∂K = inv(V) * A.dval * V
+ ∂Kdiag = diag(∂K)
D.dval .= real.(copy(∂Kdiag))
A.dval .= zero(eltype(A.val))
- shadow = D.dval
+ shadow = D.dval
elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed})
make_zero!(D.dval)
shadow = D.dval
@@ -340,20 +359,21 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
elseif EnzymeRules.needs_shadow(config)
return shadow
elseif EnzymeRules.needs_primal(config)
- return Dmat.diag
+ return Dmat.diag
else
return nothing
end
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(eigh_full!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(eigh_full!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
Dmat, V = func.val(A.val, DV.val; kwargs...)
if isa(A, Const) || all(iszero, A.dval)
make_zero!(DV.dval[1])
@@ -361,18 +381,18 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
make_zero!(A.dval)
shadow = (DV.dval[1], DV.dval[2])
else
- ∂K = inv(V) * A.dval * V
- ∂Kdiag = diagview(∂K)
- ∂Ddiag = diagview(DV.dval[1])
+ ∂K = inv(V) * A.dval * V
+ ∂Kdiag = diagview(∂K)
+ ∂Ddiag = diagview(DV.dval[1])
∂Ddiag .= real.(∂Kdiag)
- D = diagview(Dmat)
- dDD = transpose(D) .- D
- ∂K ./= dDD
+ D = diagview(Dmat)
+ dDD = transpose(D) .- D
+ ∂K ./= dDD
∂Kdiag .= zero(eltype(V))
mul!(DV.dval[2], V, ∂K, 1, 0)
- shadow = DV.dval[2]
+ shadow = DV.dval[2]
A.dval .= zero(eltype(A.val))
- shadow = (Diagonal(∂Ddiag), DV.dval[2])
+ shadow = (Diagonal(∂Ddiag), DV.dval[2])
end
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return Duplicated((Dmat, V), shadow)
@@ -385,24 +405,25 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
end
end
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
- func::Const{typeof(eig_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.forward(
+ config::EnzymeRules.FwdConfig,
+ func::Const{typeof(eig_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
Dval, V = eig_full(A.val, alg.val; kwargs...)
if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
- ∂K = inv(V) * A.dval * V
- ∂Kdiag = diag(∂K)
+ ∂K = inv(V) * A.dval * V
+ ∂Kdiag = diag(∂K)
D.dval .= copy(∂Kdiag)
A.dval .= zero(eltype(A.val))
- shadow = D.dval
+ shadow = D.dval
elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed})
make_zero!(D.dval)
- shadow = D.dval
+ shadow = D.dval
end
eig_vals!(A.val, zeros(complex(eltype(A.val)), size(A.val, 1)))
D.val .= diagview(Dval)
@@ -411,30 +432,31 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
elseif EnzymeRules.needs_shadow(config)
return shadow
elseif EnzymeRules.needs_primal(config)
- return Dmat.diag
+ return Dmat.diag
else
return nothing
end
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_trunc!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...,
- ) where {RT, T}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_trunc!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...,
+ ) where {RT, T}
# form cache if needed
- cache_A = copy(A.val)
+ cache_A = copy(A.val)
MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg)
- cache_DV = copy.(DV.val)
- DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
- ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
- primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
+ cache_DV = copy.(DV.val)
+ DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
+ ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
+ primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
- dD, dV = DV.dval
+ dD, dV = DV.dval
dDtrunc = Diagonal(diagview(dD)[ind])
dVtrunc = dV[:, ind]
(dDtrunc, dVtrunc)
@@ -444,17 +466,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_trunc!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...) where {RT, T}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_trunc!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...
+ ) where {RT, T}
cache_A, cache_DV, cache_dDVtrunc, ind = cache
- D, V = cache_DV
+ D, V = cache_DV
dD, dV = cache_dDVtrunc
if !isa(A, Const) && !isa(DV, Const)
A.dval .= zero(eltype(A.val))
@@ -469,24 +493,25 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_trunc!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...,
- ) where {RT, T}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_trunc!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...,
+ ) where {RT, T}
# form cache if needed
- cache_A = copy(A.val)
+ cache_A = copy(A.val)
eig_full!(A.val, DV.val, alg.val.alg)
- cache_DV = copy.(DV.val)
- DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
- ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
- primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
+ cache_DV = copy.(DV.val)
+ DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
+ ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
+ primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
- dD, dV = DV.dval
+ dD, dV = DV.dval
dDtrunc = Diagonal(diagview(dD)[ind])
dVtrunc = dV[:, ind]
(dDtrunc, dVtrunc)
@@ -496,17 +521,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_trunc!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- ϵ::Annotation{Vector{T}},
- alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
- kwargs...) where {RT, T}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_trunc!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ ϵ::Annotation{Vector{T}},
+ alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+ kwargs...
+ ) where {RT, T}
cache_A, cache_DV, cache_dDVtrunc, ind = cache
- D, V = cache_DV
+ D, V = cache_DV
dD, dV = cache_dDVtrunc
if !isa(A, Const) && !isa(DV, Const)
A.dval .= zero(eltype(A.val))
@@ -521,47 +548,49 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_full!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_full!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
# form cache if needed
cache_DV = nothing
- cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+ cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
func.val(A.val, DV.val, alg.val; kwargs...)
primal = EnzymeRules.needs_primal(config) ? DV.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? DV.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_full!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_full!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_A, cache_DV = cache
- DVval = !isnothing(cache_DV) ? cache_DV : DV.val
- Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂DV = isa(DV, Const) ? nothing : DV.dval
+ DVval = !isnothing(cache_DV) ? cache_DV : DV.val
+ Aval = !isnothing(cache_A) ? cache_A : A.val
+ ∂DV = isa(DV, Const) ? nothing : DV.dval
if !isa(A, Const) && !isa(DV, Const)
- Dmat, V = DVval
+ Dmat, V = DVval
∂Dmat, ∂V = ∂DV
- A.dval .= zero(eltype(Aval))
+ A.dval .= zero(eltype(Aval))
MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DVval, ∂DV; kwargs...)
A.dval .*= 2
diagview(A.dval) ./= 2
for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2)
if i > j
- A.dval[i, j] = zero(eltype(A.dval))
+ A.dval[i, j] = zero(eltype(A.dval))
end
end
end
@@ -571,40 +600,42 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_D = nothing
- cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+ cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
func.val(A.val, D.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? D.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? D.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
# form cache if needed
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eig_vals!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eig_vals!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_A, cache_D = cache
Dval = !isnothing(cache_D) ? cache_D : D.val
Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂D = isa(D, Const) ? nothing : D.dval
+ ∂D = isa(D, Const) ? nothing : D.dval
if !isa(A, Const) && !isa(D, Const)
- _, V = eig_full(Aval, alg.val)
+ _, V = eig_full(Aval, alg.val)
A.dval .= zero(eltype(Aval))
- PΔV = V' \ Diagonal(D.dval)
+ PΔV = V' \ Diagonal(D.dval)
if eltype(A.dval) <: Real
ΔAc = PΔV * V'
A.dval .+= real.(ΔAc)
@@ -618,45 +649,47 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
return (nothing, nothing, nothing)
end
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_vals!)},
- ::Type{RT},
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.augmented_primal(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_vals!)},
+ ::Type{RT},
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_D = nothing
- cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
+ cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
func.val(A.val, D.val, alg.val; kwargs...)
- primal = EnzymeRules.needs_primal(config) ? D.val : nothing
+ primal = EnzymeRules.needs_primal(config) ? D.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
# form cache if needed
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D))
end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
- func::Const{typeof(eigh_vals!)},
- ::Type{RT},
- cache,
- A::Annotation{<:AbstractMatrix},
- D::Annotation{<:AbstractVector},
- alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
- kwargs...,
- ) where {RT}
+function EnzymeRules.reverse(
+ config::EnzymeRules.RevConfigWidth{1},
+ func::Const{typeof(eigh_vals!)},
+ ::Type{RT},
+ cache,
+ A::Annotation{<:AbstractMatrix},
+ D::Annotation{<:AbstractVector},
+ alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+ kwargs...,
+ ) where {RT}
cache_A, cache_D = cache
Dval = !isnothing(cache_D) ? cache_D : D.val
Aval = !isnothing(cache_A) ? cache_A : A.val
- ∂D = isa(D, Const) ? nothing : D.dval
+ ∂D = isa(D, Const) ? nothing : D.dval
if !isa(A, Const) && !isa(D, Const)
_, V = eigh_full(Aval, alg.val)
- A.dval .= zero(eltype(Aval))
+ A.dval .= zero(eltype(Aval))
mul!(A.dval, V * Diagonal(real(∂D)), V', 1, 0)
A.dval .*= 2
diagview(A.dval) ./= 2
for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2)
if i > j
- A.dval[i, j] = zero(eltype(A.dval))
+ A.dval[i, j] = zero(eltype(A.dval))
end
end
end
diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
index 1797157..9bc5cae 100644
--- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
+++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
@@ -4,7 +4,7 @@ using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview
-using MatrixAlgebraKit: svd_pushforward!
+using MatrixAlgebraKit: svd_pushforward!
using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward!
@@ -12,74 +12,76 @@ using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_
using LinearAlgebra
# two-argument factorizations like LQ, QR, EIG
-for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
- (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
- (lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
- (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
- (eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint),
- (eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
- (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
- (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
- )
+for (f, pb, pf, adj) in (
+ (qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+ (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+ (lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+ (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+ (eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint),
+ (eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
+ (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
+ (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
+ )
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- dA .= zero(eltype(A))
- args = Mooncake.primal(args_dargs)
- dargs = Mooncake.tangent(args_dargs)
+ A, dA = arrayify(A_dA)
+ dA .= zero(eltype(A))
+ args = Mooncake.primal(args_dargs)
+ dargs = Mooncake.tangent(args_dargs)
arg1, darg1 = arrayify(args[1], dargs[1])
arg2, darg2 = arrayify(args[2], dargs[2])
function $adj(::Mooncake.NoRData)
dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
- args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
+ args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
darg1 .= zero(eltype(arg1))
darg2 .= zero(eltype(arg2))
return Mooncake.CoDual(args, dargs), $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- args = Mooncake.primal(args_dargs)
- args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
- dargs = Mooncake.tangent(args_dargs)
- arg1, darg1 = arrayify(args[1], dargs[1])
- arg2, darg2 = arrayify(args[2], dargs[2])
+ A, dA = arrayify(A_dA)
+ args = Mooncake.primal(args_dargs)
+ args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
+ dargs = Mooncake.tangent(args_dargs)
+ arg1, darg1 = arrayify(args[1], dargs[1])
+ arg2, darg2 = arrayify(args[2], dargs[2])
darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2))
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
return Mooncake.Dual(args, dargs)
end
end
end
-for (f, f_full, pb, pf, adj) in ((qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
- (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
- )
+for (f, f_full, pb, pf, adj) in (
+ (qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
+ (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
+ )
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- Ac = MatrixAlgebraKit.copy_input($f_full, A)
+ A, dA = arrayify(A_dA)
+ Ac = MatrixAlgebraKit.copy_input($f_full, A)
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
- arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
+ arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
function $adj(::Mooncake.NoRData)
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
$pb(dA, A, arg, darg; kwargs...)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
- return arg_darg, $adj
+ return arg_darg, $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, arg_darg::Dual{<:AbstractMatrix}, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
- A, dA = arrayify(A_dA)
- Ac = MatrixAlgebraKit.copy_input($f_full, A)
+ A, dA = arrayify(A_dA)
+ Ac = MatrixAlgebraKit.copy_input($f_full, A)
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
- arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
+ arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
$pf(dA, A, arg, darg; kwargs...)
- dA .= zero(dA)
+ dA .= zero(dA)
return arg_darg
end
end
@@ -89,33 +91,33 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
nD, V = eig_full(A, alg_dalg.primal; kwargs...)
# update tangent
- tmp = V \ dA
- dD .= diagview(tmp * V)
- dA .= zero(eltype(dA))
+ tmp = V \ dA
+ dD .= diagview(tmp * V)
+ dA .= zero(eltype(dA))
return Mooncake.Dual(nD.diag, dD_)
end
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
- dA .= zero(eltype(dA))
- # update primal
- DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
- V = DV[2]
+ dA .= zero(eltype(dA))
+ # update primal
+ DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
+ V = DV[2]
dD .= zero(eltype(D))
function deig_vals_adjoint(::Mooncake.NoRData)
PΔV = V' \ Diagonal(dD)
@@ -163,30 +165,30 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
nD, V = eigh_full(A, alg_dalg.primal; kwargs...)
# update tangent
- tmp = inv(V) * dA * V
- dD .= real.(diagview(tmp))
- D .= nD.diag
- dA .= zero(eltype(dA))
+ tmp = inv(V) * dA * V
+ dD .= real.(diagview(tmp))
+ D .= nD.diag
+ dA .= zero(eltype(dA))
return D_dD
end
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
- D_ = Mooncake.primal(D_dD)
- dD_ = Mooncake.tangent(D_dD)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ D_ = Mooncake.primal(D_dD)
+ dD_ = Mooncake.tangent(D_dD)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
D, dD = arrayify(D_, dD_)
- DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
+ DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
function deigh_vals_adjoint(::Mooncake.NoRData)
mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
@@ -199,60 +201,60 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
- A, dA = arrayify(A_dA)
- USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
- dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
- U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
- S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
+ A, dA = arrayify(A_dA)
+ USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
+ dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
+ U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
+ S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
- USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
+ USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
function dsvd_adjoint(::Mooncake.NoRData)
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
minmn = min(size(A)...)
if size(U, 2) == size(Vᴴ, 1) == minmn # compact
- dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
+ dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
- vU = view(U, :, 1:minmn)
- vS = Diagonal(diagview(S)[1:minmn])
- vVᴴ = view(Vᴴ, 1:minmn, :)
- vdU = view(dU, :, 1:minmn)
- vdS = view(dS, 1:minmn, 1:minmn)
- vdVᴴ = view(dVᴴ, 1:minmn, :)
- dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
+ vU = view(U, :, 1:minmn)
+ vS = Diagonal(diagview(S)[1:minmn])
+ vVᴴ = view(Vᴴ, 1:minmn, :)
+ vdU = view(dU, :, 1:minmn)
+ vdS = view(dS, 1:minmn, 1:minmn)
+ vdVᴴ = view(dVᴴ, 1:minmn, :)
+ dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
end
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
- dU .= zero(dU)
- dS .= zero(dS)
+ dU .= zero(dU)
+ dS .= zero(dS)
dVᴴ .= zero(dVᴴ)
return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof($f)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
- dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
- A, dA = arrayify(A_, dA_)
+ USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
+ dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
+ A, dA = arrayify(A_, dA_)
$f(A, USVᴴ, alg_dalg.primal; kwargs...)
# update tangents
- U_, S_, Vᴴ_ = USVᴴ
+ U_, S_, Vᴴ_ = USVᴴ
dU_, dS_, dVᴴ_ = dUSVᴴ
- U, dU = arrayify(U_, dU_)
- S, dS = arrayify(S_, dS_)
- Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
- minmn = min(size(A)...)
+ U, dU = arrayify(U_, dU_)
+ S, dS = arrayify(S_, dS_)
+ Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
+ minmn = min(size(A)...)
if ($f == svd_compact!) # compact
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...)
else # full
- vU = view(U, :, 1:minmn)
- vS = view(S, 1:minmn, 1:minmn)
- vVᴴ = view(Vᴴ, 1:minmn, :)
- vdU = view(dU, :, 1:minmn)
- vdS = view(dS, 1:minmn, 1:minmn)
- vdVᴴ = view(dVᴴ, 1:minmn, :)
+ vU = view(U, :, 1:minmn)
+ vS = view(S, 1:minmn, 1:minmn)
+ vVᴴ = view(Vᴴ, 1:minmn, :)
+ vdU = view(dU, :, 1:minmn)
+ vdS = view(dS, 1:minmn, 1:minmn)
+ vdVᴴ = view(dVᴴ, 1:minmn, :)
svd_pushforward!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ); kwargs...)
end
return USVᴴ_dUSVᴴ
@@ -263,15 +265,15 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual; kwargs...)
# compute primal
- S_ = Mooncake.primal(S_dS)
- dS_ = Mooncake.tangent(S_dS)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ S_ = Mooncake.primal(S_dS)
+ dS_ = Mooncake.tangent(S_dS)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
# update tangent
- S, dS = arrayify(S_, dS_)
+ S, dS = arrayify(S_, dS_)
copyto!(dS, diag(real.(Vᴴ * dA' * U)))
copyto!(S, diagview(nS))
dA .= zero(eltype(dA))
@@ -281,17 +283,17 @@ end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
# compute primal
- S_ = Mooncake.primal(S_dS)
- dS_ = Mooncake.tangent(S_dS)
- A_ = Mooncake.primal(A_dA)
- dA_ = Mooncake.tangent(A_dA)
+ S_ = Mooncake.primal(S_dS)
+ dS_ = Mooncake.tangent(S_dS)
+ A_ = Mooncake.primal(A_dA)
+ dA_ = Mooncake.tangent(A_dA)
A, dA = arrayify(A_, dA_)
S, dS = arrayify(S_, dS_)
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
- S .= diagview(nS)
- dS .= zero(eltype(S))
+ S .= diagview(nS)
+ dS .= zero(eltype(S))
function dsvd_vals_adjoint(::Mooncake.NoRData)
- dA .= U * Diagonal(dS) * Vᴴ
+ dA .= U * Diagonal(dS) * Vᴴ
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
end
return S_dS, dsvd_vals_adjoint
diff --git a/src/common/view.jl b/src/common/view.jl
index 0bc7b9e..c8ae1aa 100644
--- a/src/common/view.jl
+++ b/src/common/view.jl
@@ -1,5 +1,5 @@
# diagind: provided by LinearAlgebra.jl
-diagview(D::Diagonal) = D.diag
+diagview(D::Diagonal) = D.diag
diagview(D::AbstractMatrix) = view(D, diagind(D))
# triangularind
diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl
index 19a43cb..36506f3 100644
--- a/src/pushforwards/eig.jl
+++ b/src/pushforwards/eig.jl
@@ -1,11 +1,11 @@
function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
- D, V = DV
- ΔD, ΔV = ΔDV
- iVΔAV = inv(V) * ΔA * V
+ D, V = DV
+ ΔD, ΔV = ΔDV
+ iVΔAV = inv(V) * ΔA * V
diagview(ΔD) .= diagview(iVΔAV)
- F = 1 ./ (transpose(diagview(D)) .- diagview(D))
+ F = 1 ./ (transpose(diagview(D)) .- diagview(D))
fill!(diagview(F), zero(eltype(F)))
- K̇ = F .* iVΔAV
+ K̇ = F .* iVΔAV
mul!(ΔV, V, K̇, 1, 0)
zero!(ΔA)
return ΔDV
diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl
index 69685b1..5fdfdca 100644
--- a/src/pushforwards/eigh.jl
+++ b/src/pushforwards/eigh.jl
@@ -1,16 +1,16 @@
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
- D, V = DV
- dD, dV = dDV
- tmpV = V \ dA
- ∂K = tmpV * V
- ∂Kdiag = diag(∂K)
- dD.diag .= real.(∂Kdiag)
- dDD = transpose(diagview(D)) .- diagview(D)
- F = one(eltype(dDD)) ./ dDD
+ D, V = DV
+ dD, dV = dDV
+ tmpV = V \ dA
+ ∂K = tmpV * V
+ ∂Kdiag = diag(∂K)
+ dD.diag .= real.(∂Kdiag)
+ dDD = transpose(diagview(D)) .- diagview(D)
+ F = one(eltype(dDD)) ./ dDD
diagview(F) .= zero(eltype(F))
- ∂K .*= F
- ∂V = mul!(tmpV, V, ∂K)
+ ∂K .*= F
+ ∂V = mul!(tmpV, V, ∂K)
copyto!(dV, ∂V)
- dA .= zero(eltype(A))
+ dA .= zero(eltype(A))
return (dD, dV)
end
diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl
index 2d390a5..ed5a72d 100644
--- a/src/pushforwards/lq.jl
+++ b/src/pushforwards/lq.jl
@@ -62,7 +62,7 @@
end=#
function lq_pushforward!(dA, A, LQ, dLQ; kwargs...)
- qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
+ return qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
end
-function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end
+function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real = MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real = tol, gauge_atol::Real = tol) end
diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl
index 2001c41..78ab79d 100644
--- a/src/pushforwards/polar.jl
+++ b/src/pushforwards/polar.jl
@@ -1,23 +1,23 @@
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
- W, P = WP
+ W, P = WP
ΔW, ΔP = ΔWP
- aWdA = adjoint(W) * ΔA
- K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA)))
- L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P)
- ΔW .= W * K̇ + L̇
- ΔP .= aWdA - K̇*P
+ aWdA = adjoint(W) * ΔA
+ K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA)))
+ L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P)
+ ΔW .= W * K̇ + L̇
+ ΔP .= aWdA - K̇ * P
MatrixAlgebraKit.zero!(ΔA)
return (ΔW, ΔP)
end
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
- P, Wᴴ = PWᴴ
+ P, Wᴴ = PWᴴ
ΔP, ΔWᴴ = ΔPWᴴ
- dAW = ΔA * adjoint(Wᴴ)
- K̇ = sylvester(P, P, -(dAW - adjoint(dAW)))
- L̇ = inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
- ΔWᴴ .= K̇ * Wᴴ + L̇
- ΔP .= dAW - P * K̇
+ dAW = ΔA * adjoint(Wᴴ)
+ K̇ = sylvester(P, P, -(dAW - adjoint(dAW)))
+ L̇ = inv(P) * ΔA * (Diagonal(ones(el...*[Comment body truncated]* |
This adds CPU-based reverse rules for all the factorizations in MatrixAlgebraKit (see the notes below for exceptions), using Mooncake or Enzyme. I tried my best to reuse the existing pullbacks.
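For orientation, the eigenvalue-only reverse rule for eigh_vals! in the diff amounts to the first-order identity below. This is only a sketch, assuming a Hermitian A = V Λ Vᴴ with simple eigenvalues λᵢ and orthonormal eigenvectors vᵢ; the cotangent symbols Δλ and ΔA are notation chosen here for illustration, not names from the PR.

```latex
\begin{aligned}
d\lambda_i &= v_i^{H}\,(dA)\,v_i, \\
\sum_i \Delta\lambda_i \, d\lambda_i
  &= \operatorname{Re}\operatorname{tr}\!\left[\left(V \operatorname{diag}(\Delta\lambda)\, V^{H}\right)^{H} dA\right], \\
\Delta A &= V \operatorname{diag}(\Delta\lambda)\, V^{H}, \qquad \Delta\lambda \in \mathbb{R}^{n}.
\end{aligned}
```

The last line corresponds to the mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0) call in the Mooncake rule, where real(dD) plays the role of Δλ.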
Some notes:
No eigh_trunc!, eig_trunc!, or svd_trunc! rules due to "Can't use test_reverse with MixedDuplicated" (EnzymeAD/Enzyme.jl#2677) and "MixedDuplicated fails in Enzyme.autodiff" (EnzymeAD/Enzyme.jl#2678). This is because of their mixed return of mutable and immutable objects. I'll work on creating a "bypass" signature for each of them that passes in a reference to epsilon, which should allow us to test them directly.