Skip to content

Conversation

@kshyatt
Copy link
Member

@kshyatt kshyatt commented Oct 22, 2025

This adds CPU based reverse rules for all the factorizations (NB below) in MatrixAlgebraKit, using Mooncake or Enzyme. I tried my best to use existing pullbacks.

Some notes:

@kshyatt kshyatt requested review from Jutho and lkdvos October 22, 2025 08:48
@Jutho
Copy link
Member

Jutho commented Oct 22, 2025

I haven't reviewed anything yet, but would already like to suggest to use pushforward instead of "pull forward" 😄 . If you make this change, may I request to also name the folder "pushforwards"? (I am not too enthusiastic about random abbreviations like leaving out vowels.)

@Jutho
Copy link
Member

Jutho commented Oct 22, 2025

Ok I had a brief look already, only to conclude that this will be a though PR to review 😸 . If I have time this week, I will try to go through the generic pushforward definitions already. All the Moonzyme specific stuff, I will probably need a bit of an introduction first (or read the respective manuals).

@kshyatt
Copy link
Member Author

kshyatt commented Oct 22, 2025

@Jutho would it make more sense to remove the forward mode code entirely from this PR and shunt it off into a new one?

@Jutho
Copy link
Member

Jutho commented Oct 22, 2025

Maybe, as you wish. But then you cannot test it in that PR I assume. Which is not necessary for it to be reviewed, so it could still be helpful.

@kshyatt
Copy link
Member Author

kshyatt commented Oct 22, 2025

I do think moving the fwd mode stuff out might make this substantially easier to review, actually. We can always test the forward rules once the (working) reverse ones are in place.

@lkdvos
Copy link
Member

lkdvos commented Oct 22, 2025

I think I agree, it might be nice to split this up into several separated parts to get things moving easier. I would propose the following:

  • Since the GPU stuff isn't used yet, it might be nice to not have that code in the PR, just to avoid clutter
  • Can we consider Enzyme and Mooncake separately? Am I correct thinking that these don't actually interact, so we could review/test/implement them on their own?
  • I think there are a bunch of small improvements to the error and warning messages, I think these are all good and we can simply merge them already
  • It might actually be reasonable to add the remove_f_gaugedependence functions into the main package, not sure how yo feel about this? We kind of need that in the downstream package tests anyways, so it might be useful to just include them directly.

The rest definitely looks great, do you think it could be reasonable to schedule a meeting to just go over the general Mooncake and/or Enzyme approach as a whole? Obviously I don't want to put this on you if you aren't up for it, and I can try and read through them myself, but if you would be okay it might be nice to get the explanation from someone who already has some experience with it :)

@kshyatt
Copy link
Member Author

kshyatt commented Oct 22, 2025

Points 1-3 I 100% agree on. Would also be happy to chat/do a walkthrough of both packages (as much as I can...).

ltRtmp = view(Rtmp, MatrixAlgebraKit.lowertriangularind(Rtmp))
ltRtmp .= zero(eltype(Rtmp))
dR11 .= Rtmp * R11
dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dR11 * invR11 is Rtmp, so that can be simplified.

dQ, dR = dQR
dQ1 = view(dQ, 1:m, 1:m1)
dQ2 = view(dQ, 1:m, m1+1:m2+m1)
dQ3 = m1+m2+1 < size(dQ, 2) ? view(dQ, 1:m, m1+m2+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean m1+m2 < size(dQ, 2) or m1+m2+1 <= size(dQ, 2)?

Also, does just using view(dQ, 1:m, m1+m2+1:size(dQ,2)) always not just work? In the case that m1+m2 == size(dQ, 2), we automatically get a view with size(dQ3, 2)=0, no ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic is a bit out of date (working on a new cleaner pushforwards PR 😉 ) but yes indeed I think these are correct observations

@github-actions
Copy link

github-actions bot commented Nov 7, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
index a80dce6..2d795ed 100644
--- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
+++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl
@@ -16,43 +16,47 @@ using LinearAlgebra
 
 
 # two-argument factorizations like LQ, QR, EIG
-for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!), 
-                    (qr_compact!, qr_pullback!, qr_pushforward!),
-                    (lq_full!, lq_pullback!, lq_pushforward!), 
-                    (lq_compact!, lq_pullback!, lq_pushforward!),
-                    (eig_full!, eig_pullback!, eig_pushforward!),
-                    (left_polar!, left_polar_pullback!, left_polar_pushforward!),
-                    (right_polar!, right_polar_pullback!, right_polar_pushforward!),
-                   )
+for (f, pb, pf) in (
+        (qr_full!, qr_pullback!, qr_pushforward!),
+        (qr_compact!, qr_pullback!, qr_pushforward!),
+        (lq_full!, lq_pullback!, lq_pushforward!),
+        (lq_compact!, lq_pullback!, lq_pushforward!),
+        (eig_full!, eig_pullback!, eig_pushforward!),
+        (left_polar!, left_polar_pullback!, left_polar_pushforward!),
+        (right_polar!, right_polar_pullback!, right_polar_pushforward!),
+    )
     @eval begin
-        function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                              func::Const{typeof($f)},
-                                              ::Type{RT},
-                                              A::Annotation{<:AbstractMatrix},
-                                              arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
-                                              alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                              kwargs...,
-                                             ) where {RT}
+        function EnzymeRules.augmented_primal(
+                config::EnzymeRules.RevConfigWidth{1},
+                func::Const{typeof($f)},
+                ::Type{RT},
+                A::Annotation{<:AbstractMatrix},
+                arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                kwargs...,
+            ) where {RT}
             cache_arg = nothing
             # form cache if needed
-            cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val)  : nothing
+            cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(arg) <: Const)) ? copy(A.val) : nothing
             func.val(A.val, arg.val, alg.val; kwargs...)
-            primal = EnzymeRules.needs_primal(config) ? arg.val  : nothing
+            primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
             shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
             return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg))
         end
-        function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                                     func::Const{typeof($f)},
-                                     dret::Type{RT},
-                                     cache,
-                                     A::Annotation{<:AbstractMatrix},
-                                     arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
-                                     alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                     kwargs...) where {RT}
+        function EnzymeRules.reverse(
+                config::EnzymeRules.RevConfigWidth{1},
+                func::Const{typeof($f)},
+                dret::Type{RT},
+                cache,
+                A::Annotation{<:AbstractMatrix},
+                arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                kwargs...
+            ) where {RT}
             cache_A, cache_arg = cache
             argval = arg.val
