@@ -371,15 +371,41 @@ function Base.setindex!(a::ConcreteIFRTArray, v, args::Vararg{Int,N}) where {N}
371
371
return a
372
372
end
373
373
374
- # TODO is there any way to allocate an uninitialized buffer in XLA?
375
- function Base. similar (a:: ConcretePJRTArray{T} , :: Type{S} = T, dims:: Dims = size (a)) where {T,S}
376
- return ConcretePJRTArray (
377
- Array {S} (undef, dims); client= XLA. client (a), device= XLA. device (a), a. sharding
374
# Build a `ConcretePJRTArray` with element type `S` and size `dims` from the array *type*.
#
# Keywords:
# - `client`:   PJRT client to use; `XLA.default_backend()` is used when it is `nothing`.
# - `idx`:      device index; only consulted when `device` is `nothing`, in which case
#               the device is looked up with `XLA.get_device(client, idx)`.
# - `device`:   explicit PJRT device; takes precedence over `idx`.
# - `sharding`: callable sharding object. It is invoked as
#               `sharding(client, device, S, dims)` and its two return values become the
#               data tuple and the sharding stored in the result.
#
# NOTE(review): the initial contents of the returned buffers are whatever the sharding
# call produces -- presumably uninitialized; confirm against the sharding implementation.
@inline function Base.similar(
    ::Type{<:ConcretePJRTArray},
    ::Type{S},
    dims::Dims;
    client::Union{Nothing,XLA.PJRT.Client}=nothing,
    idx::Union{Int,Nothing}=nothing,
    device::Union{Nothing,XLA.PJRT.Device}=nothing,
    sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {S}
    # Resolve the defaults into fresh locals instead of rebinding the keywords.
    resolved_client = isnothing(client) ? XLA.default_backend() : client
    resolved_device = if idx isa Int && isnothing(device)
        XLA.get_device(resolved_client, idx)
    else
        device
    end

    buffers, shard_info = sharding(resolved_client, resolved_device, S, dims)

    return ConcretePJRTArray{S,length(dims),length(buffers),typeof(shard_info)}(
        buffers, dims, shard_info
    )
end
390
+
391
# Create a new `ConcretePJRTArray` shaped like `a`, optionally with a different element
# type `S` and/or size `dims`, reusing `a`'s sharding and its number of device buffers `D`.
#
# The per-device slices for `dims` are computed from `a.sharding`; each new buffer is then
# obtained by calling `similar` on the corresponding existing buffer `a.data[i]` with the
# lengths of that device's slices as its size.
#
# NOTE(review): the updated sharding returned by `sharding_to_array_slices` is discarded
# and `a.sharding` is stored in the result -- confirm this is intended when `dims`
# differs from `size(a)`.
function Base.similar(
    a::ConcretePJRTArray{T,N,D,Sh}, ::Type{S}=T, dims::Dims=size(a)
) where {S,T,Sh,N,D}
    device_to_array_slices, _ = Sharding.sharding_to_array_slices(
        a.sharding, dims; return_updated_sharding=Val(true), client=XLA.client(a)
    )
    # One slice set per existing device buffer is required.
    @assert length(device_to_array_slices) == D

    new_buffers = ntuple(Val(D)) do i
        Base.@_inline_meta
        shard_dims = Dims(length.(device_to_array_slices[i]))
        return Base.similar(a.data[i], S, shard_dims)
    end

    return ConcretePJRTArray{S,length(dims),D,Sh}(new_buffers, dims, a.sharding)
end
404
+
380
405
# `similar` with only a new size: keep the element type of `a`.
function Base.similar(a::ConcretePJRTArray, dims::Dims)
    return similar(a, eltype(a), dims)
end
381
- function Base. similar (:: Type{ConcretePJRTArray{T}} , dims) where {T}
382
- return ConcretePJRTArray (similar (Array{T}, dims))
406
+
407
# `similar` on a `ConcretePJRTArray` type whose element type `T` is already fixed:
# forward to the `(type, eltype, dims)` method, passing all keywords through unchanged.
@inline Base.similar(AT::Type{<:ConcretePJRTArray{T}}, dims; kwargs...) where {T} =
    Base.similar(AT, T, dims; kwargs...)
384
410
385
411
function Base. similar (a:: ConcreteIFRTArray{T} , :: Type{S} = T, dims:: Dims = size (a)) where {T,S}
@@ -396,16 +422,16 @@ end
396
422
# Broadcasting over any `ConcretePJRTArray` uses the `ArrayStyle{ConcretePJRTArray}` style.
function Base.BroadcastStyle(::Type{<:ConcretePJRTArray})
    return Broadcast.ArrayStyle{ConcretePJRTArray}()
end
397
423
# Broadcasting over any `ConcreteIFRTArray` uses the `ArrayStyle{ConcreteIFRTArray}` style.
function Base.BroadcastStyle(::Type{<:ConcreteIFRTArray})
    return Broadcast.ArrayStyle{ConcreteIFRTArray}()
end
398
424
399
- # XXX : correct device + sharding?
400
- function Base. similar (
401
- bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcretePJRTArray}} , :: Type{T}
425
# Allocate the destination for a broadcast whose style is `ArrayStyle{ConcretePJRTArray}`.
#
# The result has element type `T` and one dimension per broadcast axis. Keywords are
# forwarded unchanged to the type-based `similar` method.
#
# `axes(bc)` is a tuple of index ranges, whereas the type-based `similar` method takes
# `dims::Dims` (a tuple of `Int`s), so the axes are converted to their lengths first.
@inline function Base.similar(
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcretePJRTArray}},
    ::Type{T};
    kwargs...,
) where {T}
    return similar(ConcretePJRTArray, T, map(length, axes(bc)); kwargs...)
