Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
MatrixAlgebraKitCUDAExt = "CUDA"
MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra"
MatrixAlgebraKitGenericSchurExt = "GenericSchur"
MatrixAlgebraKitMooncakeExt = "Mooncake"

[compat]
AMDGPU = "2"
Expand All @@ -30,6 +32,7 @@ GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
LinearAlgebra = "1"
Mooncake = "0.4.174"
SafeTestsets = "0.1"
StableRNGs = "1"
Test = "1"
Expand All @@ -43,11 +46,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
333 changes: 333 additions & 0 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and
cotangent `ΔWP` of `left_polar(A)`.
"""
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP)
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
# Extract the Polar components
W, P = WP

Expand Down Expand Up @@ -34,7 +34,7 @@ end
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ`
and cotangent `ΔPWᴴ` of `right_polar(A)`.
"""
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ)
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...)
# Extract the Polar components
P, Wᴴ = PWᴴ

Expand Down
19 changes: 9 additions & 10 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ function svd_pullback!(
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

# Extract the SVD components
U, Smat, Vᴴ = USVᴴ
m, n = size(U, 1), size(Vᴴ, 2)
(m, n) == size(ΔA) || throw(DimensionMismatch())
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
minmn = min(m, n)
S = diagview(Smat)
length(S) == minmn || throw(DimensionMismatch())
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
r = searchsortedlast(S, rank_atol; rev = true) # rank
Ur = view(U, :, 1:r)
Vᴴr = view(Vᴴ, 1:r, :)
Expand All @@ -44,22 +43,22 @@ function svd_pullback!(
UΔU = fill!(similar(U, (r, r)), 0)
VΔV = fill!(similar(Vᴴ, (r, r)), 0)
if !iszerotangent(ΔU)
m == size(ΔU, 1) || throw(DimensionMismatch())
m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)"))
pU = size(ΔU, 2)
pU > r && throw(DimensionMismatch())
pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)"))
indU = axes(U, 2)[ind]
length(indU) == pU || throw(DimensionMismatch())
length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))"))
UΔUp = view(UΔU, :, indU)
mul!(UΔUp, Ur', ΔU)
# ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
end
if !iszerotangent(ΔVᴴ)
n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)"))
pV = size(ΔVᴴ, 1)
pV > r && throw(DimensionMismatch())
pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)"))
indV = axes(Vᴴ, 1)[ind]
length(indV) == pV || throw(DimensionMismatch())
length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))"))
VΔVp = view(VΔV, :, indV)
mul!(VΔVp, Vᴴr, ΔVᴴ')
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
Expand All @@ -82,7 +81,7 @@ function svd_pullback!(
ΔS = diagview(ΔSmat)
pS = length(ΔS)
indS = axes(S, 1)[ind]
length(indS) == pS || throw(DimensionMismatch())
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
view(diagview(UdΔAV), indS) .+= real.(ΔS)
end
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA
Expand Down
31 changes: 31 additions & 0 deletions test/ad_utils.jl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that I have to copy this implementation in TensorKit, and failed horrendously:
https://github.com/QuantumKitHub/TensorKit.jl/blob/bbeb8e510727fd2ce77c579ae9833f5056c85cc8/test/autodiff/ad.jl#L105-L120
I think it would be nice to have this as part of the main package, and if not, to turn this into a "test module" that can be included from within downstream packages (probably mostly by adding some comments to this file to indicate that its location and contents are considered "public" and cannot be moved/removed without breaking changes)

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
function remove_svdgauge_dependence!(
ΔU, ΔVᴴ, U, S, Vᴴ;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
)
gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true)
gaugepart = project_antihermitian!(gaugepart)
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
mul!(ΔU, U, gaugepart, -1, 1)
return ΔU, ΔVᴴ
end
function remove_eiggauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
)
gaugepart = V' * ΔV
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how important this is, because it is only used for tests, but V / (V' * V) could probably be replaced with

Q, R = qr_compact(V)
pinvV = rdiv!(Q, UpperTriangular(R)')

for optimal numerical accuracy/stability.

return ΔV
end
function remove_eighgauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
)
gaugepart = V' * ΔV
gaugepart = project_antihermitian!(gaugepart)
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
mul!(ΔV, V, gaugepart, -1, 1)
return ΔV
end

precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T)))
33 changes: 1 addition & 32 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!

function remove_svdgauge_dependence!(
ΔU, ΔVᴴ, U, S, Vᴴ;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
)
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
gaugepart = (gaugepart - gaugepart') / 2
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
mul!(ΔU, U, gaugepart, -1, 1)
return ΔU, ΔVᴴ
end
function remove_eiggauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
)
gaugepart = V' * ΔV
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
return ΔV
end
function remove_eighgauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
)
gaugepart = V' * ΔV
gaugepart = (gaugepart - gaugepart') / 2
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
mul!(ΔV, V, gaugepart, -1, 1)
return ΔV
end

precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32))
precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64))
include("ad_utils.jl")

for f in
(
Expand Down
Loading