Skip to content

Commit 032ea57

Browse files
avik-pal and wsmoses authored
feat: plugin API for new device implementations (#1241)
* refactor: move TPUs under accelerators
* feat: register using cplugin api from julia
* refactor: cleanup
* feat: metal plugin register
* fix: have different names for client
* fix: conditions
* fix: path to dylib
* fix: ordering
* fix: update C API
* chore: bump jll
* fix: api
* fix: use more recent metal plugins
* test: metal plugin tests
* ci: ensure compatibility
* fix: disable registration of metal for now
* chore: run fmt
* chore: bump jll

---------

Co-authored-by: William Moses <[email protected]>
1 parent 51ae0a6 commit 032ea57

File tree

13 files changed

+278
-132
lines changed

13 files changed

+278
-132
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ steps:
4444
REACTANT_TEST_GROUP: "{{matrix.group}}"
4545
JULIA_DEBUG: "Reactant,Reactant_jll"
4646
CUDA_VISIBLE_DEVICES: 0
47+
REACTANT_BACKEND_GROUP: "GPU"
4748
if: build.message !~ /\[skip tests\]/
4849
timeout_in_minutes: 120
4950

@@ -138,31 +139,3 @@ steps:
138139
# rocm: "*"
139140
# if: build.message !~ /\[skip tests\]/
140141
# timeout_in_minutes: 60
141-
142-
# - label: "Metal Julia v{{matrix.version}}"
143-
# matrix:
144-
# setup:
145-
# version:
146-
# - "1.8"
147-
# - "1.9"
148-
# plugins:
149-
# - JuliaCI/julia#v1:
150-
# version: "{{matrix.version}}"
151-
# agents:
152-
# queue: "juliaecosystem"
153-
# os: "macos"
154-
# arch: "aarch64"
155-
# if: build.message !~ /\[skip tests\]/
156-
# timeout_in_minutes: 60
157-
# commands: |
158-
# echo "--- Setup Julia packages"
159-
# julia --color=yes -e '
160-
# import Pkg
161-
# Pkg.develop(; path = pwd())
162-
# Pkg.develop(; path = joinpath(pwd(), "lib", "EnzymeCore"))
163-
# Pkg.develop(; name = "Metal")' || exit 3
164-
165-
# echo "+++ Run tests"
166-
# julia --color=yes test/metal.jl
167-
# env:
168-
# JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ jobs:
143143
id: run_tests
144144
env:
145145
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
146+
ENABLE_PJRT_COMPATIBILITY: 1
146147
REACTANT_TEST_GROUP: ${{ matrix.test_group }}
147148
XLA_FLAGS: "--xla_force_host_platform_device_count=12"
148149
JULIA_DEBUG: "Reactant,Reactant_jll"

deps/ReactantExtra/API.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -534,16 +534,8 @@ extern "C" PjRtClient *MakeTPUClient(const char *tpu_path, const char **error) {
534534
return nullptr;
535535
}
536536

537-
const PJRT_Api *pluginLoad =
538-
LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error);
539-
if (pluginLoad == nullptr)
540-
return nullptr;
541-
auto tpu_status = InitializePjrtPlugin("tpu", error);
542-
if (tpu_status)
543-
return nullptr;
544-
545-
pjrt_client_register_profiler(pluginLoad);
546-
return GetCApiClient("TPU");
537+
return MakeClientUsingPluginAPI("tpu", tpu_library_path.c_str(), "TPU",
538+
error);
547539
}
548540