-            Aval   = !isnothing(cache_A) ? cache_A  : A.val
-            ∂arg   = isa(arg, Const)     ? nothing : arg.dval
+            Aval = !isnothing(cache_A) ? cache_A : A.val
+            ∂arg = isa(arg, Const) ? nothing : arg.dval
             if !isa(A, Const) && !isa(arg, Const)
                 A.dval .= zero(eltype(Aval))
                 $pb(A.dval, A.val, argval, ∂arg; kwargs...)
@@ -60,24 +64,25 @@ for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
             !isa(arg, Const) && make_zero!(arg.dval)
             return (nothing, nothing, nothing)
         end
-        function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
-                                     func::Const{typeof($f)},
-                                     ::Type{RT},
-                                     A::Annotation{<:AbstractMatrix},
-                                     arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
-                                     alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                     kwargs...,
-                                    ) where {RT}
-            ret  = func.val(A.val, arg.val, alg.val; kwargs...)
+        function EnzymeRules.forward(
+                config::EnzymeRules.FwdConfig,
+                func::Const{typeof($f)},
+                ::Type{RT},
+                A::Annotation{<:AbstractMatrix},
+                arg::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                kwargs...,
+            ) where {RT}
+            ret = func.val(A.val, arg.val, alg.val; kwargs...)
             arg1, arg2 = ret
             m, n = size(A.val)
 
             if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
                 darg1, darg2 = arg.dval
-                dA           = A.dval
+                dA = A.dval
                 darg1, darg2 = $pf(dA, A.val, ret, arg.dval)
-                dA          .= zero(eltype(A.val))
-                shadow       = (darg1, darg2)
+                dA .= zero(eltype(A.val))
+                shadow = (darg1, darg2)
             elseif isa(A, Const) && !!isa(arg, Union{Duplicated, DuplicatedNoNeed})
                 make_zero!(arg.dval)
                 shadow = arg.dval
@@ -96,54 +101,59 @@ for (f, pb, pf) in ((qr_full!, qr_pullback!, qr_pushforward!),
     end
 end
 
-for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!), 
-                    (lq_null!, lq_null_pullback!, lq_null_pushforward!), 
-                   )
+for (f, pb, pf) in (
+        (qr_null!, qr_null_pullback!, qr_null_pushforward!),
+        (lq_null!, lq_null_pullback!, lq_null_pushforward!),
+    )
     @eval begin
-        function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                              func::Const{typeof($f)},
-                                              ::Type{RT},
-                                              A::Annotation{<:AbstractMatrix},
-                                              arg::Annotation{<:AbstractMatrix},
-                                              alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                              kwargs...,
-                                             ) where {RT}
+        function EnzymeRules.augmented_primal(
+                config::EnzymeRules.RevConfigWidth{1},
+                func::Const{typeof($f)},
+                ::Type{RT},
+                A::Annotation{<:AbstractMatrix},
+                arg::Annotation{<:AbstractMatrix},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                kwargs...,
+            ) where {RT}
             cache_arg = nothing
             # form cache if needed
             cache_A = nothing #copy(A.val)
             func.val(copy(A.val), arg.val, alg.val; kwargs...)
-            primal = EnzymeRules.needs_primal(config) ? arg.val  : nothing
+            primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
             shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
             return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_arg))
         end
-                
-        function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                                     func::Const{typeof($f)},
-                                     dret::Type{RT},
-                                     cache,
-                                     A::Annotation{<:AbstractMatrix},
-                                     arg::Annotation{<:AbstractMatrix},
-                                     alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                     tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(arg.val),
-                                     rank_atol::Real=tol,
-                                     gauge_atol::Real=tol,
-                                     kwargs...) where {RT}
+
+        function EnzymeRules.reverse(
+                config::EnzymeRules.RevConfigWidth{1},
+                func::Const{typeof($f)},
+                dret::Type{RT},
+                cache,
+                A::Annotation{<:AbstractMatrix},
+                arg::Annotation{<:AbstractMatrix},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                tol::Real = MatrixAlgebraKit.default_pullback_gaugetol(arg.val),
+                rank_atol::Real = tol,
+                gauge_atol::Real = tol,
+                kwargs...
+            ) where {RT}
             cache_A, cache_arg = cache
-            Aval   = isnothing(cache_A) ? A.val  : cache_A
+            Aval = isnothing(cache_A) ? A.val : cache_A
             if !isa(A, Const) && !isa(arg, Const)
                 A.dval .= zero(eltype(A.val))
                 $pb(A.dval, A.val, arg.val, arg.dval; kwargs...)
             end
             return (nothing, nothing, nothing)
         end
