Skip to content

Commit 271eb45

Browse files
committed
Make the copying rules first-class
1 parent 6a3be73 commit 271eb45

File tree

3 files changed

+146
-6
lines changed

3 files changed

+146
-6
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
2323
return CoDual(Ac, dAc), copy_input_pb
2424
end
2525

26-
# two-argument factorizations like LQ, QR, EIG
26+
# two-argument in-place factorizations like LQ, QR, EIG
2727
for (f, pb, adj) in (
2828
(qr_full!, qr_pullback!, :dqr_adjoint),
2929
(qr_compact!, qr_pullback!, :dqr_adjoint),
@@ -61,6 +61,39 @@ for (f, pb, adj) in (
6161
end
6262
end
6363

64+
# two-argument copying factorizations like LQ, QR, EIG
65+
for (f, pb, adj) in (
66+
(qr_full, qr_pullback!, :dqr_adjoint),
67+
(lq_full, lq_pullback!, :dlq_adjoint),
68+
(qr_compact, qr_pullback!, :dqr_adjoint),
69+
(lq_compact, lq_pullback!, :dlq_adjoint),
70+
(eig_full, eig_pullback!, :deig_adjoint),
71+
(eigh_full, eigh_pullback!, :deigh_adjoint),
72+
(left_polar, left_polar_pullback!, :dleft_polar_adjoint),
73+
(right_polar, right_polar_pullback!, :dright_polar_adjoint),
74+
)
75+
76+
@eval begin
77+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
78+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
79+
A, dA = arrayify(A_dA)
80+
output = $f(A, Mooncake.primal(alg_dalg); kwargs...)
81+
output_codual = Mooncake.CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
82+
function $adj(::Mooncake.NoRData)
83+
arg1, arg2 = Mooncake.primal(output_codual)
84+
darg1_, darg2_ = Mooncake.tangent(output_codual)
85+
arg1, darg1 = Mooncake.arrayify(arg1, darg1_)
86+
arg2, darg2 = Mooncake.arrayify(arg2, darg2_)
87+
$pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
88+
MatrixAlgebraKit.zero!(darg1)
89+
MatrixAlgebraKit.zero!(darg2)
90+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
91+
end
92+
return output_codual, $adj
93+
end
94+
end
95+
end
96+
6497
for (f, f_full, pb, adj) in (
6598
(qr_null!, qr_full, qr_null_pullback!, :dqr_null_adjoint),
6699
(lq_null!, lq_full, lq_null_pullback!, :dlq_null_adjoint),
@@ -86,6 +119,29 @@ for (f, f_full, pb, adj) in (
86119
end
87120
end
88121

122+
for (f, pb, adj) in (
123+
(qr_null, qr_null_pullback!, :dqr_null_adjoint),
124+
(lq_null, lq_null_pullback!, :dlq_null_adjoint),
125+
)
126+
@eval begin
127+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
128+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
129+
A, dA = arrayify(A_dA)
130+
output = $f(A, Mooncake.primal(alg_dalg); kwargs...)
131+
output_codual = Mooncake.CoDual(output, Mooncake.zero_tangent(output))
132+
function $adj(::Mooncake.NoRData)
133+
arg = Mooncake.primal(output_codual)
134+
darg_ = Mooncake.tangent(output_codual)
135+
arg, darg = Mooncake.arrayify(arg, darg_)
136+
$pb(dA, A, arg, darg; kwargs...)
137+
MatrixAlgebraKit.zero!(darg)
138+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
139+
end
140+
return output_codual, $adj
141+
end
142+
end
143+
end
144+
89145
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
90146
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
91147
# compute primal
@@ -135,6 +191,35 @@ function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA:
135191
return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint
136192
end
137193

194+
for (f, full_f, pb, adj) in (
195+
(eig_vals, eig_full, eig_pullback!, :deig_vals_adjoint),
196+
(eigh_vals, eigh_full, eigh_pullback!, :deigh_vals_adjoint),
197+
)
198+
@eval begin
199+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
200+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual; kwargs...)
201+
# compute primal
202+
A_ = Mooncake.primal(A_dA)
203+
dA_ = Mooncake.tangent(A_dA)
204+
A, dA = arrayify(A_, dA_)
205+
output = $f(A, Mooncake.primal(alg_dalg); kwargs...)
206+
output_codual = Mooncake.CoDual(output, Mooncake.zero_tangent(output))
207+
# update primal
208+
DV = $full_f(A, Mooncake.primal(alg_dalg); kwargs...)
209+
V = DV[2]
210+
function $adj(::Mooncake.NoRData)
211+
D = Mooncake.primal(output_codual)
212+
dD_ = Mooncake.tangent(output_codual)
213+
D, dD = Mooncake.arrayify(D, dD_)
214+
$pb(dA, A, (D, V), (dD, nothing); kwargs...)
215+
MatrixAlgebraKit.zero!(dD)
216+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
217+
end
218+
return output_codual, $adj
219+
end
220+
end
221+
end
222+
138223

139224
for f in (svd_full!, svd_compact!)
140225
@eval begin
@@ -178,6 +263,41 @@ for f in (svd_full!, svd_compact!)
178263
end
179264
end
180265

266+
for f in (svd_full, svd_compact)
267+
@eval begin
268+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
269+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual; kwargs...)
270+
A, dA = arrayify(A_dA)
271+
USVᴴ = $f(A, Mooncake.primal(alg_dalg); kwargs...)
272+
USVᴴ_codual = Mooncake.CoDual(USVᴴ, Mooncake.fdata(Mooncake.zero_tangent(USVᴴ)))
273+
function dsvd_adjoint(::Mooncake.NoRData)
274+
U, S, Vᴴ = Mooncake.primal(USVᴴ_codual)
275+
dU_, dS_, dVᴴ_ = Mooncake.tangent(USVᴴ_codual)
276+
U, dU = Mooncake.arrayify(U, dU_)
277+
S, dS = Mooncake.arrayify(S, dS_)
278+
Vᴴ, dVᴴ = Mooncake.arrayify(Vᴴ, dVᴴ_)
279+
if ($f == svd_compact)
280+
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
281+
else # full
282+
minmn = min(size(A)...)
283+
vU = view(U, :, 1:minmn)
284+
vS = Diagonal(diagview(S)[1:minmn])
285+
vVᴴ = view(Vᴴ, 1:minmn, :)
286+
vdU = view(dU, :, 1:minmn)
287+
vdS = Diagonal(diagview(dS)[1:minmn])
288+
vdVᴴ = view(dVᴴ, 1:minmn, :)
289+
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
290+
end
291+
MatrixAlgebraKit.zero!(dU)
292+
MatrixAlgebraKit.zero!(dS)
293+
MatrixAlgebraKit.zero!(dVᴴ)
294+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
295+
end
296+
return USVᴴ_codual, dsvd_adjoint
297+
end
298+
end
299+
end
300+
181301
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
182302
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
183303
# compute primal
@@ -199,4 +319,24 @@ function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::
199319
return S_dS, dsvd_vals_adjoint
200320
end
201321

322+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
323+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals)}, A_dA::CoDual, alg_dalg::CoDual; kwargs...)
324+
# compute primal
325+
A = Mooncake.primal(A_dA)
326+
dA_ = Mooncake.tangent(A_dA)
327+
A, dA = arrayify(A, dA_)
328+
S = svd_vals(A, Mooncake.primal(alg_dalg); kwargs...)
329+
U, _, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
330+
S_codual = Mooncake.CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S)))
331+
function dsvd_vals_adjoint(::Mooncake.NoRData)
332+
S = Mooncake.primal(S_codual)
333+
dS_ = Mooncake.tangent(S_codual)
334+
S, dS = Mooncake.arrayify(S, dS_)
335+
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
336+
MatrixAlgebraKit.zero!(dS)
337+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
338+
end
339+
return S_codual, dsvd_vals_adjoint
340+
end
341+
202342
end

