Skip to content

Commit d2bfcb2

Browse files
wsmoses and avik-pal authored
Faster copyto! and similar (#1476)
* Faster copyto! and similar * fix * fix * fix alloc * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix comment * fix * fix shape * fix * fix * fix * Fix * chore: run formatter * fix: init * fix: correct client * Update basic.jl * fix: missing var * Update Project.toml * Update ConcreteRArray.jl * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix --------- Co-authored-by: Avik Pal <[email protected]>
1 parent 4c03824 commit d2bfcb2

File tree

6 files changed: +351 -25 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ PythonCall = "0.9.25"
9090
Random = "1.10"
9191
Random123 = "1.7"
9292
ReactantCore = "0.1.15"
93-
Reactant_jll = "0.0.220"
93+
Reactant_jll = "0.0.221"
9494
ScopedValues = "1.3.0"
9595
Scratch = "1.2"
9696
Sockets = "1.10"

src/ConcreteRArray.jl

Lines changed: 148 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -371,15 +371,41 @@ function Base.setindex!(a::ConcreteIFRTArray, v, args::Vararg{Int,N}) where {N}
371371
return a
372372
end
373373

374-
# TODO is there any way to allocate an uninitialized buffer in XLA?
375-
function Base.similar(a::ConcretePJRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S}
376-
return ConcretePJRTArray(
377-
Array{S}(undef, dims); client=XLA.client(a), device=XLA.device(a), a.sharding
374+
@inline function Base.similar(::Type{<:ConcretePJRTArray}, ::Type{S}, dims::Dims;
375+
client::Union{Nothing,XLA.PJRT.Client}=nothing,
376+
idx::Union{Int,Nothing}=nothing,
377+
device::Union{Nothing,XLA.PJRT.Device}=nothing,
378+
sharding::Sharding.AbstractSharding=Sharding.NoSharding()
379+
) where {S}
380+
client = client === nothing ? XLA.default_backend() : client
381+
382+
if idx isa Int && device === nothing
383+
device = XLA.get_device(client, idx)
384+
end
385+
386+
sdata, sharding = sharding(client, device, S, dims)
387+
388+
return ConcretePJRTArray{S,length(dims),length(sdata),typeof(sharding)}(sdata, dims, sharding)
389+
end
390+
391+
function Base.similar(
392+
a::ConcretePJRTArray{T,N,D,Sh}, ::Type{S}=T, dims::Dims=size(a)
393+
) where {S,T,Sh,N,D}
394+
device_to_array_slices, sharding = Sharding.sharding_to_array_slices(
395+
a.sharding, dims; return_updated_sharding=Val(true), client=XLA.client(a)
378396
)
397+
@assert length(device_to_array_slices) == D
398+
sdata = ntuple(Val(D)) do i
399+
Base.@_inline_meta
400+
Base.similar(a.data[i], S, Dims(length.(device_to_array_slices[i])))
401+
end
402+
return ConcretePJRTArray{S,length(dims),D,Sh}(sdata, dims, a.sharding)
379403
end
404+
380405
Base.similar(a::ConcretePJRTArray, dims::Dims) = similar(a, eltype(a), dims)
381-
function Base.similar(::Type{ConcretePJRTArray{T}}, dims) where {T}
382-
return ConcretePJRTArray(similar(Array{T}, dims))
406+
407+
@inline function Base.similar(AT::Type{<:ConcretePJRTArray{T}}, dims; kwargs...) where {T}
408+
return Base.similar(AT, T, dims; kwargs...)
383409
end
384410

385411
function Base.similar(a::ConcreteIFRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S}
@@ -396,16 +422,16 @@ end
396422
Base.BroadcastStyle(::Type{<:ConcretePJRTArray}) = Broadcast.ArrayStyle{ConcretePJRTArray}()
397423
Base.BroadcastStyle(::Type{<:ConcreteIFRTArray}) = Broadcast.ArrayStyle{ConcreteIFRTArray}()
398424

399-
# XXX: correct device + sharding?
400-
function Base.similar(
401-
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcretePJRTArray}}, ::Type{T}
425+
@inline function Base.similar(
426+
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcretePJRTArray}}, ::Type{T}; kwargs...
402427
) where {T}
403-
return ConcretePJRTArray(similar(Array{T}, axes(bc)))
428+
return similar(ConcretePJRTArray, T, axes(bc); kwargs...)
404429
end
405-
function Base.similar(
406-
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteIFRTArray}}, ::Type{T}
430+
431+
@inline function Base.similar(
432+
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteIFRTArray}}, ::Type{T}; kwargs...
407433
) where {T}
408-
return ConcreteIFRTArray(similar(Array{T}, axes(bc)))
434+
return similar(ConcreteIFRTArray, T, axes(bc); kwargs...)
409435
end
410436

