Skip to content

Commit 7edcc83

Browse files
committed
Reverse rules for Enzyme
1 parent 4fbc3bf commit 7edcc83

File tree

12 files changed

+899
-88
lines changed

12 files changed

+899
-88
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,22 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1111
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1314

1415
[extensions]
1516
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
1617
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
1718
MatrixAlgebraKitCUDAExt = "CUDA"
19+
MatrixAlgebraKitEnzymeExt = "Enzyme"
1820

1921
[compat]
2022
AMDGPU = "2"
2123
Aqua = "0.6, 0.7, 0.8"
2224
ChainRulesCore = "1"
2325
ChainRulesTestUtils = "1"
2426
CUDA = "5"
27+
Enzyme = "0.13.77"
28+
EnzymeTestUtils = "0.2.3"
2529
JET = "0.9, 0.10"
2630
LinearAlgebra = "1"
2731
SafeTestsets = "0.1"
@@ -34,6 +38,7 @@ julia = "1.10"
3438
[extras]
3539
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3640
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
41+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
3742
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3843
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3944
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -42,4 +47,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
4247
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4348

4449
[targets]
45-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
50+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "Enzyme", "EnzymeTestUtils"]

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 451 additions & 0 deletions
Large diffs are not rendered by default.

src/common/view.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# diagind: provided by LinearAlgebra.jl
2-
diagview(D::Diagonal) = D.diag
2+
diagview(D::Diagonal) = D.diag
33
diagview(D::AbstractMatrix) = view(D, diagind(D))
44

55
# triangularind

src/implementations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real =
1919
end
2020

2121
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
22-
check_hermitian(A, alg)
22+
#check_hermitian(A, alg)
2323
D, V = DV
2424
m = size(A, 1)
2525
@assert D isa Diagonal && V isa AbstractMatrix

src/implementations/svd.jl

Lines changed: 14 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ end
8989
function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm)
9090
return similar(A, real(eltype(A)), (min(size(A)...),))
9191
end
92-
function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
92+
function initialize_output(::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm)
9393
return initialize_output(svd_compact!, A, alg.alg)
9494
end
9595

@@ -347,46 +347,25 @@ function _gpu_gesvdj!(
347347
)
348348
throw(MethodError(_gpu_gesvdj!, (A, S, U, Vᴴ)))
349349
end
350-
function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix)
351-
m, n = size(A)
352-
m n && return _gpu_gesvd!(A, S, U, Vᴴ)
353-
# both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration)
354-
# if this condition is not met, do the SVD via adjoint
355-
minmn = min(m, n)
356-
Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
357-
Uᴴ = similar(U')
358-
V = similar(Vᴴ')
359-
if size(U) == (m, m)
360-
_gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ)
361-
else
362-
_gpu_gesvd!(Aᴴ, S, V, Uᴴ)
363-
end
364-
length(U) > 0 && adjoint!(U, Uᴴ)
365-
length(Vᴴ) > 0 && adjoint!(Vᴴ, V)
366-
return U, S, Vᴴ
367-
end
368-
369350
# GPU SVD implementation
370-
function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
351+
function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
371352
check_input(svd_full!, A, USVᴴ, alg)
372353
U, S, Vᴴ = USVᴴ
373354
fill!(S, zero(eltype(S)))
374355
m, n = size(A)
375356
minmn = min(m, n)
376-
if minmn == 0
377-
one!(U)
378-
zero!(S)
379-
one!(Vᴴ)
380-
return USVᴴ
381-
end
382357
if alg isa GPU_QRIteration
383358
isempty(alg.kwargs) ||
384-
@warn "GPU_QRIteration does not accept any keyword arguments"
385-
_gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ)
359+
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
360+
_gpu_gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
386361
elseif alg isa GPU_SVDPolar
387362
_gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
388363
elseif alg isa GPU_Jacobi
389364
_gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...)
365+
# elseif alg isa LAPACK_Bisection
366+
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
367+
# elseif alg isa LAPACK_Jacobi
368+
# throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
390369
else
391370
throw(ArgumentError("Unsupported SVD algorithm"))
392371
end
@@ -403,21 +382,16 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
403382
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
404383
# TODO: make this controllable using a `gaugefix` keyword argument
405384
gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...)
406-
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
407-
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
408-
Strunc = diagview(USVᴴtrunc[2])
409-
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
410-
ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this?
411-
return USVᴴtrunc..., ϵ
385+
return first(truncate(svd_trunc!, USVᴴ, alg.trunc))
412386
end
413387

414-
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
388+
function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
415389
check_input(svd_compact!, A, USVᴴ, alg)
416390
U, S, Vᴴ = USVᴴ
417391
if alg isa GPU_QRIteration
418392
isempty(alg.kwargs) ||
419-
@warn "GPU_QRIteration does not accept any keyword arguments"
420-
_gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ)
393+
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
394+
_gpu_gesvd!(A, S.diag, U, Vᴴ)
421395
elseif alg isa GPU_SVDPolar
422396
_gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...)
423397
elseif alg isa GPU_Jacobi
@@ -437,8 +411,8 @@ function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
437411
U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0))
438412
if alg isa GPU_QRIteration
439413
isempty(alg.kwargs) ||
440-
@warn "GPU_QRIteration does not accept any keyword arguments"
441-
_gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ)
414+
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments"))
415+
_gpu_gesvd!(A, S, U, Vᴴ)
442416
elseif alg isa GPU_SVDPolar
443417
_gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...)
444418
elseif alg isa GPU_Jacobi

