Skip to content

Commit 611dede

Browse files
committed
Format
1 parent 95406cc commit 611dede

File tree

2 files changed

+77
-76
lines changed

2 files changed

+77
-76
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ for (f!, f, pb, adj) in (
4141
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
4242
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
4343
A, dA = arrayify(A_dA)
44-
args = Mooncake.primal(args_dargs)
44+
args = Mooncake.primal(args_dargs)
4545
dargs = Mooncake.tangent(args_dargs)
4646
arg1, darg1 = arrayify(args[1], dargs[1])
4747
arg2, darg2 = arrayify(args[2], dargs[2])
48-
Ac = copy(A)
48+
Ac = copy(A)
4949
arg1c = copy(arg1)
5050
arg2c = copy(arg2)
5151
$f!(A, args, Mooncake.primal(alg_dalg); kwargs...)
@@ -68,8 +68,8 @@ for (f!, f, pb, adj) in (
6868
function $adj(::Mooncake.NoRData)
6969
arg1, arg2 = Mooncake.primal(output_codual)
7070
darg1_, darg2_ = Mooncake.tangent(output_codual)
71-
arg1, darg1 = Mooncake.arrayify(arg1, darg1_)
72-
arg2, darg2 = Mooncake.arrayify(arg2, darg2_)
71+
arg1, darg1 = Mooncake.arrayify(arg1, darg1_)
72+
arg2, darg2 = Mooncake.arrayify(arg2, darg2_)
7373
$pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
7474
MatrixAlgebraKit.zero!(darg1)
7575
MatrixAlgebraKit.zero!(darg2)
@@ -89,14 +89,14 @@ for (f!, f, pb, adj) in (
8989
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
9090
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
9191
A, dA = arrayify(A_dA)
92-
Ac = copy(A)
92+
Ac = copy(A)
9393
arg, darg = arrayify(arg_darg)
94-
argc = copy(arg)
94+
argc = copy(arg)
9595
$f!(A, arg, Mooncake.primal(alg_dalg))
9696
function $adj(::Mooncake.NoRData)
9797
A .= Ac
9898
$pb(dA, A, arg, darg; kwargs...)
99-
arg .= argc
99+
arg .= argc
100100
MatrixAlgebraKit.zero!(darg)
101101
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
102102
end
@@ -129,9 +129,9 @@ for (f!, f, full_f, pb, adj) in (
129129
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
130130
function Mooncake.rrule!!(::CoDual{<:typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
131131
# compute primal
132-
D_ = Mooncake.primal(D_dD)
132+
D_ = Mooncake.primal(D_dD)
133133
dD_ = Mooncake.tangent(D_dD)
134-
A_ = Mooncake.primal(A_dA)
134+
A_ = Mooncake.primal(A_dA)
135135
dA_ = Mooncake.tangent(A_dA)
136136
A, dA = arrayify(A_, dA_)
137137
D, dD = arrayify(D_, dD_)
@@ -160,7 +160,7 @@ for (f!, f, full_f, pb, adj) in (
160160
function $adj(::Mooncake.NoRData)
161161
D = Mooncake.primal(output_codual)
162162
dD_ = Mooncake.tangent(output_codual)
163-
D, dD = Mooncake.arrayify(D, dD_)
163+
D, dD = Mooncake.arrayify(D, dD_)
164164
$pb(dA, A, (D, V), (dD, nothing); kwargs...)
165165
MatrixAlgebraKit.zero!(dD)
166166
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
@@ -179,17 +179,17 @@ for (f, pb, adj) in (
179179
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.TruncatedAlgorithm}
180180
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual; kwargs...)
181181
# compute primal
182-
A_ = Mooncake.primal(A_dA)
183-
dA_ = Mooncake.tangent(A_dA)
184-
A, dA = arrayify(A_, dA_)
185-
alg = Mooncake.primal(alg_dalg)
182+
A_ = Mooncake.primal(A_dA)
183+
dA_ = Mooncake.tangent(A_dA)
184+
A, dA = arrayify(A_, dA_)
185+
alg = Mooncake.primal(alg_dalg)
186186
output = $f(A, alg)
187187
output_codual = Mooncake.CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
188-
function $adj(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, T}) where {T<:Real}
188+
function $adj(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, T}) where {T <: Real}
189189
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
190190
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
191-
D, dD = Mooncake.arrayify(Dtrunc, dDtrunc_)
192-
V, dV = Mooncake.arrayify(Vtrunc, dVtrunc_)
191+
D, dD = Mooncake.arrayify(Dtrunc, dDtrunc_)
192+
V, dV = Mooncake.arrayify(Vtrunc, dVtrunc_)
193193
$pb(dA, A, (D, V), (dD, dV))
194194
MatrixAlgebraKit.zero!(dD)
195195
MatrixAlgebraKit.zero!(dV)
@@ -201,19 +201,20 @@ for (f, pb, adj) in (
201201
end
202202
end
203203

204-
for (f!, f) in ((svd_full!, svd_full),
205-
(svd_compact!, svd_compact),
206-
)
204+
for (f!, f) in (
205+
(svd_full!, svd_full),
206+
(svd_compact!, svd_compact),
207+
)
207208
@eval begin
208209
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), AbstractMatrix, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
209210
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
210-
A, dA = arrayify(A_dA)
211-
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
212-
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
213-
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
214-
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
211+
A, dA = arrayify(A_dA)
212+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
213+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
214+
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
215+
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
215216
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
216-
output = $f(A, Mooncake.primal(alg_dalg); kwargs...)
217+
output = $f(A, Mooncake.primal(alg_dalg); kwargs...)
217218
function dsvd_adjoint(::Mooncake.NoRData)
218219
if ($f == svd_compact!)
219220
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
@@ -237,7 +238,7 @@ for (f!, f) in ((svd_full!, svd_full),
237238
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
238239
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual; kwargs...)
239240
A, dA = arrayify(A_dA)
240-
USVᴴ = $f(A, Mooncake.primal(alg_dalg); kwargs...)
241+
USVᴴ = $f(A, Mooncake.primal(alg_dalg); kwargs...)
241242
USVᴴ_codual = Mooncake.CoDual(USVᴴ, Mooncake.fdata(Mooncake.zero_tangent(USVᴴ)))
242243
function dsvd_adjoint(::Mooncake.NoRData)
243244
U, S, Vᴴ = Mooncake.primal(USVᴴ_codual)
@@ -271,10 +272,10 @@ end
271272
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
272273
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
273274
# compute primal
274-
S_ = Mooncake.primal(S_dS)
275-
dS_ = Mooncake.tangent(S_dS)
276-
A_ = Mooncake.primal(A_dA)
277-
dA_ = Mooncake.tangent(A_dA)
275+
S_ = Mooncake.primal(S_dS)
276+
dS_ = Mooncake.tangent(S_dS)
277+
A_ = Mooncake.primal(A_dA)
278+
dA_ = Mooncake.tangent(A_dA)
278279
A, dA = arrayify(A_, dA_)
279280
S, dS = arrayify(S_, dS_)
280281
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
@@ -290,15 +291,15 @@ end
290291
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm}
291292
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals)}, A_dA::CoDual, alg_dalg::CoDual; kwargs...)
292293
# compute primal
293-
A = Mooncake.primal(A_dA)
294-
dA_ = Mooncake.tangent(A_dA)
295-
A, dA = arrayify(A, dA_)
296-
S = svd_vals(A, Mooncake.primal(alg_dalg); kwargs...)
294+
A = Mooncake.primal(A_dA)
295+
dA_ = Mooncake.tangent(A_dA)
296+
A, dA = arrayify(A, dA_)
297+
S = svd_vals(A, Mooncake.primal(alg_dalg); kwargs...)
297298
U, _, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
298299
S_codual = Mooncake.CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S)))
299300
function dsvd_vals_adjoint(::Mooncake.NoRData)
300-
S = Mooncake.primal(S_codual)
301-
dS_ = Mooncake.tangent(S_codual)
301+
S = Mooncake.primal(S_codual)
302+
dS_ = Mooncake.tangent(S_codual)
302303
S, dS = Mooncake.arrayify(S, dS_)
303304
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
304305
MatrixAlgebraKit.zero!(dS)
@@ -311,17 +312,17 @@ end
311312
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_trunc), Any, MatrixAlgebraKit.TruncatedAlgorithm}
312313
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual; kwargs...)
313314
# compute primal
314-
A_ = Mooncake.primal(A_dA)
315-
dA_ = Mooncake.tangent(A_dA)
316-
A, dA = arrayify(A_, dA_)
317-
alg = Mooncake.primal(alg_dalg)
315+
A_ = Mooncake.primal(A_dA)
316+
dA_ = Mooncake.tangent(A_dA)
317+
A, dA = arrayify(A_, dA_)
318+
alg = Mooncake.primal(alg_dalg)
318319
output = svd_trunc(A, alg)
319320
output_codual = Mooncake.CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
320-
function dsvd_trunc_adjoint(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T}) where {T<:Real}
321+
function dsvd_trunc_adjoint(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T}) where {T <: Real}
321322
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
322323
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
323-
U, dU = Mooncake.arrayify(Utrunc, dUtrunc_)
324-
S, dS = Mooncake.arrayify(Strunc, dStrunc_)
324+
U, dU = Mooncake.arrayify(Utrunc, dUtrunc_)
325+
S, dS = Mooncake.arrayify(Strunc, dStrunc_)
325326
Vᴴ, dVᴴ = Mooncake.arrayify(Vᴴtrunc, dVᴴtrunc_)
326327
svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
327328
MatrixAlgebraKit.zero!(dU)

