From 6da83c44ccd2f8060b7d2eb3408fc8a7b6bbd2e1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 11:38:48 -0400 Subject: [PATCH 1/8] sum zero-arrays --- src/host/mapreduce.jl | 3 +++ test/testsuite/reductions.jl | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index cb5b489e..1be18263 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -87,6 +87,9 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), @eval begin Base.$(fname!)(f::Function, r::AnyGPUArray, A::AnyGPUArray{T}) where T = GPUArrays.mapreducedim!(f, $(op), r, A; init=neutral_element($(op), T)) + + Base.$fname(r::AnyGPUArray{<:Any,0}) = @allowscalar r[] + Base.$fname(f::Function, r::AnyGPUArray{<:Any,0}) = f(@allowscalar r[]) end end diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index a07ce04e..8874c64b 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -140,3 +140,20 @@ end @test A != B end end + +@testsuite "reductions/zero-arrays" AT->begin + @testset "$ET" for ET in supported_eltypes() + range = ET <: Real ? (ET(1):ET(10)) : ET + # sum + @test compare(A->sum(A), AT, reshape(rand(range, 1))) + @test compare(A->sum(abs, A), AT, reshape(rand(range, 1))) + # other functions, defined together + @test compare(A->prod(A), AT, reshape(rand(range, 1))) + @test compare(A->max(A), AT, reshape(rand(range, 1))) + @test compare(A->any(_->true, A), AT, reshape(rand(range, 1))) + @test compare(A->all(_->false, A), AT, reshape(rand(range, 1))) + # zero-dimensional view + @test compare(A->sum(A), AT, view(rand(range, 3),2)) + @test compare(A->prod(sqrt, A), AT, view(rand(range, 3),2)) + end +end From 51c25522445ca18178bb47186b28be5afe6095f4 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 12:05:03 -0400 Subject: [PATCH 2/8] max -> maximum --- test/testsuite/reductions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index 8874c64b..85feb356 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -149,7 +149,7 @@ end @test compare(A->sum(abs, A), AT, reshape(rand(range, 1))) # other functions, defined together @test compare(A->prod(A), AT, reshape(rand(range, 1))) - @test compare(A->max(A), AT, reshape(rand(range, 1))) + @test compare(A->maximum(A), AT, reshape(rand(range, 1))) @test compare(A->any(_->true, A), AT, reshape(rand(range, 1))) @test compare(A->all(_->false, A), AT, reshape(rand(range, 1))) # zero-dimensional view From ede96db7c00730efab22f7cf2a4ba7da57642f02 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 14:10:37 -0400 Subject: [PATCH 3/8] better way, allows init --- src/host/mapreduce.jl | 6 +++--- test/testsuite/reductions.jl | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index 1be18263..fd1ff205 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -87,10 +87,10 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), @eval begin Base.$(fname!)(f::Function, r::AnyGPUArray, A::AnyGPUArray{T}) where T = GPUArrays.mapreducedim!(f, $(op), r, A; init=neutral_element($(op), T)) - - Base.$fname(r::AnyGPUArray{<:Any,0}) = @allowscalar r[] - Base.$fname(f::Function, r::AnyGPUArray{<:Any,0}) = f(@allowscalar r[]) end end +Base._mapreduce_dim(f, op, init, A::AnyGPUArray{<:Any,0}, ::Colon) = op(f(@allowscalar A[]), init) +Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AnyGPUArray{<:Any,0}, ::Colon) = f(@allowscalar A[]) + LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = mapreduce(==, &, A, adjoint(A)) diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index 85feb356..7b8e3195 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -147,6 +147,7 @@ end # sum @test compare(A->sum(A), AT, reshape(rand(range, 1))) @test compare(A->sum(abs, A), AT, reshape(rand(range, 1))) + @test compare(A->sum(A, init=ET(13)), AT, reshape(rand(range, 1))) # other functions, defined together @test compare(A->prod(A), AT, reshape(rand(range, 1))) @test compare(A->maximum(A), AT, reshape(rand(range, 1))) From 6bedf05e347aa1045e4efc6560a9eb4fba43b294 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 14:39:32 -0400 Subject: [PATCH 4/8] skip tests --- test/testsuite/reductions.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index 7b8e3195..510085f8 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -147,10 +147,11 @@ end # sum @test compare(A->sum(A), AT, reshape(rand(range, 1))) @test compare(A->sum(abs, A), AT, reshape(rand(range, 1))) - @test compare(A->sum(A, init=ET(13)), AT, reshape(rand(range, 1))) + if VERSION >= v"1.6" + @test compare(A->sum(A, init=ET(13)), AT, reshape(rand(range, 1))) + end # other functions, defined together @test compare(A->prod(A), AT, reshape(rand(range, 1))) - @test compare(A->maximum(A), AT, reshape(rand(range, 1))) @test compare(A->any(_->true, A), AT, reshape(rand(range, 1))) @test compare(A->all(_->false, A), AT, reshape(rand(range, 1))) # zero-dimensional view From 1a2af66cc33d3c625bc8bc237dc4f28478b26d69 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 16:28:38 -0400 Subject: [PATCH 5/8] zero-dim printing? --- src/host/abstractarray.jl | 2 ++ test/testsuite/io.jl | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 440fa836..46a76c8c 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -55,6 +55,8 @@ Base._show_empty(io::IO, X::AnyGPUArray) = Base._show_empty(io, adapt(ToArray(), X)) Base.show_vector(io::IO, v::AnyGPUArray, args...) = Base.show_vector(io, adapt(ToArray(), v), args...) +Base.show_zero_dim(io::IO, X::AnyGPUArray{T, 0}) where T + Base.show_zero_dim(io, adapt(ToArray(), X)) ## collect to CPU (discarding wrapper type) diff --git a/test/testsuite/io.jl b/test/testsuite/io.jl index 7d0c58e4..83c29f3f 100644 --- a/test/testsuite/io.jl +++ b/test/testsuite/io.jl @@ -30,5 +30,9 @@ @test occursin(Regex("^1×1 Adjoint{Int64,\\s?$AT{Int64,\\s?1}}:\n 1\$"), msg) || occursin(Regex("^1×1 LinearAlgebra.Adjoint{Int64,\\s?$AT{Int64,\\s?1}}:\n 1\$"), msg) || occursin(Regex("^1×1 adjoint\\(::$AT{Int64,\\s?1}\\) with eltype Int64:\n 1\$"), msg) + + C = AT(fill(1)) + msg = showstr(A) + @test msg == "fill(1)" end end From 4dab8d231b1953eb03ec37dd42f690de66399436 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 17:19:52 -0400 Subject: [PATCH 6/8] don't forget the = sign --- src/host/abstractarray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 46a76c8c..d0588e80 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -55,7 +55,7 @@ Base._show_empty(io::IO, X::AnyGPUArray) = Base._show_empty(io, adapt(ToArray(), X)) Base.show_vector(io::IO, v::AnyGPUArray, args...) = Base.show_vector(io, adapt(ToArray(), v), args...) -Base.show_zero_dim(io::IO, X::AnyGPUArray{T, 0}) where T +Base.show_zero_dim(io::IO, X::AnyGPUArray{T, 0}) where T = Base.show_zero_dim(io, adapt(ToArray(), X)) ## collect to CPU (discarding wrapper type) From 8ded81e9f3767dffa324a4a5ce67baad87332156 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 18:13:40 -0400 Subject: [PATCH 7/8] old versions --- test/testsuite/io.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/testsuite/io.jl b/test/testsuite/io.jl index 83c29f3f..b2bd801c 100644 --- a/test/testsuite/io.jl +++ b/test/testsuite/io.jl @@ -33,6 +33,10 @@ C = AT(fill(1)) msg = showstr(A) - @test msg == "fill(1)" + if VERSION >= v"1.6" + @test msg == "fill(1)" + else + @test msg == "[1]" + end end end From 3c94d3774f28a07040a65891d08ae28296453e37 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 22 Jun 2021 18:23:48 -0400 Subject: [PATCH 8/8] better bounds? --- test/testsuite/io.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/io.jl b/test/testsuite/io.jl index b2bd801c..27811f1b 100644 --- a/test/testsuite/io.jl +++ b/test/testsuite/io.jl @@ -33,7 +33,7 @@ C = AT(fill(1)) msg = showstr(A) - if VERSION >= v"1.6" + if VERSION >= v"1.7-" @test msg == "fill(1)" else @test msg == "[1]"