src/pullbacks/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function eig_pullback!(
4646
Δgauge < gauge_atol ||
4747
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
4848

49-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, tol))
49+
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
5050

5151
if !iszerotangent(ΔDmat)
5252
ΔDvec = diagview(ΔDmat)

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using SafeTestsets
44
# specific ones
55
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
66
if !is_buildkite
7-
@safetestset "Algorithms" begin
7+
#=@safetestset "Algorithms" begin
88
include("algorithms.jl")
99
end
1010
@safetestset "Projections" begin
@@ -37,11 +37,11 @@ if !is_buildkite
3737
end
3838
@safetestset "Image and Null Space" begin
3939
include("orthnull.jl")
40-
end
40+
end=#
4141
@safetestset "Mooncake" begin
4242
include("mooncake.jl")
4343
end
44-
@safetestset "ChainRules" begin
44+
#=@safetestset "ChainRules" begin
4545
include("chainrules.jl")
4646
end
4747
@safetestset "MatrixAlgebraKit.jl" begin
@@ -55,7 +55,7 @@ if !is_buildkite
5555
using JET
5656
JET.test_package(MatrixAlgebraKit; target_defined_modules = true)
5757
end
58-
end
58+
end=#
5959
end
6060

6161
using CUDA

0 commit comments

Comments
 (0)