@@ -110,15 +110,6 @@ function create_result(tocopy::T, path, args...) where {T}
110
110
return Expr (:new , T, elems... )
111
111
end
112
112
113
- function __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh, N:: Integer )
114
- device_to_array_slices, hlo_sharding = path_to_shard_info[path]
115
- delete! (path_to_shard_info, path)
116
- sharding = Reactant. Sharding. HloSharding (
117
- hlo_sharding, sharding_mesh, ntuple (Returns (true ), N), ntuple (Returns (- 1 ), N)
118
- )
119
- return Reactant. Sharding. ShardInfo (sharding, device_to_array_slices)
120
- end
121
-
122
113
function create_result (
123
114
tocopy:: ConcretePJRTNumber{T,D,S} ,
124
115
path,
@@ -134,9 +125,7 @@ function create_result(
134
125
if haskey (to_unreshard_results, path)
135
126
error (" TODO: Not yet Implemented. Use IFRT for this." )
136
127
end
137
- sharding = __reconstruct_shardinfo (
138
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
139
- )
128
+ sharding = pop! (path_to_shard_info, path)
140
129
return :(ConcretePJRTNumber {$T,length($(restore)),$(typeof(sharding))} (
141
130
($ (restore). .. ,), $ sharding
142
131
))
@@ -150,9 +139,7 @@ function create_result(
150
139
if haskey (to_unreshard_results, path)
151
140
error (" TODO: Not yet Implemented. Use IFRT for this." )
152
141
end
153
- sharding = __reconstruct_shardinfo (
154
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
155
- )
142
+ sharding = pop! (path_to_shard_info, path)
156
143
return :(ConcretePJRTNumber {$T,length($(tocopy.data)),$(typeof(sharding))} (
157
144
($ (tocopy. data... ,)), $ sharding
158
145
))
@@ -175,9 +162,7 @@ function create_result(
175
162
if haskey (to_unreshard_results, path)
176
163
error (" TODO: Not yet Implemented." )
177
164
end
178
- sharding = __reconstruct_shardinfo (
179
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
180
- )
165
+ sharding = pop! (path_to_shard_info, path)
181
166
return :(ConcreteIFRTNumber {$T,$(typeof(sharding))} ($ (restore), $ sharding))
182
167
else
183
168
return :(ConcreteIFRTNumber {$T} ($ restore))
@@ -189,9 +174,7 @@ function create_result(
189
174
if haskey (to_unreshard_results, path)
190
175
error (" TODO: Not yet Implemented." )
191
176
end
192
- sharding = __reconstruct_shardinfo (
193
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
194
- )
177
+ sharding = pop! (path_to_shard_info, path)
195
178
return :(ConcreteIFRTNumber {$T,$(typeof(sharding))} ($ (tocopy. data), $ sharding))
196
179
end
197
180
return :(ConcreteIFRTNumber {$T} ($ (tocopy. data)))
@@ -212,9 +195,7 @@ function create_result(
212
195
if haskey (to_unreshard_results, path)
213
196
error (" TODO: Not yet Implemented. Use IFRT for this." )
214
197
end
215
- sharding = __reconstruct_shardinfo (
216
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
217
- )
198
+ sharding = pop! (path_to_shard_info, path)
218
199
return :(ConcretePJRTArray {$T,$N,length($(restore)),$(typeof(sharding))} (
219
200
($ (restore). .. ,), $ (tocopy. shape), $ sharding
220
201
))
@@ -228,9 +209,7 @@ function create_result(
228
209
if haskey (to_unreshard_results, path)
229
210
error (" TODO: Not yet Implemented. Use IFRT for this." )
230
211
end
231
- sharding = __reconstruct_shardinfo (
232
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
233
- )
212
+ sharding = pop! (path_to_shard_info, path)
234
213
return :(ConcretePJRTArray {$T,$N,length($(tocopy.data)),$(typeof(sharding))} (
235
214
($ (tocopy. data). .. ,), $ (tocopy. shape), $ sharding
236
215
))
@@ -257,9 +236,7 @@ function create_result(
257
236
$ (restore), $ (to_unreshard_results[path]), $ (T), $ (N), $ (tocopy. shape)
258
237
))
259
238
end
260
- sharding = __reconstruct_shardinfo (
261
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
262
- )
239
+ sharding = pop! (path_to_shard_info, path)
263
240
return :(ConcreteIFRTArray {$T,$N,$(typeof(sharding))} (
264
241
$ (restore), $ (tocopy. shape), $ sharding
265
242
))
@@ -275,9 +252,7 @@ function create_result(
275
252
$ (tocopy. data), $ (to_unreshard_results[path]), $ (T), $ (N), $ (tocopy. shape)
276
253
))
277
254
end
278
- sharding = __reconstruct_shardinfo (
279
- path, path_to_shard_info, sharding_mesh, ndims (tocopy)
280
- )
255
+ sharding = pop! (path_to_shard_info, path)
281
256
return :(ConcreteIFRTArray {$T,$N,$(typeof(sharding))} (
282
257
$ (tocopy. data), $ (tocopy. shape), $ sharding
283
258
))
@@ -1041,22 +1016,58 @@ function compile_mlir!(
1041
1016
1042
1017
# shardy passes
1043
1018
use_shardy_partitioner = false
1019
+ result_shardings = missing
1044
1020
if is_sharded
1045
1021
if shardy_passes == :default
1046
1022
# If `:default` is passed in, we will run a pass to export the sharding
1047
1023
# inside the corresponding compile function for IFRT/PJRT. This keeps the
1048
1024
# sharding readable.
1049
1025
use_shardy_partitioner = true
1050
- elseif shardy_passes == :to_mhlo_shardings
1051
- # Convert all shardy ops to corresponding mhlo attrs/ops that can be consumed by
1052
- # XLA (note we need to set `use_shardy_partitioner` to `false` in the options)
1053
- # TODO : Use https://github.com/openxla/shardy/blob/01d3205086132d1bdf0867e911c05f489918431d/shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc#L28 to pass in the options
1026
+ elseif shardy_passes == :no_stablehlo_export
1054
1027
run_pass_pipeline! (
1055
1028
mod,
1056
1029
join (
1057
- [" sdy-propagation-pipeline" , " xla-sdy-stablehlo-export-pipeline" ], ' ,'
1030
+ [
1031
+ " sdy-propagation-pipeline" ,
1032
+ " sdy-close-shardings" ,
1033
+ " canonicalize" ,
1034
+ " cse" ,
1035
+ ],
1036
+ " ," ,
1058
1037
),
1059
1038
)
1039
+ elseif shardy_passes == :to_mhlo_shardings
1040
+ # Convert all shardy ops to corresponding mhlo attrs/ops that can be consumed by
1041
+ # XLA (note we need to set `use_shardy_partitioner` to `false` in the options)
1042
+ run_pass_pipeline! (
1043
+ mod, join ([" sdy-propagation-pipeline" , " sdy-close-shardings" ], " ," )
1044
+ )
1045
+
1046
+ # Extract the result shardings from the compiled function
1047
+ result_attrs = MLIR. IR. attr (compiled_f, " res_attrs" )
1048
+ if result_attrs != = nothing
1049
+ result_shardings = Vector{
1050
+ Union{Reactant. Sharding. NamedSharding,Reactant. Sharding. NoSharding}
1051
+ }(
1052
+ undef, length (result_attrs)
1053
+ )
1054
+ for i in 1 : length (result_attrs)
1055
+ result_attr = result_attrs[i - 1 ]
1056
+ @assert MLIR. IR. isdict (result_attr)
1057
+ mlir_attr = MLIR. API. mlirDictionaryAttrGetElementByName (
1058
+ result_attr, " sdy.sharding"
1059
+ )
1060
+ if mlir_attr. ptr == C_NULL
1061
+ result_shardings[i] = Reactant. Sharding. NoSharding ()
1062
+ else
1063
+ result_shardings[i] = Reactant. Sharding. named_sharding_from_tensor_sharding_attr (
1064
+ mlir_fn_res. sharding_mesh, MLIR. IR. Attribute (mlir_attr)
1065
+ )
1066
+ end
1067
+ end
1068
+ end
1069
+
1070
+ run_pass_pipeline! (mod, join ([" xla-sdy-stablehlo-export-pipeline" ], ' ,' ))
1060
1071
1061
1072
# Run our optimization passes here -- we need to be careful to not apply folding
1062
1073
# here since that violates the semantics of `sdy.constant` which was converted to
@@ -1142,6 +1153,7 @@ function compile_mlir!(
1142
1153
mlir_fn_res. sharding_mesh,
1143
1154
mlir_fn_res. mutated_args,
1144
1155
use_shardy_partitioner,
1156
+ result_shardings,
1145
1157
)
1146
1158
end
1147
1159
@@ -1340,7 +1352,19 @@ function compile_call_expr(mod, compiler, options::Dict, args...)
1340
1352
)
1341
1353
end
1342
1354
1343
- function assert_mismatched_sharding (hlo_sharding_from_input, hlo_sharding_from_executable)
1355
# Normalize an input sharding, then compare it with the executable's sharding.
#
# `sharding_from_input` is converted to `Reactant.Sharding.HloSharding` and the
# `hlo_sharding` field of the converted value is passed, together with
# `hlo_sharding_from_executable`, to another `assert_mismatched_sharding` call.
# NOTE(review): this presumably lands on the method taking two
# `Reactant.XLA.HloSharding` arguments, which requires the `hlo_sharding` field
# to have that type -- confirm against the `HloSharding` struct definition.
function assert_mismatched_sharding(
    sharding_from_input, hlo_sharding_from_executable::Reactant.XLA.HloSharding
)
    return assert_mismatched_sharding(
        convert(Reactant.Sharding.HloSharding, sharding_from_input).hlo_sharding,
        hlo_sharding_from_executable,
    )
end
1363
+
1364
# Assert that the sharding supplied with an input equals the sharding reported
# for the executable.
#
# Both arguments are `Reactant.XLA.HloSharding` values and are compared with `==`.
# On mismatch the `@assert` fails with a message that prints both shardings and
# asks the user to file an issue; on success the function returns `nothing`.
function assert_mismatched_sharding(
    hlo_sharding_from_input::Reactant.XLA.HloSharding,
    hlo_sharding_from_executable::Reactant.XLA.HloSharding,
)
    @assert hlo_sharding_from_executable == hlo_sharding_from_input "Sharding provided by the user ($(string(hlo_sharding_from_input))) does not match the sharding computed by XLA ($(string(hlo_sharding_from_executable))). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."
    return nothing
end
1346
1370
@@ -1943,7 +1967,8 @@ function compile(f, args; sync=false, kwargs...)
1943
1967
end
1944
1968
1945
1969
result_stores = Dict {Tuple,Symbol} ()
1946
- path_to_shard_info = mlir_fn_res. is_sharded ? Dict {Tuple,Tuple} () : nothing
1970
+ path_to_shard_info =
1971
+ mlir_fn_res. is_sharded ? Dict {Tuple,Reactant.Sharding.ShardInfo} () : nothing
1947
1972
1948
1973
# generate Julia `Thunk` code
1949
1974
flatten_arg_names, flatten_code, resharded_inputs = codegen_flatten! (
@@ -1965,12 +1990,47 @@ function compile(f, args; sync=false, kwargs...)
1965
1990
)
1966
1991
1967
1992
linear_result_shard_info = if mlir_fn_res. is_sharded
1968
- output_shardings = XLA. get_output_shardings (exec)
1969
- XLA. compute_array_indices_and_hlo_sharding .(
1970
- output_shardings,
1971
- size .(mlir_fn_res. linear_results),
1972
- (mlir_fn_res. sharding_mesh. logical_device_ids,),
1993
+ output_hlo_shardings = XLA. get_output_shardings (exec)
1994
+ output_reactant_shardings = mlir_fn_res. result_shardings
1995
+ local linear_result_shard_info = Vector {Reactant.Sharding.ShardInfo} (
1996
+ undef, length (linear_results)
1973
1997
)
1998
+ for i in 1 : length (linear_results)
1999
+ res_size = size (mlir_fn_res. linear_results[i])
2000
+ array_slices, hlo_sharding = XLA. compute_array_indices_and_hlo_sharding (
2001
+ output_hlo_shardings[i],
2002
+ res_size,
2003
+ mlir_fn_res. sharding_mesh. logical_device_ids,
2004
+ )
2005
+
2006
+ if output_reactant_shardings != = missing
2007
+ reactant_sharding = output_reactant_shardings[i]
2008
+ use_hlo_sharding =
2009
+ reactant_sharding isa Reactant. Sharding. NoSharding ||
2010
+ convert (
2011
+ Reactant. Sharding. HloSharding, reactant_sharding
2012
+ ). hlo_sharding != hlo_sharding
2013
+ else
2014
+ use_hlo_sharding = true
2015
+ end
2016
+
2017
+ if use_hlo_sharding
2018
+ linear_result_shard_info[i] = Reactant. Sharding. ShardInfo (
2019
+ Reactant. Sharding. HloSharding (
2020
+ hlo_sharding,
2021
+ mlir_fn_res. sharding_mesh,
2022
+ ntuple (Returns (true ), length (res_size)),
2023
+ ntuple (Returns (- 1 ), length (res_size)),
2024
+ ),
2025
+ array_slices,
2026
+ )
2027
+ else
2028
+ linear_result_shard_info[i] = Reactant. Sharding. ShardInfo (
2029
+ output_reactant_shardings[i], array_slices
2030
+ )
2031
+ end
2032
+ end
2033
+ linear_result_shard_info
1974
2034
else
1975
2035
ntuple (Returns (nothing ), length (linear_results))
1976
2036
end
@@ -2031,6 +2091,10 @@ end
2031
2091
2032
2092
# Forward `XLA.cost_analysis` to the executable stored in the thunk's `exec` field.
XLA.cost_analysis(thunk::Thunk) = XLA.cost_analysis(thunk.exec)
2033
2093
2094
# Forward `XLA.get_output_shardings` to the executable stored in the thunk's
# `exec` field.
XLA.get_output_shardings(thunk::Thunk) = XLA.get_output_shardings(thunk.exec)
2095
+
2096
# Forward `XLA.get_parameter_shardings` to the executable stored in the thunk's
# `exec` field.
XLA.get_parameter_shardings(thunk::Thunk) = XLA.get_parameter_shardings(thunk.exec)
2097
+
2034
2098
struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
2035
2099
2036
2100
function Base. showerror (
0 commit comments