Skip to content

Commit c003b22

Browse files
authored
fix: sharding improvements for GB (#919)
* fix: warn * fix: auto-raise if we are using sharding * feat: fix codegen for mixed sharding + nosharidng * fix(IFRT): correct device to host transfer
1 parent d2118b7 commit c003b22

File tree

7 files changed

+134
-28
lines changed

7 files changed

+134
-28
lines changed

src/Compiler.jl

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,34 @@ end
4242
@nospecialize(obj::AbstractArray{T}), field, val
4343
) where {T}
4444
ancestor_obj = ancestor(obj)
45-
(isbitstype(T) || ancestor_obj isa RArray) && return Base.setfield!(obj, field, val)
45+
(isbitstype(T) || ancestor_obj isa RArray) && return setfield_carray!(obj, field, val)
4646
return Base.setindex!(obj, val, field)
4747
end
4848

4949
@inline function traced_setfield!(@nospecialize(obj::Dict), field, val)
5050
return Base.setindex!(obj, field, val)
5151
end
5252

53+
# fallback
54+
@inline function setfield_carray!(obj, field, val)
55+
return Base.setfield!(obj, field, val)
56+
end
57+
58+
@inline function setfield_carray!(obj::ConcretePJRTArray, field, val)
59+
if field !== :data || typeof(val) == typeof(getfield(obj, field))
60+
return Base.setfield!(obj, field, val)
61+
end
62+
63+
# This case is triggered if the user had provided an unsharded input (NoSharding), but
64+
# we had to replicate it before feeding it to XLA
65+
@assert !Reactant.Sharding.is_sharded(obj) "Expected unsharded input. Open an issue on \
66+
Reactant.jl with a MWE."
67+
devices = Reactant.XLA.device.(val)
68+
device = Reactant.XLA.device(only(obj.data))
69+
idx = findfirst(isequal(device), devices)
70+
return Base.setfield!(obj, field, (val[idx],))
71+
end
72+
5373
function create_result(
5474
tocopy::T, path, result_stores, path_to_shard_info, sharding_mesh
5575
) where {T}
@@ -737,10 +757,25 @@ function compile_mlir!(
737757
MLIR.IR.deactivate!(MLIR.IR.body(mod))
738758
MLIR.IR.deactivate!(mod)
739759
end
740-
(; fnwrapped, traced_result, seen_args, ret, linear_args, in_tys, linear_results) =
741-
mlir_fn_res
760+
(;
761+
fnwrapped,
762+
traced_result,
763+
seen_args,
764+
ret,
765+
linear_args,
766+
in_tys,
767+
linear_results,
768+
is_sharded,
769+
) = mlir_fn_res
742770
compiled_f = mlir_fn_res.f
743771

772+
# Custom Kernels without Raising will lead to suboptimal codegen for is_sharded, force
773+
# raising
774+
if is_sharded
775+
is_raising = true
776+
raise isa Bool && (raise = true)
777+
end
778+
744779
concrete_seen = OrderedIdDict()
745780

746781
concrete_result = make_tracer(
@@ -1274,20 +1309,40 @@ function codegen_flatten!(
12741309
device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices(
12751310
condensed_op_sharding, size(carg), mesh.logical_device_ids
12761311
)
1312+
1313+
# Extract the buffer_slice
1314+
buf_slice = Dict{eltype(device_to_array_slices),Symbol}()
1315+
counter = 0
12771316
for j in 1:length(mesh)
1278-
device_id = mesh.logical_device_ids[j]
1279-
buf = Symbol(:buf_, i, :_, device_id)
1317+
sliced_buf = Symbol(:sliced_buf_, i, :_, counter)
12801318
slice = device_to_array_slices[j]
1319+
haskey(buf_slice, slice) && continue
1320+
counter += 1
12811321
push!(
12821322
flatten_code,
1283-
:($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))),
1323+
:(
1324+
$sliced_buf = only(
1325+
Reactant._fast_slice($usbuf, $(slice...)).data
1326+
)
1327+
),
12841328
)
1329+
buf_slice[slice] = sliced_buf
1330+
end
1331+
1332+
for j in 1:length(mesh)
1333+
device_id = mesh.logical_device_ids[j]
1334+
buf = Symbol(:buf_, i, :_, device_id)
1335+
slice = device_to_array_slices[j]
12851336
sbuf = Symbol(:s, buf)
1286-
device = XLA.get_device(client, device_id)
12871337
push!(flatten_names, sbuf)
12881338
push!(
12891339
flatten_code,
1290-
:($sbuf = XLA.copy_buffer_to_device($buf, $device)),
1340+
:(
1341+
$sbuf = XLA.copy_buffer_to_device(
1342+
XLA.synced_buffer($(buf_slice[slice])),
1343+
$(XLA.get_device(client, device_id)),
1344+
)
1345+
),
12911346
)
12921347
end
12931348
end

src/ConcreteRArray.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,18 @@ function Base.getindex(a::ConcreteIFRTArray, args::Vararg{Int,N}) where {N}
251251
return convert(Array, a)[args...]
252252
end
253253

254+
# This doesn't follow the semantics of getindex with ranges. It is mostly meant to be used
255+
# inside Compiler.jl
256+
@inline function _fast_slice(
257+
a::AbstractConcreteArray{T,N}, args::Vararg{UnitRange,N}
258+
) where {T,N}
259+
# Avoid slicing all-together
260+
args == ntuple(Base.Fix1(UnitRange, 1) Base.Fix1(size, a), N) && return a
261+
# For all other cases do a compile
262+
fn = compile(getindex, (a, args...))
263+
return fn(a, args...)
264+
end
265+
254266
function mysetindex!(a, v, args::Vararg{Any,N}) where {N}
255267
setindex!(a, v, args...)
256268
return nothing

src/TracedUtils.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,7 @@ function make_mlir_fn(
226226
# Insert meshes for the sharded arguments
227227
traced_args_to_shardings = OrderedIdDict()
228228
for (k, v) in seen_args
229-
if (
230-
k isa Reactant.AbstractConcreteNumber || k isa Reactant.AbstractConcreteArray
231-
) && hasfield(typeof(k), :sharding)
229+
if k isa Reactant.AbstractConcreteNumber || k isa Reactant.AbstractConcreteArray
232230
if Reactant.Sharding.is_sharded(k)
233231
Reactant.Ops.mesh(k.sharding.mesh)
234232
traced_args_to_shardings[v] = k.sharding
@@ -373,6 +371,18 @@ function make_mlir_fn(
373371

374372
linear_arg_shardings = Vector{MLIR.IR.Attribute}(undef, length(linear_args))
375373

374+
# If an argument is mutated but is not sharded (aka sharding is NoSharding), we
375+
# need to force a replicated sharding.
376+
for i in mutated_args
377+
arg = linear_args[i]
378+
if !haskey(traced_args_to_shardings, arg)
379+
# Force a replicated sharding
380+
traced_args_to_shardings[arg] = Reactant.Sharding.NamedSharding(
381+
sharding_mesh, ntuple(Returns(nothing), ndims(arg))
382+
)
383+
end
384+
end
385+
376386
# Attach `sdy.sharding` attribute to the argument
377387
for (i, arg) in enumerate(linear_args)
378388
if haskey(traced_args_to_shardings, arg)

src/Types.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,10 @@ function ConcretePJRTArray(
163163
sdata, sharding = sharding(client, device, data)
164164
return ConcretePJRTArray{T,N,1,typeof(sharding)}(sdata, size(data), sharding)
165165
end
166-
@assert device === nothing && idx === nothing "If `sharding` is not `NoSharding`, \
167-
`device` and `idx` cannot be specified!"
166+
if device !== nothing || idx !== nothing
167+
@warn "`device` and `idx` specified for non-`NoSharding` sharding. These arguments \
168+
will be ignored."
169+
end
168170
sharded_data, sharding = sharding(client, nothing, data)
169171
return ConcretePJRTArray{T,N,length(sharded_data),typeof(sharding)}(
170172
sharded_data, size(data), sharding
@@ -282,9 +284,10 @@ function ConcreteIFRTArray(
282284
end
283285
end
284286
else
285-
@assert device === nothing && idx === nothing "If `sharding` is not `NoSharding`, \
286-
`device` and `idx` cannot be \
287-
specified!"
287+
if device !== nothing || idx !== nothing
288+
@warn "`device` and `idx` specified for non-`NoSharding` sharding. These \
289+
arguments will be ignored."
290+
end
288291
end
289292
sharded_data, sharding = sharding(client, device, data)
290293
return ConcreteIFRTArray{T,N}(sharded_data, size(data), sharding)

src/xla/Buffer.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,6 @@ end
2323
return map(synced_buffer, buffers)
2424
end
2525

26-
function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer}
27-
print(io, "$(B) storing ")
28-
show(io, mime, convert(Array, buffer))
29-
return nothing
30-
end
31-
3226
# Async Buffers
3327
abstract type AbstractAsyncBuffer <: AbstractBuffer end
3428

src/xla/IFRT/Array.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,15 @@ end
140140
function XLA.to_host(buffer::Array, data, reactant_sharding)
141141
reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding)
142142

143+
# While some client implementations might support directly copying to host, but we
144+
# avoid the complexity of supporting that for now.
145+
single_device_arrays = disassemble_into_single_device_arrays(buffer, true)
146+
143147
if reactant_sharding isa Reactant.Sharding.NoSharding
144-
GC.@preserve buffer data begin
148+
data_buffer = first(single_device_arrays)
149+
GC.@preserve data_buffer data begin
145150
@ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
146-
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
151+
data_buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
147152
)::Cvoid
148153
end
149154
return data
@@ -159,10 +164,6 @@ function XLA.to_host(buffer::Array, data, reactant_sharding)
159164
untouched."
160165
end
161166

162-
# While some client implementations might support directly copying to host, but we
163-
# avoid the complexity of supporting that for now.
164-
single_device_arrays = disassemble_into_single_device_arrays(buffer, true)
165-
166167
array_slices, _ = XLA.sharding_to_concrete_array_indices(
167168
convert(XLA.CondensedOpSharding, reactant_sharding.hlo_sharding),
168169
size(data),

test/sharding.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,34 @@ end
262262
[2, 1], #=iota_transpose_perm=#
263263
) == [0, 2, 4, 6, 1, 3, 5, 7]
264264
end
265+
266+
@testset "Sharding with Mutation" begin
267+
if length(addressable_devices) 8
268+
mesh = Sharding.Mesh(reshape(Reactant.addressable_devices(), 2, 2, 2), (:x, :y, :z))
269+
270+
x_ra = Reactant.to_rarray(
271+
randn(Float32, 4, 5); sharding=Sharding.NamedSharding(mesh, ((:x, :y), :z))
272+
)
273+
274+
y_ra = Reactant.to_rarray(randn(Float32, 5, 4); sharding=Sharding.NoSharding())
275+
276+
function fn(x, y)
277+
z = x * y
278+
y[1:2, 1:2] .= 1
279+
return z
280+
end
281+
282+
y_ra_arr = Array(y_ra)
283+
x_ra_arr = Array(x_ra)
284+
z_ra_arr = fn(x_ra_arr, y_ra_arr)
285+
286+
z_ra = @jit fn(x_ra, y_ra)
287+
y_ra_final = Array(y_ra)
288+
289+
@test z_ra_arr Array(z_ra)
290+
@test y_ra_final[1:2, 1:2] y_ra_arr[1:2, 1:2]
291+
@test all(y_ra_final[1:2, 1:2] .== 1)
292+
else
293+
@warn "Not enough addressable devices to run sharding tests"
294+
end
295+
end

0 commit comments

Comments
 (0)