Skip to content

Commit 87aa468

Browse files
authored
fix: simplify Mesh implementation (#806)
* fix: simplify Mesh implementation * chore: run formatter
1 parent 9af8d87 commit 87aa468

File tree

5 files changed

+65
-54
lines changed

5 files changed

+65
-54
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,9 @@ function vendored_buildIntrinsicLoweringPipeline(
585585
return LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
586586
end
587587

588-
function vendored_buildScalarOptimizerPipeline(fpm, @nospecialize(job), opt_level; instcombine::Bool=false)
588+
function vendored_buildScalarOptimizerPipeline(
589+
fpm, @nospecialize(job), opt_level; instcombine::Bool=false
590+
)
589591
if opt_level >= 2
590592
LLVM.add!(fpm, LLVM.Interop.AllocOptPass())
591593
LLVM.add!(fpm, LLVM.SROAPass())
@@ -597,9 +599,9 @@ function vendored_buildScalarOptimizerPipeline(fpm, @nospecialize(job), opt_leve
597599
LLVM.add!(fpm, LLVM.DCEPass())
598600
LLVM.add!(fpm, LLVM.IRCEPass())
599601
if instcombine
600-
LLVM.add!(fpm, LLVM.InstCombinePass())
602+
LLVM.add!(fpm, LLVM.InstCombinePass())
601603
else
602-
LLVM.add!(fpm, LLVM.InstSimplifyPass())
604+
LLVM.add!(fpm, LLVM.InstSimplifyPass())
603605
end
604606
LLVM.add!(fpm, LLVM.JumpThreadingPass())
605607
end

src/Compiler.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,7 +1178,7 @@ function codegen_flatten!(
11781178

11791179
push!(flatten_code, :($usbuf = $flatcode.data))
11801180
for j in 1:length(mesh)
1181-
sbuf = Symbol(:sbuf_, i, "_", mesh.device_ids[j])
1181+
sbuf = Symbol(:sbuf_, i, "_", mesh.logical_device_ids[j])
11821182
push!(flatten_names, sbuf)
11831183
push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
11841184
end
@@ -1188,10 +1188,10 @@ function codegen_flatten!(
11881188
)
11891189
push!(flatten_code, :($usbuf = $flatcode))
11901190
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
1191-
condensed_op_sharding, size(carg), mesh.device_ids
1191+
condensed_op_sharding, size(carg), mesh.logical_device_ids
11921192
)
11931193
for j in 1:length(mesh)
1194-
device_id = mesh.device_ids[j]
1194+
device_id = mesh.logical_device_ids[j]
11951195
buf = Symbol(:buf_, i, :_, device_id)
11961196
slice = device_to_array_slices[j]
11971197
push!(
@@ -1548,7 +1548,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
15481548

15491549
# compile MLIR module to XLA executable
15501550
global_device_ids = if mlir_fn_res.is_sharded
1551-
collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
1551+
vec(mlir_fn_res.sharding_mesh.device_ids)
15521552
else
15531553
Int64[]
15541554
end

src/Ops.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2235,7 +2235,7 @@ We return a NamedTuple with the following fields:
22352235
cache !== nothing && haskey(cache, m) && return cache[m]
22362236
result = mesh(
22372237
[k => Int64(v) for (k, v) in zip(m.axis_names, size(m))],
2238-
collect(Int64, m.logical_device_ids);
2238+
m.logical_device_ids;
22392239
mod,
22402240
sym_name,
22412241
location,
@@ -2246,23 +2246,25 @@ end
22462246

22472247
@noinline function mesh(
22482248
mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}},
2249-
device_ids::Vector{Int64};
2249+
device_ids::AbstractVector{Int64};
22502250
mod::MLIR.IR.Module=MLIR.IR.mmodule(),
22512251
sym_name::String="mesh",
22522252
location=mlir_stacktrace("mesh", @__FILE__, @__LINE__),
22532253
)
22542254
# See https://github.com/openxla/shardy/blob/f9d83e779a58b811b848c4edfaf68e88b636787d/shardy/dialect/sdy/ir/verifiers.cc#L647-L699 for the checks
22552255
ndevices = prod(last, mesh_axes)
2256+
22562257
@assert allunique(first, mesh_axes) "mesh_axes must be unique"
22572258
@assert ndevices == length(device_ids) "length(device_ids) should be same as \
22582259
prod(last, mesh_axes)"
2259-
@assert all(x -> x ≥ 0, device_ids) "device_ids must be non-negative"
2260-
@assert Base.sort(device_ids) == collect(Int64, 0:(ndevices - 1)) "sorted device_ids must be the same as iota(product(axes)), got $(Base.sort(device_ids))"
2260+
@assert all(Base.Fix2(≥, 0), device_ids) "device_ids must be non-negative"
2261+
@assert Base.sort(device_ids) == 0:(ndevices - 1) "sorted device_ids must be the same \
2262+
as iota(product(axes)), got \
2263+
$(Base.sort(device_ids))"
22612264

2262-
if Base.sort(device_ids) == device_ids
2263-
# error: if the ordered device ids are the same as iota(product(axes)), no need to specify them for simplicity
2264-
device_ids = Int64[]
2265-
end
2265+
# error: if the ordered device ids are the same as iota(product(axes)), no need to
2266+
# specify them for simplicity
2267+
issorted(device_ids) && (device_ids = Int64[])
22662268

22672269
ctx = MLIR.IR.context()
22682270
mesh_axis_attrs = [
@@ -2273,7 +2275,7 @@ end
22732275
Int64(length(mesh_axis_attrs)),
22742276
mesh_axis_attrs,
22752277
Int64(length(device_ids)),
2276-
device_ids,
2278+
collect(Int64, device_ids),
22772279
)
22782280

22792281
sym_name = Reactant.TracedUtils.__lookup_unique_name_in_module(mod, sym_name)
@@ -2306,10 +2308,13 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
23062308
`input` and `sharding`.
23072309
"""
23082310
@noinline function sharding_constraint(
2309-
input::Union{TracedRArray,TracedRNumber},
2311+
input::Union{AbstractArray,Number},
23102312
sharding::Reactant.Sharding.AbstractSharding;
23112313
location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__),
23122314
)
2315+
!(input isa TracedRNumber || input isa TracedRArray) &&
2316+
(input = constant(input; location))
2317+
23132318
cache = Reactant.Compiler.sdycache()
23142319
haskey(cache, sharding.mesh) || Ops.mesh(sharding.mesh; location)
23152320
(; sym_name, mesh_attr) = cache[sharding.mesh]

src/Sharding.jl

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,50 +21,54 @@ julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));
2121
julia> mesh = Mesh(reshape(devices, 4, 2), (:x, :y));
2222
```
2323
"""
24-
struct Mesh{D,ND}
25-
device_ids::NTuple{ND,Int}
26-
sorted_device_ids::NTuple{ND,Int}
27-
logical_device_ids::NTuple{ND,Int}
28-
shape::Dims{D}
24+
struct Mesh{D}
25+
device_ids::Array{Int64,D}
26+
logical_device_ids::UnitRange{Int}
2927
axis_names::NTuple{D,Symbol}
3028

3129
function Mesh(devices::AbstractArray{<:XLA.AbstractDevice}, axis_names)
3230
return Mesh(XLA.device_ordinal.(devices), axis_names)
3331
end
3432

3533
function Mesh(
36-
devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
34+
device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}}
3735
) where {D}
38-
return Mesh(XLA.device_ordinal.(devices), shape, axis_names)
36+
return new{D}(device_ids, 0:(length(device_ids) - 1), Symbol.(axis_names))
3937
end
4038

39+
# XXX (Deprecated): remove in v0.3
4140
function Mesh(
42-
device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}}
41+
devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
4342
) where {D}
44-
return Mesh(Tuple(vec(device_ids)), size(device_ids), axis_names)
43+
Base.depwarn(
44+
"Mesh(devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names) is \
45+
deprecated, use Mesh(reshape(collect(XLA.device_ordinal.(devices)), shape), \
46+
axis_names) instead",
47+
:Mesh,
48+
)
49+
global_ids = reshape(collect(XLA.device_ordinal.(devices)), shape)
50+
return Mesh(global_ids, axis_names)
4551
end
4652

