Skip to content

Conversation

giordano
Copy link
Member

@giordano giordano commented Sep 4, 2025

This is very preliminary at the moment, but I managed to get access to the devices, so it looks promising:

julia> Reactant.devices()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1756985457.845658 3031425 pjrt_api.cc:118] GetPjrtApi was found for rocm at /cosma/home/do006/dc-gior1/tmp/pjrt-plugin/jax_plugins/xla_rocm60/xla_rocm_plugin.so
I0000 00:00:1756985457.863487 3031425 pjrt_api.cc:96] PJRT_Api is set for device type rocm
I0000 00:00:1756985457.863510 3031425 pjrt_api.cc:167] The PJRT plugin has PJRT API version 0.55. The framework PJRT API version is 0.75.
I0000 00:00:1756985463.767717 3031425 pjrt_c_api_client.cc:133] PjRtCApiClient created.
4-element Vector{Reactant.XLA.PJRT.Device}:
 Reactant.XLA.PJRT.Device(Ptr{Nothing} @0x000000001bc36820, "ROCM:0 AMD Instinct MI300A")
 Reactant.XLA.PJRT.Device(Ptr{Nothing} @0x000000001bc36c70, "ROCM:1 AMD Instinct MI300A")
 Reactant.XLA.PJRT.Device(Ptr{Nothing} @0x000000001bc36fb0, "ROCM:2 AMD Instinct MI300A")
 Reactant.XLA.PJRT.Device(Ptr{Nothing} @0x000000001bc372f0, "ROCM:3 AMD Instinct MI300A")

For the record, for some reason the env var ROCM_PATH wasn't set in my environment and I had to manually do it to help XLA find libdevice, otherwise I was getting errors like

INTERNAL: bitcode module not found at ./opencl.bc

As usual, setting the environment variables

TF_CPP_MIN_VLOG_LEVEL=0
TF_CPP_MAX_VLOG_LEVEL=3

was very useful for debugging.

@giordano giordano marked this pull request as draft September 4, 2025 11:49
Comment on lines +34 to +60
# function download_rocm_pjrt_plugin_if_needed(path=nothing)
# path === nothing && (path = get_rocm_pjrt_plugin_dir())
# @assert path !== nothing "rocm_pjrt_plugin_dir is not set!"

# rocm_pjrt_plugin_path = joinpath(path, "pjrt_plugin_rocm_14.dylib")
# if !isfile(rocm_pjrt_plugin_path)
# zip_file_path = joinpath(path, "pjrt-plugin-rocm.zip")
# tmp_dir = joinpath(path, "tmp")
# Downloads.download(
# if Sys.ARCH === :aarch64
# "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_rocm-0.1.1-py3-none-macosx_13_0_arm64.whl"
# elseif Sys.ARCH === :x86_64
# "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_rocm-0.1.1-py3-none-macosx_10_14_x86_64.whl"
# else
# error("Unsupported architecture: $(Sys.ARCH)")
# end,
# zip_file_path,
# )
# run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
# mv(
# joinpath(tmp_dir, "jax_plugins", "rocm_plugin", "pjrt_plugin_rocm_14.dylib"),
# rocm_pjrt_plugin_path,
# )
# rm(tmp_dir; recursive=true)
# rm(zip_file_path; recursive=true)
# end
# end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to download the wheel from https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-jax.html, the problem being that there are different builds for different distributions, we'd need first to identify them. At the moment I'm working with a local copy and I set the ROCM_LIBRARY_PATH env var (is there a standard name?)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's doable to build JAX from source for ROCm.

I actually need to try this for another thing so I'll update you on how easy/painful it is.

end
end

has_rocm() = true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course this needs to be made more realistic than this 😄

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants