@@ -192,18 +192,20 @@ function ConcretePJRTArray(
192
192
return ConcretePJRTArray {T,N,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
193
193
end
194
194
195
- function ConcretePJRTArray (
196
- data:: Memory{T} ;
197
- client:: Union{Nothing,XLA.PJRT.Client} = nothing ,
198
- idx:: Union{Int,Nothing} = nothing ,
199
- device:: Union{Nothing,XLA.PJRT.Device} = nothing ,
200
- sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
201
- ) where {T}
202
- theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
203
- sharded_data, shardinfo = sharding (theclient, thedevice, data)
204
- shape = size (data)
205
- nsharded = length (sharded_data)
206
- return ConcretePJRTArray {T,1,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
195
+ if isdefined (Base, :Memory )
196
+ function ConcretePJRTArray (
197
+ data:: Memory{T} ;
198
+ client:: Union{Nothing,XLA.PJRT.Client} = nothing ,
199
+ idx:: Union{Int,Nothing} = nothing ,
200
+ device:: Union{Nothing,XLA.PJRT.Device} = nothing ,
201
+ sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
202
+ ) where {T}
203
+ theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
204
+ sharded_data, shardinfo = sharding (theclient, thedevice, data)
205
+ shape = size (data)
206
+ nsharded = length (sharded_data)
207
+ return ConcretePJRTArray {T,1,nsharded,typeof(shardinfo)} (sharded_data, shape, shardinfo)
208
+ end
207
209
end
208
210
209
211
Base. wait (x:: Union{ConcretePJRTArray,ConcretePJRTNumber} ) = foreach (wait, x. data)
@@ -334,17 +336,19 @@ function ConcreteIFRTArray(
334
336
return ConcreteIFRTArray {T,N,typeof(shardinfo)} (sharded_data, shape, shardinfo, padding)
335
337
end
336
338
337
- function ConcreteIFRTArray (
338
- data:: Memory{T} ;
339
- client:: Union{Nothing,XLA.IFRT.Client} = nothing ,
340
- idx:: Union{Int,Nothing} = nothing ,
341
- device:: Union{Nothing,XLA.IFRT.Device} = nothing ,
342
- sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
343
- ) where {T}
344
- theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
345
- sharded_data, shardinfo, padding = sharding (theclient, nothing , data)
346
- shape = size (data)
347
- return ConcreteIFRTArray {T,1,typeof(shardinfo)} (sharded_data, shape, shardinfo)
339
+ if isdefined (Base, :Memory )
340
+ function ConcreteIFRTArray (
341
+ data:: Memory{T} ;
342
+ client:: Union{Nothing,XLA.IFRT.Client} = nothing ,
343
+ idx:: Union{Int,Nothing} = nothing ,
344
+ device:: Union{Nothing,XLA.IFRT.Device} = nothing ,
345
+ sharding:: Sharding.AbstractSharding = Sharding. NoSharding (),
346
+ ) where {T}
347
+ theclient, thedevice = _select_client_and_device (client, idx, device, sharding)
348
+ sharded_data, shardinfo, padding = sharding (theclient, nothing , data)
349
+ shape = size (data)
350
+ return ConcreteIFRTArray {T,1,typeof(shardinfo)} (sharded_data, shape, shardinfo)
351
+ end
348
352
end
349
353
350
354
# Assemble data from multiple arrays. Needed in distributed setting where each process wont
0 commit comments