549541
extern "C" int ClientNumDevices(PjRtClient *client) {

src/Distributed.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Distributed
22

3-
using ..Reactant: Reactant
3+
using ..Reactant: Reactant, Accelerators
44
using Sockets
55

66
const initialized = Ref(false)
@@ -303,8 +303,8 @@ const _TPU_COORDINATOR_PORT = "8476"
303303
function get_coordinator_address(
304304
env::AbstractCloudTPUEnvDetector, timeout_in_seconds::Integer
305305
)
306-
coordinator_address = if Reactant.TPUUtils.has_megascale_address()
307-
Reactant.TPUUtils.get_tpu_env_value("MEGASCALE_COORDINATOR_ADDRESS")
306+
coordinator_address = if Accelerators.TPU.has_megascale_address()
307+
Accelerators.TPU.get_tpu_env_value("MEGASCALE_COORDINATOR_ADDRESS")
308308
else
309309
first(_get_worker_list_in_slice(env))
310310
end
@@ -361,13 +361,13 @@ function get_process_id(env::AbstractCloudTPUEnvDetector)
361361
end
362362

363363
function _get_num_slices(::AbstractCloudTPUEnvDetector)
364-
Reactant.TPUUtils.has_megascale_address() || return 1
365-
return parse(Int, Reactant.TPUUtils.get_tpu_env_value("MEGASCALE_NUM_SLICES"))
364+
Accelerators.TPU.has_megascale_address() || return 1
365+
return parse(Int, Accelerators.TPU.get_tpu_env_value("MEGASCALE_NUM_SLICES"))
366366
end
367367

368368
function _get_slice_id(::AbstractCloudTPUEnvDetector)
369-
Reactant.TPUUtils.has_megascale_address() || return 0
370-
return parse(Int, Reactant.TPUUtils.get_tpu_env_value("MEGASCALE_SLICE_ID"))
369+
Accelerators.TPU.has_megascale_address() || return 0
370+
return parse(Int, Accelerators.TPU.get_tpu_env_value("MEGASCALE_SLICE_ID"))
371371
end
372372

373373
function _get_process_id_in_slice end
@@ -376,7 +376,7 @@ function _get_worker_list_in_slice end
376376
## GceTPUCluster
377377

378378
function is_env_present(::GceTPUCluster)
379-
if !Reactant.TPUUtils.RUNNING_IN_CLOUD_TPU_VM[]
379+
if !Accelerators.TPU.RUNNING_IN_CLOUD_TPU_VM[]
380380
@debug "Did not detect cloud TPU VM"
381381
return false
382382
end
@@ -386,8 +386,8 @@ function is_env_present(::GceTPUCluster)
386386
return false
387387
end
388388

389-
metadata_response, metadata_code = Reactant.TPUUtils.get_metadata("agent-worker-number")
390-
if metadata_code == Reactant.TPUUtils._TPU_METADATA_RESPONSE_CODE_SUCCESS
389+
metadata_response, metadata_code = Accelerators.TPU.get_metadata("agent-worker-number")
390+
if metadata_code == Accelerators.TPU._TPU_METADATA_RESPONSE_CODE_SUCCESS
391391
@debug "Gce Tpu Cluster detected for Reactant Distributed System"
392392
return true
393393
else
@@ -400,23 +400,23 @@ function is_env_present(::GceTPUCluster)
400400
end
401401

402402
function _get_process_id_in_slice(::GceTPUCluster)
403-
return parse(Int, first(Reactant.TPUUtils.get_metadata("agent-worker-number")))
403+
return parse(Int, first(Accelerators.TPU.get_metadata("agent-worker-number")))
404404
end
405405

406406
function _get_worker_list_in_slice(::GceTPUCluster)
407-
workers = split(first(Reactant.TPUUtils.get_metadata("worker-network-endpoints")), ',')
407+
workers = split(first(Accelerators.TPU.get_metadata("worker-network-endpoints")), ',')
408408
return [split(w, ':')[3] for w in workers]
409409
end
410410

411411
## GkeTPUCluster
412412

413413
function is_env_present(::GkeTPUCluster)
414-
if Reactant.TPUUtils.RUNNING_IN_CLOUD_TPU_VM[] && haskey(ENV, "TPU_WORKER_HOSTNAMES")
414+
if Accelerators.TPU.RUNNING_IN_CLOUD_TPU_VM[] && haskey(ENV, "TPU_WORKER_HOSTNAMES")
415415
@debug "Detected GKE TPU cluster for Reactant Distributed System"
416416
return true
417417
end
418418

419-
if !Reactant.TPUUtils.RUNNING_IN_CLOUD_TPU_VM[]
419+
if !Accelerators.TPU.RUNNING_IN_CLOUD_TPU_VM[]
420420
@debug "Did not detect cloud TPU VM"
421421
return false
422422
end

src/Reactant.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ function ancestor(T::Type{<:AbstractArray})
4545
return T
4646
end
4747

48-
include("TPUs.jl")
48+
include("accelerators/Accelerators.jl")
4949

50-
using .TPUUtils: has_tpu
50+
using .Accelerators.TPU: has_tpu
5151

5252
include("mlir/MLIR.jl")
5353
include("xla/XLA.jl")

src/accelerators/Accelerators.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
module Accelerators
2+
3+
include("TPU.jl")
4+
include("Metal.jl")
5+
6+
end

src/accelerators/Metal.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
module Metal
2+
3+
using Reactant: Reactant
4+
using Scratch: @get_scratch!
5+
using Downloads
6+
7+
const metal_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)
8+
9+
function __init__()
10+
@static if Sys.isapple()
11+
Reactant.precompiling() || setup_metal_pjrt_plugin!()
12+
end
13+
end
14+
15+
function setup_metal_pjrt_plugin!()
16+
path_from_env = get(ENV, "METAL_LIBRARY_PATH", nothing)
17+
if path_from_env !== nothing && ispath(path_from_env)
18+
metal_pjrt_plugin_dir[] = path_from_env
19+
else
20+
metal_pjrt_plugin_dir[] = @get_scratch!("pjrt_metal_plugin")
21+
end
22+
download_metal_pjrt_plugin_if_needed(metal_pjrt_plugin_dir[])
23+
return nothing
24+
end
25+
26+
get_metal_pjrt_plugin_dir() = metal_pjrt_plugin_dir[]
27+
28+
function get_metal_pjrt_plugin_path()
29+
return joinpath(get_metal_pjrt_plugin_dir(), "pjrt_plugin_metal_14.dylib")
30+
end
31+
32+
function download_metal_pjrt_plugin_if_needed(path=nothing)
33+
path === nothing && (path = get_metal_pjrt_plugin_dir())
34+
@assert path !== nothing "metal_pjrt_plugin_dir is not set!"
35+
36+
metal_pjrt_plugin_path = joinpath(path, "pjrt_plugin_metal_14.dylib")
37+
if !isfile(metal_pjrt_plugin_path)
38+
zip_file_path = joinpath(path, "pjrt-plugin-metal.zip")
39+
tmp_dir = joinpath(path, "tmp")
40+
Downloads.download(
41+
if Sys.ARCH === :aarch64
42+
"https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl"
43+
elseif Sys.ARCH === :x86_64
44+
"https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl"
45+
else
46+
error("Unsupported architecture: $(Sys.ARCH)")
47+
end,
48+
zip_file_path,
49+
)
50+
run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
51+
mv(
52+
joinpath(tmp_dir, "jax_plugins", "metal_plugin", "pjrt_plugin_metal_14.dylib"),
53+
metal_pjrt_plugin_path,
54+
)
55+
rm(tmp_dir; recursive=true)
56+
rm(zip_file_path; recursive=true)
57+
end
58+
end
59+
60+
end

