Skip to content

Commit e995366

Browse files
authored
feat: more sharding utilities (#809)
* feat: add support for DimsSharding * fix: provide better error message * fix: minor fixes * feat: add shard_type and ndevices to infer the final shard types * fix: reverse incorrect patch
1 parent 24fed87 commit e995366

File tree

5 files changed

+144
-31
lines changed

5 files changed

+144
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.2.32"
4+
version = "0.2.33"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Sharding.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,16 @@ See also: [`Sharding.NamedSharding`](@ref)
9696
"""
9797
struct NoSharding <: AbstractSharding end
9898

99+
@inline ndevices(::NoSharding) = 1
100+
101+
@inline shard_type(::Type{NoSharding}, _) = ShardInfo{NoSharding,Nothing}
102+
99103
# This allows us to mark entire branches as NoSharding
100104
Base.getproperty(::NoSharding, x) = NoSharding()
101105
Base.getproperty(::NoSharding, x::Symbol) = NoSharding()
102106

103107
function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,Number})
108+
device === nothing && (device = XLA.default_device(client))
104109
buffer = XLA.PJRT.AsyncBuffer(client, x, device)
105110
return (buffer,), ShardInfo(NoSharding(), nothing)
106111
end
@@ -185,6 +190,12 @@ struct NamedSharding{D1,D2,P<:Tuple} <: AbstractSharding
185190
end
186191
end
187192

193+
@inline ndevices(sharding::NamedSharding) = length(sharding.mesh.device_ids)
194+
195+
@inline function shard_type(::Type{NamedSharding{D1,D2,P}}, N) where {D1,D2,P}
196+
return shard_type(HloSharding{D1,D2}, N)
197+
end
198+
188199
function (sharding::NamedSharding)(
189200
client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number}
190201
)
@@ -226,6 +237,84 @@ function get_shardy_tensor_sharding_attribute(
226237
)
227238
end
228239

240+
# TODO: Something like NamedDims.jl will allow us to support NamedDimsSharding similar to
241+
# `levanter`
242+
243+
"""
244+
DimsSharding(
245+
mesh::Mesh{M},
246+
dims::NTuple{D,Int},
247+
partition_spec;
248+
is_closed::NTuple{D,Bool}=ntuple(Returns(true), D),
249+
priority::NTuple{D,Int}=ntuple(i -> -1, D),
250+
)
251+
252+
Similar to [`NamedSharding`](@ref) but works for a arbitrary dimensional array. Dimensions
253+
not specified in `dims` are replicated. If any dimension in `dims` is greater than the total
254+
number of dimensions in the array, the corresponding `partition_spec`, `is_closed` and
255+
`priority` are ignored. Additionally for any negative dimensions in `dims`, the true
256+
dims are calculated as `ndims(x) - dim + 1`. A dims value of `0` will throw an error.
257+
"""
258+
struct DimsSharding{M,D,P} <: AbstractSharding
259+
mesh::Mesh{M}
260+
dims::NTuple{D,Int}
261+
partition_spec::P
262+
is_closed::NTuple{D,Bool}
263+
priority::NTuple{D,Int}
264+
265+
function DimsSharding(
266+
mesh::Mesh{M},
267+
dims::NTuple{D,Int},
268+
partition_spec;
269+
is_closed::NTuple{D,Bool}=ntuple(Returns(true), length(partition_spec)),
270+
priority::NTuple{D,Int}=ntuple(i -> -1, length(partition_spec)),
271+
) where {M,D}
272+
@assert length(partition_spec) == length(dims)
273+
# Validity checks on the inputs are deferred to NamedSharding
274+
return new{M,D,typeof(partition_spec)}(
275+
mesh, dims, partition_spec, is_closed, priority
276+
)
277+
end
278+
end
279+
280+
@inline ndevices(sharding::DimsSharding) = length(sharding.mesh.device_ids)
281+
282+
@inline function shard_type(::Type{DimsSharding{M,D,P}}, N) where {M,D,P}
283+
return shard_type(HloSharding{M,N}, N)
284+
end
285+
286+
function standardize_sharding(sharding::DimsSharding, x::Union{AbstractArray,Number})
287+
final_dims = map(sharding.dims) do d
288+
@assert !iszero(d) "dims cannot contain 0"
289+
return ifelse(d < 0, ndims(x) + d + 1, d)
290+
end
291+
292+
dim_indices = ntuple(i -> findfirst(==(i), final_dims), ndims(x))
293+
partition_spec = ntuple(ndims(x)) do i
294+
dim_index = dim_indices[i]
295+
dim_index === nothing && return nothing # replicated dimension
296+
return sharding.partition_spec[dim_index]
297+
end
298+
is_closed = ntuple(ndims(x)) do i
299+
dim_index = dim_indices[i]
300+
dim_index === nothing && return true # replicated dimension
301+
return sharding.is_closed[dim_index]
302+
end
303+
priority = ntuple(ndims(x)) do i
304+
dim_index = dim_indices[i]
305+
dim_index === nothing && return -1 # replicated dimension
306+
return sharding.priority[dim_index]
307+
end
308+
309+
return NamedSharding(sharding.mesh, partition_spec; is_closed, priority)
310+
end
311+
312+
function (sharding::DimsSharding)(
313+
client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number}
314+
)
315+
return (standardize_sharding(sharding, x))(client, device, x)
316+
end
317+
229318
# HloSharding
230319
# This stores the sharding information in the form of XLA.HloSharding, and provides a
231320
# central type for the final storage. It also potentially saves us the pain of not having
@@ -244,6 +333,12 @@ struct HloSharding{D1,D2} <: AbstractSharding
244333
end
245334
end
246335

336+
@inline ndevices(sharding::HloSharding) = length(sharding.mesh.device_ids)
337+
338+
@inline function shard_type(::Type{HloSharding{D1,D2}}, N) where {D1,D2}
339+
return ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}}
340+
end
341+
247342
function Base.convert(::Type{HloSharding}, sharding::NamedSharding)
248343
if MLIR.IR._has_context()
249344
ctx = MLIR.IR.context()
@@ -321,6 +416,10 @@ struct ShardInfo{S,D} <: AbstractSharding
321416
device_to_array_slices::D
322417
end
323418

419+
@inline ndevices(sharding::ShardInfo) = length(sharding.mesh)
420+
421+
@inline shard_type(::Type{ShardInfo{S,D}}, N) where {S,D} = shard_type(S, N)
422+
324423
function Base.getproperty(sharding::ShardInfo, name::Symbol)
325424
name (:sharding, :device_to_array_slices) && return getfield(sharding, name)
326425
return getproperty(sharding.sharding, name)
@@ -348,6 +447,7 @@ Checks whether the given sharding refers to no sharding.
348447
"""
349448
is_sharded(::NoSharding) = false
350449
is_sharded(::NamedSharding) = true
450+
is_sharded(::DimsSharding) = true
351451
is_sharded(::HloSharding) = true
352452
is_sharded(s::ShardInfo) = is_sharded(s.sharding)
353453

src/Tracing.jl

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,9 @@ Base.@nospecializeinfer function traced_type_inner(
5656
@nospecialize(sharding)
5757
)
5858
if Mode == ArrayToConcrete && T <: track_numbers
59-
if !Sharding.is_sharded(sharding)
60-
return ConcretePJRTNumber{T,1,Sharding.NoShardInfo}
61-
else
62-
error("TODO: implement sharding")
63-
end
59+
return ConcretePJRTNumber{
60+
T,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), 0)
61+
}
6462
elseif (mode == NoStopTracedTrack || mode == TracedTrack || mode == TracedSetPath) &&
6563
T <: track_numbers
6664
return TracedRNumber{T}
@@ -300,11 +298,12 @@ Base.@nospecializeinfer function traced_type_inner(
300298
if mode == ConcreteToTraced
301299
throw("TracedRArray cannot be traced")
302300
elseif mode == TracedToConcrete
303-
if !Sharding.is_sharded(sharding)
304-
return ConcretePJRTArray{T.parameters[1],T.parameters[2],1,Sharding.NoShardInfo}
305-
else
306-
error("TODO: implement sharding")
307-
end
301+
return ConcretePJRTArray{
302+
T.parameters[1],
303+
T.parameters[2],
304+
Sharding.ndevices(sharding),
305+
Sharding.shard_type(typeof(sharding), T.parameters[2]),
306+
}
308307
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
309308
return T
310309
else
@@ -322,14 +321,21 @@ Base.@nospecializeinfer function traced_type_inner(
322321
if mode == ConcreteToTraced
323322
throw("TracedRNumber cannot be traced")
324323
elseif mode == TracedToConcrete
325-
if !Sharding.is_sharded(sharding)
326-
if T isa UnionAll
327-
return UnionAll(T.var, ConcretePJRTNumber{T.var,1,Sharding.NoShardInfo})
328-
end
329-
return ConcretePJRTNumber{T.parameters[1],1,Sharding.NoShardInfo}
330-
else
331-
error("TODO: implement sharding")
324+
if T isa UnionAll
325+
return UnionAll(
326+
T.var,
327+
ConcretePJRTNumber{
328+
T.var,
329+
Sharding.ndevices(sharding),
330+
Sharding.shard_type(typeof(sharding), 0),
331+
},
332+
)
332333
end
334+
return ConcretePJRTNumber{
335+
T.parameters[1],
336+
Sharding.ndevices(sharding),
337+
Sharding.shard_type(typeof(sharding), 0),
338+
}
333339
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
334340
return T
335341
else
@@ -347,11 +353,9 @@ Base.@nospecializeinfer function traced_type_inner(
347353
if mode == ConcreteToTraced
348354
throw("TracedRNG cannot be traced")
349355
elseif mode == TracedToConcrete
350-
if !Sharding.is_sharded(sharding)
351-
return ConcreteRNG{1,Sharding.NoShardInfo}
352-
else
353-
error("TODO: implement sharding")
354-
end
356+
return ConcreteRNG{
357+
traced_type_inner(TracedRArray{UInt64,1}, seen, mode, track_numbers, sharding)
358+
}
355359
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
356360
return T
357361
else
@@ -413,11 +417,9 @@ Base.@nospecializeinfer function traced_type_inner(
413417
else
414418
N = ndims(A)
415419
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
416-
if !Sharding.is_sharded(sharding)
417-
return ConcretePJRTArray{T,N,1,Sharding.NoShardInfo}
418-
else
419-
error("TODO: implement sharding")
420-
end
420+
return ConcretePJRTArray{
421+
T,N,Sharding.ndevices(sharding),Sharding.shard_type(typeof(sharding), N)
422+
}
421423
else
422424
return Array{
423425
traced_type_inner(T, seen, mode, track_numbers, getproperty(sharding, 1)),N
@@ -914,7 +916,7 @@ function make_tracer(
914916
if !Sharding.is_sharded(sharding)
915917
return prev
916918
else
917-
error("TODO: implement sharding")
919+
return ConcretePJRTNumber(prev; sharding)
918920
end
919921
end
920922
if mode != ConcreteToTraced
@@ -1106,7 +1108,6 @@ function make_tracer(
11061108
return nothing
11071109
end
11081110
RT = Core.Typeof(prev)
1109-
Sharding.is_sharded(sharding) && error("Cannot specify sharding for Numbers")
11101111
if RT <: track_numbers
11111112
if mode == ArrayToConcrete
11121113
return ConcretePJRTNumber(prev; sharding)

src/xla/Sharding.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,16 @@ function sharding_to_concrete_array_indices(
263263
@assert n_shards > 0 "Invalid number of shards: $n_shards"
264264
n_shards == 1 && return [1:dim]
265265
shard_size, remainder = divrem(dim, n_shards)
266-
@assert remainder == 0 "Dimension $dim not evenly divisible by $n_shards shards"
266+
267+
if remainder != 0
268+
throw(
269+
DimensionMismatch(
270+
"Dimension of Size $(dim) cannot be partitioned into $(n_shards) \
271+
shards each of size $(shard_size) (remainder = $(remainder)).",
272+
),
273+
)
274+
end
275+
267276
return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)]
268277
end
269278

test/sharding.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,19 @@ fn_test3(x) = sum(x; dims=1)
8181
Sharding.NamedSharding(mesh, ("model", "data")),
8282
Sharding.NamedSharding(mesh, ("model", nothing)),
8383
Sharding.NamedSharding(mesh, (nothing, "data")),
84+
Sharding.DimsSharding(mesh, (2,), (:data,)),
8485
),
8586
(
8687
Sharding.NamedSharding(mesh, ("model", "data")),
8788
Sharding.NamedSharding(mesh, (nothing, "data")),
8889
Sharding.NoSharding(),
90+
Sharding.DimsSharding(mesh, (-2,), (:model,)),
8991
),
9092
(
9193
Sharding.NamedSharding(mesh, ("model", "data")),
9294
Sharding.NoSharding(),
9395
Sharding.NoSharding(),
96+
Sharding.NamedSharding(mesh, ("model", "data")),
9497
),
9598
)
9699
samples_ra = Reactant.to_rarray(samples; sharding=samples_sharding)

0 commit comments

Comments
 (0)