-        function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
-                                     func::Const{typeof($f)},
-                                     ::Type{RT},
-                                     A::Annotation{<:AbstractMatrix},
-                                     arg::Annotation{<:AbstractMatrix},
-                                     alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                     kwargs...,
-                                    ) where {RT}
+        function EnzymeRules.forward(
+                config::EnzymeRules.FwdConfig,
+                func::Const{typeof($f)},
+                ::Type{RT},
+                A::Annotation{<:AbstractMatrix},
+                arg::Annotation{<:AbstractMatrix},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                kwargs...,
+            ) where {RT}
             ret = func.val(A.val, arg.val, alg.val; kwargs...)
 
             if isa(arg, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
@@ -170,15 +180,16 @@ for (f, pb, pf) in ((qr_null!, qr_null_pullback!, qr_null_pushforward!),
 end
 
 
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
-                             func::Const{typeof(svd_compact!)},
-                             ::Type{RT},
-                             A::Annotation{<:AbstractMatrix},
-                             USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
-    ret    = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfig,
+        func::Const{typeof(svd_compact!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
+    ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
     shadow = if EnzymeRules.needs_shadow(config)
         svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
     else
@@ -196,20 +207,21 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
 end
 
 # TODO
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
-                             func::Const{typeof(svd_full!)},
-                             ::Type{RT},
-                             A::Annotation{<:AbstractMatrix},
-                             USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
-    ret        = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfig,
+        func::Const{typeof(svd_full!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
+    ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
     shadow = if EnzymeRules.needs_shadow(config)
-            svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
-        else
-            nothing
-        end
+        svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
+    else
+        nothing
+    end
     if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
         return Duplicated(ret, shadow)
     elseif EnzymeRules.needs_shadow(config)
@@ -222,33 +234,36 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
 end
 for f in (:svd_compact!, :svd_full!)
     @eval begin
-        function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                              func::Const{typeof($f)},
-                                              ::Type{RT},
-                                              A::Annotation{<:AbstractMatrix},
-                                              USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
-                                              alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                              kwargs...,
-                                             ) where {RT}
+        function EnzymeRules.augmented_primal(
+                config::EnzymeRules.RevConfigWidth{1},
+                func::Const{typeof($f)},
+                ::Type{RT},
+                A::Annotation{<:AbstractMatrix},
+                USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                kwargs...,
+            ) where {RT}
             # form cache if needed
-            cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val)  : nothing
-            cache_A    = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val)  : nothing
+            cache_USVᴴ = (EnzymeRules.overwritten(config)[3] && !(typeof(USVᴴ) <: Const)) ? copy(USVᴴ.val) : nothing
+            cache_A = (EnzymeRules.overwritten(config)[2] && !(typeof(A) <: Const)) ? copy(A.val) : nothing
             func.val(A.val, USVᴴ.val, alg.val; kwargs...)
-            primal = EnzymeRules.needs_primal(config) ? USVᴴ.val  : nothing
+            primal = EnzymeRules.needs_primal(config) ? USVᴴ.val : nothing
             shadow = EnzymeRules.needs_shadow(config) ? USVᴴ.dval : nothing
             return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ))
         end
-        function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                                     func::Const{typeof($f)},
-                                     dret::Type{RT},
-                                     cache,
-                                     A::Annotation{<:AbstractMatrix},
-                                     USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
-                                     alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                     kwargs...) where {RT}
+        function EnzymeRules.reverse(
+                config::EnzymeRules.RevConfigWidth{1},
+                func::Const{typeof($f)},
+                dret::Type{RT},
+                cache,
+                A::Annotation{<:AbstractMatrix},
+                USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+                alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+                kwargs...
+            ) where {RT}
             cache_A, cache_USVᴴ = cache
             USVᴴval = !isnothing(cache_USVᴴ) ? cache_USVᴴ : USVᴴ.val
-            ∂USVᴴ   = isa(USVᴴ, Const) ? nothing : USVᴴ.dval
+            ∂USVᴴ = isa(USVᴴ, Const) ? nothing : USVᴴ.dval
             if !isa(A, Const) && !isa(USVᴴ, Const)
                 A.dval .= zero(eltype(A.dval))
                 MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴval, ∂USVᴴ; kwargs...)
@@ -261,26 +276,27 @@ for f in (:svd_compact!, :svd_full!)
     end
 end
 
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                      func::Const{typeof(svd_trunc!)},
-                                      ::Type{RT},
-                                      A::Annotation{<:AbstractMatrix},
-                                      USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
-                                      ϵ::Annotation{Vector{T}},
-                                      alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
-                                      kwargs...,
-                                     ) where {RT, T<:Real}
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(svd_trunc!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+        ϵ::Annotation{Vector{T}},
+        alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+        kwargs...,
+    ) where {RT, T <: Real}
     # form cache if needed
-    cache_A     = copy(A.val)
+    cache_A = copy(A.val)
     svd_compact!(A.val, USVᴴ.val, alg.val.alg)
-    cache_USVᴴ  = copy.(USVᴴ.val)
-    USVᴴ′, ind  = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
-    ϵ.val[1]    = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
-    primal      = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val)  : nothing
+    cache_USVᴴ = copy.(USVᴴ.val)
+    USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
+    ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
+    primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
     shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
         dU, dS, dVᴴ = USVᴴ.dval
-        dStrunc  = Diagonal(diagview(dS)[ind])
-        dUtrunc  = dU[:, ind]
+        dStrunc = Diagonal(diagview(dS)[ind])
+        dUtrunc = dU[:, ind]
         dVᴴtrunc = dVᴴ[ind, :]
         (dUtrunc, dStrunc, dVᴴtrunc)
     else
@@ -289,17 +305,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
     shadow = EnzymeRules.needs_shadow(config) ? (shadow_USVᴴ..., ϵ.dval) : nothing
     return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind))
 end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                             func::Const{typeof(svd_trunc!)},
-                             dret::Type{RT},
-                             cache,
-                             A::Annotation{<:AbstractMatrix},
-                             USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
-                             ϵ::Annotation{Vector{T}},
-                             alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
-                             kwargs...) where {RT, T<:Real}
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(svd_trunc!)},
+        dret::Type{RT},
+        cache,
+        A::Annotation{<:AbstractMatrix},
+        USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
+        ϵ::Annotation{Vector{T}},
+        alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+        kwargs...
+    ) where {RT, T <: Real}
     cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
-    U, S, Vᴴ    = cache_USVᴴ
+    U, S, Vᴴ = cache_USVᴴ
     dU, dS, dVᴴ = shadow_USVᴴ
     if !isa(A, Const) && !isa(USVᴴ, Const)
         A.dval .= zero(eltype(A.val))
@@ -314,21 +332,22 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
     return (nothing, nothing, nothing, nothing)
 end
 
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
-                             func::Const{typeof(eigh_vals!)},
-                             ::Type{RT},
-                             A::Annotation{<:AbstractMatrix},
-                             D::Annotation{<:AbstractVector},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfig,
+        func::Const{typeof(eigh_vals!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        D::Annotation{<:AbstractVector},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
     Dmat, V = eigh_full(A.val; kwargs...)
     if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
-        ∂K      = inv(V) * A.dval * V
-        ∂Kdiag  = diag(∂K)
+        ∂K = inv(V) * A.dval * V
+        ∂Kdiag = diag(∂K)
         D.dval .= real.(copy(∂Kdiag))
         A.dval .= zero(eltype(A.val))
-        shadow  = D.dval 
+        shadow = D.dval
     elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed})
         make_zero!(D.dval)
         shadow = D.dval
@@ -340,20 +359,21 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
     elseif EnzymeRules.needs_shadow(config)
         return shadow
     elseif EnzymeRules.needs_primal(config)