src/TPUs.jl renamed to src/accelerators/TPU.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module TPUUtils
1+
module TPU
22

33
using Reactant: Reactant
44
using EnumX: @enumx
@@ -31,20 +31,24 @@ end
3131

3232
get_libtpu_dir() = libtpu_dir[]
3333

34-
get_libtpu_path() = get_libtpu_dir() * "/libtpu.so"
34+
get_libtpu_path() = joinpath(get_libtpu_dir(), "libtpu.so")
3535

3636
function download_libtpu_if_needed(path=nothing)
3737
path === nothing && (path = get_libtpu_dir())
3838
@assert path !== nothing "libtpu_dir is not set!"
39-
if !isfile(path * "/libtpu.so")
39+
40+
libtpu_path = joinpath(path, "libtpu.so")
41+
if !isfile(libtpu_path)
42+
zip_file_path = joinpath(path, "tpu.zip")
43+
tmp_dir = joinpath(path, "tmp")
4044
Downloads.download(
4145
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250415+nightly-py3-none-manylinux_2_31_x86_64.whl",
42-
path * "/tpu.zip",
46+
zip_file_path,
4347
)
44-
run(`unzip -qq $(path*"/tpu.zip") -d $(path)/tmp`)
45-
run(`mv $(path)/tmp/libtpu/libtpu.so $(path)/libtpu.so`)
46-
rm(path * "/tmp"; recursive=true)
47-
rm(path * "/tpu.zip"; recursive=true)
48+
run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
49+
mv(joinpath(tmp_dir, "libtpu", "libtpu.so"), libtpu_path)
50+
rm(tmp_dir; recursive=true)
51+
rm(zip_file_path; recursive=true)
4852
end
4953
end
5054

