Skip to content

Commit 358ff04

Browse files
authored
fix: codegen for resharded input arrays (#1031)
* fix: codegen for resharded input arrays
* fix: handle unresharding path II
* fix: set
* refactor: add a with_context function
* fix: don't rerun tracing for replicating
* feat: fast-path using hlo_sharding
* fix: missing dispatch
* chore: fix comment
* feat: cache unresharding
* fix: annotate resargs shardings as well
* fix: error on unsupported path
* test: SingleDeviceSharding test
* fix: codegen for cached
* fix: field name
* fix: mlir codegen
1 parent 3579eb2 commit 358ff04

File tree

8 files changed

+349
-172
lines changed

8 files changed

+349
-172
lines changed

src/Compiler.jl

Lines changed: 154 additions & 97 deletions
Large diffs are not rendered by default.

src/Sharding.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -582,14 +582,7 @@ struct HloSharding{D1,D2,PS} <: AbstractSharding
582582
end
583583

584584
function Base.convert(::Type{HloSharding}, sharding::NamedSharding)
585-
if MLIR.IR._has_context()
586-
ctx = MLIR.IR.context()
587-
else
588-
ctx = MLIR.IR.Context(Reactant.registry[], false)
589-
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
590-
end
591-
592-
MLIR.IR.context!(ctx) do
585+
MLIR.IR.with_context(; allow_use_existing=true) do ctx
593586
mesh_op = Reactant.Ops.mesh(
594587
sharding.mesh; mod=MLIR.IR.Module(MLIR.IR.Location(; context=ctx))
595588
)

src/TracedUtils.jl

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -452,18 +452,33 @@ function make_mlir_fn(
452452
end
453453

454454
# Ensure the sharding of the mutated arguments is propagated to the results
455-
result_not_replicated = falses(length(linear_results))
456455
for i in mutated_args
457456
arg = linear_args[i]
458-
if has_residx(arg) && haskey(traced_args_to_shardings, arg)
459-
residx = findfirst(Base.Fix1(===, arg), linear_results)
460-
@assert residx !== nothing
461-
result_not_replicated[residx] = true
457+
458+
if haskey(traced_args_to_shardings, arg) &&
459+
(has_residx(arg) || has_resargidx(arg))
460+
idx = findfirst(Base.Fix1(===, arg), linear_results)
461+
@assert idx !== nothing
462462
attr, dialect = linear_arg_shardings[i]
463463
if dialect == :sdy
464-
MLIR.API.mlirFuncSetResultAttr(func2, residx - 1, "sdy.sharding", attr)
464+
MLIR.API.mlirFuncSetResultAttr(func2, idx - 1, "sdy.sharding", attr)
465+
elseif dialect == :mhlo
466+
MLIR.API.mlirFuncSetResultAttr(func2, idx - 1, "mhlo.sharding", attr)
467+
else
468+
error("Unsupported dialect for tensor sharding: $(dialect)")
469+
end
470+
end
471+
end
472+
473+
for (i, res) in enumerate(linear_results)
474+
if has_argidx(res) && haskey(traced_args_to_shardings, res)
475+
argidx = findfirst(Base.Fix1(===, res), linear_args)
476+
@assert argidx !== nothing
477+
attr, dialect = linear_arg_shardings[argidx]
478+
if dialect == :sdy
479+
MLIR.API.mlirFuncSetResultAttr(func2, i - 1, "sdy.sharding", attr)
465480
elseif dialect == :mhlo
466-
MLIR.API.mlirFuncSetResultAttr(func2, residx - 1, "mhlo.sharding", attr)
481+
MLIR.API.mlirFuncSetResultAttr(func2, i - 1, "mhlo.sharding", attr)
467482
else
468483
error("Unsupported dialect for tensor sharding: $(dialect)")
469484
end
@@ -562,30 +577,6 @@ function push_val!(ad_inputs, x, path)
562577
return push!(ad_inputs, x)
563578
end
564579

565-
function get_argidx(x)
566-
for path in get_paths(x)
567-
if length(path) == 0
568-
continue
569-
end
570-
if path[1] == :args
571-
return path[2]::Int, path
572-
end
573-
end
574-
throw(AssertionError("No path found for $x"))
575-
end
576-
577-
function has_argidx(x)
578-
for path in get_paths(x)
579-
if length(path) == 0
580-
continue
581-
end
582-
if path[1] == :args
583-
return true
584-
end
585-
end
586-
return false
587-
end
588-
589580
function set!(x, path, tostore; emptypath=false)
590581
for p in path
591582
x = Reactant.Compiler.traced_getfield(x, p)
@@ -596,28 +587,33 @@ function set!(x, path, tostore; emptypath=false)
596587
return emptypath && set_paths!(x, ())
597588
end
598589

599-
function get_residx(x)
600-
for path in get_paths(x)
601-
if length(path) == 0
602-
continue
603-
end
604-
if path[1] == :result
605-
return path
590+
for (fn, key) in ((:arg, :args), (:res, :result), (:resarg, :resargs))
591+
has_fn = Symbol(:has_, fn, :idx)
592+
@eval begin
593+
function $(has_fn)(x)
594+
for path in get_paths(x)
595+
length(path) == 0 && continue
596+
path[1] == $(Meta.quot(key)) && return true
597+
end
598+
return false
606599
end
607600
end
608-
throw(AssertionError("No path found $x"))
609601
end
610602

611-
function has_residx(x)
603+
function get_argidx(x)
604+
for path in get_paths(x)
605+
length(path) == 0 && continue
606+
path[1] == :args && return (path[2]::Int, path)
607+
end
608+
throw(AssertionError("No path found for $x"))
609+
end
610+
611+
function get_residx(x)
612612
for path in get_paths(x)
613-
if length(path) == 0
614-
continue
615-
end
616-
if path[1] == :result
617-
return true
618-
end
613+
length(path) == 0 && continue
614+
path[1] == :result && return path
619615
end
620-
return false
616+
throw(AssertionError("No path found for $x"))
621617
end
622618

623619
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}

src/mlir/IR/Context.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,33 @@ function context!(f, ctx::Context)
7979
end
8080
end
8181

82+
function with_context(f; allow_use_existing=false)
83+
delete_context = false
84+
if allow_use_existing && _has_context()
85+
ctx = context()
86+
else
87+
delete_context = true
88+
ctx = Context(Reactant.registry[], false)
89+
Reactant.Compiler.context_gc_vector[ctx] = Vector{
90+
Union{Reactant.TracedRArray,Reactant.TracedRNumber}
91+
}(
92+
undef, 0
93+
)
94+
@ccall API.mlir_c.RegisterDialects(ctx::API.MlirContext)::Cvoid
95+
end
96+
97+
activate!(ctx)
98+
result = try
99+
f(ctx)
100+
finally
101+
deactivate!(ctx)
102+
end
103+
104+
delete_context && Base.delete!(Reactant.Compiler.context_gc_vector, ctx)
105+
106+
return result
107+
end
108+
82109
function enable_multithreading!(enable::Bool=true; context::Context=context())
83110
API.mlirContextEnableMultithreading(context, enable)
84111
return context

src/xla/IFRT/Array.jl

Lines changed: 88 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -223,28 +223,100 @@ function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
223223
)
224224
end
225225

