@@ -96,11 +96,16 @@ See also: [`Sharding.NamedSharding`](@ref)
96
96
"""
97
97
struct NoSharding <: AbstractSharding end
98
98
99
# `NoSharding` always reports a device count of `1`.
@inline function ndevices(::NoSharding)
    return 1
end
100
+
101
# The second argument is ignored for `NoSharding`: the resulting `ShardInfo` type has
# `Nothing` in its second type parameter.
@inline function shard_type(::Type{NoSharding}, _)
    return ShardInfo{NoSharding,Nothing}
end
102
+
99
103
# Any property access on a `NoSharding` yields another `NoSharding`. This allows us to
# mark entire branches as NoSharding.
function Base.getproperty(::NoSharding, x)
    return NoSharding()
end
function Base.getproperty(::NoSharding, x::Symbol)
    return NoSharding()
end
102
106
103
107
# Create an `XLA.PJRT.AsyncBuffer` for `x` on `device`, falling back to the client's
# default device when `device` is `nothing`. Returns a 1-tuple containing the buffer,
# together with a `ShardInfo` wrapping `NoSharding()`.
function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,Number})
    if device === nothing
        device = XLA.default_device(client)
    end
    buffer = XLA.PJRT.AsyncBuffer(client, x, device)
    info = ShardInfo(NoSharding(), nothing)
    return (buffer,), info
end
@@ -185,6 +190,12 @@ struct NamedSharding{D1,D2,P<:Tuple} <: AbstractSharding
185
190
end
186
191
end
187
192
193
# The device count of a `NamedSharding` is the number of device ids stored in its mesh.
@inline function ndevices(sharding::NamedSharding)
    return length(sharding.mesh.device_ids)
end
194
+
195
# A `NamedSharding{D1,D2,P}` uses the same shard-info type as `HloSharding{D1,D2}`.
@inline shard_type(::Type{NamedSharding{D1,D2,P}}, N) where {D1,D2,P} =
    shard_type(HloSharding{D1,D2}, N)
198
+
188
199
function (sharding:: NamedSharding )(
189
200
client:: XLA.PJRT.Client , device:: Nothing , x:: Union{AbstractArray,Number}
190
201
)
@@ -226,6 +237,84 @@ function get_shardy_tensor_sharding_attribute(
226
237
)
227
238
end
228
239
240
+ # TODO : Something like NamedDims.jl will allow us to support NamedDimsSharding similar to
241
+ # `levanter`
242
+
243
"""
    DimsSharding(
        mesh::Mesh{M},
        dims::NTuple{D,Int},
        partition_spec;
        is_closed::NTuple{D,Bool}=ntuple(Returns(true), D),
        priority::NTuple{D,Int}=ntuple(i -> -1, D),
    )

Similar to [`NamedSharding`](@ref) but works for an arbitrary dimensional array.
Dimensions not specified in `dims` are replicated. If any dimension in `dims` is greater
than the total number of dimensions in the array, the corresponding `partition_spec`,
`is_closed` and `priority` are ignored. Additionally for any negative dimensions in `dims`,
the true dims are calculated as `ndims(x) + dim + 1` (so `-1` refers to the last
dimension). A dims value of `0` will throw an error.

