Skip to content

Commit 2234edc

Browse files
committed
Mooncake reverse rules
1 parent c99084b commit 2234edc

File tree

10 files changed

+817
-46
lines changed

10 files changed

+817
-46
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1313
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
1414
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
15+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1516

1617
[extensions]
1718
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
1819
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
1920
MatrixAlgebraKitCUDAExt = "CUDA"
2021
MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra"
2122
MatrixAlgebraKitGenericSchurExt = "GenericSchur"
23+
MatrixAlgebraKitMooncakeExt = "Mooncake"
2224

2325
[compat]
2426
AMDGPU = "2"
@@ -30,6 +32,7 @@ GenericLinearAlgebra = "0.3.19"
3032
GenericSchur = "0.5.6"
3133
JET = "0.9, 0.10"
3234
LinearAlgebra = "1"
35+
Mooncake = "0.4.167"
3336
SafeTestsets = "0.1"
3437
StableRNGs = "1"
3538
Test = "1"
@@ -41,11 +44,12 @@ julia = "1.10"
4144
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4245
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4346
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
47+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
4448
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4549
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4650
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4751
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
4852
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4953

5054
[targets]
51-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
55+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
module MatrixAlgebraKitMooncakeExt
2+
3+
using Mooncake
4+
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
5+
using MatrixAlgebraKit
6+
using MatrixAlgebraKit: inv_safe, diagview, copy_input
7+
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
8+
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!
10+
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11+
using MatrixAlgebraKit: svd_pullback!
12+
using LinearAlgebra
13+
14+
15+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
16+
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
17+
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
18+
dAc = Mooncake.zero_tangent(Ac)
19+
function copy_input_pb(::Mooncake.NoRData)
20+
Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
21+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
22+
end
23+
return CoDual(Ac, dAc), copy_input_pb
24+
end
25+
26+
# two-argument factorizations like LQ, QR, EIG
27+
for (f, pb, adj) in (
28+
(qr_full!, qr_pullback!, :dqr_adjoint),
29+
(qr_compact!, qr_pullback!, :dqr_adjoint),
30+
(lq_full!, lq_pullback!, :dlq_adjoint),
31+
(lq_compact!, lq_pullback!, :dlq_adjoint),
32+
(eig_full!, eig_pullback!, :deig_adjoint),
33+
(eigh_full!, eigh_pullback!, :deigh_adjoint),
34+
(left_polar!, left_polar_pullback!, :dleft_polar_adjoint),
35+
(right_polar!, right_polar_pullback!, :dright_polar_adjoint),
36+
)
37+
38+
@eval begin
39+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
40+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
41+
A, dA = arrayify(A_dA)
42+
args = Mooncake.primal(args_dargs)
43+
dargs = Mooncake.tangent(args_dargs)
44+
arg1, darg1 = arrayify(args[1], dargs[1])
45+
arg2, darg2 = arrayify(args[2], dargs[2])
46+
Ac = copy(A)
47+
arg1c = copy(arg1)
48+
arg2c = copy(arg2)
49+
output = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
50+
function $adj(::Mooncake.NoRData)
51+
$pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
52+
arg1 .= arg1c
53+
arg2 .= arg2c
54+
A .= Ac
55+
MatrixAlgebraKit.zero!(darg1)
56+
MatrixAlgebraKit.zero!(darg2)
57+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
58+
end
59+
return Mooncake.CoDual(args, dargs), $adj
60+
end
61+
end
62+
end
63+
64+
for (f, f_full, pb, adj) in (
65+
(qr_null!, qr_full, qr_null_pullback!, :dqr_null_adjoint),
66+
(lq_null!, lq_full, lq_null_pullback!, :dlq_null_adjoint),
67+
)
68+
@eval begin
69+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
70+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
71+
A, dA = arrayify(A_dA)
72+
Ac = copy(A)
73+
arg, darg = arrayify(arg_darg)
74+
argc = copy(arg)
75+
# WHY is this copy needed?
76+
arg = $f(copy(A), arg, Mooncake.primal(alg_dalg))
77+
function $adj(::Mooncake.NoRData)
78+
$pb(dA, A, arg, darg; kwargs...)
79+
A .= Ac
80+
arg .= argc
81+
MatrixAlgebraKit.zero!(darg)
82+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
83+
end
84+
return arg_darg, $adj
85+
end
86+
end
87+
end
88+
89+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
90+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
91+
# compute primal
92+
D_ = Mooncake.primal(D_dD)
93+
dD_ = Mooncake.tangent(D_dD)
94+
A_ = Mooncake.primal(A_dA)
95+
dA_ = Mooncake.tangent(A_dA)
96+
A, dA = arrayify(A_, dA_)
97+
D, dD = arrayify(D_, dD_)
98+
Ac = copy(A)
99+
Dc = copy(D)
100+
# update primal
101+
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
102+
V = DV[2]
103+
eig_vals!(A, D, Mooncake.primal(alg_dalg))
104+
function deig_vals_adjoint(::Mooncake.NoRData)
105+
A .= Ac
106+
eig_pullback!(dA, A, (D, V), (dD, nothing); kwargs...)
107+
D .= Dc
108+
MatrixAlgebraKit.zero!(dD)
109+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
110+
end
111+
return D_dD, deig_vals_adjoint
112+
end
113+
114+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
115+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
116+
# compute primal
117+
D_ = Mooncake.primal(D_dD)
118+
dD_ = Mooncake.tangent(D_dD)
119+
A_ = Mooncake.primal(A_dA)
120+
dA_ = Mooncake.tangent(A_dA)
121+
A, dA = arrayify(A_, dA_)
122+
Ac = copy(A)
123+
D, dD = arrayify(D_, dD_)
124+
Dc = copy(D)
125+
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
126+
D .= diagview(DV[1])
127+
V = DV[2]
128+
function deigh_vals_adjoint(::Mooncake.NoRData)
129+
A .= Ac
130+
eigh_pullback!(dA, A, (D, V), (dD, nothing); kwargs...)
131+
D .= Dc
132+
MatrixAlgebraKit.zero!(dD)
133+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
134+
end
135+
return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint
136+
end
137+
138+
139+
for f in (svd_full!, svd_compact!)
140+
@eval begin
141+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
142+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
143+
A, dA = arrayify(A_dA)
144+
Ac = copy(A)
145+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
146+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
147+
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
148+
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
149+
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
150+
Uc = copy(U)
151+
Sc = copy(S)
152+
Vᴴc = copy(Vᴴ)
153+
USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
154+
minmn = min(size(A)...)
155+
function dsvd_adjoint(::Mooncake.NoRData)
156+
A .= Ac
157+
if ($f == svd_compact!)
158+
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
159+
else # full
160+
vU = view(U, :, 1:minmn)
161+
vS = Diagonal(diagview(S)[1:minmn])
162+
vVᴴ = view(Vᴴ, 1:minmn, :)
163+
vdU = view(dU, :, 1:minmn)
164+
vdS = Diagonal(diagview(dS)[1:minmn])
165+
vdVᴴ = view(dVᴴ, 1:minmn, :)
166+
svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
167+
end
168+
U .= Uc
169+
S .= Sc
170+
Vᴴ .= Vᴴc
171+
MatrixAlgebraKit.zero!(dU)
172+
MatrixAlgebraKit.zero!(dS)
173+
MatrixAlgebraKit.zero!(dVᴴ)
174+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
175+
end
176+
return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
177+
end
178+
end
179+
end
180+
181+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
182+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
183+
# compute primal
184+
S_ = Mooncake.primal(S_dS)
185+
dS_ = Mooncake.tangent(S_dS)
186+
A_ = Mooncake.primal(A_dA)
187+
dA_ = Mooncake.tangent(A_dA)
188+
A, dA = arrayify(A_, dA_)
189+
S, dS = arrayify(S_, dS_)
190+
Ac = copy(A)
191+
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
192+
S .= diagview(nS)
193+
function dsvd_vals_adjoint(::Mooncake.NoRData)
194+
A .= Ac
195+
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
196+
MatrixAlgebraKit.zero!(dS)
197+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
198+
end
199+
return S_dS, dsvd_vals_adjoint
200+
end
201+
202+
end

