Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/accelerators/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ module Accelerators

include("TPU.jl")
include("Metal.jl")
include("ROCm.jl")

end
62 changes: 62 additions & 0 deletions src/accelerators/ROCm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module ROCm

using Reactant: Reactant
using Scratch: @get_scratch!
using Downloads

# Directory expected to contain the ROCm PJRT plugin library. Starts as `nothing` and is
# assigned by `setup_rocm_pjrt_plugin!`.
const rocm_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)

"""
    __init__()

Module initializer. On Linux, and only when Reactant is not precompiling, configure the
ROCm PJRT plugin directory via `setup_rocm_pjrt_plugin!`. On other platforms the body is
removed at compile time by `@static`.
"""
function __init__()
    @static if Sys.islinux()
        # Short-circuit: skip plugin setup while precompiling.
        Reactant.precompiling() || setup_rocm_pjrt_plugin!()
    end
end

"""
    has_rocm() -> Bool

Return `true` only when the ROCm backend looks usable: the host is Linux, a plugin
directory has been configured in `rocm_pjrt_plugin_dir`, and the plugin library file
returned by `get_rocm_pjrt_plugin_path` exists.

Returning `true` unconditionally would make every caller that branches on `has_rocm()`
pick the ROCm path, even on machines without the plugin.
"""
function has_rocm()
    # Plugin setup only ever runs on Linux.
    Sys.islinux() || return false
    # The directory stays `nothing` when setup was skipped (e.g. during precompilation).
    rocm_pjrt_plugin_dir[] === nothing && return false
    return isfile(get_rocm_pjrt_plugin_path())
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course this needs to be made more realistic than this 😄


"""
    setup_rocm_pjrt_plugin!()

Choose the directory that should contain the ROCm PJRT plugin and store it in
`rocm_pjrt_plugin_dir`.

If the `ROCM_LIBRARY_PATH` environment variable names an existing directory, that
directory is used. Otherwise the `"pjrt_rocm_plugin"` scratch space is used; when the
variable is set but does not name a directory, a warning is emitted so the fallback is
not silent.

Always returns `nothing`.
"""
function setup_rocm_pjrt_plugin!()
    env_dir = get(ENV, "ROCM_LIBRARY_PATH", nothing)
    # The value is later used as a directory (joined with the plugin file name), so
    # require a directory rather than any existing path.
    if env_dir !== nothing && isdir(env_dir)
        rocm_pjrt_plugin_dir[] = env_dir
    else
        if env_dir !== nothing
            @warn "ROCM_LIBRARY_PATH is set but is not an existing directory; falling \
                   back to the scratch space." ROCM_LIBRARY_PATH = env_dir
        end
        rocm_pjrt_plugin_dir[] = @get_scratch!("pjrt_rocm_plugin")
    end
    # download_rocm_pjrt_plugin_if_needed(rocm_pjrt_plugin_dir[])
    return nothing
end

"""
    get_rocm_pjrt_plugin_dir()

Return the current contents of `rocm_pjrt_plugin_dir`: a `String`, or `nothing` if no
directory has been assigned yet.
"""
function get_rocm_pjrt_plugin_dir()
    return rocm_pjrt_plugin_dir[]
end

"""
    get_rocm_pjrt_plugin_path() -> String

Return the path of the ROCm PJRT plugin library, i.e. `xla_rocm_plugin.so` inside the
directory returned by `get_rocm_pjrt_plugin_dir`.

Throws an `ErrorException` if the plugin directory has not been set, instead of letting
`joinpath(nothing, ...)` fail with a `MethodError`.
"""
function get_rocm_pjrt_plugin_path()
    plugin_dir = get_rocm_pjrt_plugin_dir()
    if plugin_dir === nothing
        error("rocm_pjrt_plugin_dir is not set! Call `setup_rocm_pjrt_plugin!()` first.")
    end
    return joinpath(plugin_dir, "xla_rocm_plugin.so")
end

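# NOTE: the disabled sketch below still uses macOS wheel URLs and a `.dylib` file name,
# which do not match the Linux-only setup and the `xla_rocm_plugin.so` name used above.
# It must be adapted before being re-enabled.
# function download_rocm_pjrt_plugin_if_needed(path=nothing)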
# path === nothing && (path = get_rocm_pjrt_plugin_dir())
# @assert path !== nothing "rocm_pjrt_plugin_dir is not set!"

# rocm_pjrt_plugin_path = joinpath(path, "pjrt_plugin_rocm_14.dylib")
# if !isfile(rocm_pjrt_plugin_path)
# zip_file_path = joinpath(path, "pjrt-plugin-rocm.zip")
# tmp_dir = joinpath(path, "tmp")
# Downloads.download(
# if Sys.ARCH === :aarch64
# "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_rocm-0.1.1-py3-none-macosx_13_0_arm64.whl"
# elseif Sys.ARCH === :x86_64
# "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_rocm-0.1.1-py3-none-macosx_10_14_x86_64.whl"
# else
# error("Unsupported architecture: $(Sys.ARCH)")
# end,
# zip_file_path,
# )
# run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
# mv(
# joinpath(tmp_dir, "jax_plugins", "rocm_plugin", "pjrt_plugin_rocm_14.dylib"),
# rocm_pjrt_plugin_path,
# )
# rm(tmp_dir; recursive=true)
# rm(zip_file_path; recursive=true)
# end
# end
Comment on lines +34 to +60
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to download the wheel from https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-jax.html. The problem is that there are different builds for different distributions, so we would first need to identify which one applies. At the moment I'm working with a local copy and setting the ROCM_LIBRARY_PATH env var (is there a standard name?)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's doable to build JAX from source for ROCm.

I actually need to try this for another thing so I'll update you on how easy/painful it is.


end # module ROCm
18 changes: 18 additions & 0 deletions src/xla/IFRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,14 @@ const cpu_client_count = Ref(0)
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const rocm_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:ROCmClient, :rocm_client_count),
)
main_fn = Symbol(:MakeIFRTPJRT, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -219,6 +221,22 @@ function MakeIFRTPJRTMetalClient(;
)
end

"""
    MakeIFRTPJRTROCmClient(;
        rocm_pjrt_plugin_path, node_id=0, num_nodes=1, distributed_runtime_client=nothing
    )

Create a client for ROCm by forwarding to `MakeIFRTPJRTClientViaPluginAPI` with the
plugin library at `rocm_pjrt_plugin_path`, device type `"rocm"` and client name `"ROCM"`.
`node_id`, `num_nodes` and `distributed_runtime_client` are passed through unchanged.
"""
function MakeIFRTPJRTROCmClient(;
    rocm_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    device_type = "rocm"
    client_name = "ROCM"
    return MakeIFRTPJRTClientViaPluginAPI(
        rocm_pjrt_plugin_path,
        device_type,
        client_name;
        node_id=node_id,
        num_nodes=num_nodes,
        distributed_runtime_client=distributed_runtime_client,
    )
end

function MakeIFRTPJRTClientViaPluginAPI(
library_path::String,
device_type::String,
Expand Down
16 changes: 16 additions & 0 deletions src/xla/PJRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@ const cpu_client_count = Ref(0)
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const rocm_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:ROCmClient, :rocm_client_count),
)
main_fn = Symbol(:Make, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -207,6 +209,20 @@ function MakeMetalClient(;
return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL")
end

"""
    MakeROCmClient(;
        rocm_pjrt_plugin_path, node_id=0, num_nodes=1, distributed_runtime_client=nothing
    )

Create a ROCm client by forwarding to `MakeClientUsingPluginAPI` with the plugin library
at `rocm_pjrt_plugin_path`, device type `"rocm"` and client name `"ROCM"`.

The distributed keywords are accepted only at their defaults: an `AssertionError` is
thrown unless `node_id == 0`, `num_nodes == 1` and `distributed_runtime_client` is
`nothing`.
"""
function MakeROCmClient(;
    rocm_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    # Reject any non-default distributed setting before touching the plugin.
    node_id == 0 ||
        throw(AssertionError("`PJRT.MakeROCmClient` does not support node_id"))
    num_nodes == 1 ||
        throw(AssertionError("`PJRT.MakeROCmClient` does not support num_nodes > 1"))
    distributed_runtime_client === nothing || throw(
        AssertionError(
            "`PJRT.MakeROCmClient` does not support distributed_runtime_client"
        ),
    )

    return MakeClientUsingPluginAPI(rocm_pjrt_plugin_path, "rocm", "ROCM")
end

function MakeClientUsingPluginAPI(
library_path::String, device_type::String, client_name::String=uppercase(device_type)
)
Expand Down
16 changes: 16 additions & 0 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ for runtime in (:PJRT, :IFRT)
catch e
println(stdout, e)
end
elseif Accelerators.ROCm.has_rocm()
try
if was_initialized && haskey(state.clients, "rocm")
XLA.free_client(state.clients["rocm"])
XLA.$(runtime).rocm_client_count[] -= 1
end
gpu = $(runtime).ROCmClient(
;
rocm_pjrt_plugin_path=Accelerators.ROCm.get_rocm_pjrt_plugin_path(),
common_kwargs...
)
state.clients["rocm"] = gpu
state.default_client = gpu
catch e
println(stdout, e)
end
else
try
if was_initialized && haskey(state.clients, "cuda")
Expand Down
Loading