Closed
Description
# simulating multiple devices on host
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
ENV["CUDA_VISIBLE_DEVICES"] = ""

using Reactant

# 2×2×2 mesh over the 8 simulated host devices
mesh = Sharding.Mesh(reshape(Reactant.addressable_devices(), 2, 2, 2), (:x, :y, :z))

# x is sharded over the mesh; y is not sharded
x_ra = Reactant.to_rarray(
    randn(Float32, 4, 5);
    sharding=Sharding.NamedSharding(mesh, ((:x, :y), :z)),
)
y_ra = Reactant.to_rarray(
    randn(Float32, 5, 4);
    sharding=Sharding.NoSharding(),
)

# Returns x * y and mutates y in place
function fn(x, y)
    z = x * y
    y[1:2, 1:2] .= 1
    return z
end

# Reference: run fn on plain Julia arrays (this works)
begin
    y_ra_arr = Array(y_ra)
    x_ra_arr = Array(x_ra)
    z_ra_arr = fn(x_ra_arr, y_ra_arr)
    display(z_ra_arr)
    display(y_ra_arr)
end

# Same computation compiled with @jit (the call raises the error below)
begin
    y_ra2 = copy(y_ra)
    z_ra = @jit fn(x_ra, y_ra2)
    display(z_ra)
    display(y_ra2)
end
2025-03-15 14:54:14.444217: I external/xla/xla/service/spmd/shardy/shardy_xla_pass.cc:304] Using Shardy for XLA SPMD propagation.
2025-03-15 14:54:14.447168: I external/xla/xla/hlo/utils/hlo_sharding_util.cc:3025] There is no registered layout_canonicalization_callback.
ERROR: TypeError: in setfield!, expected Tuple{Reactant.XLA.PJRT.AsyncBuffer}, got a value of type NTuple{8, Reactant.XLA.PJRT.AsyncBuffer}
Stacktrace:
 [1] traced_setfield!
   @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:45 [inlined]
 [2] macro expansion
   @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:1811 [inlined]
 [3] (::Reactant.Compiler.Thunk{…})(::ConcretePJRTArray{…}, ::ConcretePJRTArray{…})
   @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:1858
 [4] top-level scope
   @ /mnt/software/lux/Reactant.jl/envs/shardy/sharding.jl:1132
Some type information was truncated. Use `show(err)` to see complete types.
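The message suggests that the write-back of the mutated `y_ra2` (created with `NoSharding`, so it holds a single buffer) receives one buffer per device of the 2×2×2 mesh, i.e. 8 buffers. That reading is an assumption based only on the error text, not on Reactant's internals. The same `TypeError` can be reproduced in plain Julia, without Reactant, because `setfield!` checks the declared field type and does not convert. `SingleBufferHolder` below is a hypothetical stand-in, not a Reactant type:

```julia
# Hypothetical stand-in for a container whose field is typed to hold exactly one buffer
mutable struct SingleBufferHolder
    buffers::Tuple{Int}
end

holder = SingleBufferHolder((0,))

# Writing eight values (one per simulated shard) into a field typed as a 1-tuple fails:
# setfield! type-checks against the declared field type and performs no conversion.
err = try
    setfield!(holder, :buffers, ntuple(i -> i, 8))
    nothing
catch e
    e
end

println(sprint(showerror, err))
```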