
Sharding + NoSharding with Argument Mutation #925

@avik-pal

# Simulate 8 devices on the host CPU (set before loading Reactant) and hide any GPUs
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
ENV["CUDA_VISIBLE_DEVICES"] = ""

using Reactant

mesh = Sharding.Mesh(reshape(Reactant.addressable_devices(), 2, 2, 2), (:x, :y, :z))

x_ra = Reactant.to_rarray(
    randn(Float32, 4, 5);
    sharding=Sharding.NamedSharding(mesh, ((:x, :y), :z)),
)

y_ra = Reactant.to_rarray(
    randn(Float32, 5, 4);
    sharding=Sharding.NoSharding(),
)

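# Matrix product, then an in-place mutation of the second argument,
# which has to be written back to the caller's array after execution
function fn(x, y)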
    z = x * y
    y[1:2, 1:2] .= 1
    return z
end

# Reference result with plain CPU Arrays (works)
begin
    y_ra_arr = Array(y_ra)
    x_ra_arr = Array(x_ra)
    z_ra_arr = fn(x_ra_arr, y_ra_arr)
    display(z_ra_arr)
    display(y_ra_arr)
end

# Same computation compiled with Reactant: sharded `x`, unsharded `y` (fails)
begin
    y_ra2 = copy(y_ra)
    z_ra = @jit fn(x_ra, y_ra2)
    display(z_ra)
    display(y_ra2)
end
2025-03-15 14:54:14.444217: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
2025-03-15 14:54:14.447168: I external/xla/xla/hlo/utils/hlo_sharding_util.cc:3025] There is no registered layout_canonicalization_callback.
ERROR: TypeError: in setfield!, expected Tuple{Reactant.XLA.PJRT.AsyncBuffer}, got a value of type NTuple{8, Reactant.XLA.PJRT.AsyncBuffer}
Stacktrace:
 [1] traced_setfield!
   @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:45 [inlined]
 [2] macro expansion
   @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:1811 [inlined]
 [3] (::Reactant.Compiler.Thunk{…})(::ConcretePJRTArray{…}, ::ConcretePJRTArray{…})
   @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:1858
 [4] top-level scope
   @ /mnt/software/lux/Reactant.jl/envs/shardy/sharding.jl:1132
Some type information was truncated. Use `show(err)` to see complete types.
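The `TypeError` itself is Julia's ordinary field-type check: `setfield!` does not call `convert`, so a field declared as a 1-tuple of buffers rejects an 8-tuple. Judging from the message, the unsharded `y_ra2` holds a single `AsyncBuffer`, while the write-back of the mutated argument appears to produce one buffer per device of the 8-device mesh. The sketch below reproduces only that Julia-level failure, using hypothetical stand-in types (`Buf`, `FakeArray`); it does not use Reactant and is not meant to reflect its internal layout.

```julia
# Hypothetical stand-in for a device buffer (plays the role of
# Reactant.XLA.PJRT.AsyncBuffer in the error message).
struct Buf
    id::Int
end

# Hypothetical stand-in for an unsharded concrete array: exactly one buffer.
mutable struct FakeArray
    data::Tuple{Buf}
end

arr = FakeArray((Buf(1),))

# Writing back eight per-device buffers violates the declared field type.
# `setfield!` performs no conversion, so Julia raises a TypeError.
err = try
    setfield!(arr, :data, ntuple(i -> Buf(i), 8))
    nothing
catch e
    e
end

println(err isa TypeError)   # true
```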
