@@ -41,11 +41,11 @@ for (f!, f, pb, adj) in (
4141 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Tuple{<: Any , <: Any }, MatrixAlgebraKit. AbstractAlgorithm}
4242 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , args_dargs:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} ; kwargs... )
4343 A, dA = arrayify (A_dA)
44- args = Mooncake. primal (args_dargs)
44+ args = Mooncake. primal (args_dargs)
4545 dargs = Mooncake. tangent (args_dargs)
4646 arg1, darg1 = arrayify (args[1 ], dargs[1 ])
4747 arg2, darg2 = arrayify (args[2 ], dargs[2 ])
48- Ac = copy (A)
48+ Ac = copy (A)
4949 arg1c = copy (arg1)
5050 arg2c = copy (arg2)
5151 $ f! (A, args, Mooncake. primal (alg_dalg); kwargs... )
@@ -68,8 +68,8 @@ for (f!, f, pb, adj) in (
6868 function $adj (:: Mooncake.NoRData )
6969 arg1, arg2 = Mooncake. primal (output_codual)
7070 darg1_, darg2_ = Mooncake. tangent (output_codual)
71- arg1, darg1 = Mooncake. arrayify (arg1, darg1_)
72- arg2, darg2 = Mooncake. arrayify (arg2, darg2_)
71+ arg1, darg1 = Mooncake. arrayify (arg1, darg1_)
72+ arg2, darg2 = Mooncake. arrayify (arg2, darg2_)
7373 $ pb (dA, A, (arg1, arg2), (darg1, darg2); kwargs... )
7474 MatrixAlgebraKit. zero! (darg1)
7575 MatrixAlgebraKit. zero! (darg2)
@@ -89,14 +89,14 @@ for (f!, f, pb, adj) in (
8989 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
9090 function Mooncake. rrule!! (f_df:: CoDual{typeof($f!)} , A_dA:: CoDual , arg_darg:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} ; kwargs... )
9191 A, dA = arrayify (A_dA)
92- Ac = copy (A)
92+ Ac = copy (A)
9393 arg, darg = arrayify (arg_darg)
94- argc = copy (arg)
94+ argc = copy (arg)
9595 $ f! (A, arg, Mooncake. primal (alg_dalg))
9696 function $adj (:: Mooncake.NoRData )
9797 A .= Ac
9898 $ pb (dA, A, arg, darg; kwargs... )
99- arg .= argc
99+ arg .= argc
100100 MatrixAlgebraKit. zero! (darg)
101101 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
102102 end
@@ -129,9 +129,9 @@ for (f!, f, full_f, pb, adj) in (
129129 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
130130 function Mooncake. rrule!! (:: CoDual{<:typeof($f!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual ; kwargs... )
131131 # compute primal
132- D_ = Mooncake. primal (D_dD)
132+ D_ = Mooncake. primal (D_dD)
133133 dD_ = Mooncake. tangent (D_dD)
134- A_ = Mooncake. primal (A_dA)
134+ A_ = Mooncake. primal (A_dA)
135135 dA_ = Mooncake. tangent (A_dA)
136136 A, dA = arrayify (A_, dA_)
137137 D, dD = arrayify (D_, dD_)
@@ -160,7 +160,7 @@ for (f!, f, full_f, pb, adj) in (
160160 function $adj (:: Mooncake.NoRData )
161161 D = Mooncake. primal (output_codual)
162162 dD_ = Mooncake. tangent (output_codual)
163- D, dD = Mooncake. arrayify (D, dD_)
163+ D, dD = Mooncake. arrayify (D, dD_)
164164 $ pb (dA, A, (D, V), (dD, nothing ); kwargs... )
165165 MatrixAlgebraKit. zero! (dD)
166166 return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
@@ -179,17 +179,17 @@ for (f, pb, adj) in (
179179 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. TruncatedAlgorithm}
180180 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual ; kwargs... )
181181 # compute primal
182- A_ = Mooncake. primal (A_dA)
183- dA_ = Mooncake. tangent (A_dA)
184- A, dA = arrayify (A_, dA_)
185- alg = Mooncake. primal (alg_dalg)
182+ A_ = Mooncake. primal (A_dA)
183+ dA_ = Mooncake. tangent (A_dA)
184+ A, dA = arrayify (A_, dA_)
185+ alg = Mooncake. primal (alg_dalg)
186186 output = $ f (A, alg)
187187 output_codual = Mooncake. CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
188- function $adj (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, T} ) where {T<: Real }
188+ function $adj (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, T} ) where {T <: Real }
189189 Dtrunc, Vtrunc, ϵ = Mooncake. primal (output_codual)
190190 dDtrunc_, dVtrunc_, dϵ = Mooncake. tangent (output_codual)
191- D, dD = Mooncake. arrayify (Dtrunc, dDtrunc_)
192- V, dV = Mooncake. arrayify (Vtrunc, dVtrunc_)
191+ D, dD = Mooncake. arrayify (Dtrunc, dDtrunc_)
192+ V, dV = Mooncake. arrayify (Vtrunc, dVtrunc_)
193193 $ pb (dA, A, (D, V), (dD, dV))
194194 MatrixAlgebraKit. zero! (dD)
195195 MatrixAlgebraKit. zero! (dV)
@@ -201,19 +201,20 @@ for (f, pb, adj) in (
201201 end
202202end
203203
204- for (f!, f) in ((svd_full!, svd_full),
205- (svd_compact!, svd_compact),
206- )
204+ for (f!, f) in (
205+ (svd_full!, svd_full),
206+ (svd_compact!, svd_compact),
207+ )
207208 @eval begin
208209 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f!), AbstractMatrix, Tuple{<: Any , <: Any , <: Any }, MatrixAlgebraKit. AbstractAlgorithm}
209210 function Mooncake. rrule!! (:: CoDual{typeof($f!)} , A_dA:: CoDual , USVᴴ_dUSVᴴ:: CoDual , alg_dalg:: CoDual ; kwargs... )
210- A, dA = arrayify (A_dA)
211- USVᴴ = Mooncake. primal (USVᴴ_dUSVᴴ)
212- dUSVᴴ = Mooncake. tangent (USVᴴ_dUSVᴴ)
213- U, dU = arrayify (USVᴴ[1 ], dUSVᴴ[1 ])
214- S, dS = arrayify (USVᴴ[2 ], dUSVᴴ[2 ])
211+ A, dA = arrayify (A_dA)
212+ USVᴴ = Mooncake. primal (USVᴴ_dUSVᴴ)
213+ dUSVᴴ = Mooncake. tangent (USVᴴ_dUSVᴴ)
214+ U, dU = arrayify (USVᴴ[1 ], dUSVᴴ[1 ])
215+ S, dS = arrayify (USVᴴ[2 ], dUSVᴴ[2 ])
215216 Vᴴ, dVᴴ = arrayify (USVᴴ[3 ], dUSVᴴ[3 ])
216- output = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
217+ output = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
217218 function dsvd_adjoint (:: Mooncake.NoRData )
218219 if ($ f == svd_compact!)
219220 svd_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
@@ -237,7 +238,7 @@ for (f!, f) in ((svd_full!, svd_full),
237238 @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), AbstractMatrix, MatrixAlgebraKit. AbstractAlgorithm}
238239 function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual ; kwargs... )
239240 A, dA = arrayify (A_dA)
240- USVᴴ = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
241+ USVᴴ = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
241242 USVᴴ_codual = Mooncake. CoDual (USVᴴ, Mooncake. fdata (Mooncake. zero_tangent (USVᴴ)))
242243 function dsvd_adjoint (:: Mooncake.NoRData )
243244 U, S, Vᴴ = Mooncake. primal (USVᴴ_codual)
@@ -271,10 +272,10 @@ end
271272@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
272273function Mooncake. rrule!! (:: CoDual{typeof(MatrixAlgebraKit.svd_vals!)} , A_dA:: CoDual , S_dS:: CoDual , alg_dalg:: CoDual ; kwargs... )
273274 # compute primal
274- S_ = Mooncake. primal (S_dS)
275- dS_ = Mooncake. tangent (S_dS)
276- A_ = Mooncake. primal (A_dA)
277- dA_ = Mooncake. tangent (A_dA)
275+ S_ = Mooncake. primal (S_dS)
276+ dS_ = Mooncake. tangent (S_dS)
277+ A_ = Mooncake. primal (A_dA)
278+ dA_ = Mooncake. tangent (A_dA)
278279 A, dA = arrayify (A_, dA_)
279280 S, dS = arrayify (S_, dS_)
280281 U, nS, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg); kwargs... )
@@ -290,15 +291,15 @@ end
290291@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals), Any, MatrixAlgebraKit. AbstractAlgorithm}
291292function Mooncake. rrule!! (:: CoDual{typeof(MatrixAlgebraKit.svd_vals)} , A_dA:: CoDual , alg_dalg:: CoDual ; kwargs... )
292293 # compute primal
293- A = Mooncake. primal (A_dA)
294- dA_ = Mooncake. tangent (A_dA)
295- A, dA = arrayify (A, dA_)
296- S = svd_vals (A, Mooncake. primal (alg_dalg); kwargs... )
294+ A = Mooncake. primal (A_dA)
295+ dA_ = Mooncake. tangent (A_dA)
296+ A, dA = arrayify (A, dA_)
297+ S = svd_vals (A, Mooncake. primal (alg_dalg); kwargs... )
297298 U, _, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg); kwargs... )
298299 S_codual = Mooncake. CoDual (S, Mooncake. fdata (Mooncake. zero_tangent (S)))
299300 function dsvd_vals_adjoint (:: Mooncake.NoRData )
300- S = Mooncake. primal (S_codual)
301- dS_ = Mooncake. tangent (S_codual)
301+ S = Mooncake. primal (S_codual)
302+ dS_ = Mooncake. tangent (S_codual)
302303 S, dS = Mooncake. arrayify (S, dS_)
303304 svd_pullback! (dA, A, (U, S, Vᴴ), (nothing , dS, nothing ))
304305 MatrixAlgebraKit. zero! (dS)
@@ -311,17 +312,17 @@ end
311312@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_trunc), Any, MatrixAlgebraKit. TruncatedAlgorithm}
312313function Mooncake. rrule!! (:: CoDual{typeof(MatrixAlgebraKit.svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual ; kwargs... )
313314 # compute primal
314- A_ = Mooncake. primal (A_dA)
315- dA_ = Mooncake. tangent (A_dA)
316- A, dA = arrayify (A_, dA_)
317- alg = Mooncake. primal (alg_dalg)
315+ A_ = Mooncake. primal (A_dA)
316+ dA_ = Mooncake. tangent (A_dA)
317+ A, dA = arrayify (A_, dA_)
318+ alg = Mooncake. primal (alg_dalg)
318319 output = svd_trunc (A, alg)
319320 output_codual = Mooncake. CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
320- function dsvd_trunc_adjoint (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T} ) where {T<: Real }
321+ function dsvd_trunc_adjoint (dy:: Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T} ) where {T <: Real }
321322 Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake. primal (output_codual)
322323 dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake. tangent (output_codual)
323- U, dU = Mooncake. arrayify (Utrunc, dUtrunc_)
324- S, dS = Mooncake. arrayify (Strunc, dStrunc_)
324+ U, dU = Mooncake. arrayify (Utrunc, dUtrunc_)
325+ S, dS = Mooncake. arrayify (Strunc, dStrunc_)
325326 Vᴴ, dVᴴ = Mooncake. arrayify (Vᴴtrunc, dVᴴtrunc_)
326327 svd_trunc_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
327328 MatrixAlgebraKit. zero! (dU)
0 commit comments