Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 243 additions & 76 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,11 @@ include("pullbacks/eigh.jl")
include("pullbacks/svd.jl")
include("pullbacks/polar.jl")

include("pushforwards/qr.jl")
include("pushforwards/lq.jl")
include("pushforwards/eig.jl")
include("pushforwards/eigh.jl")
include("pushforwards/polar.jl")
include("pushforwards/svd.jl")

end
2 changes: 1 addition & 1 deletion src/common/view.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# diagind: provided by LinearAlgebra.jl
diagview(D::Diagonal) = D.diag
diagview(D::Diagonal) = D.diag
diagview(D::AbstractMatrix) = view(D, diagind(D))

# triangularind
Expand Down
2 changes: 1 addition & 1 deletion src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real =
end

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
check_hermitian(A, alg)
#check_hermitian(A, alg)
D, V = DV
m = size(A, 1)
@assert D isa Diagonal && V isa AbstractMatrix
Expand Down
3 changes: 2 additions & 1 deletion src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ function eig_pullback!(
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
VᴴΔV ./= conj.(transpose(D) .- D)
diagview(VᴴΔV) .= zero(eltype(VᴴΔV))

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
Expand Down
15 changes: 15 additions & 0 deletions src/pushforwards/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
    eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)

Forward-mode rule for an eigenvalue decomposition with factors `(D, V) = DV`.

The perturbation `ΔA` is rotated into the eigenbasis as `V⁻¹ ΔA V`. Its diagonal is
written into the diagonal of `ΔD`. Unless `ΔV` is a zero tangent, `ΔV` is overwritten
with `V * K̇`, where `K̇[i, j] = (V⁻¹ ΔA V)[i, j] / (D[j, j] - D[i, i])` off the diagonal
and zero on it.

Returns the (mutated) tuple `ΔDV`. `A` and `kwargs` are accepted but not used.
"""
function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
    D, V = DV
    ΔD, ΔV = ΔDV

    # Perturbation expressed in the eigenbasis.
    ΔA_eig = inv(V) * ΔA * V
    diagview(ΔD) .= diagview(ΔA_eig)

    # Nothing more to do when no eigenvector tangent is requested.
    iszerotangent(ΔV) && return ΔDV

    λ = diagview(D)
    invgap = 1 ./ (transpose(λ) .- λ)
    # The diagonal holds 1/0; it carries no information and is reset to zero.
    fill!(diagview(invgap), zero(eltype(invgap)))
    mul!(ΔV, V, invgap .* ΔA_eig, 1, 0)
    return ΔDV
end

"""
    eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...)

Placeholder with an empty body: no argument is read or modified and `nothing` is
returned.
"""
function eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
    return nothing
end
19 changes: 19 additions & 0 deletions src/pushforwards/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
    eigh_pushforward!(dA, A, DV, dDV; kwargs...)

Forward-mode rule for a Hermitian eigenvalue decomposition with factors `(D, V) = DV`.

The perturbation `dA` is rotated into the eigenbasis as `V⁻¹ dA V`; the real part of its
diagonal is written into the diagonal of `dD`. Unless `dV` is a zero tangent, `dV` is
overwritten with `V * (F .* (V⁻¹ dA V))`, where `F[i, j] = 1 / (D[j, j] - D[i, i])` off
the diagonal and zero on it.

Returns `(dD, dV)`. `A` and `kwargs` are accepted but not used.
"""
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
    D, V = DV
    dD, dV = dDV

    # `buf` is reused below as the output buffer of the final product.
    buf = V \ dA
    ∂K = buf * V
    diagview(dD) .= real.(diagview(∂K))

    iszerotangent(dV) && return (dD, dV)

    λ = diagview(D)
    F = one(eltype(λ)) ./ (transpose(λ) .- λ)
    # The diagonal holds 1/0; reset it so it does not contribute.
    diagview(F) .= zero(eltype(F))
    ∂K .*= F
    copyto!(dV, mul!(buf, V, ∂K))
    return (dD, dV)
end