end
405
- function Base. similar (
406
- bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteIFRTArray}} , :: Type{T}
430
+
431
# Allocate the destination for a broadcast whose style is `ArrayStyle{ConcreteIFRTArray}`.
#
# The result has element type `T` and one dimension per broadcast axis. Keywords are
# forwarded unchanged to the type-based `similar` method.
#
# `axes(bc)` is a tuple of index ranges; it is converted to a tuple of lengths so the
# call matches a `dims::Dims` signature, mirroring the `ConcretePJRTArray` method.
# NOTE(review): the type-based `ConcreteIFRTArray` method is not visible here -- confirm
# it accepts `Dims`.
@inline function Base.similar(
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteIFRTArray}},
    ::Type{T};
    kwargs...,
) where {T}
    return similar(ConcreteIFRTArray, T, map(length, axes(bc)); kwargs...)
end
410
436
411
437
# TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`)
@@ -429,9 +455,10 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
429
455
),
430
456
)
431
457
end
432
- aux = copyto! (
433
- similar (Array{ElType}, axes (bc)), convert (Broadcast. Broadcasted{Nothing}, bc)
434
- )
458
+
459
+ aux = similar (ConcretePJRTArray, ElType, length .(axes (bc)))
460
+
461
+ copyto! (aux, convert (Broadcast. Broadcasted{Nothing}, bc))
435
462
return ConcretePJRTArray (aux) # XXX : result should be on correct device?
436
463
end
437
464
@@ -484,6 +511,111 @@ for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
484
511
end
485
512
end
486
513
514
# Copy the contents of a `ConcreteIFRTArray` into a host `Vector` with the same eltype.
#
# Only a full copy into the start of `dest` is supported: `n` must equal `length(src)`
# and `doffs` must be 1, otherwise an `AssertionError` is thrown. `n == 0` returns `dest`
# untouched, and a negative `n` is rejected through `Base._throw_argerror`.
# Returns `dest`.
function Base.copyto!(
    dest::Vector{T},
    doffs::Int64,
    src::Reactant.ConcreteIFRTArray{T},
    soffs::Int64,
    n::Int64,
) where {T}
    iszero(n) && return dest
    if n < 0
        Base._throw_argerror(" Number of elements to copy must be non-negative.")
    end
    @boundscheck begin
        checkbounds(dest, doffs:(doffs + n - 1))
        checkbounds(src, soffs:(soffs + n - 1))
    end

    n == length(src) ||
        throw(AssertionError(" Only full array copyto! supported from ConcreteIFRTArray"))
    doffs == 1 ||
        throw(AssertionError(" Dest offset not yet supported in ConcreteIFRTArray copyto!"))

    async_array = src.data
    ifrt_array = async_array.buffer
    # `wait` on the async wrapper before handing the underlying handle to the C API.
    wait(async_array)

    src_byte_offset = (soffs - 1) * sizeof(T)

    # Keep `dest` rooted while its raw pointer is in use by the foreign call.
    GC.@preserve dest begin
        @ccall Reactant.MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
            ifrt_array.buffer::Ptr{Cvoid},
            pointer(dest, doffs)::Ptr{T},
            src_byte_offset::Int64,
        )::Ptr{Cvoid}
    end

    return dest
