Skip to content

Commit 3a2b61c

Browse files
authored
Fix aarch build (#64)
* More prints * fix * for yggy * more fixes * cleanup
1 parent 4177c8d commit 3a2b61c

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,16 @@
4545
#include "xla/pjrt/cpu/cpu_client.h"
4646
#include "xla/pjrt/status_casters.h"
4747
#include "xla/pjrt/pjrt_executable.h"
48+
#include "xla/pjrt/pjrt_api.h"
49+
#include "xla/pjrt/pjrt_c_api_client.h"
4850
#include "xla/python/ifrt/executable.h"
51+
#include "xla/service/cpu/simple_orc_jit.h"
4952

5053
#include "xla/python/ifrt/hlo/hlo_program.h"
54+
#include "llvm/MC/TargetRegistry.h"
55+
#include "llvm/TargetParser/Host.h"
56+
#include "llvm/ExecutionEngine/ExecutionEngine.h"
57+
#include "llvm/Support/Process.h"
5158

5259
using namespace mlir;
5360
using namespace llvm;
@@ -58,7 +65,17 @@ using namespace xla;
5865

5966
extern "C" void InitializeLogs() {
6067
absl::InitializeLog();
68+
LLVMInitializeX86Target();
69+
LLVMInitializeX86TargetInfo();
70+
LLVMInitializeX86TargetMC();
71+
LLVMInitializeX86AsmPrinter();
72+
LLVMInitializeX86AsmParser();
73+
6174
LLVMInitializeAArch64Target();
75+
LLVMInitializeAArch64TargetInfo();
76+
LLVMInitializeAArch64TargetMC();
77+
LLVMInitializeAArch64AsmPrinter();
78+
LLVMInitializeAArch64AsmParser();
6279
}
6380

6481
extern "C"
@@ -102,6 +119,39 @@ extern "C" PjRtClient* MakeGPUClient(int node_id, int num_nodes, int* allowed_de
102119
}
103120
}
104121

122+
const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
123+
124+
extern "C" PjRtClient* MakeTPUClient(const char* tpu_path , const char** error) {
125+
// Prefer $TPU_LIBRARY_PATH if set
126+
std::string tpu_library_path;
127+
if (tpu_path) {
128+
tpu_library_path = std::string(tpu_path);
129+
} else if (auto path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath)) {
130+
tpu_library_path = *path;
131+
} else {
132+
*error = "Could not find TPU path";
133+
return nullptr;
134+
}
135+
136+
absl::StatusOr<const PJRT_Api *> pluginLoad = pjrt::LoadPjrtPlugin("tpu", tpu_library_path);
137+
if (!pluginLoad.ok()) {
138+
auto str = pluginLoad.status().message();
139+
char* err = (char*)malloc(str.size()+1);
140+
memcpy(err, str.data(), str.size()+1);
141+
*error = err;
142+
return nullptr;
143+
}
144+
absl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
145+
if (!tpu_status.ok()) {
146+
auto str = tpu_status.message();
147+
char* err = (char*)malloc(str.size()+1);
148+
memcpy(err, str.data(), str.size()+1);
149+
*error = err;
150+
return nullptr;
151+
}
152+
return xla::GetCApiClient("TPU").value().release();
153+
}
154+
105155
extern "C" int ClientNumDevices(PjRtClient* client) {
106156
return client->device_count();
107157
}

deps/ReactantExtra/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,13 @@ cc_library(
319319
"@llvm-project//llvm:Support",
320320
"@llvm-project//llvm:AArch64AsmParser",
321321
"@llvm-project//llvm:AArch64CodeGen",
322+
"@llvm-project//llvm:X86AsmParser",
323+
"@llvm-project//llvm:X86CodeGen",
322324
"@enzyme_ad//src/enzyme_ad/jax:TransformOps",
323325
"@enzyme_ad//src/enzyme_ad/jax:XLADerivatives",
324326
"@stablehlo//:chlo_ops",
327+
"@xla//xla/pjrt:pjrt_api",
328+
"@xla//xla/pjrt:pjrt_c_api_client",
325329
"@xla//xla/pjrt/cpu:cpu_client",
326330
"@xla//xla/pjrt/gpu:se_gpu_pjrt_client",
327331

deps/ReactantExtra/external

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)