
Commit 9f59836

Authored by wsmoses and Avik Pal (avik-pal)
Fix for jll (#1228)
* Fix for jll
* fixup
* chore: run formatter
* fix: pointer type for allowed_devices
* fix
* fix: gpu build
* Update Project.toml
* Update pipeline.yml
* Update .buildkite/pipeline.yml
* Update .buildkite/pipeline.yml

---------

Co-authored-by: Avik Pal <[email protected]>
Co-authored-by: Avik Pal <[email protected]>
1 parent e43cd85 commit 9f59836

9 files changed: 37 additions, 28 deletions

.buildkite/pipeline.yml

Lines changed: 1 addition & 1 deletion

@@ -42,8 +42,8 @@ steps:
       cuda: "*"
     env:
       REACTANT_TEST_GROUP: "{{matrix.group}}"
-      CUDA_VISIBLE_DEVICES: 0
       JULIA_DEBUG: "Reactant,Reactant_jll"
+      CUDA_VISIBLE_DEVICES: 0
     if: build.message !~ /\[skip tests\]/
     timeout_in_minutes: 120

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -87,7 +87,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.9"
-Reactant_jll = "0.0.155"
+Reactant_jll = "0.0.158"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"

deps/ReactantExtra/API.cpp

Lines changed: 23 additions & 18 deletions

@@ -413,9 +413,9 @@ extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
 
 // xla/python/xla.cc 390
 extern "C" PjRtClient *
-MakeGPUClient(int node_id, int num_nodes, int *allowed_devices,
-              int num_allowed_devices, double memory_fraction, bool preallocate,
-              const char *platform_name, const char **error,
+MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices,
+              int64_t num_allowed_devices, double memory_fraction,
+              bool preallocate, const char *platform_name, const char **error,
               void *distributed_runtime_client) {
   GpuClientOptions options;
 
@@ -437,10 +437,15 @@ MakeGPUClient(int node_id, int num_nodes, int *allowed_devices,
   options.allocator_config.memory_fraction = memory_fraction;
   options.node_id = node_id;
   options.num_nodes = num_nodes;
-  options.allowed_devices =
-      allowed_devices ? std::set<int>(allowed_devices,
-                                      allowed_devices + num_allowed_devices)
-                      : std::optional<std::set<int>>();
+  if (allowed_devices) {
+    std::set<int> allowed_devices_set;
+    for (int i = 0; i < num_allowed_devices; i++) {
+      allowed_devices_set.insert(static_cast<int>(allowed_devices[i]));
+    }
+    options.allowed_devices = allowed_devices_set;
+  } else {
+    options.allowed_devices = std::optional<std::set<int>>();
+  }
   options.platform_name =
       platform_name ? std::string(platform_name) : std::optional<std::string>();
   // options.collectives = num_nodes;
@@ -1406,8 +1411,10 @@ ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id,
       device_id, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir,
       use_shardy_partitioner, num_replicas, num_partitions,
       use_spmd_partitioning);
+  xla::ifrt::DeviceListRef devices = MyValueOrThrow(
+      xla::ifrt::GetDeviceListFromXlaCompileOptions(client, compile_options));
   auto options = std::make_unique<xla::ifrt::XlaCompileOptions>(
-      xla::ifrt::XlaCompileOptions(compile_options));
+      compile_options, std::move(devices));
 
   mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap(cmod));
   if (use_spmd_partitioning && use_shardy_partitioner) {
@@ -1635,10 +1642,12 @@ ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes,
       kv_store);
 }
 
-extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
-    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
-    double memory_fraction, bool preallocate, const char *platform_name,
-    const char **error, void *distributed_runtime_client) {
+extern "C" ifrt::Client *
+ifrt_make_pjrt_gpu_client(int node_id, int num_nodes, int64_t *allowed_devices,
+                          int64_t num_allowed_devices, double memory_fraction,
+                          bool preallocate, const char *platform_name,
+                          const char **error,
+                          void *distributed_runtime_client) {
   PjRtClient *pjrt_client = MakeGPUClient(
       node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
       preallocate, platform_name, error, distributed_runtime_client);
@@ -2457,12 +2466,8 @@ extern "C" void ifrt_hlo_module_cost_analysis_properties(
 
 #pragma endregion
 
-extern "C" void dump_op(Operation *op) {
-  llvm::errs() << *op << "\n";
-}
-extern "C" void dump_mval(mlir::Value v) {
-  llvm::errs() << v << "\n";
-}
+extern "C" void dump_op(Operation *op) { llvm::errs() << *op << "\n"; }
+extern "C" void dump_mval(mlir::Value v) { llvm::errs() << v << "\n"; }
 extern "C" void dump_operation(Operation *op, const char *filename) {
   std::error_code EC;
   llvm::raw_fd_ostream file(filename, EC, llvm::sys::fs::OF_Text);
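For illustration, a minimal hypothetical C++ caller of the updated MakeGPUClient signature (int64_t device ids), assuming the extern "C" prototype above is visible to the caller. The stand-in forward declaration, device list, memory fraction, and platform name are illustrative values, not Reactant defaults.

    #include <cstdint>
    #include <cstdio>

    class PjRtClient;  // stand-in forward declaration for xla::PjRtClient

    extern "C" PjRtClient *
    MakeGPUClient(int node_id, int num_nodes, int64_t *allowed_devices,
                  int64_t num_allowed_devices, double memory_fraction,
                  bool preallocate, const char *platform_name, const char **error,
                  void *distributed_runtime_client);

    int main() {
      int64_t devices[] = {0};  // restrict the client to GPU 0
      const char *error = nullptr;
      PjRtClient *client = MakeGPUClient(
          /*node_id=*/0, /*num_nodes=*/1, devices, /*num_allowed_devices=*/1,
          /*memory_fraction=*/0.75, /*preallocate=*/false,
          /*platform_name=*/"gpu", &error, /*distributed_runtime_client=*/nullptr);
      if (client == nullptr && error != nullptr)
        std::fprintf(stderr, "MakeGPUClient failed: %s\n", error);
      return 0;
    }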

deps/ReactantExtra/BUILD

Lines changed: 5 additions & 1 deletion

@@ -900,7 +900,7 @@ cc_library(
        "-Wl,-exported_symbol,_addSdyPropagationPipeline",
    ]}),
    deps = [
-      "@enzyme//:EnzymeMLIR",
+        "@enzyme//:EnzymeMLIR",
        "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:ArithDialect",
@@ -1025,8 +1025,11 @@ cc_library(
        "@jax//jaxlib/mosaic:tpu_dialect_capi_objects",
        "@jax//jaxlib/triton:triton_dialect_capi_objects",
        "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
+        "@xla//xla/service:gpu_plugin",
+        "@xla//xla/pjrt/c:pjrt_c_api_gpu",
    ] + select({
        "@xla//xla/tsl:is_cuda_enabled_and_oss":[
+            "@xla//xla/stream_executor:cuda_platform",
            "@xla//xla/stream_executor/cuda:all_runtime",
            "@xla//xla/service/gpu/model:hlo_op_profiles",
            "@xla//xla/service/gpu/model:hlo_op_profile_proto_cc_impl",
@@ -1040,6 +1043,7 @@ cc_library(
        "//conditions:default": [
        ],
    }) + if_rocm([
+        "@xla//xla/stream_executor:rocm_platform",
        "@xla//xla/service/gpu:amdgpu_compiler",
        "@xla//xla/backends/profiler/gpu:device_tracer",
    ]) + select({

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "fc12061c02f057da8cd22e7e7bb12e050eca3f60"
+ENZYMEXLA_COMMIT = "e1f2496cc251cc30bd2b155ad3133316617beca8"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(

src/Compiler.jl

Lines changed: 1 addition & 1 deletion

@@ -915,7 +915,7 @@ end
 
 # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
 # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
-const enzyme_pass::String = "enzyme{postpasses=\"canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,arith-raise{stablehlo=true},canonicalize,cse,canonicalize\"}"
+const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
 
 function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
     pm = MLIR.IR.PassManager()

src/xla/IFRT/Client.jl

Lines changed: 2 additions & 2 deletions

@@ -177,8 +177,8 @@ function MakeIFRTPJRTGPUClient(;
     client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_gpu_client(
         node_id::Cint,
         num_nodes::Cint,
-        allowed_devices::Ptr{Cvoid},
-        num_allowed_devices::Cint,
+        allowed_devices::Ptr{Int64},
+        num_allowed_devices::Int64,
         XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble,
         XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool,
         platform::Cstring,

src/xla/PJRT/Client.jl

Lines changed: 2 additions & 2 deletions

@@ -163,8 +163,8 @@ function MakeGPUClient(;
     client = @ccall MLIR.API.mlir_c.MakeGPUClient(
         node_id::Cint,
         num_nodes::Cint,
-        allowed_devices::Ptr{Cvoid},
-        num_allowed_devices::Cint,
+        allowed_devices::Ptr{Int64},
+        num_allowed_devices::Int64,
         XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble,
         XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool,
         platform::Cstring,

test/autodiff.jl

Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ end
         (res.val ≈ 4ones(2, 2)) &&
         (res.derivs[1] ≈ 4ones(2, 2)) &&
         (res.derivs[2] ≈ 2ones(2, 2))
-    end broken = true
+    end
 end
 
 @testset "onehot" begin