-        return Dmat.diag 
+        return Dmat.diag
     else
         return nothing
     end
 end
 
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
-                             func::Const{typeof(eigh_full!)},
-                             ::Type{RT},
-                             A::Annotation{<:AbstractMatrix},
-                             DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfig,
+        func::Const{typeof(eigh_full!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
     Dmat, V = func.val(A.val, DV.val; kwargs...)
     if isa(A, Const) || all(iszero, A.dval)
         make_zero!(DV.dval[1])
@@ -361,18 +381,18 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
         make_zero!(A.dval)
         shadow = (DV.dval[1], DV.dval[2])
     else
-        ∂K      = inv(V) * A.dval * V
-        ∂Kdiag  = diagview(∂K)
-        ∂Ddiag  = diagview(DV.dval[1])
+        ∂K = inv(V) * A.dval * V
+        ∂Kdiag = diagview(∂K)
+        ∂Ddiag = diagview(DV.dval[1])
         ∂Ddiag .= real.(∂Kdiag)
-        D       = diagview(Dmat)
-        dDD     = transpose(D) .- D
-        ∂K    ./= dDD 
+        D = diagview(Dmat)
+        dDD = transpose(D) .- D
+        ∂K ./= dDD
         ∂Kdiag .= zero(eltype(V))
         mul!(DV.dval[2], V, ∂K, 1, 0)
-        shadow  = DV.dval[2]
+        shadow = DV.dval[2]
         A.dval .= zero(eltype(A.val))
-        shadow  = (Diagonal(∂Ddiag), DV.dval[2])
+        shadow = (Diagonal(∂Ddiag), DV.dval[2])
     end
     if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
         return Duplicated((Dmat, V), shadow)
@@ -385,24 +405,25 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
     end
 end
 
-function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
-                             func::Const{typeof(eig_vals!)},
-                             ::Type{RT},
-                             A::Annotation{<:AbstractMatrix},
-                             D::Annotation{<:AbstractVector},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfig,
+        func::Const{typeof(eig_vals!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        D::Annotation{<:AbstractVector},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
     Dval, V = eig_full(A.val, alg.val; kwargs...)
     if isa(D, Union{Duplicated, DuplicatedNoNeed}) && !isa(A, Const)
-        ∂K      = inv(V) * A.dval * V
-        ∂Kdiag  = diag(∂K)
+        ∂K = inv(V) * A.dval * V
+        ∂Kdiag = diag(∂K)
         D.dval .= copy(∂Kdiag)
         A.dval .= zero(eltype(A.val))
-        shadow  = D.dval
+        shadow = D.dval
     elseif isa(A, Const) && !!isa(D, Union{Duplicated, DuplicatedNoNeed})
         make_zero!(D.dval)
-        shadow  = D.dval
+        shadow = D.dval
     end
     eig_vals!(A.val, zeros(complex(eltype(A.val)), size(A.val, 1)))
     D.val .= diagview(Dval)
@@ -411,30 +432,31 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
     elseif EnzymeRules.needs_shadow(config)
         return shadow
     elseif EnzymeRules.needs_primal(config)
-        return Dmat.diag 
+        return Dmat.diag
     else
         return nothing
     end
 end
 
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                      func::Const{typeof(eigh_trunc!)},
-                                      ::Type{RT},
-                                      A::Annotation{<:AbstractMatrix},
-                                      DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
-                                      ϵ::Annotation{Vector{T}},
-                                      alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
-                                      kwargs...,
-                                     ) where {RT, T}
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eigh_trunc!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+        ϵ::Annotation{Vector{T}},
+        alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+        kwargs...,
+    ) where {RT, T}
     # form cache if needed
-    cache_A   = copy(A.val)
+    cache_A = copy(A.val)
     MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg)
-    cache_DV  = copy.(DV.val)
-    DV′, ind  = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
-    ϵ.val[1]  = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
-    primal    = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val)  : nothing
+    cache_DV = copy.(DV.val)
+    DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
+    ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
+    primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
     shadow_DV = if !isa(A, Const) && !isa(DV, Const)
-        dD, dV  = DV.dval
+        dD, dV = DV.dval
         dDtrunc = Diagonal(diagview(dD)[ind])
         dVtrunc = dV[:, ind]
         (dDtrunc, dVtrunc)
@@ -444,17 +466,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
     shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
     return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
 end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                             func::Const{typeof(eigh_trunc!)},
-                             ::Type{RT},
-                             cache,
-                             A::Annotation{<:AbstractMatrix},
-                             DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
-                             ϵ::Annotation{Vector{T}},
-                             alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
-                             kwargs...) where {RT, T}
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eigh_trunc!)},
+        ::Type{RT},
+        cache,
+        A::Annotation{<:AbstractMatrix},
+        DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+        ϵ::Annotation{Vector{T}},
+        alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+        kwargs...
+    ) where {RT, T}
     cache_A, cache_DV, cache_dDVtrunc, ind = cache
-    D, V   = cache_DV
+    D, V = cache_DV
     dD, dV = cache_dDVtrunc
     if !isa(A, Const) && !isa(DV, Const)
         A.dval .= zero(eltype(A.val))
@@ -469,24 +493,25 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
     return (nothing, nothing, nothing, nothing)
 end
 
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                      func::Const{typeof(eig_trunc!)},
-                                      ::Type{RT},
-                                      A::Annotation{<:AbstractMatrix},
-                                      DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
-                                      ϵ::Annotation{Vector{T}},
-                                      alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
-                                      kwargs...,
-                                     ) where {RT, T}
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eig_trunc!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+        ϵ::Annotation{Vector{T}},
+        alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+        kwargs...,
+    ) where {RT, T}
     # form cache if needed
-    cache_A   = copy(A.val)
+    cache_A = copy(A.val)
     eig_full!(A.val, DV.val, alg.val.alg)
-    cache_DV  = copy.(DV.val)
-    DV′, ind  = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
-    ϵ.val[1]  = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
-    primal    = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val)  : nothing
+    cache_DV = copy.(DV.val)
+    DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
+    ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
+    primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
     shadow_DV = if !isa(A, Const) && !isa(DV, Const)
-        dD, dV  = DV.dval
+        dD, dV = DV.dval
         dDtrunc = Diagonal(diagview(dD)[ind])
         dVtrunc = dV[:, ind]
         (dDtrunc, dVtrunc)
@@ -496,17 +521,19 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
     shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
     return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
 end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                             func::Const{typeof(eig_trunc!)},
-                             ::Type{RT},
-                             cache,
-                             A::Annotation{<:AbstractMatrix},
-                             DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
-                             ϵ::Annotation{Vector{T}},
-                             alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
-                             kwargs...) where {RT, T}
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eig_trunc!)},
+        ::Type{RT},
+        cache,
+        A::Annotation{<:AbstractMatrix},
+        DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+        ϵ::Annotation{Vector{T}},
+        alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
+        kwargs...
+    ) where {RT, T}
     cache_A, cache_DV, cache_dDVtrunc, ind = cache
