diff --git a/src/host/abstractarray.jl b/src/host/abstractarray.jl index 440fa836..d0588e80 100644 --- a/src/host/abstractarray.jl +++ b/src/host/abstractarray.jl @@ -55,6 +55,8 @@ Base._show_empty(io::IO, X::AnyGPUArray) = Base._show_empty(io, adapt(ToArray(), X)) Base.show_vector(io::IO, v::AnyGPUArray, args...) = Base.show_vector(io, adapt(ToArray(), v), args...) +Base.show_zero_dim(io::IO, X::AnyGPUArray{T, 0}) where T = + Base.show_zero_dim(io, adapt(ToArray(), X)) ## collect to CPU (discarding wrapper type) diff --git a/src/host/mapreduce.jl b/src/host/mapreduce.jl index cb5b489e..fd1ff205 100644 --- a/src/host/mapreduce.jl +++ b/src/host/mapreduce.jl @@ -90,4 +90,7 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), end end +Base._mapreduce_dim(f, op, init, A::AnyGPUArray{<:Any,0}, ::Colon) = op(f(@allowscalar A[]), init) +Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AnyGPUArray{<:Any,0}, ::Colon) = f(@allowscalar A[]) + LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = mapreduce(==, &, A, adjoint(A)) diff --git a/test/testsuite/io.jl b/test/testsuite/io.jl index 7d0c58e4..27811f1b 100644 --- a/test/testsuite/io.jl +++ b/test/testsuite/io.jl @@ -30,5 +30,13 @@ @test occursin(Regex("^1×1 Adjoint{Int64,\\s?$AT{Int64,\\s?1}}:\n 1\$"), msg) || occursin(Regex("^1×1 LinearAlgebra.Adjoint{Int64,\\s?$AT{Int64,\\s?1}}:\n 1\$"), msg) || occursin(Regex("^1×1 adjoint\\(::$AT{Int64,\\s?1}\\) with eltype Int64:\n 1\$"), msg) + + C = AT(fill(1)) + msg = showstr(A) + if VERSION >= v"1.7-" + @test msg == "fill(1)" + else + @test msg == "[1]" + end end end diff --git a/test/testsuite/reductions.jl b/test/testsuite/reductions.jl index a07ce04e..510085f8 100644 --- a/test/testsuite/reductions.jl +++ b/test/testsuite/reductions.jl @@ -140,3 +140,22 @@ end @test A != B end end + +@testsuite "reductions/zero-arrays" AT->begin + @testset "$ET" for ET in supported_eltypes() + range = ET <: Real ? (ET(1):ET(10)) : ET + # sum + @test compare(A->sum(A), AT, reshape(rand(range, 1))) + @test compare(A->sum(abs, A), AT, reshape(rand(range, 1))) + if VERSION >= v"1.6" + @test compare(A->sum(A, init=ET(13)), AT, reshape(rand(range, 1))) + end + # other functions, defined together + @test compare(A->prod(A), AT, reshape(rand(range, 1))) + @test compare(A->any(_->true, A), AT, reshape(rand(range, 1))) + @test compare(A->all(_->false, A), AT, reshape(rand(range, 1))) + # zero-dimensional view + @test compare(A->sum(A), AT, view(rand(range, 3),2)) + @test compare(A->prod(sqrt, A), AT, view(rand(range, 3),2)) + end +end