[ReactantExtra] Update XLA and adapt API.cpp (#995)
Conversation
This is currently segfaulting during precompilation, presumably because the API changed a lot. After disabling the precompilation workload (setting `precompile_workload = false` under the `[Reactant]` section of the preferences), I can reproduce the failure interactively:

```julia
julia> using Reactant, Reactant.XLA

julia> Reactant.initialize_dialect()

julia> if XLA.REACTANT_XLA_RUNTIME == "PJRT"
           client = XLA.PJRT.CPUClient(; checkcount=false)
       elseif XLA.REACTANT_XLA_RUNTIME == "IFRT"
           client = XLA.IFRT.CPUClient(; checkcount=false)
       else
           error("Unsupported runtime: $(XLA.REACTANT_XLA_RUNTIME)")
       end;

julia> x = ConcreteRNumber(2.0; client);

julia> Reactant.compile(sin, (x,); client, optimize=:all)
ERROR: module @reactant_sin attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<f64> {reactant.donated}) -> tensor<f64> {
    %0 = stablehlo.sine %arg0 : tensor<f64>
    return %0 : tensor<f64>
  }
}
UNIMPLEMENTED: Compile with MLIR Module is not supported.
Stacktrace:
 [1] reactant_err(msg::Cstring)
   @ Reactant.XLA ~/.julia/dev/Reactant/src/xla/Utils.jl:12
 [2] compile(client::Reactant.XLA.PJRT.Client, device::Reactant.XLA.PJRT.Device, mod::Reactant.MLIR.IR.Module; is_sharded::Bool, global_device_ids::Vector{…}, num_outputs::Int64, num_parameters::Int64, num_replicas::Int64, num_partitions::Int64, use_shardy_partitioner::Bool)
   @ Reactant.XLA.PJRT ~/.julia/dev/Reactant/src/xla/PJRT/LoadedExecutable.jl:82
 [3] compile_xla(f::Function, args::Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{…}}}; client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{optimize::Symbol})
   @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1913
 [4] compile_xla
   @ ~/.julia/dev/Reactant/src/Compiler.jl:1870 [inlined]
 [5] compile(f::Function, args::Tuple{ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{…}}}; sync::Bool, kwargs::@Kwargs{client::Reactant.XLA.PJRT.Client, optimize::Symbol})
   @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1936
 [6] top-level scope
   @ REPL[7]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
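For reference, the precompilation workaround mentioned above can be written as a preferences entry. This is a minimal sketch assuming the standard Preferences.jl mechanism; the `LocalPreferences.toml` file name is my assumption, not something stated in the thread:

```toml
# LocalPreferences.toml (assumed location, next to the active Project.toml).
# Disables Reactant's precompilation workload so the package can be loaded
# without running the workload that currently segfaults.
[Reactant]
precompile_workload = false
```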
Good news is that IFRT works:

```julia
julia> using Reactant, Reactant.XLA
Precompiling Reactant...
  1 dependency successfully precompiled in 8 seconds. 76 already precompiled.

julia> client = XLA.IFRT.CPUClient(; checkcount=false)
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1742598544.115851 4185742 pjrt_client.cc:525] PjRt-IFRT device count: total=1, addressable=1
I0000 00:00:1742598544.115892 4185742 pjrt_client.cc:529] Addressable PjRt-IFRT device: CpuDevice(id=0)
Reactant.XLA.IFRT.Client(Ptr{Nothing} @0x0000000001e972b0)

julia> x = ConcreteRNumber(2.0; client);

julia> Reactant.compile(sin, (x,); client, optimize=:all)
Reactant.Compiler.Thunk{typeof(sin), Symbol("##sin_reactant#231"), Tuple{ConcreteIFRTNumber{Float64, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, false, Reactant.XLA.IFRT.LoadedExecutable, Reactant.XLA.IFRT.Device}(sin, Reactant.XLA.IFRT.LoadedExecutable(Ptr{Nothing} @0x000000000aa369b0, 1, 1, false, 1, 1), Reactant.XLA.IFRT.Device(Ptr{Nothing} @0x0000000001e92f60))
```
It segfaults in the CUDA integration tests with PJRT.

For what it's worth, the backtrace in gdb is not much different from the one above.

Reported upstream: EnzymeAD/Enzyme-JAX#515
The last two Julia frames we hit before the segfault are `Reactant.jl/src/mlir/IR/Pass.jl` line 74 and `Reactant.jl/src/mlir/libMLIR_h.jl` lines 8438–8442 (both at commit e07bbe1), and after that we're entirely in MLIR/LLVM land.
CUDA integration tests pass for me with openxla/xla#24050. Once that PR is merged we can move to the newer version of XLA, and we should hopefully be good (assuming nothing else breaks).
This updates XLA to openxla/xla@821715b to address the upstream issue, and uses the Enzyme-JAX version from EnzymeAD/Enzyme-JAX#511 for testing.