`partition_spec` must contain exactly one entry per element of `dims`; otherwise an
`ArgumentError` is thrown.
"""
struct DimsSharding{M,D,P} <: AbstractSharding
    mesh::Mesh{M}
    dims::NTuple{D,Int}
    partition_spec::P
    is_closed::NTuple{D,Bool}
    priority::NTuple{D,Int}

    function DimsSharding(
        mesh::Mesh{M},
        dims::NTuple{D,Int},
        partition_spec;
        # Defaults are sized by `D` (the length of `dims`) so that a mismatched
        # `partition_spec` reaches the explicit length check below instead of failing the
        # keyword type assertion first.
        is_closed::NTuple{D,Bool}=ntuple(Returns(true), Val(D)),
        priority::NTuple{D,Int}=ntuple(Returns(-1), Val(D)),
    ) where {M,D}
        if length(partition_spec) != D
            throw(
                ArgumentError(
                    "partition_spec must have one entry per element of dims: got \
                    $(length(partition_spec)) entries for $(D) dims",
                ),
            )
        end
        # Validity checks on the inputs are deferred to NamedSharding
        return new{M,D,typeof(partition_spec)}(
            mesh, dims, partition_spec, is_closed, priority
        )
    end
end
279
+
280
# The device count of a `DimsSharding` is the number of device ids stored in its mesh.
@inline function ndevices(sharding::DimsSharding)
    return length(sharding.mesh.device_ids)
end
281
+
282
# The shard-info type is that of `HloSharding{M,N}`: the mesh parameter `M` paired with the
# `N` passed by the caller.
@inline shard_type(::Type{DimsSharding{M,D,P}}, N) where {M,D,P} =
    shard_type(HloSharding{M,N}, N)
285
+
286
# Convert a `DimsSharding` into a `NamedSharding` that covers every dimension of `x`.
#
# Negative entries of `sharding.dims` are resolved relative to the end of `x`
# (`ndims(x) + d + 1`). Array dimensions that no entry of `dims` resolves to are treated
# as replicated: partition `nothing`, closed, priority `-1`. Entries of `dims` that do not
# correspond to any dimension of `x` are ignored.
#
# Throws an `ArgumentError` if `sharding.dims` contains `0`.
function standardize_sharding(sharding::DimsSharding, x::Union{AbstractArray,Number})
    nd = ndims(x)

    final_dims = map(sharding.dims) do d
        # Explicit throw rather than `@assert`, which is not guaranteed to stay enabled.
        iszero(d) && throw(ArgumentError("dims cannot contain 0"))
        return ifelse(d < 0, nd + d + 1, d)
    end

    # For each dimension of `x`: the position within `sharding.dims` that refers to it, or
    # `nothing` when the dimension is not mentioned (i.e. replicated).
    dim_indices = ntuple(i -> findfirst(==(i), final_dims), nd)

    # Expand a per-`dims` tuple to a per-array-dimension tuple, filling replicated
    # dimensions with `default`.
    function expand(values, default)
        return map(dim_indices) do dim_index
            return dim_index === nothing ? default : values[dim_index]
        end
    end

    partition_spec = expand(sharding.partition_spec, nothing)
    is_closed = expand(sharding.is_closed, true)
    priority = expand(sharding.priority, -1)

    return NamedSharding(sharding.mesh, partition_spec; is_closed, priority)
end
311
+
312
# Applying a `DimsSharding` first standardizes it against `x` and then forwards the call,
# with the same arguments, to the resulting sharding object.
function (sharding::DimsSharding)(
    client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number}
)
    standardized = standardize_sharding(sharding, x)
    return standardized(client, device, x)
end
317
+
229
318
# HloSharding
230
319
# This stores the sharding information in the form of XLA.HloSharding, and provides a
231
320
# central type for the final storage. It also potentially saves us the pain of not having
@@ -244,6 +333,12 @@ struct HloSharding{D1,D2} <: AbstractSharding
244
333
end
245
334
end
246
335
336
# The device count of an `HloSharding` is the number of device ids stored in its mesh.
@inline function ndevices(sharding::HloSharding)
    return length(sharding.mesh.device_ids)
end
337
+
338
# The shard info for an `HloSharding` stores a vector of `N`-tuples of index ranges in its
# second type parameter.
@inline shard_type(::Type{HloSharding{D1,D2}}, N) where {D1,D2} =
    ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}}
341
+
247
342
function Base. convert (:: Type{HloSharding} , sharding:: NamedSharding )
248
343
if MLIR. IR. _has_context ()
249
344
ctx = MLIR. IR. context ()
@@ -321,6 +416,10 @@ struct ShardInfo{S,D} <: AbstractSharding
321
416
device_to_array_slices:: D
322
417
end
323
418
419
# A `ShardInfo` reports the device count of the sharding it wraps. Delegating (instead of
# reading `sharding.mesh` through the forwarding `getproperty`) keeps this consistent with
# the per-type `ndevices` methods, including `NoSharding`, whose property access does not
# return a mesh.
@inline ndevices(sharding::ShardInfo) = ndevices(sharding.sharding)
420
+
421
# The shard type of a `ShardInfo{S,D}` is the shard type of the wrapped sharding type `S`.
@inline function shard_type(::Type{ShardInfo{S,D}}, N) where {S,D}
    return shard_type(S, N)
end
422
+
324
423
function Base. getproperty (sharding:: ShardInfo , name:: Symbol )
325
424
name ∈ (:sharding , :device_to_array_slices ) && return getfield (sharding, name)
326
425
return getproperty (sharding. sharding, name)
@@ -348,6 +447,7 @@ Checks whether the given sharding refers to no sharding.
348
447
"""
349
448
function is_sharded(::NoSharding)
    return false
end

# The mesh-based sharding types all count as sharded.
function is_sharded(::NamedSharding)
    return true
end
function is_sharded(::DimsSharding)
    return true
end
function is_sharded(::HloSharding)
    return true
end

# A `ShardInfo` is sharded exactly when the sharding it wraps is.
function is_sharded(s::ShardInfo)
    wrapped = s.sharding
    return is_sharded(wrapped)
end
353
453
0 commit comments