"""
    eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...)

Placeholder with an empty body: no argument is read or modified and `nothing` is
returned.
"""
function eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...)
    return nothing
end
7 changes: 7 additions & 0 deletions src/pushforwards/lq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
    lq_pushforward!(dA, A, LQ, dLQ; tol, rank_atol, gauge_atol)

Forward-mode rule for an LQ decomposition, implemented by delegating to
`qr_pushforward!` on the adjoint problem: `dA` and `A` are wrapped in `adjoint`, and the
factor and tangent tuples are reversed and adjointed element-wise.

The return value is whatever `qr_pushforward!` returns for those adjoint arguments.
`tol` defaults to `default_pullback_gauge_atol(LQ[1])`; `rank_atol` and `gauge_atol`
default to `tol`. All three are forwarded unchanged.
"""
function lq_pushforward!(
        dA, A, LQ, dLQ;
        tol::Real = default_pullback_gauge_atol(LQ[1]),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    # (L, Q) of A corresponds to (Qᴴ, Lᴴ) as the QR factors of Aᴴ.
    QR = map(adjoint, reverse(LQ))
    dQR = map(adjoint, reverse(dLQ))
    return qr_pushforward!(dA', A', QR, dQR; tol, rank_atol, gauge_atol)
end

"""
    lq_null_pushforward!(dA, A, Nᴴ, dNᴴ; tol, rank_atol, gauge_atol)

Placeholder for the forward-mode rule of the LQ null space.

Currently no tangent is computed: `dNᴴ` is left untouched and `nothing` is returned,
both when `Nᴴ` has a zero dimension (where there is nothing to do) and otherwise.
The keyword arguments are accepted but not used.
"""
function lq_null_pushforward!(
        dA, A, Nᴴ, dNᴴ;
        tol::Real = default_pullback_gauge_atol(Nᴴ),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    iszero(min(size(Nᴴ)...)) && return nothing # nothing to do
    # TODO: the rule for a non-empty null space is not implemented yet.
    # Return `nothing` explicitly so both paths agree on the return type (previously
    # this path returned the `false` produced by the short-circuit expression).
    return nothing
end
21 changes: 21 additions & 0 deletions src/pushforwards/polar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
    left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)

Overwrite the tangents `(ΔW, ΔP) = ΔWP` of the factors `(W, P) = WP` with the
first-order change induced by the perturbation `ΔA`, and return `(ΔW, ΔP)`.

`ΔW` is assembled as `W * K̇ + L̇`, where `K̇` solves the Sylvester equation
`P * K̇ + K̇ * P = Wᴴ ΔA - (Wᴴ ΔA)ᴴ` and `L̇ = (I - W Wᴴ) ΔA P⁻¹`. `ΔP` is set to
`Wᴴ ΔA - K̇ * P`. `A` and `kwargs` are accepted but not used.
"""
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
    W, P = WP
    ΔW, ΔP = ΔWP

    Wᴴ = adjoint(W)
    WᴴΔA = Wᴴ * ΔA
    skew = WᴴΔA - adjoint(WᴴΔA)
    # `sylvester(A, B, C)` solves A*X + X*B + C = 0, hence the sign flip.
    K̇ = sylvester(P, P, -skew)

    # Component of ΔA outside the column space of W.
    projector = Diagonal(ones(eltype(W), size(W, 1))) - W * Wᴴ
    L̇ = projector * ΔA * inv(P)

    ΔW .= W * K̇ + L̇
    ΔP .= WᴴΔA - K̇ * P
    return (ΔW, ΔP)
end

