|
42 | 42 | @nospecialize(obj::AbstractArray{T}), field, val
|
43 | 43 | ) where {T}
|
44 | 44 | ancestor_obj = ancestor(obj)
|
45 |
| - (isbitstype(T) || ancestor_obj isa RArray) && return Base.setfield!(obj, field, val) |
| 45 | + (isbitstype(T) || ancestor_obj isa RArray) && return setfield_carray!(obj, field, val) |
46 | 46 | return Base.setindex!(obj, val, field)
|
47 | 47 | end
|
48 | 48 |
|
"""
    traced_setfield!(obj::Dict, field, val)

Store `val` in the dictionary `obj` under the key `field`.

For dictionaries the "field" is the key, so this forwards to
`Base.setindex!(obj, val, field)`, i.e. `obj[field] = val`. Returns whatever
`Base.setindex!` returns.
"""
@inline function traced_setfield!(@nospecialize(obj::Dict), field, val)
    # `setindex!` takes (collection, value, key): the value comes *before* the key.
    return Base.setindex!(obj, val, field)
end
|
52 | 52 |
|
# Generic fallback: when no more specific `setfield_carray!` method applies, simply
# forward to `Base.setfield!` and return its result.
@inline setfield_carray!(obj, field, val) = Base.setfield!(obj, field, val)
| 57 | + |
"""
    setfield_carray!(obj::ConcretePJRTArray, field, val)

Set `field` of `obj` to `val`, with special handling for the `:data` field.

If `field` is not `:data`, or `val` has the same type as the value currently stored in
`field`, this is a plain `Base.setfield!`. Otherwise `obj` is asserted to be unsharded,
the element of `val` whose device equals the device of `only(obj.data)` is selected, and
that element is stored wrapped in a 1-tuple.

Throws an `ErrorException` if no element of `val` is on the device of `obj`.
"""
@inline function setfield_carray!(obj::ConcretePJRTArray, field, val)
    if field !== :data || typeof(val) == typeof(getfield(obj, field))
        return Base.setfield!(obj, field, val)
    end

    # This case is triggered if the user had provided an unsharded input (NoSharding), but
    # we had to replicate it before feeding it to XLA
    @assert !Reactant.Sharding.is_sharded(obj) "Expected unsharded input. Open an issue on \
                                                Reactant.jl with a MWE."
    devices = Reactant.XLA.device.(val)
    device = Reactant.XLA.device(only(obj.data))
    idx = findfirst(isequal(device), devices)
    # `findfirst` returns `nothing` when there is no match; fail with a clear message
    # instead of the opaque indexing error `val[nothing]` would raise.
    if idx === nothing
        error(
            "None of the buffers in `val` is on the device of `obj`. Open an issue on " *
            "Reactant.jl with a MWE.",
        )
    end
    return Base.setfield!(obj, field, (val[idx],))
end
| 72 | + |
53 | 73 | function create_result(
|
54 | 74 | tocopy::T, path, result_stores, path_to_shard_info, sharding_mesh
|
55 | 75 | ) where {T}
|
@@ -737,10 +757,25 @@ function compile_mlir!(
|
737 | 757 | MLIR.IR.deactivate!(MLIR.IR.body(mod))
|
738 | 758 | MLIR.IR.deactivate!(mod)
|
739 | 759 | end
|
740 |
| - (; fnwrapped, traced_result, seen_args, ret, linear_args, in_tys, linear_results) = |
741 |
| - mlir_fn_res |
| 760 | + (; |
| 761 | + fnwrapped, |
| 762 | + traced_result, |
| 763 | + seen_args, |
| 764 | + ret, |
| 765 | + linear_args, |
| 766 | + in_tys, |
| 767 | + linear_results, |
| 768 | + is_sharded, |
| 769 | + ) = mlir_fn_res |
742 | 770 | compiled_f = mlir_fn_res.f
|
743 | 771 |
|
| 772 | + # Custom Kernels without Raising will lead to suboptimal codegen for is_sharded, force |
| 773 | + # raising |
| 774 | + if is_sharded |
| 775 | + is_raising = true |
| 776 | + raise isa Bool && (raise = true) |
| 777 | + end |
| 778 | + |
744 | 779 | concrete_seen = OrderedIdDict()
|
745 | 780 |
|
746 | 781 | concrete_result = make_tracer(
|
@@ -1274,20 +1309,40 @@ function codegen_flatten!(
|
1274 | 1309 | device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices(
|
1275 | 1310 | condensed_op_sharding, size(carg), mesh.logical_device_ids
|
1276 | 1311 | )
|
| 1312 | + |
| 1313 | + # Extract the buffer_slice |
| 1314 | + buf_slice = Dict{eltype(device_to_array_slices),Symbol}() |
| 1315 | + counter = 0 |
1277 | 1316 | for j in 1:length(mesh)
|
1278 |
| - device_id = mesh.logical_device_ids[j] |
1279 |
| - buf = Symbol(:buf_, i, :_, device_id) |
| 1317 | + sliced_buf = Symbol(:sliced_buf_, i, :_, counter) |
1280 | 1318 | slice = device_to_array_slices[j]
|
| 1319 | + haskey(buf_slice, slice) && continue |
| 1320 | + counter += 1 |
1281 | 1321 | push!(
|
1282 | 1322 | flatten_code,
|
1283 |
| - :($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))), |
| 1323 | + :( |
| 1324 | + $sliced_buf = only( |
| 1325 | + Reactant._fast_slice($usbuf, $(slice...)).data |
| 1326 | + ) |
| 1327 | + ), |
1284 | 1328 | )
|
| 1329 | + buf_slice[slice] = sliced_buf |
| 1330 | + end |
| 1331 | + |
| 1332 | + for j in 1:length(mesh) |
| 1333 | + device_id = mesh.logical_device_ids[j] |
| 1334 | + buf = Symbol(:buf_, i, :_, device_id) |
| 1335 | + slice = device_to_array_slices[j] |
1285 | 1336 | sbuf = Symbol(:s, buf)
|
1286 |
| - device = XLA.get_device(client, device_id) |
1287 | 1337 | push!(flatten_names, sbuf)
|
1288 | 1338 | push!(
|
1289 | 1339 | flatten_code,
|
1290 |
| - :($sbuf = XLA.copy_buffer_to_device($buf, $device)), |
| 1340 | + :( |
| 1341 | + $sbuf = XLA.copy_buffer_to_device( |
| 1342 | + XLA.synced_buffer($(buf_slice[slice])), |
| 1343 | + $(XLA.get_device(client, device_id)), |
| 1344 | + ) |
| 1345 | + ), |
1291 | 1346 | )
|
1292 | 1347 | end
|
1293 | 1348 | end
|
|
0 commit comments