Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,24 @@ end

## showing

Base.show(io::IO, x::GPUArray) = Base.show(io, Array(x))
Base.show(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:GPUArray}) =
Base.show(io, LinearAlgebra.adjoint(Array(x.parent)))
Base.show(io::IO, x::LinearAlgebra.Transpose{<:Any,<:GPUArray}) =
Base.show(io, LinearAlgebra.transpose(Array(x.parent)))

Base.show_vector(io::IO, x::GPUArray) = Base.show_vector(io, Array(x))

for (atype, op) in
[(:(GPUArray), :(Array)),
(:(LinearAlgebra.Adjoint{<:Any,<:GPUArray}), :(x->LinearAlgebra.adjoint(Array(parent(x))))),
(:(LinearAlgebra.Transpose{<:Any,<:GPUArray}), :(x->LinearAlgebra.transpose(Array(parent(x)))))]
@eval begin
# for display
Base.print_array(io::IO, X::($atype)) =
Base.print_array(io,($op)(X))

# for show
Base._show_nonempty(io::IO, X::($atype), prefix::String) =
Base._show_nonempty(io,($op)(X),prefix)
Base._show_empty(io::IO, X::($atype)) =
Base._show_empty(io,($op)(X))
Base.show_vector(io::IO, v::($atype), args...) =
Base.show_vector(io,($op)(v),args...)
end
end

# memory operations

Expand Down
38 changes: 24 additions & 14 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
# functions in base implemented with a direct loop need to be overloaded to use mapreduce


Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, false, A)
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, true, A)
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, 0, A))
Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, A; init = false)
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, A; init = true)
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0))

Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, true, A, B))
Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, A, B; init = true))

# hack to get around of fetching the first element of the GPUArray
# as a startvalue, which is a bit complicated with the current reduce implementation
function startvalue(f, T)
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, 1, A)")
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, A; init = 1)")
end
startvalue(::typeof(+), T) = zero(T)
startvalue(::typeof(Base.add_sum), T) = zero(T)
Expand Down Expand Up @@ -50,20 +50,30 @@ gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof(
gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T
gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T

function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}) where {T, N}
function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}; dims = :, init...) where {T, N}
mapreduce_impl(f, op, init.data, A, dims)
end

function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUArray{T, N}, ::Colon) where {T, N}
OT = gpu_promote_type(op, T)
v0 = startvalue(op, OT) # TODO do this better
mapreduce(f, op, v0, A)
acc_mapreduce(f, op, v0, A, ())
end
function acc_mapreduce end
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray, C::Number)
acc_mapreduce(f, op, v0, A, (B, C))

function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUArray{T, N}, ::Colon) where {T, N}
acc_mapreduce(f, op, nt.init, A, ())
end
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray)
acc_mapreduce(f, op, v0, A, (B,))

function mapreduce_impl(f, op, nt, A::GPUArray{T, N}, dims) where {T, N}
Base._mapreduce_dim(f, op, nt, A, dims)
end
function Base.mapreduce(f, op, v0, A::GPUArray)
acc_mapreduce(f, op, v0, A, ())

function acc_mapreduce end
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray, C::Number; init)
acc_mapreduce(f, op, init, A, (B, C))
end
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray; init)
acc_mapreduce(f, op, init, A, (B,))
end

@generated function mapreducedim_kernel(state, f, op, R, A, range::NTuple{N, Any}) where N
Expand Down
18 changes: 18 additions & 0 deletions src/testsuite/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,29 @@ function test_io(AT)
@testset "showing" begin
io = IOBuffer()
A = AT(Int64[1])
B = AT(Int64[1 2;3 4]) # vectors and non-vector arrays showing
# are handled differently in base/arrayshow.jl

show(io, MIME("text/plain"), A)
seekstart(io)
@test String(take!(io)) == "1-element $AT{Int64,1}:\n 1"

show(io, A)
seekstart(io)
msg = String(take!(io)) # result of e.g. `print` differs on 32bit and 64bit machines
# due to different definition of `Int` type
# print([1]) shows as [1] on 64bit but Int64[1] on 32bit
@test msg == "[1]" || msg == "Int64[1]"

show(io, MIME("text/plain"), B)
seekstart(io)
@test String(take!(io)) == "2×2 $AT{Int64,2}:\n 1 2\n 3 4"

show(io, B)
seekstart(io)
msg = String(take!(io))
@test msg == "[1 2; 3 4]" || msg == "Int64[1 2; 3 4]"

show(io, MIME("text/plain"), A')
seekstart(io)
msg = String(take!(io)) # the printing of Adjoint depends on global state
Expand Down
9 changes: 9 additions & 0 deletions src/testsuite/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ function test_mapreduce(AT)
x = T(y)
@test sum(y, dims = 2) ≈ Array(sum(x, dims = 2))
@test sum(y, dims = 1) ≈ Array(sum(x, dims = 1))

y = rand(range, N, N)
x = T(y)
_zero = zero(ET)
_addone(z) = z + one(ET)
@test mapreduce(_addone, +, y; dims = 2, init = _zero) ≈
Array(mapreduce(_addone, +, x; dims = 2, init = _zero))
@test mapreduce(_addone, +, y; init = _zero) ≈
mapreduce(_addone, +, x; init = _zero)
end
end
@testset "sum maximum minimum prod" begin
Expand Down