-    D, V   = cache_DV
+    D, V = cache_DV
     dD, dV = cache_dDVtrunc
     if !isa(A, Const) && !isa(DV, Const)
         A.dval .= zero(eltype(A.val))
@@ -521,47 +548,49 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
     return (nothing, nothing, nothing, nothing)
 end
 
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                      func::Const{typeof(eigh_full!)},
-                                      ::Type{RT},
-                                      A::Annotation{<:AbstractMatrix},
-                                      DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
-                                      alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                      kwargs...,
-                                     ) where {RT}
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eigh_full!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        DV::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix}},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
     # form cache if needed
     cache_DV = nothing
-    cache_A  = EnzymeRules.overwritten(config)[2] ? copy(A.val)  : nothing
+    cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
     func.val(A.val, DV.val, alg.val; kwargs...)
     primal = EnzymeRules.needs_primal(config) ? DV.val : nothing
     shadow = EnzymeRules.needs_shadow(config) ? DV.dval : nothing
     return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV))
 end
 
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                             func::Const{typeof(eigh_full!)},
-                             ::Type{RT},
-                             cache,
-                             A::Annotation{<:AbstractMatrix},
-                             DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eigh_full!)},
+        ::Type{RT},
+        cache,
+        A::Annotation{<:AbstractMatrix},
+        DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
 
     cache_A, cache_DV = cache
-    DVval   = !isnothing(cache_DV) ? cache_DV : DV.val
-    Aval    = !isnothing(cache_A)  ? cache_A  : A.val
-    ∂DV     = isa(DV, Const) ? nothing : DV.dval
+    DVval = !isnothing(cache_DV) ? cache_DV : DV.val
+    Aval = !isnothing(cache_A) ? cache_A : A.val
+    ∂DV = isa(DV, Const) ? nothing : DV.dval
     if !isa(A, Const) && !isa(DV, Const)
-        Dmat, V   = DVval
+        Dmat, V = DVval
         ∂Dmat, ∂V = ∂DV
-        A.dval   .= zero(eltype(Aval))
+        A.dval .= zero(eltype(Aval))
         MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DVval, ∂DV; kwargs...)
         A.dval .*= 2
         diagview(A.dval) ./= 2
         for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2)
             if i > j
-                A.dval[i, j] = zero(eltype(A.dval)) 
+                A.dval[i, j] = zero(eltype(A.dval))
             end
         end
     end
@@ -571,40 +600,42 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
     return (nothing, nothing, nothing)
 end
 
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                      func::Const{typeof(eig_vals!)},
-                                      ::Type{RT},
-                                      A::Annotation{<:AbstractMatrix},
-                                      D::Annotation{<:AbstractVector},
-                                      alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                      kwargs...,
-                                     ) where {RT}
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eig_vals!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        D::Annotation{<:AbstractVector},
+        alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
     cache_D = nothing
-    cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val)  : nothing
+    cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
     func.val(A.val, D.val, alg.val; kwargs...)
-    primal = EnzymeRules.needs_primal(config) ? D.val  : nothing
+    primal = EnzymeRules.needs_primal(config) ? D.val : nothing
     shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
     # form cache if needed
     return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D))
 end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                             func::Const{typeof(eig_vals!)},
-                             ::Type{RT},
-                             cache,
-                             A::Annotation{<:AbstractMatrix},
-                             D::Annotation{<:AbstractVector},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eig_vals!)},
+        ::Type{RT},
+        cache,
+        A::Annotation{<:AbstractMatrix},
+        D::Annotation{<:AbstractVector},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
 
     cache_A, cache_D = cache
     Dval = !isnothing(cache_D) ? cache_D : D.val
     Aval = !isnothing(cache_A) ? cache_A : A.val
-    ∂D   = isa(D, Const) ? nothing : D.dval
+    ∂D = isa(D, Const) ? nothing : D.dval
     if !isa(A, Const) && !isa(D, Const)
-        _, V    = eig_full(Aval, alg.val)
+        _, V = eig_full(Aval, alg.val)
         A.dval .= zero(eltype(Aval))
-        PΔV     = V' \ Diagonal(D.dval)
+        PΔV = V' \ Diagonal(D.dval)
         if eltype(A.dval) <: Real
             ΔAc = PΔV * V'
             A.dval .+= real.(ΔAc)
@@ -618,45 +649,47 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
     return (nothing, nothing, nothing)
 end
 
-function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
-                                      func::Const{typeof(eigh_vals!)},
-                                      ::Type{RT},
-                                      A::Annotation{<:AbstractMatrix},
-                                      D::Annotation{<:AbstractVector},
-                                      alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
-                                      kwargs...,
-                                     ) where {RT}
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eigh_vals!)},
+        ::Type{RT},
+        A::Annotation{<:AbstractMatrix},
+        D::Annotation{<:AbstractVector},
+        alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
     cache_D = nothing
-    cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val)  : nothing
+    cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : nothing
     func.val(A.val, D.val, alg.val; kwargs...)
-    primal = EnzymeRules.needs_primal(config) ? D.val  : nothing
+    primal = EnzymeRules.needs_primal(config) ? D.val : nothing
     shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
     # form cache if needed
     return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D))
 end
-function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
-                             func::Const{typeof(eigh_vals!)},
-                             ::Type{RT},
-                             cache,
-                             A::Annotation{<:AbstractMatrix},
-                             D::Annotation{<:AbstractVector},
-                             alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
-                             kwargs...,
-                            ) where {RT}
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(eigh_vals!)},
+        ::Type{RT},
+        cache,
+        A::Annotation{<:AbstractMatrix},
+        D::Annotation{<:AbstractVector},
+        alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
+        kwargs...,
+    ) where {RT}
 
     cache_A, cache_D = cache
     Dval = !isnothing(cache_D) ? cache_D : D.val
     Aval = !isnothing(cache_A) ? cache_A : A.val
-    ∂D   = isa(D, Const) ? nothing : D.dval
+    ∂D = isa(D, Const) ? nothing : D.dval
     if !isa(A, Const) && !isa(D, Const)
         _, V = eigh_full(Aval, alg.val)
