Skip to content

Commit 0465dae

Browse files
committed
patch through leftorth
1 parent 2ced266 commit 0465dae

File tree

1 file changed

+43
-39
lines changed

1 file changed

+43
-39
lines changed

src/tensors/factorizations.jl

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -311,54 +311,58 @@ end
311311
#------------------------------------------------------------------------------------------
312312
const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
313313

314+
function _reverse!(t::AbstractTensorMap; dims=:)
315+
for (c, b) in blocks(t)
316+
reverse!(b; dims)
317+
end
318+
return t
319+
end
320+
314321
function leftorth!(t::TensorMap{<:RealOrComplexFloat};
315322
alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(),
316323
atol::Real=zero(float(real(scalartype(t)))),
317324
rtol::Real=(alg (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
318325
eps(real(float(one(scalartype(t))))) * iszero(atol))
319326
InnerProductStyle(t) === EuclideanInnerProduct() ||
320327
throw_invalid_innerproduct(:leftorth!)
321-
if !iszero(rtol)
322-
atol = max(atol, rtol * norm(t))
323-
end
324-
I = sectortype(t)
325-
dims = SectorDict{I,Int}()
326328

327-
# compute QR factorization for each block
328-
if !isempty(blocks(t))
329-
generator = Base.Iterators.map(blocks(t)) do (c, b)
330-
Qc, Rc = MatrixAlgebra.leftorth!(b, alg, atol)
331-
dims[c] = size(Qc, 2)
332-
return c => (Qc, Rc)
329+
if alg isa QR
330+
return left_orth!(t; kind=:qr, atol, rtol)
331+
elseif alg isa QRpos
332+
return left_orth!(t; kind=:qrpos, atol, rtol)
333+
elseif alg isa SDD
334+
return left_orth!(t; kind=:svd, atol, rtol)
335+
elseif alg isa Polar
336+
return left_orth!(t; kind=:polar, atol, rtol)
337+
elseif alg isa SVD
338+
kind = :svd
339+
if iszero(atol) && iszero(rtol)
340+
alg′ = LAPACK_QRIteration()
341+
return left_orth!(t; kind, alg=BlockAlgorithm(alg′, default_blockscheduler(t)),
342+
atol, rtol)
343+
else
344+
trunc = MatrixAlgebraKit.TruncationKeepAbove(atol, rtol)
345+
svd_alg = LAPACK_QRIteration()
346+
scheduler = default_blockscheduler(t)
347+
alg′ = MatrixAlgebraKit.TruncatedAlgorithm(BlockAlgorithm(svd_alg, scheduler),
348+
trunc)
349+
return left_orth!(t; kind, alg=alg′, atol, rtol)
333350
end
334-
QRdata = SectorDict(generator)
351+
elseif alg isa QL
352+
_reverse!(t; dims=2)
353+
Q, R = left_orth!(t; kind=:qr, atol, rtol)
354+
_reverse!(Q; dims=2)
355+
_reverse!(R)
356+
return Q, R
357+
elseif alg isa QLpos
358+
_reverse!(t; dims=2)
359+
Q, R = left_orth!(t; kind=:qrpos, atol, rtol)
360+
_reverse!(Q; dims=2)
361+
_reverse!(R)
362+
return Q, R
335363
end
336364

337-
# construct new space
338-
S = spacetype(t)
339-
V = S(dims)
340-
if alg isa Polar
341-
@assert V domain(t)
342-
W = domain(t)
343-
elseif length(domain(t)) == 1 && domain(t) V
344-
W = domain(t)
345-
elseif length(codomain(t)) == 1 && codomain(t) V
346-
W = codomain(t)
347-
else
348-
W = ProductSpace(V)
349-
end
350-
351-
# construct output tensors
352-
T = float(scalartype(t))
353-
Q = similar(t, T, codomain(t) W)
354-
R = similar(t, T, W domain(t))
355-
if !isempty(blocks(t))
356-
for (c, (Qc, Rc)) in QRdata
357-
copy!(block(Q, c), Qc)
358-
copy!(block(R, c), Rc)
359-
end
360-
end
361-
return Q, R
365+
throw(ArgumentError("Algorithm $alg not implemented for leftorth!"))
362366
end
363367

364368
function leftnull!(t::TensorMap{<:RealOrComplexFloat};
@@ -685,8 +689,8 @@ function LinearAlgebra.ishermitian(t::TensorMap)
685689
end
686690

687691
function LinearAlgebra.isposdef!(t::TensorMap)
688-
domain(t) == codomain(t) ||
689-
throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same"))
692+
domain(t) codomain(t) ||
693+
throw(SpaceMismatch("`isposdef` requires domain and codomain to be isomorphic"))
690694
InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false
691695
for (c, b) in blocks(t)
692696
isposdef!(b) || return false

0 commit comments

Comments
 (0)