"""
    right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)

Overwrite the tangents `(ΔP, ΔWᴴ) = ΔPWᴴ` of the factors `(P, Wᴴ) = PWᴴ` with the
first-order change induced by the perturbation `ΔA`.

`ΔWᴴ` is assembled as `K̇ * Wᴴ + L̇`, where `K̇` solves the Sylvester equation
`P * K̇ + K̇ * P = ΔA W - (ΔA W)ᴴ` and `L̇ = P⁻¹ ΔA (I - W Wᴴ)`. `ΔP` is set to
`ΔA W - P * K̇`. `A` and `kwargs` are accepted but not used.

Returns `(ΔP, ΔWᴴ)`, i.e. the tangents in the same order as the input tuple `ΔPWᴴ`.
"""
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
    P, Wᴴ = PWᴴ
    ΔP, ΔWᴴ = ΔPWᴴ

    W = adjoint(Wᴴ)
    ΔAW = ΔA * W
    # `sylvester(A, B, C)` solves A*X + X*B + C = 0, hence the sign flip.
    K̇ = sylvester(P, P, -(ΔAW - adjoint(ΔAW)))

    # Component of ΔA outside the row space of Wᴴ.
    projector = Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - W * Wᴴ
    L̇ = inv(P) * ΔA * projector

    ΔWᴴ .= K̇ * Wᴴ + L̇
    ΔP .= ΔAW - P * K̇
    # Return in the order of the input tuple (P first), matching the other rules;
    # the previous `(ΔWᴴ, ΔP)` was swapped relative to `ΔPWᴴ`.
    return (ΔP, ΔWᴴ)
end
61 changes: 61 additions & 0 deletions src/pushforwards/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
    qr_pushforward!(dA, A, QR, dQR; tol, rank_atol, gauge_atol)

Forward-mode rule for a QR decomposition with factors `(Q, R) = QR`: overwrite the
tangents `(dQ, dR) = dQR` with the first-order change induced by the perturbation `dA`
and return `(dQ, dR)`.

The numerical rank `p` is the index of the last diagonal entry of `R` whose magnitude is
at least `rank_atol` (`0` if there is none). The factors are partitioned accordingly:

- `Q1`, `R11`, `R12`: the leading `p` columns of `Q` and leading `p` rows of `R`;
- `Q2`, `dR22`: the columns/rows `p+1:min(m, n)` beyond the numerical rank;
- `Q3`: columns `min(m, n)+1:m` of `Q`, only handled when `Q` has such columns.

`tol` defaults to `default_pullback_gauge_atol(QR[2])`; `rank_atol` and `gauge_atol`
default to `tol`. Only `rank_atol` is used by the current implementation.
"""
function qr_pushforward!(
        dA, A, QR, dQR;
        tol::Real = default_pullback_gauge_atol(QR[2]),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    Q, R = QR
    dQ, dR = dQR
    m, n = size(A, 1), size(A, 2)
    minmn = min(m, n)

    # `findlast` yields `nothing` when no diagonal entry reaches `rank_atol`
    # (e.g. a zero matrix); treat that as rank zero instead of failing below.
    p = something(findlast(>=(rank_atol) ∘ abs, diagview(R)), 0)
    m3 = m - minmn

    Q1 = view(Q, 1:m, 1:p) # full rank portion
    Q2 = view(Q, 1:m, (p + 1):minmn)
    R11 = view(R, 1:p, 1:p)
    R12 = view(R, 1:p, (p + 1):n)

    dA1 = view(dA, 1:m, 1:p)
    dA2 = view(dA, 1:m, (p + 1):n)

    dQ1 = view(dQ, 1:m, 1:p)
    dQ2 = view(dQ, 1:m, (p + 1):minmn)
    dR11 = view(dR, 1:p, 1:p)
    dR12 = view(dR, 1:p, (p + 1):n)
    dR22 = view(dR, (p + 1):minmn, (p + 1):n)

    # Forward rule for Q1 and R11 -- for a QR that is not rank-deficient, this is all
    # we need. `Rtmp` becomes the upper-triangular factor with dR11 = Rtmp * R11.
    invR11 = inv(R11)
    tmp = Q1' * dA1 * invR11
    Rtmp = tmp + tmp'
    diagview(Rtmp) ./= 2
    ltRtmp = view(Rtmp, lowertriangularind(Rtmp))
    ltRtmp .= zero(eltype(Rtmp))
    dR11 .= Rtmp * R11
    dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
    dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)

    if size(Q2, 2) > 0
        dQ2 .= -Q1 * (Q1' * Q2)
        dQ2 .+= Q2 * (Q2' * dQ2)
    end

    if m3 > 0 && size(Q, 2) > minmn
        # Only present for qr_full or rank-deficient qr_compact.
        Q3 = view(Q, :, (minmn + 1):m)
        # Take every trailing column of dQ, including the case of exactly one such
        # column (the previous `minmn+1 < size(dQ, 2)` guard excluded it and the
        # assignment below then failed with a dimension mismatch).
        dQ3 = view(dQ, :, (minmn + 1):size(dQ, 2))
        # NOTE(review): this copies Q3 itself into its tangent block; confirm that this
        # is the intended choice for the trailing columns.
        dQ3 .= Q3
    end

    if !isempty(dR22)
        _, r22 = qr_compact(dA2 - dQ1 * R12 - Q1 * dR12; positive = true)
        dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
    end
    return (dQ, dR)
end

"""
    qr_null_pushforward!(dA, A, N, dN; tol, rank_atol, gauge_atol)

