Skip to content

Commit 61b2efa

Browse files
committed
Test GPU QR in testsuite
1 parent 56db38b commit 61b2efa

File tree

8 files changed

+32
-258
lines changed

8 files changed

+32
-258
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ GenericLinearAlgebra = "0.3.19"
3030
GenericSchur = "0.5.6"
3131
JET = "0.9, 0.10"
3232
LinearAlgebra = "1"
33+
Random = "1"
3334
SafeTestsets = "0.1"
3435
StableRNGs = "1"
3536
Test = "1"
@@ -43,11 +44,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4344
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4445
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4546
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
47+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4648
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4749
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4850
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4951
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
5052
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5153

5254
[targets]
53-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"]
55+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Random"]

src/implementations/qr.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,12 @@ function _gpu_unmqr!(
270270
end
271271

272272
function _gpu_qr!(
273-
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive = false, blocksize = 1
273+
A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; pivoted = false, positive = false, blocksize = 1
274274
)
275275
blocksize > 1 &&
276276
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
277+
pivoted &&
278+
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition"))
277279
m, n = size(A)
278280
minmn = min(m, n)
279281
computeR = length(R) > 0
@@ -309,10 +311,12 @@ function _gpu_qr!(
309311
end
310312

311313
function _gpu_qr_null!(
312-
A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1
314+
A::AbstractMatrix, N::AbstractMatrix; positive = false, blocksize = 1, pivoted = false
313315
)
314316
blocksize > 1 &&
315317
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition"))
318+
pivoted &&
319+
throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition"))
316320
m, n = size(A)
317321
minmn = min(m, n)
318322
fill!(N, zero(eltype(N)))

test/amd/qr.jl

Lines changed: 0 additions & 123 deletions
This file was deleted.

test/cuda/qr.jl

Lines changed: 0 additions & 123 deletions
This file was deleted.

test/qr.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Test
33
using TestExtras
44
using StableRNGs
55
using LinearAlgebra: diag, I, Diagonal
6+
using CUDA, AMDGPU
67

78
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
89
GenericFloats = (Float16, BigFloat, Complex{BigFloat})
@@ -14,6 +15,14 @@ m = 54
1415
for T in BLASFloats, n in (37, m, 63)
1516
TestSuite.seed_rng!(123)
1617
TestSuite.test_qr(T, (m, n))
18+
if CUDA.functional()
19+
TestSuite.test_qr(CuMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
20+
TestSuite.test_qr(Diagonal{T, CuVector{T}}, m; test_pivoted = false, test_blocksize = false)
21+
end
22+
if AMDGPU.functional()
23+
TestSuite.test_qr(ROCMatrix{T}, (m, n); test_pivoted = false, test_blocksize = false)
24+
TestSuite.test_qr(Diagonal{T, ROCVector{T}}, m; test_pivoted = false, test_blocksize = false)
25+
end
1726
end
1827
for T in (BLASFloats..., GenericFloats...)
1928
AT = Diagonal{T, Vector{T}}

test/runtests.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ end
5757

5858
using CUDA
5959
if CUDA.functional()
60-
@safetestset "CUDA QR" begin
61-
include("cuda/qr.jl")
62-
end
6360
@safetestset "CUDA LQ" begin
6461
include("cuda/lq.jl")
6562
end
@@ -85,9 +82,6 @@ end
8582

8683
using AMDGPU
8784
if AMDGPU.functional()
88-
@safetestset "AMDGPU QR" begin
89-
include("amd/qr.jl")
90-
end
9185
@safetestset "AMDGPU LQ" begin
9286
include("amd/lq.jl")
9387
end

test/testsuite/TestSuite.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using MatrixAlgebraKit
1313
using MatrixAlgebraKit: diagview
1414
using LinearAlgebra: Diagonal, norm, istriu
1515
using Random, StableRNGs
16+
using AMDGPU, CUDA
1617

1718
const tests = Dict()
1819

@@ -33,7 +34,11 @@ seed_rng!(seed) = Random.seed!(rng, seed)
3334

3435
instantiate_matrix(::Type{T}, size) where {T <: Number} = randn(rng, T, size)
3536
instantiate_matrix(::Type{AT}, size) where {AT <: Array} = randn(rng, eltype(AT), size)
37+
instantiate_matrix(::Type{AT}, size) where {AT <: CuArray} = CuArray(randn(rng, eltype(AT), size))
38+
instantiate_matrix(::Type{AT}, size) where {AT <: ROCArray} = ROCArray(randn(rng, eltype(AT), size))
3639
instantiate_matrix(::Type{AT}, size) where {AT <: Diagonal} = Diagonal(randn(rng, eltype(AT), size))
40+
instantiate_matrix(::Type{AT}, size) where {T, AT <: Diagonal{T, <:CuVector}} = Diagonal(CuArray(randn(rng, eltype(AT), size)))
41+
instantiate_matrix(::Type{AT}, size) where {T, AT <: Diagonal{T, <:ROCVector}} = Diagonal(ROCArray(randn(rng, eltype(AT), size)))
3742

3843
precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T)))
3944
precision(::Type{T}) where {T} = precision(eltype(T))

test/testsuite/qr.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
function test_qr(T::Type, sz; kwargs...)
22
summary_str = testargs_summary(T, sz)
33
return @testset "qr $summary_str" begin
4-
test_qr_compact(T, sz; kwargs...)
5-
test_qr_full(T, sz; kwargs...)
6-
test_qr_null(T, sz; kwargs...)
4+
if T <: Union{CuArray, ROCArray}
5+
test_qr_compact(T, sz; pivoted = false, kwargs...)
6+
test_qr_full(T, sz; pivoted = false, kwargs...)
7+
test_qr_null(T, sz; pivoted = false, kwargs...)
8+
else
9+
test_qr_compact(T, sz; kwargs...)
10+
test_qr_full(T, sz; kwargs...)
11+
test_qr_null(T, sz; kwargs...)
12+
end
713
end
814
end
915

0 commit comments

Comments
 (0)