test/mooncake.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -416,53 +416,53 @@ end
416416
@testset "svd_trunc" begin
417417
@testset for r in 1:4:minmn
418418
U, S, Vᴴ = svd_compact(A)
419-
ΔU = randn(rng, T, m, minmn)
420-
ΔS = randn(rng, real(T), minmn, minmn)
421-
ΔS2 = Diagonal(randn(rng, real(T), minmn))
422-
ΔVᴴ = randn(rng, T, minmn, n)
423-
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
419+
ΔU = randn(rng, T, m, minmn)
420+
ΔS = randn(rng, real(T), minmn, minmn)
421+
ΔS2 = Diagonal(randn(rng, real(T), minmn))
422+
ΔVᴴ = randn(rng, T, minmn, n)
423+
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
424424
truncalg = TruncatedAlgorithm(alg, truncrank(r))
425-
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
426-
Strunc = Diagonal(diagview(S)[ind])
427-
Utrunc = U[:, ind]
428-
Vᴴtrunc = Vᴴ[ind, :]
429-
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
430-
ΔUtrunc = ΔU[:, ind]
425+
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
426+
Strunc = Diagonal(diagview(S)[ind])
427+
Utrunc = U[:, ind]
428+
Vᴴtrunc = Vᴴ[ind, :]
429+
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
430+
ΔUtrunc = ΔU[:, ind]
431431
ΔVᴴtrunc = ΔVᴴ[ind, :]
432-
dStrunc = make_mooncake_tangent(ΔStrunc)
433-
dUtrunc = make_mooncake_tangent(ΔUtrunc)
432+
dStrunc = make_mooncake_tangent(ΔStrunc)
433+
dUtrunc = make_mooncake_tangent(ΔUtrunc)
434434
dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc)
435-
ϵ = zero(real(T))
435+
ϵ = zero(real(T))
436436
dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
437-
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
438-
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
437+
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
438+
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
439439
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
440440
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
441441
#Mooncake.TestUtils.test_rule(rng, svd_trunc!, copy(A), (U, S, Vᴴ), truncalg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴerr, atol=atol, rtol=rtol, is_primitive=false)
442442
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
443443
end
444444
@testset "trunctol" begin
445445
U, S, Vᴴ = svd_compact(A)
446-
ΔU = randn(rng, T, m, minmn)
447-
ΔS = randn(rng, real(T), minmn, minmn)
448-
ΔS2 = Diagonal(randn(rng, real(T), minmn))
449-
ΔVᴴ = randn(rng, T, minmn, n)
450-
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
446+
ΔU = randn(rng, T, m, minmn)
447+
ΔS = randn(rng, real(T), minmn, minmn)
448+
ΔS2 = Diagonal(randn(rng, real(T), minmn))
449+
ΔVᴴ = randn(rng, T, minmn, n)
450+
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
451451
truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2))
452-
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
453-
Strunc = Diagonal(diagview(S)[ind])
454-
Utrunc = U[:, ind]
455-
Vᴴtrunc = Vᴴ[ind, :]
456-
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
457-
ΔUtrunc = ΔU[:, ind]
452+
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
453+
Strunc = Diagonal(diagview(S)[ind])
454+
Utrunc = U[:, ind]
455+
Vᴴtrunc = Vᴴ[ind, :]
456+
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
457+
ΔUtrunc = ΔU[:, ind]
458458
ΔVᴴtrunc = ΔVᴴ[ind, :]
459-
dStrunc = make_mooncake_tangent(ΔStrunc)
460-
dUtrunc = make_mooncake_tangent(ΔUtrunc)
459+
dStrunc = make_mooncake_tangent(ΔStrunc)
460+
dUtrunc = make_mooncake_tangent(ΔUtrunc)
461461
dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc)
462-
ϵ = zero(real(T))
462+
ϵ = zero(real(T))
463463
dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
464-
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
465-
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
464+
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
465+
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
466466
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
467467
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
468468
#Mooncake.TestUtils.test_rule(rng, svd_trunc!, copy(A), (U, S, Vᴴ), truncalg; mode=Mooncake.ReverseMode, output_tangent=dUSVᴴerr, atol=atol, rtol=rtol, is_primitive=false)

0 commit comments

Comments
 (0)