53+
# XXX (Deprecated): remove in v0.3
4754
function Mesh(
48-
device_ids::NTuple{D1,Int},
49-
shape::Dims{D},
50-
axis_names::NTuple{D,Union{String,Symbol}},
55+
device_ids::Dims{D1}, shape::Dims{D}, axis_names::NTuple{D,Union{String,Symbol}}
5156
) where {D,D1}
52-
@assert allunique(device_ids)
53-
return new{D,D1}(
54-
device_ids,
55-
Tuple(sort([device_ids...])),
56-
ntuple(Base.Fix2(-, 1), D1),
57-
shape,
58-
Symbol.(axis_names),
57+
Base.depwarn(
58+
"Mesh(device_ids::Dims{D1}, shape::Dims{D}, \
59+
axis_names::NTuple{D,Union{String,Symbol}}) is deprecated, use \
60+
Mesh(reshape(collect(Int64, device_ids), shape), axis_names) instead",
61+
:Mesh,
5962
)
63+
return Mesh(reshape(collect(Int64, device_ids), shape), axis_names)
6064
end
6165
end
6266

63-
Base.length(::Mesh{D,ND}) where {D,ND} = ND
67+
Base.length(m::Mesh) = length(m.device_ids)
6468
Base.ndims(::Mesh{D}) where {D} = D
6569