end
547
+
548
# Copy `n` elements from a `ConcretePJRTArray` (starting at linear index `soffs`) into a
# host `Vector` of the same eltype (starting at index `doffs`).
#
# `n == 0` returns `dest` untouched, and a negative `n` is rejected through
# `Base._throw_argerror`. The source must hold exactly one device buffer (asserted).
# Offsets and length are passed to the C routine in bytes. Returns `dest`.
function Base.copyto!(
    dest::Vector{T},
    doffs::Int64,
    src::Reactant.ConcretePJRTArray{T},
    soffs::Int64,
    n::Int64,
) where {T}
    iszero(n) && return dest
    if n < 0
        Base._throw_argerror(" Number of elements to copy must be non-negative.")
    end
    @boundscheck begin
        checkbounds(dest, doffs:(doffs + n - 1))
        checkbounds(src, soffs:(soffs + n - 1))
    end

    pjrt_client = XLA.client(src)
    @assert length(src.data) == 1
    async_buffer = src.data[1]
    device_buffer = async_buffer.buffer
    # `wait` on the async wrapper before handing the underlying handle to the C API.
    wait(async_buffer)

    elsize = sizeof(T)
    src_byte_offset = (soffs - 1) * elsize
    nbytes = n * elsize

    # Keep `dest` rooted while its raw pointer is in use by the foreign call.
    GC.@preserve dest begin
        @ccall Reactant.MLIR.API.mlir_c.CopyFromBuffer(
            pjrt_client.client::Ptr{Cvoid},
            device_buffer.buffer::Ptr{Cvoid},
            pointer(dest, doffs)::Ptr{T},
            src_byte_offset::Int64,
            nbytes::Int64,
        )::Ptr{Cvoid}
    end

    return dest
end
578
+
579
# Convenience method: copy every element of `src` into `dest`, starting at the first
# element of both, by forwarding to the five-argument `copyto!`.
function Base.copyto!(
    dest::Vector{T}, src::Union{Reactant.ConcretePJRTArray{T},Reactant.ConcreteIFRTArray{T}}
) where {T}
    nelems = length(src)
    return copyto!(dest, 1, src, 1, nelems)
end
584
+
585
# Copy `n` elements from a host `Vector` (starting at index `soffs`) into a
# `ConcretePJRTArray` of the same eltype (starting at linear index `doffs`).
#
# `n == 0` returns `dest` untouched, and a negative `n` is rejected through
# `Base._throw_argerror`. Only the first device buffer (`dest.data[1]`) is written, so a
# destination with any other number of buffers is rejected with an `AssertionError`,
# matching the single-buffer check in the PJRT-to-`Vector` method.
# Offsets and length are passed to the C routine in bytes. Returns `dest`.
function Base.copyto!(
    dest::Reactant.ConcretePJRTArray{T},
    doffs::Int64,
    src::Vector{T},
    soffs::Int64,
    n::Int64,
) where {T}
    n == 0 && return dest
    n > 0 || Base._throw_argerror(" Number of elements to copy must be non-negative.")
    @boundscheck checkbounds(dest, doffs:(doffs + n - 1))
    @boundscheck checkbounds(src, soffs:(soffs + n - 1))

    client = XLA.client(dest)
    # Without this guard a multi-buffer destination would be only partially written,
    # with no error.
    if length(dest.data) != 1
        throw(
            AssertionError(
                "copyto! into a ConcretePJRTArray requires exactly one device buffer"
            ),
        )
    end
    dest_async = dest.data[1]
    dest_sync = dest_async.buffer
    # `wait` on the async wrapper before handing the underlying handle to the C API.
    wait(dest_async)

    # Keep `src` rooted while its raw pointer is in use by the foreign call.
    GC.@preserve src begin
        @ccall Reactant.MLIR.API.mlir_c.CopyToBuffer(
            client.client::Ptr{Cvoid},
            dest_sync.buffer::Ptr{Cvoid},
            pointer(src, soffs)::Ptr{T},
            ((doffs - 1) * sizeof(T))::Int64,
            (n * sizeof(T))::Int64,
        )::Ptr{Cvoid}
    end

    return dest
end
614
+
615
# Convenience method: copy every element of the host vector `src` into `dest`, starting
# at the first element of both, by forwarding to the five-argument `copyto!`.
function Base.copyto!(dest::Reactant.ConcretePJRTArray{T}, src::Vector{T}) where {T}
    nelems = length(src)
    return copyto!(dest, 1, src, 1, nelems)
end
618
+
487
619
for aType in (:ConcretePJRTArray , :ConcreteIFRTArray )
488
620
@eval begin
489
621
function Base. copyto! (
0 commit comments