411437
# TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`)
@@ -429,9 +455,10 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
429455
),
430456
)
431457
end
432-
aux = copyto!(
433-
similar(Array{ElType}, axes(bc)), convert(Broadcast.Broadcasted{Nothing}, bc)
434-
)
458+
459+
aux = similar(ConcretePJRTArray, ElType, length.(axes(bc)))
460+
461+
copyto!(aux, convert(Broadcast.Broadcasted{Nothing}, bc))
435462
return ConcretePJRTArray(aux) # XXX: result should be on correct device?
436463
end
437464

@@ -484,6 +511,111 @@ for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
484511
end
485512
end
486513

514+
function Base.copyto!(
515+
dest::Vector{T},
516+
doffs::Int64,
517+
src::Reactant.ConcreteIFRTArray{T},
518+
soffs::Int64,
519+
n::Int64,
520+
) where {T}
521+
n == 0 && return dest
522+
n > 0 || Base._throw_argerror("Number of elements to copy must be non-negative.")
523+
@boundscheck checkbounds(dest, doffs:(doffs + n - 1))
524+
@boundscheck checkbounds(src, soffs:(soffs + n - 1))
525+
526+
if n != length(src)
527+
throw(AssertionError("Only full array copyto! supported from ConcreteIFRTArray"))
528+
end
529+
if doffs != 1
530+
throw(AssertionError("Dest offset not yet supported in ConcreteIFRTArray copyto!"))
531+
end
532+
533+
src_async = src.data
534+
src_sync = src_async.buffer
535+
wait(src_async)
536+
537+
GC.@preserve dest begin
538+
@ccall Reactant.MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
539+
src_sync.buffer::Ptr{Cvoid},
540+
pointer(dest, doffs)::Ptr{T},
541+
((soffs - 1) * sizeof(T))::Int64,
542+
)::Ptr{Cvoid}
543+
end
544+
545+
return dest
546+
end
547+
548+
function Base.copyto!(
549+
dest::Vector{T},
550+
doffs::Int64,
551+
src::Reactant.ConcretePJRTArray{T},
552+
soffs::Int64,
553+
n::Int64,
554+
) where {T}
555+
n == 0 && return dest
556+
n > 0 || Base._throw_argerror("Number of elements to copy must be non-negative.")
557+
@boundscheck checkbounds(dest, doffs:(doffs + n - 1))
558+
@boundscheck checkbounds(src, soffs:(soffs + n - 1))
559+
560+
client = XLA.client(src)
561+
@assert length(src.data) == 1
562+
src_async = src.data[1]
563+
src_sync = src_async.buffer
564+
wait(src_async)
565+
566+
GC.@preserve dest begin
567+
@ccall Reactant.MLIR.API.mlir_c.CopyFromBuffer(
568+
client.client::Ptr{Cvoid},
569+
src_sync.buffer::Ptr{Cvoid},
570+
pointer(dest, doffs)::Ptr{T},
571+
((soffs - 1) * sizeof(T))::Int64,
572+
(n * sizeof(T))::Int64,
573+
)::Ptr{Cvoid}
574+
end
575+
576+
return dest
577+
end
578+
579+
function Base.copyto!(
580+
dest::Vector{T}, src::Union{Reactant.ConcretePJRTArray{T},Reactant.ConcreteIFRTArray{T}}
581+
) where {T}
582+
return copyto!(dest, 1, src, 1, length(src))
583+
end
584+
585+
function Base.copyto!(
586+
dest::Reactant.ConcretePJRTArray{T},
587+
doffs::Int64,
588+
src::Vector{T},
589+
soffs::Int64,
590+
n::Int64,
591+
) where {T}
592+
n == 0 && return dest
593+
n > 0 || Base._throw_argerror("Number of elements to copy must be non-negative.")
594+
@boundscheck checkbounds(dest, doffs:(doffs + n - 1))
595+
@boundscheck checkbounds(src, soffs:(soffs + n - 1))
596+
597+
client = XLA.client(dest)
598+
dest_async = dest.data[1]
599+
dest_sync = dest_async.buffer
600+
wait(dest_async)
601+
602+
GC.@preserve src begin
603+
@ccall Reactant.MLIR.API.mlir_c.CopyToBuffer(
604+
client.client::Ptr{Cvoid},
605+
dest_sync.buffer::Ptr{Cvoid},
606+
pointer(src, soffs)::Ptr{T},
607+
((doffs - 1) * sizeof(T))::Int64,
608+
(n * sizeof(T))::Int64,
609+
)::Ptr{Cvoid}
610+
end
611+
612+
return dest
613+
end
614+
615+
function Base.copyto!(dest::Reactant.ConcretePJRTArray{T}, src::Vector{T}) where {T}
616+
return copyto!(dest, 1, src, 1, length(src))
617+
end
618+
487619
for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
488620
@eval begin
489621
function Base.copyto!(

src/Sharding.jl

Lines changed: 47 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -152,13 +152,6 @@ Base.in(axis::Union{String,Symbol}, mesh::Mesh) = Symbol(axis) ∈ mesh.axis_nam
152152

153153
abstract type AbstractSharding end
154154

155-
function (T::AbstractSharding)(::XLA.AbstractClient, device, ::Union{AbstractArray,Number})
156-
return error(
157-
"(::$(T))(::XLA.AbstractClient, device, ::Union{AbstractArray,Number}) is \
158-
not implemented"
159-
)
160-
end
161-
162155
# By default we use same sharding for all leaf nodes
163156
Base.getproperty(sharding::AbstractSharding, name) = sharding
164157
function Base.getproperty(sharding::AbstractSharding, name::Symbol)
@@ -198,6 +191,12 @@ function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,
198191
return (buffer,), ShardInfo(NoSharding(), nothing)
199192
end
200193

194+
function (::NoSharding)(client::XLA.PJRT.Client, device, S::Type, dims::Dims)
195+
device === nothing && (device = XLA.default_device(client))
196+
buffer = similar(XLA.PJRT.AsyncBuffer, S, dims; client, device)
197+
return (buffer,), ShardInfo(NoSharding(), nothing)
198+
end
199+
201200
function (::NoSharding)(client::XLA.IFRT.Client, device, x::Union{AbstractArray,Number})
202201
device === nothing && (device = XLA.default_device(client))
203202
return (
@@ -208,7 +207,7 @@ end
208207
function sharding_to_array_slices(
209208
sharding::NoSharding, size_x; client=nothing, return_updated_sharding=Val(false)
210209
)
211-
slices = Base.OneTo.(size_x)
210+
slices = (Base.OneTo.(size_x),)
212211
return_updated_sharding isa Val{true} && return (slices, sharding)
213212
return slices
214213
end
@@ -419,6 +418,25 @@ function (sharding::NamedSharding)(
419418
return data, ShardInfo(sharding, device_to_array_slices)
420419
end
421420

421+
function (sharding::NamedSharding)(
422+
client::XLA.PJRT.Client, _, S::Type, dims::Dims
423+
)
424+
if !issorted(sharding.mesh.logical_device_ids)
425+
error("PJRT doesn't support non-iota meshes. Use IFRT instead.")
426+
end
427+
428+
device_to_array_slices, sharding = sharding_to_array_slices(
429+
sharding, dims; client, return_updated_sharding=Val(true)
430+
)
431+
432+
data = ntuple(length(sharding.mesh)) do i
433+
Base.@_inline_meta
434+
Base.similar(XLA.PJRT.AsyncBuffer, S, Dims(length.(device_to_array_slices[i])); client, device=XLA.get_device(client, sharding.mesh.device_ids[i]))
435+
end
436+
437+
return data, ShardInfo(sharding, device_to_array_slices)
438+
end
439+
422440
function (sharding::NamedSharding)(
423441
client::XLA.IFRT.Client, _, x::Union{AbstractArray,Number}
424442
)
@@ -704,6 +722,10 @@ function (sharding::DimsSharding)(
704722
return (NamedSharding(sharding, ndims(x)))(client, device, x)
705723
end
706724

725+
function (sharding::DimsSharding)(client::XLA.PJRT.Client, dev, S::Type, dims::Dims)
726+
return (NamedSharding(sharding, length(dims)))(client, dev, S, dims)
727+
end
728+
707729
function sharding_to_array_slices(sharding::DimsSharding, size_x; kwargs...)
708730
return sharding_to_array_slices(
709731
NamedSharding(sharding, length(size_x)), size_x; kwargs...
@@ -748,6 +770,11 @@ function (sharding::Replicated)(
748770
return (NamedSharding(sharding, ndims(x)))(client, device, x)
749771
end
750772

773+
function (sharding::Replicated)(client::XLA.PJRT.Client, dev, S::Type, dims::Dims)
774+
return (NamedSharding(sharding, length(dims)))(client, dev, S, dims)
775+
end
776+
777+
751778
function sharding_to_array_slices(sharding::Replicated, size_x; kwargs...)
752779
return sharding_to_array_slices(
753780
NamedSharding(sharding, length(size_x)), size_x; kwargs...
@@ -936,6 +963,18 @@ function (sharding::HloSharding)(
936963
return data, ShardInfo(sharding, device_to_array_slices)
937964
end
938965

966+
function (sharding::HloSharding)(
967+
client::XLA.PJRT.Client, ::Nothing, S::Type, dims::Dims
968+
)
969+
device_to_array_slices = sharding_to_array_slices(sharding, dims; client)
970+
971+
data = ntuple(length(sharding.mesh)) do i
972+
Base.similar(XLA.PJRT.AsyncBuffer, S, Dims(length.(device_to_array_slices[i])); client, device=XLA.get_device(client, sharding.mesh.device_ids[i]))
973+
end
974+
975+
return data, ShardInfo(sharding, device_to_array_slices)
976+
end
977+
939978
function (sharding::HloSharding)(
940979
client::XLA.IFRT.Client, ::Nothing, x::Union{AbstractArray,Number}
941980
)

src/xla/PJRT/AsyncBuffer.jl

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,3 +11,11 @@ function Base.copy(b::AsyncBuffer)
1111
Base.wait(b)
1212
return AsyncBuffer(Base.copy(b.buffer), nothing)
1313
end
14+
15+
function Base.similar(a::AsyncBuffer, args...)
16+
return AsyncBuffer(Base.similar(a.buffer, args...)::Buffer, nothing)
17+
end
18+
19+
@inline function Base.similar(::Type{AsyncBuffer}, args...; kwargs...)
20+
return AsyncBuffer(Base.similar(Buffer, args...; kwargs...)::Buffer, nothing)
21+
end

0 commit comments

Comments
 (0)