-        A.dval   .= zero(eltype(Aval))
+        A.dval .= zero(eltype(Aval))
         mul!(A.dval, V * Diagonal(real(∂D)), V', 1, 0)
         A.dval .*= 2
         diagview(A.dval) ./= 2
         for i in 1:size(A.dval, 1), j in 1:size(A.dval, 2)
             if i > j
-                A.dval[i, j] = zero(eltype(A.dval)) 
+                A.dval[i, j] = zero(eltype(A.dval))
             end
         end
     end
diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
index 1797157..9bc5cae 100644
--- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
+++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
@@ -4,7 +4,7 @@ using Mooncake
 using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
 using MatrixAlgebraKit
 using MatrixAlgebraKit: inv_safe, diagview
-using MatrixAlgebraKit: svd_pushforward! 
+using MatrixAlgebraKit: svd_pushforward!
 using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward!
 using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward!
 using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward!
@@ -12,74 +12,76 @@ using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_
 using LinearAlgebra
 
 # two-argument factorizations like LQ, QR, EIG
-for (f, pb, pf, adj) in ((qr_full!,    qr_pullback!, qr_pushforward!, :dqr_adjoint), 
-                         (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
-                         (lq_full!,    lq_pullback!, lq_pushforward!, :dlq_adjoint), 
-                         (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
-                         (eig_full!,   eig_pullback!, eig_pushforward!, :deig_adjoint),
-                         (eigh_full!,  eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
-                         (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
-                         (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
-                        )
+for (f, pb, pf, adj) in (
+        (qr_full!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+        (qr_compact!, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+        (lq_full!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+        (lq_compact!, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+        (eig_full!, eig_pullback!, eig_pushforward!, :deig_adjoint),
+        (eigh_full!, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
+        (left_polar!, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
+        (right_polar!, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
+    )
 
     @eval begin
         @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
         function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
-            A, dA       = arrayify(A_dA)
-            dA         .= zero(eltype(A))
-            args        = Mooncake.primal(args_dargs)
-            dargs       = Mooncake.tangent(args_dargs)
+            A, dA = arrayify(A_dA)
+            dA .= zero(eltype(A))
+            args = Mooncake.primal(args_dargs)
+            dargs = Mooncake.tangent(args_dargs)
             arg1, darg1 = arrayify(args[1], dargs[1])
             arg2, darg2 = arrayify(args[2], dargs[2])
             function $adj(::Mooncake.NoRData)
                 dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
                 return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
             end
-            args   = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
+            args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
             darg1 .= zero(eltype(arg1))
             darg2 .= zero(eltype(arg2))
             return Mooncake.CoDual(args, dargs), $adj
         end
         @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
         function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
-            A, dA        = arrayify(A_dA)
-            args         = Mooncake.primal(args_dargs)
-            args         = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
-            dargs        = Mooncake.tangent(args_dargs)
-            arg1, darg1  = arrayify(args[1], dargs[1])
-            arg2, darg2  = arrayify(args[2], dargs[2])
+            A, dA = arrayify(A_dA)
+            args = Mooncake.primal(args_dargs)
+            args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
+            dargs = Mooncake.tangent(args_dargs)
+            arg1, darg1 = arrayify(args[1], dargs[1])
+            arg2, darg2 = arrayify(args[2], dargs[2])
             darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2))
-            dA          .= zero(eltype(A))
+            dA .= zero(eltype(A))
             return Mooncake.Dual(args, dargs)
         end
     end
 end
 
-for (f, f_full, pb, pf, adj) in ((qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint), 
-                                 (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint), 
-                                )
+for (f, f_full, pb, pf, adj) in (
+        (qr_null!, qr_full, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
+        (lq_null!, lq_full, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
+    )
     @eval begin
         @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
         function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
-            A, dA     = arrayify(A_dA)
-            Ac        = MatrixAlgebraKit.copy_input($f_full, A)
+            A, dA = arrayify(A_dA)
+            Ac = MatrixAlgebraKit.copy_input($f_full, A)
             arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
-            arg       = $f(Ac, arg, Mooncake.primal(alg_dalg))
+            arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
             function $adj(::Mooncake.NoRData)
-                dA   .= zero(eltype(A))
+                dA .= zero(eltype(A))
                 $pb(dA, A, arg, darg; kwargs...)
                 return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
             end
-            return arg_darg, $adj 
+            return arg_darg, $adj
         end
         @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
         function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, arg_darg::Dual{<:AbstractMatrix}, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
-            A, dA     = arrayify(A_dA)
-            Ac        = MatrixAlgebraKit.copy_input($f_full, A)
+            A, dA = arrayify(A_dA)
+            Ac = MatrixAlgebraKit.copy_input($f_full, A)
             arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
-            arg       = $f(Ac, arg, Mooncake.primal(alg_dalg))
+            arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
             $pf(dA, A, arg, darg; kwargs...)
-            dA       .= zero(dA)
+            dA .= zero(dA)
             return arg_darg
         end
     end
@@ -89,33 +91,33 @@ end
 @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
 function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
     # compute primal
-    D_    = Mooncake.primal(D_dD)
-    dD_   = Mooncake.tangent(D_dD)
-    A_    = Mooncake.primal(A_dA)
-    dA_   = Mooncake.tangent(A_dA)
+    D_ = Mooncake.primal(D_dD)
+    dD_ = Mooncake.tangent(D_dD)
+    A_ = Mooncake.primal(A_dA)
+    dA_ = Mooncake.tangent(A_dA)
     A, dA = arrayify(A_, dA_)
     D, dD = arrayify(D_, dD_)
     nD, V = eig_full(A, alg_dalg.primal; kwargs...)
 
     # update tangent
-    tmp   = V \ dA
-    dD   .= diagview(tmp * V)
-    dA   .= zero(eltype(dA))
+    tmp = V \ dA
+    dD .= diagview(tmp * V)
+    dA .= zero(eltype(dA))
     return Mooncake.Dual(nD.diag, dD_)
 end
 
 function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
     # compute primal
-    D_    = Mooncake.primal(D_dD)
-    dD_   = Mooncake.tangent(D_dD)
-    A_    = Mooncake.primal(A_dA)
-    dA_   = Mooncake.tangent(A_dA)
+    D_ = Mooncake.primal(D_dD)
+    dD_ = Mooncake.tangent(D_dD)
+    A_ = Mooncake.primal(A_dA)
+    dA_ = Mooncake.tangent(A_dA)
     A, dA = arrayify(A_, dA_)
     D, dD = arrayify(D_, dD_)
-    dA   .= zero(eltype(dA))
-    # update primal 
-    DV  = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
-    V   = DV[2]
+    dA .= zero(eltype(dA))
+    # update primal
+    DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
+    V = DV[2]
     dD .= zero(eltype(D))
     function deig_vals_adjoint(::Mooncake.NoRData)
         PΔV = V' \ Diagonal(dD)
@@ -163,30 +165,30 @@ end
 @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
 function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
     # compute primal
-    D_    = Mooncake.primal(D_dD)
-    dD_   = Mooncake.tangent(D_dD)
-    A_    = Mooncake.primal(A_dA)
-    dA_   = Mooncake.tangent(A_dA)
+    D_ = Mooncake.primal(D_dD)
+    dD_ = Mooncake.tangent(D_dD)
+    A_ = Mooncake.primal(A_dA)
+    dA_ = Mooncake.tangent(A_dA)
     A, dA = arrayify(A_, dA_)
     D, dD = arrayify(D_, dD_)
     nD, V = eigh_full(A, alg_dalg.primal; kwargs...)
     # update tangent
-    tmp   = inv(V) * dA * V
-    dD   .= real.(diagview(tmp))
-    D    .= nD.diag
-    dA   .= zero(eltype(dA))
+    tmp = inv(V) * dA * V
+    dD .= real.(diagview(tmp))
+    D .= nD.diag
+    dA .= zero(eltype(dA))
     return D_dD
 end
 
 function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
     # compute primal
-    D_    = Mooncake.primal(D_dD)
-    dD_   = Mooncake.tangent(D_dD)
-    A_    = Mooncake.primal(A_dA)
-    dA_   = Mooncake.tangent(A_dA)
+    D_ = Mooncake.primal(D_dD)
+    dD_ = Mooncake.tangent(D_dD)
+    A_ = Mooncake.primal(A_dA)
+    dA_ = Mooncake.tangent(A_dA)
     A, dA = arrayify(A_, dA_)
     D, dD = arrayify(D_, dD_)
-    DV    = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
+    DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
     function deigh_vals_adjoint(::Mooncake.NoRData)
         mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
         return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
@@ -199,60 +201,60 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
     @eval begin
         @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
         function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
-            A, dA   = arrayify(A_dA)
-            USVᴴ    = Mooncake.primal(USVᴴ_dUSVᴴ)
-            dUSVᴴ   = Mooncake.tangent(USVᴴ_dUSVᴴ)
-            U, dU   = arrayify(USVᴴ[1], dUSVᴴ[1])
-            S, dS   = arrayify(USVᴴ[2], dUSVᴴ[2])
+            A, dA = arrayify(A_dA)
+            USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
+            dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
+            U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
+            S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
             Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
-            USVᴴ    = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
+            USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
             function dsvd_adjoint(::Mooncake.NoRData)
-                dA   .= zero(eltype(A))
+                dA .= zero(eltype(A))
                 minmn = min(size(A)...)
                 if size(U, 2) == size(Vᴴ, 1) == minmn # compact
-                    dA    = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
+                    dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
                 else # full
-                    vU    = view(U, :, 1:minmn)
-                    vS    = Diagonal(diagview(S)[1:minmn])
-                    vVᴴ   = view(Vᴴ, 1:minmn, :)
-                    vdU   = view(dU, :, 1:minmn)
-                    vdS   = view(dS, 1:minmn, 1:minmn)
-                    vdVᴴ  = view(dVᴴ, 1:minmn, :)
-                    dA    = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
+                    vU = view(U, :, 1:minmn)
+                    vS = Diagonal(diagview(S)[1:minmn])
+                    vVᴴ = view(Vᴴ, 1:minmn, :)
+                    vdU = view(dU, :, 1:minmn)
+                    vdS = view(dS, 1:minmn, 1:minmn)
+                    vdVᴴ = view(dVᴴ, 1:minmn, :)
+                    dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
                 end
                 return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
             end
-            dU  .= zero(dU)
-            dS  .= zero(dS)
+            dU .= zero(dU)
+            dS .= zero(dS)
             dVᴴ .= zero(dVᴴ)
             return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
         end
         @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
         function Mooncake.frule!!(::Dual{<:typeof($f)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual; kwargs...)
             # compute primal
-            USVᴴ   = Mooncake.primal(USVᴴ_dUSVᴴ)
-            dUSVᴴ  = Mooncake.tangent(USVᴴ_dUSVᴴ)
-            A_     = Mooncake.primal(A_dA)
-            dA_    = Mooncake.tangent(A_dA)
-            A, dA  = arrayify(A_, dA_)
+            USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
+            dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
+            A_ = Mooncake.primal(A_dA)
+            dA_ = Mooncake.tangent(A_dA)
+            A, dA = arrayify(A_, dA_)
             $f(A, USVᴴ, alg_dalg.primal; kwargs...)
 
             # update tangents
-            U_, S_, Vᴴ_    = USVᴴ
+            U_, S_, Vᴴ_ = USVᴴ
             dU_, dS_, dVᴴ_ = dUSVᴴ
-            U, dU          = arrayify(U_, dU_) 
-            S, dS          = arrayify(S_, dS_) 
-            Vᴴ, dVᴴ        = arrayify(Vᴴ_, dVᴴ_) 
-            minmn          = min(size(A)...)
+            U, dU = arrayify(U_, dU_)
+            S, dS = arrayify(S_, dS_)
+            Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
+            minmn = min(size(A)...)
             if ($f == svd_compact!) # compact
                 svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...)
             else # full
-                vU    = view(U, :, 1:minmn)
-                vS    = view(S, 1:minmn, 1:minmn)
-                vVᴴ   = view(Vᴴ, 1:minmn, :)
-                vdU   = view(dU, :, 1:minmn)
-                vdS   = view(dS, 1:minmn, 1:minmn)
-                vdVᴴ  = view(dVᴴ, 1:minmn, :)
+                vU = view(U, :, 1:minmn)
+                vS = view(S, 1:minmn, 1:minmn)
+                vVᴴ = view(Vᴴ, 1:minmn, :)
+                vdU = view(dU, :, 1:minmn)
+                vdS = view(dS, 1:minmn, 1:minmn)
+                vdVᴴ = view(dVᴴ, 1:minmn, :)
                 svd_pushforward!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ); kwargs...)
             end
             return USVᴴ_dUSVᴴ
@@ -263,15 +265,15 @@ end
 @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
 function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual; kwargs...)
     # compute primal
-    S_    = Mooncake.primal(S_dS)
-    dS_   = Mooncake.tangent(S_dS)
-    A_    = Mooncake.primal(A_dA)
-    dA_   = Mooncake.tangent(A_dA)
+    S_ = Mooncake.primal(S_dS)
+    dS_ = Mooncake.tangent(S_dS)
+    A_ = Mooncake.primal(A_dA)
+    dA_ = Mooncake.tangent(A_dA)
     A, dA = arrayify(A_, dA_)
     U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
 
     # update tangent
-    S, dS   = arrayify(S_, dS_) 
+    S, dS = arrayify(S_, dS_)
     copyto!(dS, diag(real.(Vᴴ * dA' * U)))
     copyto!(S, diagview(nS))
     dA .= zero(eltype(dA))
@@ -281,17 +283,17 @@ end
 @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
 function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
     # compute primal
-    S_    = Mooncake.primal(S_dS)
-    dS_   = Mooncake.tangent(S_dS)
-    A_    = Mooncake.primal(A_dA)
-    dA_   = Mooncake.tangent(A_dA)
+    S_ = Mooncake.primal(S_dS)
+    dS_ = Mooncake.tangent(S_dS)
+    A_ = Mooncake.primal(A_dA)
+    dA_ = Mooncake.tangent(A_dA)
     A, dA = arrayify(A_, dA_)
     S, dS = arrayify(S_, dS_)
     U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
-    S    .= diagview(nS)
-    dS   .= zero(eltype(S))
+    S .= diagview(nS)
+    dS .= zero(eltype(S))
     function dsvd_vals_adjoint(::Mooncake.NoRData)
-        dA   .= U * Diagonal(dS) * Vᴴ
+        dA .= U * Diagonal(dS) * Vᴴ
         return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
     end
     return S_dS, dsvd_vals_adjoint
diff --git a/src/common/view.jl b/src/common/view.jl
index 0bc7b9e..c8ae1aa 100644
--- a/src/common/view.jl
+++ b/src/common/view.jl
@@ -1,5 +1,5 @@
 # diagind: provided by LinearAlgebra.jl
-diagview(D::Diagonal)       = D.diag
+diagview(D::Diagonal) = D.diag
 diagview(D::AbstractMatrix) = view(D, diagind(D))
 
 # triangularind
diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl
index 19a43cb..36506f3 100644
--- a/src/pushforwards/eig.jl
+++ b/src/pushforwards/eig.jl
@@ -1,11 +1,11 @@
 function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
-    D, V     = DV
-    ΔD, ΔV   = ΔDV
-    iVΔAV    = inv(V) * ΔA * V
+    D, V = DV
+    ΔD, ΔV = ΔDV
+    iVΔAV = inv(V) * ΔA * V
     diagview(ΔD) .= diagview(iVΔAV)
-    F        = 1 ./ (transpose(diagview(D)) .- diagview(D))
+    F = 1 ./ (transpose(diagview(D)) .- diagview(D))
     fill!(diagview(F), zero(eltype(F)))
-    K̇        = F .* iVΔAV
+    K̇ = F .* iVΔAV
     mul!(ΔV, V, K̇, 1, 0)
     zero!(ΔA)
     return ΔDV
diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl
index 69685b1..5fdfdca 100644
--- a/src/pushforwards/eigh.jl
+++ b/src/pushforwards/eigh.jl
@@ -1,16 +1,16 @@
 function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
-    D, V         = DV
-    dD, dV       = dDV
-    tmpV         = V \ dA
-    ∂K           = tmpV * V
-    ∂Kdiag       = diag(∂K)
-    dD.diag     .= real.(∂Kdiag)
-    dDD          = transpose(diagview(D)) .- diagview(D)
-    F            = one(eltype(dDD)) ./ dDD
+    D, V = DV
+    dD, dV = dDV
+    tmpV = V \ dA
+    ∂K = tmpV * V
+    ∂Kdiag = diag(∂K)
+    dD.diag .= real.(∂Kdiag)
+    dDD = transpose(diagview(D)) .- diagview(D)
+    F = one(eltype(dDD)) ./ dDD
     diagview(F) .= zero(eltype(F))
-    ∂K         .*= F
-    ∂V           = mul!(tmpV, V, ∂K) 
+    ∂K .*= F
+    ∂V = mul!(tmpV, V, ∂K)
     copyto!(dV, ∂V)
-    dA          .= zero(eltype(A))
+    dA .= zero(eltype(A))
     return (dD, dV)
 end
diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl
index 2d390a5..ed5a72d 100644
--- a/src/pushforwards/lq.jl
+++ b/src/pushforwards/lq.jl
@@ -62,7 +62,7 @@
 end=#
 
 function lq_pushforward!(dA, A, LQ, dLQ; kwargs...)
-    qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
+    return qr_pushforward!(dA, A, (adjoint(LQ[2]), adjoint(LQ[1])), (adjoint(dLQ[2]), adjoint(dLQ[1])); kwargs...)
 end
 
-function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) end
+function lq_null_pushforward!(dA, A, LQ, dLQ; tol::Real = MatrixAlgebraKit.default_pullback_gaugetol(LQ[1]), rank_atol::Real = tol, gauge_atol::Real = tol) end
diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl
index 2001c41..78ab79d 100644
--- a/src/pushforwards/polar.jl
+++ b/src/pushforwards/polar.jl
@@ -1,23 +1,23 @@
 function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
-    W, P   = WP
+    W, P = WP
     ΔW, ΔP = ΔWP
-    aWdA   = adjoint(W) * ΔA
-    K̇      = sylvester(P, P, -(aWdA - adjoint(aWdA)))
-    L̇      = (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P)
-    ΔW     .= W * K̇ + L̇
-    ΔP     .= aWdA - K̇*P
+    aWdA = adjoint(W) * ΔA
+    K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA)))
+    L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P)
+    ΔW .= W * K̇ + L̇
+    ΔP .= aWdA - K̇ * P
     MatrixAlgebraKit.zero!(ΔA)
     return (ΔW, ΔP)
 end
 
 function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
-    P, Wᴴ   = PWᴴ
+    P, Wᴴ = PWᴴ
     ΔP, ΔWᴴ = ΔPWᴴ
-    dAW     = ΔA * adjoint(Wᴴ)
-    K̇       = sylvester(P, P, -(dAW - adjoint(dAW)))
-    L̇       = inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
-    ΔWᴴ    .= K̇ * Wᴴ + L̇
-    ΔP     .= dAW - P * K̇
+    dAW = ΔA * adjoint(Wᴴ)
+    K̇ = sylvester(P, P, -(dAW - adjoint(dAW)))
+    L̇ = inv(P) * ΔA * (Diagonal(ones(el...*[Comment body truncated]*

@kshyatt kshyatt marked this pull request as draft November 12, 2025 12:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants