
Commit 79e509a

fix: move away from sdy internal API (#1012)
* fix: move away from sdy internal API
* feat: generate sharding tensor attribute directly
* feat: generalize tensor shardings
* fix: remove unnecessary change
* fix: restore passes
* fix: store parent sharding
* feat: SdySharding
* feat: parse NamedSharding from tensorattr
* fix: fixes
* fix: more fixes
* chore: cleanup
* feat: handle non-divisible case
* fix: avoid HloSharding roundtrip completely
* fix: shard_type
* fix: temporarily throw an error on axis splits
* feat: support sub-axes info
* fix: use vector for partition spec
* refactor: remove __reconstruct_shardinfo
Parent: fa35069

11 files changed (+575 lines, −192 lines)

.github/workflows/CI-localjll.yml

Lines changed: 2 additions & 2 deletions
```diff
@@ -95,7 +95,7 @@ jobs:
        shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
        env:
          JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
-         XLA_FLAGS: "--xla_force_host_platform_device_count=8"
+         XLA_FLAGS: "--xla_force_host_platform_device_count=12"
          JULIA_DEBUG: "Reactant,Reactant_jll"
      - name: "Setup Runtime Preferences"
        run: |
@@ -115,7 +115,7 @@ jobs:
        shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
        env:
          JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
-         XLA_FLAGS: "--xla_force_host_platform_device_count=8"
+         XLA_FLAGS: "--xla_force_host_platform_device_count=12"
          JULIA_DEBUG: "Reactant,Reactant_jll"
      - uses: julia-actions/julia-processcoverage@v1
      - uses: codecov/codecov-action@v5
```

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -139,7 +139,7 @@ jobs:
        env:
          JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
          REACTANT_TEST_GROUP: ${{ matrix.test_group }}
-         XLA_FLAGS: "--xla_force_host_platform_device_count=8"
+         XLA_FLAGS: "--xla_force_host_platform_device_count=12"
          JULIA_DEBUG: "Reactant,Reactant_jll"
      - uses: julia-actions/julia-processcoverage@v1
      - uses: codecov/codecov-action@v5
```

.github/workflows/downgrade.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -74,7 +74,7 @@ jobs:
        env:
          JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
          REACTANT_TEST_GROUP: ${{ matrix.test_group }}
-         XLA_FLAGS: "--xla_force_host_platform_device_count=8"
+         XLA_FLAGS: "--xla_force_host_platform_device_count=12"
          JULIA_DEBUG: "Reactant,Reactant_jll"
      - uses: julia-actions/julia-processcoverage@v1
      - uses: codecov/codecov-action@v5
```

src/Compiler.jl

Lines changed: 109 additions & 45 deletions
```diff
@@ -110,15 +110,6 @@ function create_result(tocopy::T, path, args...) where {T}
     return Expr(:new, T, elems...)
 end
 
-function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh, N::Integer)
-    device_to_array_slices, hlo_sharding = path_to_shard_info[path]
-    delete!(path_to_shard_info, path)
-    sharding = Reactant.Sharding.HloSharding(
-        hlo_sharding, sharding_mesh, ntuple(Returns(true), N), ntuple(Returns(-1), N)
-    )
-    return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices)
-end
-
 function create_result(
     tocopy::ConcretePJRTNumber{T,D,S},
     path,
```
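Note on the removal above: `__reconstruct_shardinfo` rebuilt a `ShardInfo` from a raw `HloSharding` at every call site. After this change the codegen dictionary already holds finished `ShardInfo` values, so each `create_result` method simply pops its entry (see the hunks below). A minimal, self-contained sketch of the new pattern, using a hypothetical `ToyShardInfo` in place of `Reactant.Sharding.ShardInfo`:

```julia
# Toy stand-in for Reactant.Sharding.ShardInfo; only the dict/pop! pattern matters here.
struct ToyShardInfo
    sharding::Symbol
    device_to_array_slices::Vector{UnitRange{Int}}
end

# The dict now maps result paths to fully built shard info ...
path_to_shard_info = Dict{Tuple,ToyShardInfo}()
path_to_shard_info[(:result, 1)] = ToyShardInfo(:sharded_along_dim1, [1:4, 5:8])

# ... so codegen pops the entry in one call, replacing the old lookup + `delete!` pair.
sharding = pop!(path_to_shard_info, (:result, 1))
@assert isempty(path_to_shard_info)
```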
```diff
@@ -134,9 +125,7 @@ function create_result(
         if haskey(to_unreshard_results, path)
             error("TODO: Not yet Implemented. Use IFRT for this.")
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcretePJRTNumber{$T,length($(restore)),$(typeof(sharding))}(
             ($(restore)...,), $sharding
         ))
@@ -150,9 +139,7 @@ function create_result(
         if haskey(to_unreshard_results, path)
             error("TODO: Not yet Implemented. Use IFRT for this.")
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcretePJRTNumber{$T,length($(tocopy.data)),$(typeof(sharding))}(
             ($(tocopy.data...,)), $sharding
         ))
@@ -175,9 +162,7 @@ function create_result(
         if haskey(to_unreshard_results, path)
             error("TODO: Not yet Implemented.")
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcreteIFRTNumber{$T,$(typeof(sharding))}($(restore), $sharding))
     else
         return :(ConcreteIFRTNumber{$T}($restore))
@@ -189,9 +174,7 @@ function create_result(
         if haskey(to_unreshard_results, path)
             error("TODO: Not yet Implemented.")
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcreteIFRTNumber{$T,$(typeof(sharding))}($(tocopy.data), $sharding))
     end
     return :(ConcreteIFRTNumber{$T}($(tocopy.data)))
@@ -212,9 +195,7 @@ function create_result(
         if haskey(to_unreshard_results, path)
             error("TODO: Not yet Implemented. Use IFRT for this.")
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcretePJRTArray{$T,$N,length($(restore)),$(typeof(sharding))}(
             ($(restore)...,), $(tocopy.shape), $sharding
         ))
@@ -228,9 +209,7 @@ function create_result(
         if haskey(to_unreshard_results, path)
             error("TODO: Not yet Implemented. Use IFRT for this.")
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcretePJRTArray{$T,$N,length($(tocopy.data)),$(typeof(sharding))}(
             ($(tocopy.data)...,), $(tocopy.shape), $sharding
         ))
@@ -257,9 +236,7 @@ function create_result(
                 $(restore), $(to_unreshard_results[path]), $(T), $(N), $(tocopy.shape)
             ))
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcreteIFRTArray{$T,$N,$(typeof(sharding))}(
             $(restore), $(tocopy.shape), $sharding
         ))
@@ -275,9 +252,7 @@ function create_result(
                 $(tocopy.data), $(to_unreshard_results[path]), $(T), $(N), $(tocopy.shape)
             ))
         end
-        sharding = __reconstruct_shardinfo(
-            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
-        )
+        sharding = pop!(path_to_shard_info, path)
         return :(ConcreteIFRTArray{$T,$N,$(typeof(sharding))}(
             $(tocopy.data), $(tocopy.shape), $sharding
         ))
```
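Every `create_result` method above follows the same recipe: pop the precomputed shard info for the current path and splice it into a quoted constructor expression that later becomes thunk code. A small standalone sketch of that quoting/interpolation pattern (the `ConcreteArrayLike` name and the placeholder values are illustrative, not Reactant types):

```julia
# Placeholder values; in Reactant these would be the element type, rank,
# popped ShardInfo, and the name of a variable in the generated code.
T, N = Float32, 2
sharding = (:toy_sharding,)
restore = :restored_buffers

# `$` splices the runtime values (and `typeof(sharding)`) into the quoted call.
ex = :(ConcreteArrayLike{$T,$N,$(typeof(sharding))}($restore, $sharding))
# displays as: ConcreteArrayLike{Float32, 2, Tuple{Symbol}}(restored_buffers, (:toy_sharding,))
```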
```diff
@@ -1041,22 +1016,58 @@ function compile_mlir!(
 
     # shardy passes
     use_shardy_partitioner = false
+    result_shardings = missing
     if is_sharded
         if shardy_passes == :default
             # If `:default` is passed in, we will run a pass to export the sharding
             # inside the corresponding compile function for IFRT/PJRT. This keeps the
             # sharding readable.
             use_shardy_partitioner = true
-        elseif shardy_passes == :to_mhlo_shardings
-            # Convert all shardy ops to corresponding mhlo attrs/ops that can be consumed by
-            # XLA (note we need to set `use_shardy_partitioner` to `false` in the options)
-            # TODO: Use https://github.com/openxla/shardy/blob/01d3205086132d1bdf0867e911c05f489918431d/shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc#L28 to pass in the options
+        elseif shardy_passes == :no_stablehlo_export
             run_pass_pipeline!(
                 mod,
                 join(
-                    ["sdy-propagation-pipeline", "xla-sdy-stablehlo-export-pipeline"], ','
+                    [
+                        "sdy-propagation-pipeline",
+                        "sdy-close-shardings",
+                        "canonicalize",
+                        "cse",
+                    ],
+                    ",",
                 ),
             )
+        elseif shardy_passes == :to_mhlo_shardings
+            # Convert all shardy ops to corresponding mhlo attrs/ops that can be consumed by
+            # XLA (note we need to set `use_shardy_partitioner` to `false` in the options)
+            run_pass_pipeline!(
+                mod, join(["sdy-propagation-pipeline", "sdy-close-shardings"], ",")
+            )
+
+            # Extract the result shardings from the compiled function
+            result_attrs = MLIR.IR.attr(compiled_f, "res_attrs")
+            if result_attrs !== nothing
+                result_shardings = Vector{
+                    Union{Reactant.Sharding.NamedSharding,Reactant.Sharding.NoSharding}
+                }(
+                    undef, length(result_attrs)
+                )
+                for i in 1:length(result_attrs)
+                    result_attr = result_attrs[i - 1]
+                    @assert MLIR.IR.isdict(result_attr)
+                    mlir_attr = MLIR.API.mlirDictionaryAttrGetElementByName(
+                        result_attr, "sdy.sharding"
+                    )
+                    if mlir_attr.ptr == C_NULL
+                        result_shardings[i] = Reactant.Sharding.NoSharding()
+                    else
+                        result_shardings[i] = Reactant.Sharding.named_sharding_from_tensor_sharding_attr(
+                            mlir_fn_res.sharding_mesh, MLIR.IR.Attribute(mlir_attr)
+                        )
+                    end
+                end
+            end
+
+            run_pass_pipeline!(mod, join(["xla-sdy-stablehlo-export-pipeline"], ','))
 
             # Run our optimization passes here -- we need to be careful to not apply folding
             # here since that violates the semantics of `sdy.constant` which was converted to
```
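Two details worth noting in the new `:to_mhlo_shardings` branch above: the pass pipeline is just a comma-joined string handed to `run_pass_pipeline!`, and the MLIR result attributes are read with a zero-based index (`result_attrs[i - 1]`) while the Julia-side `result_shardings` vector stays one-based. A standalone snippet showing what the `join` call produces:

```julia
# The pass names collapse into a single comma-separated pipeline string.
passes = ["sdy-propagation-pipeline", "sdy-close-shardings", "canonicalize", "cse"]
pipeline = join(passes, ",")
@assert pipeline == "sdy-propagation-pipeline,sdy-close-shardings,canonicalize,cse"
```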
```diff
@@ -1142,6 +1153,7 @@ function compile_mlir!(
         mlir_fn_res.sharding_mesh,
         mlir_fn_res.mutated_args,
         use_shardy_partitioner,
+        result_shardings,
     )
 end
 
@@ -1340,7 +1352,19 @@ function compile_call_expr(mod, compiler, options::Dict, args...)
     )
 end
 
-function assert_mismatched_sharding(hlo_sharding_from_input, hlo_sharding_from_executable)
+function assert_mismatched_sharding(
+    sharding_from_input, hlo_sharding_from_executable::Reactant.XLA.HloSharding
+)
+    return assert_mismatched_sharding(
+        convert(Reactant.Sharding.HloSharding, sharding_from_input).hlo_sharding,
+        hlo_sharding_from_executable,
+    )
+end
+
+function assert_mismatched_sharding(
+    hlo_sharding_from_input::Reactant.XLA.HloSharding,
+    hlo_sharding_from_executable::Reactant.XLA.HloSharding,
+)
     @assert hlo_sharding_from_executable == hlo_sharding_from_input "Sharding provided by the user ($(string(hlo_sharding_from_input))) does not match the sharding computed by XLA ($(string(hlo_sharding_from_executable))). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."
 end
 
```
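The old untyped `assert_mismatched_sharding` is split into two methods above: a catch-all that converts whatever sharding the user supplied into an `XLA.HloSharding` and re-dispatches, plus a typed method that performs the comparison. A toy sketch of the same convert-then-forward dispatch pattern (types and names below are made up, not Reactant's):

```julia
# Hypothetical sharding representations standing in for Reactant/XLA types.
struct ToyHloSharding
    spec::String
end
struct ToyNamedSharding
    spec::String
end
Base.convert(::Type{ToyHloSharding}, s::ToyNamedSharding) = ToyHloSharding(s.spec)

# Catch-all method: normalize the first argument, then re-dispatch.
check_sharding(input, expected::ToyHloSharding) =
    check_sharding(convert(ToyHloSharding, input), expected)

# Typed method: both sides are now in the same representation, so compare directly.
function check_sharding(input::ToyHloSharding, expected::ToyHloSharding)
    @assert input.spec == expected.spec "sharding mismatch"
    return nothing
end

check_sharding(ToyNamedSharding("devices=[2,1]"), ToyHloSharding("devices=[2,1]"))
```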
```diff
@@ -1943,7 +1967,8 @@ function compile(f, args; sync=false, kwargs...)
     end
 
     result_stores = Dict{Tuple,Symbol}()
-    path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing
+    path_to_shard_info =
+        mlir_fn_res.is_sharded ? Dict{Tuple,Reactant.Sharding.ShardInfo}() : nothing
 
     # generate Julia `Thunk` code
     flatten_arg_names, flatten_code, resharded_inputs = codegen_flatten!(
@@ -1965,12 +1990,47 @@
     )
 
     linear_result_shard_info = if mlir_fn_res.is_sharded
-        output_shardings = XLA.get_output_shardings(exec)
-        XLA.compute_array_indices_and_hlo_sharding.(
-            output_shardings,
-            size.(mlir_fn_res.linear_results),
-            (mlir_fn_res.sharding_mesh.logical_device_ids,),
+        output_hlo_shardings = XLA.get_output_shardings(exec)
+        output_reactant_shardings = mlir_fn_res.result_shardings
+        local linear_result_shard_info = Vector{Reactant.Sharding.ShardInfo}(
+            undef, length(linear_results)
         )
+        for i in 1:length(linear_results)
+            res_size = size(mlir_fn_res.linear_results[i])
+            array_slices, hlo_sharding = XLA.compute_array_indices_and_hlo_sharding(
+                output_hlo_shardings[i],
+                res_size,
+                mlir_fn_res.sharding_mesh.logical_device_ids,
+            )
+
+            if output_reactant_shardings !== missing
+                reactant_sharding = output_reactant_shardings[i]
+                use_hlo_sharding =
+                    reactant_sharding isa Reactant.Sharding.NoSharding ||
+                    convert(
+                        Reactant.Sharding.HloSharding, reactant_sharding
+                    ).hlo_sharding != hlo_sharding
+            else
+                use_hlo_sharding = true
+            end
+
+            if use_hlo_sharding
+                linear_result_shard_info[i] = Reactant.Sharding.ShardInfo(
+                    Reactant.Sharding.HloSharding(
+                        hlo_sharding,
+                        mlir_fn_res.sharding_mesh,
+                        ntuple(Returns(true), length(res_size)),
+                        ntuple(Returns(-1), length(res_size)),
+                    ),
+                    array_slices,
+                )
+            else
+                linear_result_shard_info[i] = Reactant.Sharding.ShardInfo(
+                    output_reactant_shardings[i], array_slices
+                )
+            end
+        end
+        linear_result_shard_info
     else
         ntuple(Returns(nothing), length(linear_results))
     end
```
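The loop above decides, per result, whether to keep the sharding derived from the executable's `HloSharding` or the higher-level sharding parsed back from the `sdy.sharding` attribute: the named sharding is used only when it exists and agrees with what XLA actually produced. A toy restatement of that fallback rule with placeholder values:

```julia
# Placeholder values; only the decision logic mirrors the diff above.
named_sharding = nothing                    # parsed from `sdy.sharding`, may be absent
hlo_from_executable = "{devices=[2,1]0,1}"  # pretend sharding reported by XLA
as_hlo(s) = "{devices=[2,1]0,1}"            # pretend conversion of a named sharding

# Fall back to the executable's sharding when there is no named sharding
# or when the named sharding does not round-trip to the same HLO sharding.
use_hlo = named_sharding === nothing || as_hlo(named_sharding) != hlo_from_executable
result_sharding = use_hlo ? hlo_from_executable : named_sharding
@assert result_sharding == hlo_from_executable
```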
```diff
@@ -2031,6 +2091,10 @@ end
 
 XLA.cost_analysis(thunk::Thunk) = XLA.cost_analysis(thunk.exec)
 
+XLA.get_output_shardings(thunk::Thunk) = XLA.get_output_shardings(thunk.exec)
+
+XLA.get_parameter_shardings(thunk::Thunk) = XLA.get_parameter_shardings(thunk.exec)
+
 struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
 
 function Base.showerror(
```
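The two new one-liners above let a compiled thunk answer sharding queries by delegating to the executable it wraps. The same forwarding pattern in a self-contained toy form:

```julia
# Hypothetical wrapper/wrapped pair illustrating the delegation, not Reactant's types.
struct ToyExecutable
    output_shardings::Vector{String}
end
struct ToyThunk
    exec::ToyExecutable
end

get_output_shardings(exec::ToyExecutable) = exec.output_shardings
get_output_shardings(thunk::ToyThunk) = get_output_shardings(thunk.exec)  # forward to the executable

@assert get_output_shardings(ToyThunk(ToyExecutable(["replicated"]))) == ["replicated"]
```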

src/Ops.jl

Lines changed: 5 additions & 2 deletions
```diff
@@ -2414,9 +2414,12 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
     cache = Reactant.Compiler.sdycache()
     haskey(cache, sharding.mesh) || mesh(sharding.mesh; location)
     (; sym_name, mesh_attr) = cache[sharding.mesh]
-    tensor_sharding_attr = Reactant.Sharding.get_shardy_tensor_sharding_attribute(
-        sharding, MLIR.IR.context(), sym_name, mesh_attr; do_transpose=false
+
+    tensor_sharding_attr, dialect = Reactant.Sharding.get_tensor_sharding_attribute(
+        sharding, MLIR.IR.context(), sym_name, mesh_attr, size(input); do_transpose=false
     )
+    @assert dialect == :sdy "Expected dialect to be `sdy`, got $(dialect)"
+
     resharded_value = MLIR.IR.result(
         MLIR.Dialects.sdy.sharding_constraint(
             input.mlir_data; sharding=tensor_sharding_attr, location
```
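The attribute builder now returns both the tensor-sharding attribute and the dialect it was generated for, and this `sdy.sharding_constraint` call site asserts that it received an `sdy` attribute. A toy sketch of that return-and-check convention (the builder below is a placeholder, not the Reactant API):

```julia
# Placeholder builder returning (attribute, dialect); only the calling convention
# mirrors the change above.
toy_get_tensor_sharding_attribute(sharding) = ("#sdy.sharding_per_value<...>", :sdy)

tensor_sharding_attr, dialect = toy_get_tensor_sharding_attribute(:some_sharding)
@assert dialect == :sdy "Expected dialect to be `sdy`, got $(dialect)"
```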