226-
shard_info = Reactant.Sharding.ShardInfo(
227-
reactant_sharding,
228-
Reactant.Sharding.sharding_to_array_slices(reactant_sharding, size_arr),
229-
)
230-
sharding_constraint = Reactant.Sharding.NamedSharding(
226+
XLA.is_replicated(hlo_sharding) && return array
227+
228+
output_sharding = Reactant.Sharding.NamedSharding(
231229
mesh, ntuple(Returns(nothing), length(size_arr))
232230
)
233231

234-
data = Reactant.ConcreteIFRTArray{eltype(array),length(size_arr),typeof(shard_info)}(
235-
AsyncArray(array, nothing), size_arr, shard_info
232+
# Manually write the MLIR for resharding
233+
ctx = MLIR.IR.Context(Reactant.registry[], false)
234+
Reactant.Compiler.context_gc_vector[ctx] = Vector{
235+
Union{Reactant.TracedRArray,Reactant.TracedRNumber}
236+
}(
237+
undef, 0
236238
)
239+
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
240+
MLIR.IR.activate!(ctx)
241+
242+
sdycache = IdDict{
243+
Reactant.Sharding.Mesh,
244+
@NamedTuple{
245+
sym_name::MLIR.IR.Attribute,
246+
mesh_attr::MLIR.IR.Attribute,
247+
mesh_op::MLIR.IR.Operation,
248+
}
249+
}()
250+
Reactant.Compiler.activate_sdycache!(sdycache)
251+
252+
output_buffer = try
253+
data_mlir_type = [MLIR.IR.TensorType(reverse(size_arr), MLIR.IR.Type(eltype(array)))]
254+
mod = MLIR.IR.Module(MLIR.IR.Location(; context=ctx))
255+
256+
(; sym_name, mesh_attr) = Reactant.Ops.mesh(mesh; mod=mod)
257+
common_args = (ctx, sym_name, mesh_attr, size_arr)
258+
common_kwargs = (; dialect=:sdy, do_transpose=true)
259+
input_tensor_sharding_attr, _ = Reactant.Sharding.get_tensor_sharding_attribute(
260+
reactant_sharding, common_args...; common_kwargs...
261+
)
262+
output_tensor_sharding_attr, _ = Reactant.Sharding.get_tensor_sharding_attribute(
263+
output_sharding, common_args...; common_kwargs...
264+
)
237265

238-
# TODO: Directly write the MLIR for this part??
239-
fn_compiled = Reactant.compile(
240-
identity,
241-
(data,);
242-
shardy_passes=:to_mhlo_shardings,
243-
optimize=false,
244-
output_shardings=Dict(1 => sharding_constraint),
245-
)
266+
func = MLIR.Dialects.func.func_(;
267+
sym_name="main",
268+
function_type=MLIR.IR.FunctionType(data_mlir_type, data_mlir_type),
269+
no_inline=true,
270+
body=MLIR.IR.Region(),
271+
)
272+
fnbody = MLIR.IR.Block(data_mlir_type, [MLIR.IR.Location()])
273+
push!(MLIR.IR.region(func, 1), fnbody)
274+
MLIR.IR.activate!(fnbody)
275+
try
276+
MLIR.Dialects.func.return_([MLIR.IR.argument(fnbody, 1)])
277+
finally
278+
MLIR.IR.deactivate!(fnbody)
279+
end
280+
push!(MLIR.IR.body(mod), func)
281+
282+
MLIR.API.mlirFuncSetArgAttr(func, 0, "sdy.sharding", input_tensor_sharding_attr)
283+
MLIR.API.mlirFuncSetResultAttr(func, 0, "sdy.sharding", output_tensor_sharding_attr)
284+
285+
Reactant.Compiler.run_pass_pipeline!(
286+
mod,
287+
join(
288+
[
289+
"sdy-propagation-pipeline",
290+
"sdy-close-shardings",
291+
"xla-sdy-stablehlo-export-pipeline",
292+
"canonicalize",
293+
"cse",
294+
],
295+
",",
296+
),
297+
)
298+
299+
exec = XLA.compile(
300+
XLA.client(array),
301+
nothing,
302+
mod;
303+
is_sharded=true,
304+
global_device_ids=vec(mesh.device_ids),
305+
num_outputs=1, # unused
306+
num_parameters=1, # unused
307+
num_replicas=-1, # unused
308+
num_partitions=-1, # unused
309+
use_shardy_partitioner=false, # unused
310+
)
311+
312+
only(XLA.execute(exec, (array.buffer,), (UInt8(0),), Val(1)))
313+
finally
314+
Reactant.Compiler.deactivate_sdycache!(sdycache)
315+
MLIR.IR.deactivate!(ctx)
316+
end
317+
delete!(Reactant.Compiler.context_gc_vector, ctx)
246318

247-
return fn_compiled(data).data.buffer
319+
return output_buffer
248320
end
249321

250322
function XLA.unsafe_buffer_pointer(::Array)

src/xla/IFRT/AsyncArray.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,10 @@ function replicate_array_to_all_devices(array::AsyncArray, args...)
1818
wait(array)
1919
return replicate_array_to_all_devices(array.buffer, args...)
2020
end
21+
22+
function XLA.to_host(array::AsyncArray, data, reactant_sharding)
23+
wait(array)
24+
return XLA.to_host(array.buffer, data, reactant_sharding)
25+
end
26+
27+
XLA.sharding(x::AsyncArray) = XLA.sharding(x.buffer)

test/bcast.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@ end
1919
end
2020

2121
function test()
22-
ctx = MLIR.IR.Context(Reactant.registry[], false)
23-
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
24-
25-
MLIR.IR.context!(ctx) do
22+
MLIR.IR.with_context() do ctx
2623
mod = MLIR.IR.Module(MLIR.IR.Location())
2724
modbody = MLIR.IR.body(mod)
2825

test/sharding.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,3 +349,31 @@ end
349349
@warn "Not enough addressable devices to run sharding tests"
350350
end
351351
end
352+
353+
@testset "Bad Codegen for Resharded Inputs: #1027" begin
354+
if length(addressable_devices) ≥ 12 && Reactant.XLA.runtime() isa Val{:IFRT}
355+
x_ra = Reactant.to_rarray(
356+
randn(Float32, 32, 32);
357+
sharding=Sharding.NamedSharding(
358+
Sharding.Mesh(reshape(0:11, 3, 4), (:x, :y)), (:x, :y)
359+
),
360+
)
361+
362+
z_ra = Reactant.to_rarray(ones(Float32, 32, 32))
363+
364+
function test1!(x, z)
365+
y = x .+ x'
366+
x .+= y
367+
z .= x
368+
return z
369+
end
370+
371+
@jit test1!(x_ra, z_ra)
372+
373+
@test contains(
374+
string(Reactant.XLA.sharding(z_ra.data.buffer)), "SingleDeviceSharding"
375+
)
376+
else
377+
@warn "Not enough addressable devices to run sharding tests"
378+
end
379+
end

0 commit comments

Comments (0)