src/xla/IFRT/Client.jl

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,15 @@ end
112112

113113
# Different Backends
114114
const cpu_client_count = Ref(0)
115-
const gpu_client_count = Ref(0)
115+
const cuda_client_count = Ref(0)
116116
const tpu_client_count = Ref(0)
117+
const metal_client_count = Ref(0)
117118

118119
for (backend, counter) in (
119120
(:CPUClient, :cpu_client_count),
120-
(:GPUClient, :gpu_client_count),
121+
(:CUDAClient, :cuda_client_count),
121122
(:TPUClient, :tpu_client_count),
123+
(:MetalClient, :metal_client_count),
122124
)
123125
main_fn = Symbol(:MakeIFRTPJRT, backend)
124126
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
@@ -159,7 +161,7 @@ function MakeIFRTPJRTCPUClient(;
159161
return client, refstr
160162
end
161163

162-
function MakeIFRTPJRTGPUClient(;
164+
function MakeIFRTPJRTCUDAClient(;
163165
node_id::Integer=0,
164166
num_nodes::Integer=1,
165167
platform::String="gpu",
@@ -196,19 +198,51 @@ function MakeIFRTPJRTTPUClient(;
196198
num_nodes::Integer=1,
197199
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
198200
)
199-
refstr = Ref{Cstring}()
201+
return MakeIFRTPJRTClientViaPluginAPI(
202+
tpu_path, "tpu", "TPU"; node_id, num_nodes, distributed_runtime_client
203+
)
204+
end
205+
206+
function MakeIFRTPJRTMetalClient(;
207+
metal_pjrt_plugin_path::String,
208+
node_id::Integer=0,
209+
num_nodes::Integer=1,
210+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
211+
)
212+
return MakeIFRTPJRTClientViaPluginAPI(
213+
metal_pjrt_plugin_path,
214+
"metal",
215+
"METAL";
216+
node_id,
217+
num_nodes,
218+
distributed_runtime_client,
219+
)
220+
end
221+
222+
function MakeIFRTPJRTClientViaPluginAPI(
223+
library_path::String,
224+
device_type::String,
225+
client_name::String=uppercase(device_type);
226+
node_id::Integer=0,
227+
num_nodes::Integer=1,
228+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
229+
)
230+
pjrt_client = XLA.PJRT.MakeClientUsingPluginAPI(library_path, device_type, client_name)
231+
200232
distributed_runtime_client =
201233
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
202234

203-
GC.@preserve refstr distributed_runtime_client begin
204-
client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_tpu_client(
205-
tpu_path::Cstring,
206-
refstr::Ptr{Cstring},
235+
errstr = Ref{Cstring}()
236+
GC.@preserve pjrt_client errstr distributed_runtime_client device_type begin
237+
client = @ccall MLIR.API.mlir_c.ifrt_pjrt_make_client_with_default_kv_store(
238+
pjrt_client::Ptr{Cvoid},
207239
node_id::Cint,
208240
num_nodes::Cint,
209241
distributed_runtime_client::Ptr{Cvoid},
242+
errstr::Ptr{Cstring},
243+
device_type::Cstring,
210244
)::Ptr{Cvoid}
211245
end
212246

213-
return client, refstr
247+
return client, errstr
214248
end

0 commit comments

Comments
 (0)