src/common/view.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# diagind: provided by LinearAlgebra.jl
22
diagview(D::Diagonal) = D.diag
33
diagview(D::AbstractMatrix) = view(D, diagind(D))
4+
diagview(D::AbstractVector) = D
45

56
# triangularind
67
function lowertriangularind(A::AbstractMatrix)

src/pullbacks/eig.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function eig_pullback!(
4646
Δgauge < gauge_atol ||
4747
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
4848

49-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
49+
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, tol))
5050

5151
if !iszerotangent(ΔDmat)
5252
ΔDvec = diagview(ΔDmat)

src/pullbacks/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and
55
cotangent `ΔWP` of `left_polar(A)`.
66
"""
7-
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP)
7+
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
88
# Extract the Polar components
99
W, P = WP
1010

@@ -34,7 +34,7 @@ end
3434
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ`
3535
and cotangent `ΔPWᴴ` of `right_polar(A)`.
3636
"""
37-
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ)
37+
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...)
3838
# Extract the Polar components
3939
P, Wᴴ = PWᴴ
4040

src/pullbacks/svd.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,13 @@ function svd_pullback!(
2626
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
2727
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
2828
)
29-
3029
# Extract the SVD components
3130
U, Smat, Vᴴ = USVᴴ
3231
m, n = size(U, 1), size(Vᴴ, 2)
33-
(m, n) == size(ΔA) || throw(DimensionMismatch())
32+
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
3433
minmn = min(m, n)
3534
S = diagview(Smat)
36-
length(S) == minmn || throw(DimensionMismatch())
35+
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
3736
r = searchsortedlast(S, rank_atol; rev = true) # rank
3837
Ur = view(U, :, 1:r)
3938
Vᴴr = view(Vᴴ, 1:r, :)
@@ -44,22 +43,22 @@ function svd_pullback!(
4443
UΔU = fill!(similar(U, (r, r)), 0)
4544
VΔV = fill!(similar(Vᴴ, (r, r)), 0)
4645
if !iszerotangent(ΔU)
47-
m == size(ΔU, 1) || throw(DimensionMismatch())
46+
m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)"))
4847
pU = size(ΔU, 2)
49-
pU > r && throw(DimensionMismatch())
48+
pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)"))
5049
indU = axes(U, 2)[ind]
51-
length(indU) == pU || throw(DimensionMismatch())
50+
length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))"))
5251
UΔUp = view(UΔU, :, indU)
5352
mul!(UΔUp, Ur', ΔU)
5453
# ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
5554
ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
5655
end
5756
if !iszerotangent(ΔVᴴ)
58-
n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
57+
n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)"))
5958
pV = size(ΔVᴴ, 1)
60-
pV > r && throw(DimensionMismatch())
59+
pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)"))
6160
indV = axes(Vᴴ, 1)[ind]
62-
length(indV) == pV || throw(DimensionMismatch())
61+
length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))"))
6362
VΔVp = view(VΔV, :, indV)
6463
mul!(VΔVp, Vᴴr, ΔVᴴ')
6564
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
@@ -82,7 +81,7 @@ function svd_pullback!(
8281
ΔS = diagview(ΔSmat)
8382
pS = length(ΔS)
8483
indS = axes(S, 1)[ind]
85-
length(indS) == pS || throw(DimensionMismatch())
84+
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
8685
view(diagview(UdΔAV), indS) .+= real.(ΔS)
8786
end
8887
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA

test/ad_utils.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
function remove_svdgauge_dependence!(
2+
ΔU, ΔVᴴ, U, S, Vᴴ;
3+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
4+
)
5+
gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true)
6+
gaugepart = project_antihermitian!(gaugepart)
7+
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
8+
mul!(ΔU, U, gaugepart, -1, 1)
9+
return ΔU, ΔVᴴ
10+
end
11+
function remove_eiggauge_dependence!(
12+
ΔV, D, V;
13+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
14+
)
15+
gaugepart = V' * ΔV
16+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
17+
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
18+
return ΔV
19+
end
20+
function remove_eighgauge_dependence!(
21+
ΔV, D, V;
22+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
23+
)
24+
gaugepart = V' * ΔV
25+
gaugepart = project_antihermitian!(gaugepart)
26+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
27+
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
28+
return ΔV
29+
end
30+
31+
precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32))
32+
precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64))

0 commit comments

Comments
 (0)