Skip to content

Commit d151045

Browse files
committed
A few more updates for GPU compatibility for TensorKit
1 parent eceef30 commit d151045

File tree

8 files changed

+143
-97
lines changed

8 files changed

+143
-97
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5555

5656
[targets]
5757
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"]
58+
59+
[sources]
60+
CUDA = {url="https://github.com/juliagpu/cuda.jl", rev="master"}

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,17 +128,23 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
128128
end
129129

130130
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
131-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
132-
all(A.diag .== adjoint(A.diag))
131+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Real} = true
132+
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Complex}
133+
return all(isreal.(A.diag))
134+
end
133135
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) =
134136
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
135137

136138
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) =
137139
all(A .== -adjoint(A))
138-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
139-
all(A.diag .== -adjoint(A.diag))
140+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Real} = all(iszero(A.diag))
141+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T <: Complex} =
142+
all(iszero.(real.(A.diag)))
140143
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) =
141144
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
145+
function MatrixAlgebraKit.isantihermitian_approx(A::Diagonal{T, <:StridedROCVector{T}}; atol, rtol, kwargs...) where {T <: Real}
146+
return sum(abs2, A.diag) max(atol, rtol * norm(A))
147+
end
142148

143149
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
144150
axes(A) == axes(B) || throw(DimensionMismatch())

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,17 +153,23 @@ end
153153

154154
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
155155
all(A .== adjoint(A))
156-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
157-
all(A.diag .== adjoint(A.diag))
156+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Real} = true
157+
function MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Complex}
158+
return all(isreal.(A.diag))
159+
end
158160
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) =
159161
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)
160162

161163
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
162164
all(A .== -adjoint(A))
163-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
164-
all(A.diag .== -adjoint(A.diag))
165+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Real} = all(iszero(A.diag))
166+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T <: Complex} =
167+
all(iszero.(real.(A.diag)))
165168
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) =
166169
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)
170+
function MatrixAlgebraKit.isantihermitian_approx(A::Diagonal{T, <:StridedCuVector{T}}; atol, rtol, kwargs...) where {T <: Real}
171+
return sum(abs2, A.diag) max(atol, rtol * norm(A))
172+
end
167173

168174
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
169175
axes(A) == axes(B) || throw(DimensionMismatch())

src/MatrixAlgebraKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module MatrixAlgebraKit
33
using LinearAlgebra: LinearAlgebra
44
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
55
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
6-
using LinearAlgebra: sylvester, lu!
6+
using LinearAlgebra: sylvester, lu!, diagm
77
using LinearAlgebra: isposdef, issymmetric
88
using LinearAlgebra: Diagonal, diag, diagind, isdiag
99
using LinearAlgebra: UpperTriangular, LowerTriangular

src/implementations/projections.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)
99

1010
function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
1111
LinearAlgebra.checksquare(A)
12-
n = Base.require_one_based_indexing(A)
12+
Base.require_one_based_indexing(A)
13+
n = size(A, 1)
1314
B === A || @check_size(B, (n, n))
1415
return nothing
1516
end
1617
function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
1718
LinearAlgebra.checksquare(A)
18-
n = Base.require_one_based_indexing(A)
19+
Base.require_one_based_indexing(A)
20+
n = size(A, 1)
1921
B === A || @check_size(B, (n, n))
2022
return nothing
2123
end
@@ -61,6 +63,15 @@ function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm)
6163
return W
6264
end
6365

66+
function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti}
67+
if anti
68+
diagview(A) .= imag.(diagview(B)) .* im
69+
else
70+
diagview(A) .= real.(diagview(B))
71+
end
72+
return A
73+
end
74+
6475
function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
6576
n = size(A, 1)
6677
for j in 1:blocksize:n

src/implementations/svd.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
356356

357357
return USVᴴ
358358
end
359+
svd_full!(A::Diagonal, USVᴴ, alg::GPU_SVDAlgorithm) = svd_full!(diagm(A.diag), USVᴴ, alg)
359360

360361
function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
361362
check_input(svd_trunc!, A, USVᴴ, alg.alg)
@@ -373,6 +374,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
373374

374375
return Utr, Str, Vᴴtr, ϵ
375376
end
377+
svd_trunc!(A::Diagonal, USVᴴ, alg::GPU_SVDAlgorithm) = svd_trunc!(diagm(A.diag), USVᴴ, alg)
376378

377379
function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
378380
check_input(svd_compact!, A, USVᴴ, alg)
@@ -396,6 +398,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
396398

