Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]

[sources]
CUDA = {url="https://github.com/juliagpu/cuda.jl", rev="master"}
14 changes: 10 additions & 4 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,23 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
end

MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
all(A.diag .== adjoint(A.diag))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Real} = true
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Complex}
return all(isreal.(A.diag))
end
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) =
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)

MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) =
all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Real} = all(iszero(A.diag))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Complex} =
all(iszero.(real.(A.diag)))
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) =
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
function MatrixAlgebraKit.isantihermitian_approx(A::Diagonal{T, <:StridedROCVector{T}}; atol, rtol, kwargs...) where {T <: Real}
return sum(abs2, A.diag) ≤ max(atol, rtol * norm(A))
end

function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
axes(A) == axes(B) || throw(DimensionMismatch())
Expand Down
14 changes: 10 additions & 4 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,23 @@ end

MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
all(A .== adjoint(A))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
all(A.diag .== adjoint(A.diag))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Real} = true
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Complex}
return all(isreal.(A.diag))
end
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) =
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)

MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Real} = all(iszero(A.diag))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Complex} =
all(iszero.(real.(A.diag)))
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) =
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
function MatrixAlgebraKit.isantihermitian_approx(A::Diagonal{T, <:StridedCuVector{T}}; atol, rtol, kwargs...) where {T <: Real}
return sum(abs2, A.diag) ≤ max(atol, rtol * norm(A))
end

function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
axes(A) == axes(B) || throw(DimensionMismatch())
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKit
using LinearAlgebra: LinearAlgebra
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
using LinearAlgebra: sylvester, lu!
using LinearAlgebra: sylvester, lu!, diagm
using LinearAlgebra: isposdef, issymmetric
using LinearAlgebra: Diagonal, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
Expand Down
15 changes: 13 additions & 2 deletions src/implementations/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)

function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
LinearAlgebra.checksquare(A)
n = Base.require_one_based_indexing(A)
Base.require_one_based_indexing(A)
n = size(A, 1)
B === A || @check_size(B, (n, n))
return nothing
end
function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
LinearAlgebra.checksquare(A)
n = Base.require_one_based_indexing(A)
Base.require_one_based_indexing(A)
n = size(A, 1)
B === A || @check_size(B, (n, n))
return nothing
end
Expand Down Expand Up @@ -61,6 +63,15 @@ function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm)
return W
end

function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti}
if anti
diagview(A) .= imag.(diagview(B)) .* im
else
diagview(A) .= real.(diagview(B))
end
return A
end

function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
n = size(A, 1)
for j in 1:blocksize:n
Expand Down
10 changes: 10 additions & 0 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)

return USVᴴ
end
svd_full!(A::Diagonal, USVᴴ, alg::GPU_SVDAlgorithm) = svd_full!(diagm(A.diag), USVᴴ, alg)

function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
check_input(svd_trunc!, A, USVᴴ, alg.alg)
Expand All @@ -373,6 +374,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran

return Utr, Str, Vᴴtr, ϵ
end
svd_trunc!(A::Diagonal, USVᴴ, alg::GPU_SVDAlgorithm) = svd_trunc!(diagm(A.diag), USVᴴ, alg)

function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
check_input(svd_compact!, A, USVᴴ, alg)
Expand All @@ -396,6 +398,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)

return USVᴴ
end
svd_compact!(A::Diagonal, USVᴴ, alg::GPU_SVDAlgorithm) = svd_compact!(diagm(A.diag), USVᴴ, alg)
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
_largest(x, y) = abs(x) < abs(y) ? y : x

Expand All @@ -418,3 +421,10 @@ function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)

return S
end
function svd_vals!(A::Diagonal, S, alg::GPU_SVDAlgorithm)
check_input(svd_vals!, A, S, alg)
Ad = diagview(A)
S .= abs.(Ad)
sort!(S; rev = true)
return S
end
91 changes: 48 additions & 43 deletions test/amd/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,41 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
m = 54
noisefactor = eps(real(T))^(3 / 4)
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
A = ROCArray(randn(rng, T, m, m))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)
for A in (ROCArray(randn(rng, T, m, m)), Diagonal(ROCArray(randn(rng, T, m))))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)

Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
Bh_approx = Bh + noisefactor * Aa
@test !ishermitian(Bh_approx)
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
# this is still hermitian for real Diagonals!
Bh_approx = Bh + noisefactor * Aa
if !isa(A, Diagonal) && !(T <: Real)
@test !ishermitian(Bh_approx)
end
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)

Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)

Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah
Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah

copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
end
end
end

Expand All @@ -54,24 +58,25 @@ end
svdalgs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
@testset "algorithm $alg" for alg in algs
A = ROCArray(randn(rng, T, m, n))
W = project_isometric(A, alg)
@test isisometric(W)
W2 = project_isometric(W, alg)
@test W2 ≈ W # stability of the projection
@test W * (W' * A) ≈ A
for A in (ROCArray(randn(rng, T, m, n)), Diagonal(ROCArray(randn(rng, T, m))))
W = project_isometric(A, alg)
@test isisometric(W)
W2 = project_isometric(W, alg)
@test W2 ≈ W # stability of the projection
@test W * (W' * A) ≈ A

Ac = similar(A)
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
@test W2 === W
@test isisometric(W)
Ac = similar(A)
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
@test W2 === W
@test isisometric(W)

# test that W is closer to A then any other isometry
for k in 1:10
δA = ROCArray(randn(rng, T, m, n))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) > norm(A - W)
# test that W is closer to A then any other isometry
for k in 1:10
δA = ROCArray(randn(rng, T, size(A)...))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) > norm(A - W)
end
end
end
end
Expand Down
91 changes: 48 additions & 43 deletions test/cuda/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,41 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
m = 54
noisefactor = eps(real(T))^(3 / 4)
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
A = CuArray(randn(rng, T, m, m))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)
for A in (CuArray(randn(rng, T, m, m)), Diagonal(CuArray(randn(rng, T, m))))
Ah = (A + A') / 2
Aa = (A - A') / 2
Ac = copy(A)

Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
Bh_approx = Bh + noisefactor * Aa
@test !ishermitian(Bh_approx)
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
Bh = project_hermitian(A, alg)
@test ishermitian(Bh)
@test Bh ≈ Ah
@test A == Ac
# this is still hermitian for real Diagonals!
Bh_approx = Bh + noisefactor * Aa
if !isa(A, Diagonal) && !(T <: Real)
@test !ishermitian(Bh_approx)
end
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)

Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
Ba = project_antihermitian(A, alg)
@test isantihermitian(Ba)
@test Ba ≈ Aa
@test A == Ac
Ba_approx = Ba + noisefactor * Ah
@test !isantihermitian(Ba_approx)
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)

Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah
Bh = project_hermitian!(Ac, alg)
@test Bh === Ac
@test ishermitian(Bh)
@test Bh ≈ Ah

copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
copy!(Ac, A)
Ba = project_antihermitian!(Ac, alg)
@test Ba === Ac
@test isantihermitian(Ba)
@test Ba ≈ Aa
end
end
end

Expand All @@ -54,24 +58,25 @@ end
svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
@testset "algorithm $alg" for alg in algs
A = CuArray(randn(rng, T, m, n))
W = project_isometric(A, alg)
@test isisometric(W)
W2 = project_isometric(W, alg)
@test W2 ≈ W # stability of the projection
@test W * (W' * A) ≈ A
for A in (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m))))
W = project_isometric(A, alg)
@test isisometric(W)
W2 = project_isometric(W, alg)
@test W2 ≈ W # stability of the projection
@test W * (W' * A) ≈ A

Ac = similar(A)
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
@test W2 === W
@test isisometric(W)
Ac = similar(A)
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
@test W2 === W
@test isisometric(W)

# test that W is closer to A then any other isometry
for k in 1:10
δA = CuArray(randn(rng, T, m, n))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) > norm(A - W)
# test that W is closer to A then any other isometry
for k in 1:10
δA = CuArray(randn(rng, T, size(A)...))
W = project_isometric(A, alg)
W2 = project_isometric(A + δA / 100, alg)
@test norm(A - W2) > norm(A - W)
end
end
end
end
Expand Down