@@ -23,7 +23,7 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
2323 return CoDual (Ac, dAc), copy_input_pb
2424end
2525
26- # two-argument factorizations like LQ, QR, EIG
26+ # two-argument in-place factorizations like LQ, QR, EIG
2727for (f, pb, adj) in (
2828 (qr_full!, qr_pullback!, :dqr_adjoint ),
2929 (qr_compact!, qr_pullback!, :dqr_adjoint ),
@@ -61,6 +61,39 @@ for (f, pb, adj) in (
6161 end
6262end
6363
64+ # two-argument copying factorizations like LQ, QR, EIG
65+ for (f, pb, adj) in (
66+ (qr_full, qr_pullback!, :dqr_adjoint ),
67+ (lq_full, lq_pullback!, :dlq_adjoint ),
68+ (qr_compact, qr_pullback!, :dqr_adjoint ),
69+ (lq_compact, lq_pullback!, :dlq_adjoint ),
70+ (eig_full, eig_pullback!, :deig_adjoint ),
71+ (eigh_full, eigh_pullback!, :deigh_adjoint ),
72+ (left_polar, left_polar_pullback!, :dleft_polar_adjoint ),
73+ (right_polar, right_polar_pullback!, :dright_polar_adjoint ),
74+ )
75+
76+ @eval begin
77+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
78+ function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} ; kwargs... )
79+ A, dA = arrayify (A_dA)
80+ output = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
81+ output_codual = Mooncake. CoDual (output, Mooncake. fdata (Mooncake. zero_tangent (output)))
82+ function $adj (:: Mooncake.NoRData )
83+ arg1, arg2 = Mooncake. primal (output_codual)
84+ darg1_, darg2_ = Mooncake. tangent (output_codual)
85+ arg1, darg1 = Mooncake. arrayify (arg1, darg1_)
86+ arg2, darg2 = Mooncake. arrayify (arg2, darg2_)
87+ $ pb (dA, A, (arg1, arg2), (darg1, darg2); kwargs... )
88+ MatrixAlgebraKit. zero! (darg1)
89+ MatrixAlgebraKit. zero! (darg2)
90+ return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
91+ end
92+ return output_codual, $ adj
93+ end
94+ end
95+ end
96+
6497for (f, f_full, pb, adj) in (
6598 (qr_null!, qr_full, qr_null_pullback!, :dqr_null_adjoint ),
6699 (lq_null!, lq_full, lq_null_pullback!, :dlq_null_adjoint ),
@@ -86,6 +119,29 @@ for (f, f_full, pb, adj) in (
86119 end
87120end
88121
122+ for (f, pb, adj) in (
123+ (qr_null, qr_null_pullback!, :dqr_null_adjoint ),
124+ (lq_null, lq_null_pullback!, :dlq_null_adjoint ),
125+ )
126+ @eval begin
127+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
128+ function Mooncake. rrule!! (f_df:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual{<:MatrixAlgebraKit.AbstractAlgorithm} ; kwargs... )
129+ A, dA = arrayify (A_dA)
130+ output = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
131+ output_codual = Mooncake. CoDual (output, Mooncake. zero_tangent (output))
132+ function $adj (:: Mooncake.NoRData )
133+ arg = Mooncake. primal (output_codual)
134+ darg_ = Mooncake. tangent (output_codual)
135+ arg, darg = Mooncake. arrayify (arg, darg_)
136+ $ pb (dA, A, arg, darg; kwargs... )
137+ MatrixAlgebraKit. zero! (darg)
138+ return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
139+ end
140+ return output_codual, $ adj
141+ end
142+ end
143+ end
144+
89145@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. eig_vals!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
90146function Mooncake. rrule!! (:: CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)} , A_dA:: CoDual , D_dD:: CoDual , alg_dalg:: CoDual ; kwargs... )
91147 # compute primal
@@ -135,6 +191,35 @@ function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA:
135191 return Mooncake. CoDual (DV[1 ]. diag, dD_), deigh_vals_adjoint
136192end
137193
194+ for (f, full_f, pb, adj) in (
195+ (eig_vals, eig_full, eig_pullback!, :deig_vals_adjoint ),
196+ (eigh_vals, eigh_full, eigh_pullback!, :deigh_vals_adjoint ),
197+ )
198+ @eval begin
199+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
200+ function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual ; kwargs... )
201+ # compute primal
202+ A_ = Mooncake. primal (A_dA)
203+ dA_ = Mooncake. tangent (A_dA)
204+ A, dA = arrayify (A_, dA_)
205+ output = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
206+ output_codual = Mooncake. CoDual (output, Mooncake. zero_tangent (output))
207+ # update primal
208+ DV = $ full_f (A, Mooncake. primal (alg_dalg); kwargs... )
209+ V = DV[2 ]
210+ function $adj (:: Mooncake.NoRData )
211+ D = Mooncake. primal (output_codual)
212+ dD_ = Mooncake. tangent (output_codual)
213+ D, dD = Mooncake. arrayify (D, dD_)
214+ $ pb (dA, A, (D, V), (dD, nothing ); kwargs... )
215+ MatrixAlgebraKit. zero! (dD)
216+ return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
217+ end
218+ return output_codual, $ adj
219+ end
220+ end
221+ end
222+
138223
139224for f in (svd_full!, svd_compact!)
140225 @eval begin
@@ -178,6 +263,41 @@ for f in (svd_full!, svd_compact!)
178263 end
179264end
180265
266+ for f in (svd_full, svd_compact)
267+ @eval begin
268+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f), AbstractMatrix, MatrixAlgebraKit. AbstractAlgorithm}
269+ function Mooncake. rrule!! (:: CoDual{typeof($f)} , A_dA:: CoDual , alg_dalg:: CoDual ; kwargs... )
270+ A, dA = arrayify (A_dA)
271+ USVᴴ = $ f (A, Mooncake. primal (alg_dalg); kwargs... )
272+ USVᴴ_codual = Mooncake. CoDual (USVᴴ, Mooncake. fdata (Mooncake. zero_tangent (USVᴴ)))
273+ function dsvd_adjoint (:: Mooncake.NoRData )
274+ U, S, Vᴴ = Mooncake. primal (USVᴴ_codual)
275+ dU_, dS_, dVᴴ_ = Mooncake. tangent (USVᴴ_codual)
276+ U, dU = Mooncake. arrayify (U, dU_)
277+ S, dS = Mooncake. arrayify (S, dS_)
278+ Vᴴ, dVᴴ = Mooncake. arrayify (Vᴴ, dVᴴ_)
279+ if ($ f == svd_compact)
280+ svd_pullback! (dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
281+ else # full
282+ minmn = min (size (A)... )
283+ vU = view (U, :, 1 : minmn)
284+ vS = Diagonal (diagview (S)[1 : minmn])
285+ vVᴴ = view (Vᴴ, 1 : minmn, :)
286+ vdU = view (dU, :, 1 : minmn)
287+ vdS = Diagonal (diagview (dS)[1 : minmn])
288+ vdVᴴ = view (dVᴴ, 1 : minmn, :)
289+ svd_pullback! (dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
290+ end
291+ MatrixAlgebraKit. zero! (dU)
292+ MatrixAlgebraKit. zero! (dS)
293+ MatrixAlgebraKit. zero! (dVᴴ)
294+ return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
295+ end
296+ return USVᴴ_codual, dsvd_adjoint
297+ end
298+ end
299+ end
300+
181301@is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals!), Any, Any, MatrixAlgebraKit. AbstractAlgorithm}
182302function Mooncake. rrule!! (:: CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)} , A_dA:: CoDual , S_dS:: CoDual , alg_dalg:: CoDual ; kwargs... )
183303 # compute primal
@@ -199,4 +319,24 @@ function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::
199319 return S_dS, dsvd_vals_adjoint
200320end
201321
322+ @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof (MatrixAlgebraKit. svd_vals), Any, MatrixAlgebraKit. AbstractAlgorithm}
323+ function Mooncake. rrule!! (:: CoDual{<:typeof(MatrixAlgebraKit.svd_vals)} , A_dA:: CoDual , alg_dalg:: CoDual ; kwargs... )
324+ # compute primal
325+ A = Mooncake. primal (A_dA)
326+ dA_ = Mooncake. tangent (A_dA)
327+ A, dA = arrayify (A, dA_)
328+ S = svd_vals (A, Mooncake. primal (alg_dalg); kwargs... )
329+ U, _, Vᴴ = svd_compact (A, Mooncake. primal (alg_dalg); kwargs... )
330+ S_codual = Mooncake. CoDual (S, Mooncake. fdata (Mooncake. zero_tangent (S)))
331+ function dsvd_vals_adjoint (:: Mooncake.NoRData )
332+ S = Mooncake. primal (S_codual)
333+ dS_ = Mooncake. tangent (S_codual)
334+ S, dS = Mooncake. arrayify (S, dS_)
335+ svd_pullback! (dA, A, (U, S, Vᴴ), (nothing , dS, nothing ))
336+ MatrixAlgebraKit. zero! (dS)
337+ return Mooncake. NoRData (), Mooncake. NoRData (), Mooncake. NoRData ()
338+ end
339+ return S_codual, dsvd_vals_adjoint
340+ end
341+
202342end
0 commit comments