397399
return USVᴴ
398400
end
401+
svd_compact!(A::Diagonal, USVᴴ, alg::GPU_SVDAlgorithm) = svd_compact!(diagm(A.diag), USVᴴ, alg)
399402
_argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x)))
400403
_largest(x, y) = abs(x) < abs(y) ? y : x
401404

@@ -418,3 +421,10 @@ function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
418421

419422
return S
420423
end
424+
function svd_vals!(A::Diagonal, S, alg::GPU_SVDAlgorithm)
425+
check_input(svd_vals!, A, S, alg)
426+
Ad = diagview(A)
427+
S .= abs.(Ad)
428+
sort!(S; rev = true)
429+
return S
430+
end

test/amd/projections.jl

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,41 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
1212
m = 54
1313
noisefactor = eps(real(T))^(3 / 4)
1414
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
15-
A = ROCArray(randn(rng, T, m, m))
16-
Ah = (A + A') / 2
17-
Aa = (A - A') / 2
18-
Ac = copy(A)
15+
for A in (ROCArray(randn(rng, T, m, m)), Diagonal(ROCArray(randn(rng, T, m))))
16+
Ah = (A + A') / 2
17+
Aa = (A - A') / 2
18+
Ac = copy(A)
1919

20-
Bh = project_hermitian(A, alg)
21-
@test ishermitian(Bh)
22-
@test Bh Ah
23-
@test A == Ac
24-
Bh_approx = Bh + noisefactor * Aa
25-
@test !ishermitian(Bh_approx)
26-
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
20+
Bh = project_hermitian(A, alg)
21+
@test ishermitian(Bh)
22+
@test Bh Ah
23+
@test A == Ac
24+
# this is still hermitian for real Diagonals!
25+
Bh_approx = Bh + noisefactor * Aa
26+
if !isa(A, Diagonal) && !(T <: Real)
27+
@test !ishermitian(Bh_approx)
28+
end
29+
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
2730

28-
Ba = project_antihermitian(A, alg)
29-
@test isantihermitian(Ba)
30-
@test Ba Aa
31-
@test A == Ac
32-
Ba_approx = Ba + noisefactor * Ah
33-
@test !isantihermitian(Ba_approx)
34-
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
31+
Ba = project_antihermitian(A, alg)
32+
@test isantihermitian(Ba)
33+
@test Ba Aa
34+
@test A == Ac
35+
Ba_approx = Ba + noisefactor * Ah
36+
@test !isantihermitian(Ba_approx)
37+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
3538

36-
Bh = project_hermitian!(Ac, alg)
37-
@test Bh === Ac
38-
@test ishermitian(Bh)
39-
@test Bh Ah
39+
Bh = project_hermitian!(Ac, alg)
40+
@test Bh === Ac
41+
@test ishermitian(Bh)
42+
@test Bh Ah
4043

41-
copy!(Ac, A)
42-
Ba = project_antihermitian!(Ac, alg)
43-
@test Ba === Ac
44-
@test isantihermitian(Ba)
45-
@test Ba Aa
44+
copy!(Ac, A)
45+
Ba = project_antihermitian!(Ac, alg)
46+
@test Ba === Ac
47+
@test isantihermitian(Ba)
48+
@test Ba Aa
49+
end
4650
end
4751
end
4852

@@ -54,24 +58,25 @@ end
5458
svdalgs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
5559
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
5660
@testset "algorithm $alg" for alg in algs
57-
A = ROCArray(randn(rng, T, m, n))
58-
W = project_isometric(A, alg)
59-
@test isisometric(W)
60-
W2 = project_isometric(W, alg)
61-
@test W2 W # stability of the projection
62-
@test W * (W' * A) A
61+
for A in (ROCArray(randn(rng, T, m, n)), Diagonal(ROCArray(randn(rng, T, m))))
62+
W = project_isometric(A, alg)
63+
@test isisometric(W)
64+
W2 = project_isometric(W, alg)
65+
@test W2 W # stability of the projection
66+
@test W * (W' * A) A
6367

64-
Ac = similar(A)
65-
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
66-
@test W2 === W
67-
@test isisometric(W)
68+
Ac = similar(A)
69+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
70+
@test W2 === W
71+
@test isisometric(W)
6872

69-
# test that W is closer to A then any other isometry
70-
for k in 1:10
71-
δA = ROCArray(randn(rng, T, m, n))
72-
W = project_isometric(A, alg)
73-
W2 = project_isometric(A + δA / 100, alg)
74-
@test norm(A - W2) > norm(A - W)
73+
# test that W is closer to A then any other isometry
74+
for k in 1:10
75+
δA = ROCArray(randn(rng, T, size(A)...))
76+
W = project_isometric(A, alg)
77+
W2 = project_isometric(A + δA / 100, alg)
78+
@test norm(A - W2) > norm(A - W)
79+
end
7580
end
7681
end
7782
end

test/cuda/projections.jl

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,41 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
1212
m = 54
1313
noisefactor = eps(real(T))^(3 / 4)
1414
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
15-
A = CuArray(randn(rng, T, m, m))
16-
Ah = (A + A') / 2
17-
Aa = (A - A') / 2
18-
Ac = copy(A)
15+
for A in (CuArray(randn(rng, T, m, m)), Diagonal(CuArray(randn(rng, T, m))))
16+
Ah = (A + A') / 2
17+
Aa = (A - A') / 2
18+
Ac = copy(A)
1919

20-
Bh = project_hermitian(A, alg)
21-
@test ishermitian(Bh)
22-
@test Bh Ah
23-
@test A == Ac
24-
Bh_approx = Bh + noisefactor * Aa
25-
@test !ishermitian(Bh_approx)
26-
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
20+
Bh = project_hermitian(A, alg)
21+
@test ishermitian(Bh)
22+
@test Bh Ah
23+
@test A == Ac
24+
# this is still hermitian for real Diagonals!
25+
Bh_approx = Bh + noisefactor * Aa
26+
if !isa(A, Diagonal) && !(T <: Real)
27+
@test !ishermitian(Bh_approx)
28+
end
29+
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
2730

28-
Ba = project_antihermitian(A, alg)
29-
@test isantihermitian(Ba)
30-
@test Ba Aa
31-
@test A == Ac
32-
Ba_approx = Ba + noisefactor * Ah
33-
@test !isantihermitian(Ba_approx)
34-
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
31+
Ba = project_antihermitian(A, alg)
32+
@test isantihermitian(Ba)
33+
@test Ba Aa
34+
@test A == Ac
35+
Ba_approx = Ba + noisefactor * Ah
36+
@test !isantihermitian(Ba_approx)
37+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
3538

36-
Bh = project_hermitian!(Ac, alg)
37-
@test Bh === Ac
38-
@test ishermitian(Bh)
39-
@test Bh Ah
39+
Bh = project_hermitian!(Ac, alg)
40+
@test Bh === Ac
41+
@test ishermitian(Bh)
42+
@test Bh Ah
4043

41-
copy!(Ac, A)
42-
Ba = project_antihermitian!(Ac, alg)
43-
@test Ba === Ac
44-
@test isantihermitian(Ba)
45-
@test Ba Aa
44+
copy!(Ac, A)
45+
Ba = project_antihermitian!(Ac, alg)
46+
@test Ba === Ac
47+
@test isantihermitian(Ba)
48+
@test Ba Aa
49+
end
4650
end
4751
end
4852

@@ -54,24 +58,25 @@ end
5458
svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
5559
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
5660
@testset "algorithm $alg" for alg in algs
57-
A = CuArray(randn(rng, T, m, n))
58-
W = project_isometric(A, alg)
59-
@test isisometric(W)
60-
W2 = project_isometric(W, alg)
61-
@test W2 W # stability of the projection
62-
@test W * (W' * A) A
61+
for A in (CuArray(randn(rng, T, m, n)), Diagonal(CuArray(randn(rng, T, m))))
62+
W = project_isometric(A, alg)
63+
@test isisometric(W)
64+
W2 = project_isometric(W, alg)
65+
@test W2 W # stability of the projection
66+
@test W * (W' * A) A
6367

64-
Ac = similar(A)
65-
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
66-
@test W2 === W
67-
@test isisometric(W)
68+
Ac = similar(A)
69+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
70+
@test W2 === W
71+
@test isisometric(W)
6872

69-
# test that W is closer to A then any other isometry
70-
for k in 1:10
71-
δA = CuArray(randn(rng, T, m, n))
72-
W = project_isometric(A, alg)
73-
W2 = project_isometric(A + δA / 100, alg)
74-
@test norm(A - W2) > norm(A - W)
73+
# test that W is closer to A then any other isometry
74+
for k in 1:10
75+
δA = CuArray(randn(rng, T, size(A)...))
76+
W = project_isometric(A, alg)
77+
W2 = project_isometric(A + δA / 100, alg)
78+
@test norm(A - W2) > norm(A - W)
79+
end
7580
end
7681
end
7782
end

0 commit comments

Comments
 (0)