66-
Base.size(mesh::Mesh) = mesh.shape
67-
Base.size(mesh::Mesh, axis::Int) = mesh.shape[axis]
70+
Base.size(mesh::Mesh) = size(mesh.device_ids)
71+
Base.size(mesh::Mesh, axis::Int) = size(mesh.device_ids, axis)
6872
function Base.size(mesh::Mesh, axis::Union{String,Symbol})
6973
return size(mesh, findfirst(==(Symbol(axis)), mesh.axis_names))
7074
end
@@ -146,18 +150,18 @@ julia> sharding = NamedSharding(mesh, (nothing, nothing)); # fully replicated Ma
146150
147151
See also: [`Sharding.NoSharding`](@ref)
148152
"""
149-
struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding
150-
mesh::Mesh{D1,D2}
153+
struct NamedSharding{D1,D2,P<:Tuple} <: AbstractSharding
154+
mesh::Mesh{D1}
151155
partition_spec::P
152-
is_closed::NTuple{D3,Bool}
153-
priority::NTuple{D3,Int}
156+
is_closed::NTuple{D2,Bool}
157+
priority::NTuple{D2,Int}
154158

155159
function NamedSharding(
156-
mesh::Mesh{D1,D2},
160+
mesh::Mesh{D1},
157161
partition_spec::P;
158-
is_closed::NTuple{D3,Bool}=ntuple(Returns(true), length(partition_spec)),
159-
priority::NTuple{D3,Int}=ntuple(i -> -1, length(partition_spec)),
160-
) where {D1,D2,P<:Tuple,D3}
162+
is_closed::NTuple{D2,Bool}=ntuple(Returns(true), length(partition_spec)),
163+
priority::NTuple{D2,Int}=ntuple(i -> -1, length(partition_spec)),
164+
) where {D1,P<:Tuple,D2}
161165
axis_names = Symbol[]
162166
pspec = ()
163167
for p in partition_spec
@@ -177,7 +181,7 @@ struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding
177181
end
178182
@assert allunique(axis_names) "Duplicate axis names!"
179183

180-
return new{D1,D2,typeof(pspec),D3}(mesh, pspec, is_closed, priority)
184+
return new{D1,D2,typeof(pspec)}(mesh, pspec, is_closed, priority)
181185
end
182186
end
183187

@@ -226,17 +230,17 @@ end
226230
# This stores the sharding information in the form of XLA.HloSharding, and provides a
227231
# central type for the final storage. It also potentially saves us the pain of not having
228232
# to regenerate the partition spec from the HloSharding.
229-
struct HloSharding{M,D,D2} <: AbstractSharding
233+
struct HloSharding{D1,D2} <: AbstractSharding
230234
hlo_sharding::XLA.HloSharding
231-
mesh::Mesh{M,D}
235+
mesh::Mesh{D1}
232236
is_closed::NTuple{D2,Bool}
233237
priority::NTuple{D2,Int}
234238

235239
function HloSharding(
236-
hlo_sharding::XLA.HloSharding, mesh::Mesh{M,D}, is_closed, priority
237-
) where {M,D}
240+
hlo_sharding::XLA.HloSharding, mesh::Mesh{D1}, is_closed, priority
241+
) where {D1}
238242
@assert length(is_closed) == length(priority)
239-
return new{M,D,length(is_closed)}(hlo_sharding, mesh, is_closed, priority)
243+
return new{D1,length(is_closed)}(hlo_sharding, mesh, is_closed, priority)
240244
end
241245
end
242246

src/TracedUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ function make_mlir_fn(
350350
# TODO: support multiple meshes
351351
if length(unique_meshes) > 1
352352
error("Currently we support using a single mesh")
353-
sorted_devices = [m.sorted_device_ids for m in unique_meshes]
353+
sorted_devices = [sort(vec(m.device_ids)) for m in unique_meshes]
354354
@assert allequal(sorted_devices) "All meshes must have the same device ids"
355355
end
356356
sharding_mesh = first(unique_meshes)

0 commit comments

Comments
 (0)