@@ -2,13 +2,56 @@ using Base.Broadcast
22
33import Base. Broadcast: BroadcastStyle, Broadcasted, ArrayStyle
44
5- BroadcastStyle (:: Type{T} ) where T <: GPUArray = ArrayStyle {T} ()
5+ # we define a generic `BroadcastStyle` here that should be sufficient for most cases.
6+ # dependent packages like `CuArrays` can define their own `BroadcastStyle` allowing
7+ # them to further change or optimize broadcasting.
8+ #
9+ # TODO : investigate if we should define out own `GPUArrayStyle{N} <: AbstractArrayStyle{N}`
10+ #
11+ # NOTE: this uses the specific `T` that was used e.g. `JLArray` or `CLArray` for ArrayStyle,
12+ # instead of using `ArrayStyle{GPUArray}`, due to the fact how `similar` works.
13+ BroadcastStyle (:: Type{T} ) where {T<: GPUArray } = ArrayStyle {T} ()
614
7- function Base. similar (bc:: Broadcasted{<:ArrayStyle{GPU}} , :: Type{ElType} ) where {GPU <: GPUArray , ElType}
15+ # These wrapper types otherwise forget that they are GPU compatible
16+ #
17+ # NOTE: Don't directly use ArrayStyle{GPUArray} here since that would mean that `CuArrays`
18+ # customization no longer take effect.
19+ BroadcastStyle (:: Type{<:LinearAlgebra.Transpose{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
20+ BroadcastStyle (:: Type{<:LinearAlgebra.Adjoint{<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
21+ BroadcastStyle (:: Type{<:SubArray{<:Any,<:Any,T}} ) where {T<: GPUArray } = BroadcastStyle (T)
22+
23+ backend (:: Type{<:LinearAlgebra.Transpose{<:Any,T}} ) where {T<: GPUArray } = backend (T)
24+ backend (:: Type{<:LinearAlgebra.Adjoint{<:Any,T}} ) where {T<: GPUArray } = backend (T)
25+ backend (:: Type{<:SubArray{<:Any,<:Any,T}} ) where {T<: GPUArray } = backend (T)
26+
27+ # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
28+ # and we could define our methods in terms of Union{GPUArray, WrappedArray{<:Any, <:GPUArray}}
29+ const GPUDestArray = Union{GPUArray,
30+ LinearAlgebra. Transpose{<: Any ,<: GPUArray },
31+ LinearAlgebra. Adjoint{<: Any ,<: GPUArray },
32+ SubArray{<: Any ,<: Any ,<: GPUArray }}
33+
34+ # This method is responsible for selection the output type of broadcast
35+ function Base. similar (bc:: Broadcasted{<:ArrayStyle{GPU}} , :: Type{ElType} ) where
36+ {GPU <: GPUArray , ElType}
837 similar (GPU, ElType, axes (bc))
938end
1039
11- @inline function Base. copyto! (dest:: GPUArray , bc:: Broadcasted{Nothing} )
40+ # We purposefully only specialize `copyto!`, dependent packages need to make sure that they
41+ # can handle:
42+ # - `bc::Broadcast.Broadcasted{Style}`
43+ # - `ex::Broadcast.Extruded`
44+ # - `LinearAlgebra.Transpose{,<:GPUArray}` and `LinearAlgebra.Adjoint{,<:GPUArray}`, etc
45+ # as arguments to a kernel and that they do the right conversion.
46+ #
47+ # This Broadcast can be further customize by:
48+ # - `Broadcast.preprocess(dest::GPUArray, bc::Broadcasted{Nothing})` which allows for a
49+ # complete transformation based on the output type just at the end of the pipeline.
50+ # - `Broadcast.broadcasted(::Style, f)` selection of an implementation of `f` compatible
51+ # with `Style`
52+ #
53+ # For more information see the Base documentation.
54+ @inline function Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{Nothing} )
1255 axes (dest) == axes (bc) || Broadcast. throwdm (axes (dest), axes (bc))
1356 bc′ = Broadcast. preprocess (dest, bc)
1457 gpu_call (dest, (dest, bc′)) do state, dest, bc′
2063 return dest
2164end
2265
66+ # Base defines this method as a performance optimization, but we don't know how to do
67+ # `fill!` in general for all `GPUDestArray` so we just go straight to the fallback
68+ @inline Base. copyto! (dest:: GPUDestArray , bc:: Broadcasted{<:Broadcast.AbstractArrayStyle{0}} ) =
69+ copyto! (dest, convert (Broadcasted{Nothing}, bc))
70+
71+ # TODO : is this still necessary?
2372function mapidx (f, A:: GPUArray , args:: NTuple{N, Any} ) where N
2473 gpu_call (A, (f, A, args)) do state, f, A, args
2574 ilin = @linearidx (A, state)
0 commit comments