src/pullbacks/eig.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ function eig_pullback!(
4848
Δgauge < gauge_atol ||
4949
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5050

51-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
51+
VᴴΔV ./= conj.(transpose(D) .- D)
52+
diagview(VᴴΔV) .= zero(eltype(VᴴΔV))
5253

5354
if !iszerotangent(ΔDmat)
5455
ΔDvec = diagview(ΔDmat)

src/pullbacks/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and
55
cotangent `ΔWP` of `left_polar(A)`.
66
"""
7-
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP)
7+
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
88
# Extract the Polar components
99
W, P = WP
1010

@@ -34,7 +34,7 @@ end
3434
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ`
3535
and cotangent `ΔPWᴴ` of `right_polar(A)`.
3636
"""
37-
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ)
37+
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...)
3838
# Extract the Polar components
3939
P, Wᴴ = PWᴴ
4040

src/pullbacks/svd.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@ function svd_pullback!(
2828
degeneracy_atol::Real = tol,
2929
gauge_atol::Real = tol
3030
)
31-
3231
# Extract the SVD components
3332
U, Smat, Vᴴ = USVᴴ
3433
m, n = size(U, 1), size(Vᴴ, 2)
35-
(m, n) == size(ΔA) || throw(DimensionMismatch())
34+
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
3635
minmn = min(m, n)
3736
S = diagview(Smat)
38-
length(S) == minmn || throw(DimensionMismatch())
37+
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
3938
r = searchsortedlast(S, rank_atol; rev = true) # rank
4039
Ur = view(U, :, 1:r)
4140
Vᴴr = view(Vᴴ, 1:r, :)
@@ -46,22 +45,22 @@ function svd_pullback!(
4645
UΔU = fill!(similar(U, (r, r)), 0)
4746
VΔV = fill!(similar(Vᴴ, (r, r)), 0)
4847
if !iszerotangent(ΔU)
49-
m == size(ΔU, 1) || throw(DimensionMismatch())
48+
m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)"))
5049
pU = size(ΔU, 2)
51-
pU > r && throw(DimensionMismatch())
50+
pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)"))
5251
indU = axes(U, 2)[ind]
53-
length(indU) == pU || throw(DimensionMismatch())
52+
length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))"))
5453
UΔUp = view(UΔU, :, indU)
5554
mul!(UΔUp, Ur', ΔU)
5655
# ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
5756
ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
5857
end
5958
if !iszerotangent(ΔVᴴ)
60-
n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
59+
n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)"))
6160
pV = size(ΔVᴴ, 1)
62-
pV > r && throw(DimensionMismatch())
61+
pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)"))
6362
indV = axes(Vᴴ, 1)[ind]
64-
length(indV) == pV || throw(DimensionMismatch())
63+
length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))"))
6564
VΔVp = view(VΔV, :, indV)
6665
mul!(VΔVp, Vᴴr, ΔVᴴ')
6766
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
@@ -84,7 +83,7 @@ function svd_pullback!(
8483
ΔS = diagview(ΔSmat)
8584
pS = length(ΔS)
8685
indS = axes(S, 1)[ind]
87-
length(indS) == pS || throw(DimensionMismatch())
86+
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
8887
view(diagview(UdΔAV), indS) .+= real.(ΔS)
8988
end
9089
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA

test/ad_utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
function remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ;
2+
degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S))
3+
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
4+
gaugepart = (gaugepart - gaugepart') / 2
5+
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
6+
mul!(ΔU, U, gaugepart, -1, 1)
7+
return ΔU, ΔVᴴ
8+
end
9+
function remove_eiggauge_dependence!(ΔV, D, V;
10+
degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S))
11+
gaugepart = V' * ΔV
12+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
13+
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
14+
return ΔV
15+
end
16+
function remove_eighgauge_dependence!(ΔV, D, V;
17+
degeneracy_atol=MatrixAlgebraKit.default_pullback_gaugetol(S))
18+
gaugepart = V' * ΔV
19+
gaugepart = (gaugepart - gaugepart') / 2
20+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
21+
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
22+
return ΔV
23+
end
24+
25+
precision(::Type{<:Union{Float32,Complex{Float32}}}) = sqrt(eps(Float32))
26+
precision(::Type{<:Union{Float64,Complex{Float64}}}) = sqrt(eps(Float64))

test/chainrules.jl

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
66
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
77
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
88

9-
function remove_svdgauge_dependence!(
10-
ΔU, ΔVᴴ, U, S, Vᴴ;
11-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
12-
)
13-
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
14-
gaugepart = (gaugepart - gaugepart') / 2
15-
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
16-
mul!(ΔU, U, gaugepart, -1, 1)
17-
return ΔU, ΔVᴴ
18-
end
19-
function remove_eiggauge_dependence!(
20-
ΔV, D, V;
21-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
22-
)
23-
gaugepart = V' * ΔV
24-
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
25-
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
26-
return ΔV
27-
end
28-
function remove_eighgauge_dependence!(
29-
ΔV, D, V;
30-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
31-
)
32-
gaugepart = V' * ΔV
33-
gaugepart = (gaugepart - gaugepart') / 2
34-
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
35-
mul!(ΔV, V, gaugepart, -1, 1)
36-
return ΔV
37-
end
38-
39-
precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32))
40-
precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64))
9+
include("ad_utils.jl")
4110

4211
for f in
4312
(

0 commit comments

Comments
 (0)