Placeholder for the forward-mode rule of the QR null space.

Currently no tangent is computed: `dN` is left untouched and `nothing` is returned,
both when `N` has a zero dimension (where there is nothing to do) and otherwise.
The keyword arguments are accepted but not used.
"""
function qr_null_pushforward!(
        dA, A, N, dN;
        tol::Real = default_pullback_gauge_atol(N),
        rank_atol::Real = tol, gauge_atol::Real = tol
    )
    iszero(min(size(N)...)) && return nothing # nothing to do
    # TODO: the rule for a non-empty null space is not implemented yet.
    # Return `nothing` explicitly so both paths agree on the return type (previously
    # this path returned the `false` produced by the short-circuit expression).
    return nothing
end
82 changes: 82 additions & 0 deletions src/pushforwards/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol, kwargs...)
#
# Forward-mode rule for an SVD with factors `(U, Smat, Vᴴ) = USVᴴ`: given the
# perturbation `ΔA`, writes tangents into `(ΔU, ΔS, ΔVᴴ) = ΔUSVᴴ` and returns that tuple.
# Only the leading `r` columns of `ΔU`, the diagonal of the leading `r × r` block of
# `ΔS`, and the leading `r` rows of `ΔVᴴ` are written, where `r` is the number of
# singular values selected by `rank_atol` (see below).
#
# Throws `DimensionMismatch` when `size(ΔA)` differs from `(size(U, 1), size(Vᴴ, 2))`.
# `kwargs` are accepted but not used.
function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol=default_pullback_rank_atol(A), kwargs...)
U, Smat, Vᴴ = USVᴴ
m, n = size(U, 1), size(Vᴴ, 2)
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
minmn = min(m, n)
S = diagview(Smat)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
# Last index whose singular value is still >= rank_atol; the search relies on `S`
# being sorted in descending order (`rev = true`).
r = searchsortedlast(S, rank_atol; rev = true) # rank

# Views restricted to the leading `r` singular triplets (outputs first, inputs second).
vΔU = view(ΔU, :, 1:r)
vΔS = view(ΔS, 1:r, 1:r)
vΔVᴴ = view(ΔVᴴ, 1:r, :)

vU = view(U, :, 1:r)
vS = view(S, 1:r)
vSmat = view(Smat, 1:r, 1:r)
vVᴴ = view(Vᴴ, 1:r, :)

# compact region: contributions that stay inside span(vU) / span(vV)
vV = adjoint(vVᴴ)
UΔAV = vU' * ΔA * vV
# Tangent of the singular values: real part of the diagonal of vU' * ΔA * vV.
copyto!(diagview(vΔS), diag(real.(UΔAV)))
# F[i, j] = 1 / (S[j] - S[i]) and G[i, j] = 1 / (S[j] + S[i]); the diagonal of F
# (1/0) is reset to zero just below.
# NOTE(review): off-diagonal entries of F are non-finite whenever two retained
# singular values coincide; no degeneracy guard is applied here.
F = one(eltype(S)) ./ (transpose(vS) .- vS)
G = one(eltype(S)) ./ (transpose(vS) .+ vS)
diagview(F) .= zero(eltype(F))
# Hermitian and anti-Hermitian parts of UΔAV, scaled element-wise by F and G.
hUΔAV = F .* (UΔAV + UΔAV') ./ 2
aUΔAV = G .* (UΔAV - UΔAV') ./ 2
K̇ = hUΔAV + aUΔAV
Ṁ = hUΔAV - aUΔAV

# check gauge condition
# NOTE(review): these are `@assert`s, which are meant for internal invariants and may
# be disabled at higher optimization levels; they are not input validation.
@assert isantihermitian(K̇)
@assert isantihermitian(Ṁ)
K̇diag = diagview(K̇)
for i in 1:length(K̇diag)
@assert K̇diag[i] ≈ (im/2) * imag(diagview(UΔAV)[i])/S[i]
end

# Tangents of the leading singular vectors, before the out-of-span correction.
∂U = vU * K̇
∂V = vV * Ṁ
# full component
# NOTE(review): this branch requires both size(U, 2) and size(Vᴴ, 1) to exceed
# min(m, n); if U has at most m columns and Vᴴ at most n rows that cannot happen, so
# the branch looks unreachable -- confirm the intended condition.
# NOTE(review): the branch body also cannot run as written: `adjoint(Uperp, ΔA, V)` is
# a three-argument call, and `V` and `Vperp` are not defined in this function.
if size(U, 2) > minmn && size(Vᴴ, 1) > minmn
Uperp = view(U, :, minmn+1:m)
Vᴴperp = view(Vᴴ, minmn+1:n, :)

aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp)

# Block matrix [0 aUAV; aUAV' 0] used as the left operand of the Sylvester solve.
UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2)))
fill!(UÃÃV, 0)
view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV
view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
rhs = vcat( adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U)
superKM = -sylvester(UÃÃV, Smat, rhs)
# NOTE(review): a single range indexes `superKM` linearly here, unlike the row/column
# views taken from `superLN` in the other branch -- confirm.
K̇perp = view(superKM, 1:size(aUAV, 2))
Ṁperp = view(superKM, size(aUAV, 2)+1:size(aUAV, 1)+size(aUAV, 2))
∂U .+= Uperp * K̇perp
∂V .+= Vperp * Ṁperp
else
# Projectors onto the complements of span(vU) and span(vV), built as dense matrices.
ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU*vU')
ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV*vVᴴ)
upper = ImUU * ΔA * vV
lower = ImVV * ΔA' * vU
rhs = vcat(upper, lower)

# Part of A outside the retained triplets, embedded in the (m + n) × (m + n)
# block matrix [0 Ã; Ã' 0].
Ã = ImUU * A * ImVV
ÃÃ = similar(A, (m + n, m + n))
fill!(ÃÃ, 0)
view(ÃÃ, (1:m), m .+ (1:n)) .= Ã
view(ÃÃ, m .+ (1:n), 1:m ) .= Ã'

# `sylvester(A, B, C)` solves A*X + X*B + C = 0; the first `m` rows of the negated
# solution are added to ∂U and the remaining `n` rows to ∂V.
superLN = -sylvester(ÃÃ, vSmat, rhs)
∂U += view(superLN, 1:size(upper, 1), :)
∂V += view(superLN, size(upper, 1)+1:size(upper,1)+size(lower,1), :)
end
# Store the results in the leading-`r` views of the output tangents; ΔVᴴ receives the
# adjoint of ∂V.
copyto!(vΔU, ∂U)
adjoint!(vΔVᴴ, ∂V)
return (ΔU, ΔS, ΔVᴴ)
end

"""
    svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, kwargs...)

Placeholder with an empty body: no argument is read or modified and `nothing` is
returned. `rank_atol` defaults to `default_pullback_rank_atol(A)`.
"""
function svd_trunc_pushforward!(
        ΔA, A, USVᴴ, ΔUSVᴴ, ind;
        rank_atol = default_pullback_rank_atol(A), kwargs...
    )
    return nothing
end
Loading
Loading