@@ -223,28 +223,100 @@ function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
        )
    end
 
-    shard_info = Reactant.Sharding.ShardInfo(
-        reactant_sharding,
-        Reactant.Sharding.sharding_to_array_slices(reactant_sharding, size_arr),
-    )
-    sharding_constraint = Reactant.Sharding.NamedSharding(
+    XLA.is_replicated(hlo_sharding) && return array
+
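+    # A NamedSharding with nothing for every dimension requests a fully replicated result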
+    output_sharding = Reactant.Sharding.NamedSharding(
         mesh, ntuple(Returns(nothing), length(size_arr))
     )
 
-    data = Reactant.ConcreteIFRTArray{eltype(array),length(size_arr),typeof(shard_info)}(
-        AsyncArray(array, nothing), size_arr, shard_info
+    # Manually write the MLIR for resharding
+    ctx = MLIR.IR.Context(Reactant.registry[], false)
+    Reactant.Compiler.context_gc_vector[ctx] = Vector{
+        Union{Reactant.TracedRArray,Reactant.TracedRNumber}
+    }(
+        undef, 0
     )
+    @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
+    MLIR.IR.activate!(ctx)
+
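+    # Per-mesh cache of the Shardy (sdy) mesh name/attribute/op, presumably consulted by Reactant.Ops.mesh below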
+    sdycache = IdDict{
+        Reactant.Sharding.Mesh,
+        @NamedTuple{
+            sym_name::MLIR.IR.Attribute,
+            mesh_attr::MLIR.IR.Attribute,
+            mesh_op::MLIR.IR.Operation,
+        }
+    }()
+    Reactant.Compiler.activate_sdycache!(sdycache)
+
+    output_buffer = try
+        data_mlir_type = [MLIR.IR.TensorType(reverse(size_arr), MLIR.IR.Type(eltype(array)))]
+        mod = MLIR.IR.Module(MLIR.IR.Location(; context=ctx))
+
+        (; sym_name, mesh_attr) = Reactant.Ops.mesh(mesh; mod=mod)
+        common_args = (ctx, sym_name, mesh_attr, size_arr)
+        common_kwargs = (; dialect=:sdy, do_transpose=true)
+        input_tensor_sharding_attr, _ = Reactant.Sharding.get_tensor_sharding_attribute(
+            reactant_sharding, common_args...; common_kwargs...
+        )
+        output_tensor_sharding_attr, _ = Reactant.Sharding.get_tensor_sharding_attribute(
+            output_sharding, common_args...; common_kwargs...
+        )
 
-    # TODO: Directly write the MLIR for this part??
-    fn_compiled = Reactant.compile(
-        identity,
-        (data,);
-        shardy_passes=:to_mhlo_shardings,
-        optimize=false,
-        output_shardings=Dict(1 => sharding_constraint),
-    )
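+        # Build an identity @main function; the sdy.sharding attributes attached to its
+        # argument and result below are what describe the reshard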
+        func = MLIR.Dialects.func.func_(;
+            sym_name="main",
+            function_type=MLIR.IR.FunctionType(data_mlir_type, data_mlir_type),
+            no_inline=true,
+            body=MLIR.IR.Region(),
+        )
+        fnbody = MLIR.IR.Block(data_mlir_type, [MLIR.IR.Location()])
+        push!(MLIR.IR.region(func, 1), fnbody)
+        MLIR.IR.activate!(fnbody)
+        try
+            MLIR.Dialects.func.return_([MLIR.IR.argument(fnbody, 1)])
+        finally
+            MLIR.IR.deactivate!(fnbody)
+        end
+        push!(MLIR.IR.body(mod), func)
+
+        MLIR.API.mlirFuncSetArgAttr(func, 0, "sdy.sharding", input_tensor_sharding_attr)
+        MLIR.API.mlirFuncSetResultAttr(func, 0, "sdy.sharding", output_tensor_sharding_attr)
+
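+        # Propagate the shardings and export them from the sdy dialect to StableHLO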
+        Reactant.Compiler.run_pass_pipeline!(
+            mod,
+            join(
+                [
+                    "sdy-propagation-pipeline",
+                    "sdy-close-shardings",
+                    "xla-sdy-stablehlo-export-pipeline",
+                    "canonicalize",
+                    "cse",
+                ],
+                ",",
+            ),
+        )
+
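+        # Compile the module as a sharded executable spanning every device in the mesh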
+        exec = XLA.compile(
+            XLA.client(array),
+            nothing,
+            mod;
+            is_sharded=true,
+            global_device_ids=vec(mesh.device_ids),
+            num_outputs=1, # unused
+            num_parameters=1, # unused
+            num_replicas=-1, # unused
+            num_partitions=-1, # unused
+            use_shardy_partitioner=false, # unused
+        )
+
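+        # Execute on the input buffer; the single result buffer comes back replicated across the mesh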
+        only(XLA.execute(exec, (array.buffer,), (UInt8(0),), Val(1)))
+    finally
+        Reactant.Compiler.deactivate_sdycache!(sdycache)
+        MLIR.IR.deactivate!(ctx)
+    end
+    delete!(Reactant.Compiler.context_gc_vector, ctx)
 
-    return fn_compiled(data).data.buffer
+    return output_buffer
 end
 
 
function XLA.unsafe_buffer_pointer(::Array)