45
45
#include " xla/pjrt/cpu/cpu_client.h"
46
46
#include " xla/pjrt/status_casters.h"
47
47
#include " xla/pjrt/pjrt_executable.h"
48
+ #include " xla/pjrt/pjrt_api.h"
49
+ #include " xla/pjrt/pjrt_c_api_client.h"
48
50
#include " xla/python/ifrt/executable.h"
51
+ #include " xla/service/cpu/simple_orc_jit.h"
49
52
50
53
#include " xla/python/ifrt/hlo/hlo_program.h"
54
+ #include " llvm/MC/TargetRegistry.h"
55
+ #include " llvm/TargetParser/Host.h"
56
+ #include " llvm/ExecutionEngine/ExecutionEngine.h"
57
+ #include " llvm/Support/Process.h"
51
58
52
59
using namespace mlir ;
53
60
using namespace llvm ;
@@ -58,7 +65,17 @@ using namespace xla;
58
65
59
66
extern " C" void InitializeLogs () {
60
67
absl::InitializeLog ();
68
+ LLVMInitializeX86Target ();
69
+ LLVMInitializeX86TargetInfo ();
70
+ LLVMInitializeX86TargetMC ();
71
+ LLVMInitializeX86AsmPrinter ();
72
+ LLVMInitializeX86AsmParser ();
73
+
61
74
LLVMInitializeAArch64Target ();
75
+ LLVMInitializeAArch64TargetInfo ();
76
+ LLVMInitializeAArch64TargetMC ();
77
+ LLVMInitializeAArch64AsmPrinter ();
78
+ LLVMInitializeAArch64AsmParser ();
62
79
}
63
80
64
81
extern " C"
@@ -102,6 +119,39 @@ extern "C" PjRtClient* MakeGPUClient(int node_id, int num_nodes, int* allowed_de
102
119
}
103
120
}
104
121
122
+ const char * const kEnvTpuLibraryPath = " TPU_LIBRARY_PATH" ;
123
+
124
+ extern " C" PjRtClient* MakeTPUClient (const char * tpu_path , const char ** error) {
125
+ // Prefer $TPU_LIBRARY_PATH if set
126
+ std::string tpu_library_path;
127
+ if (tpu_path) {
128
+ tpu_library_path = std::string (tpu_path);
129
+ } else if (auto path = llvm::sys::Process::GetEnv (kEnvTpuLibraryPath )) {
130
+ tpu_library_path = *path;
131
+ } else {
132
+ *error = " Could not find TPU path" ;
133
+ return nullptr ;
134
+ }
135
+
136
+ absl::StatusOr<const PJRT_Api *> pluginLoad = pjrt::LoadPjrtPlugin (" tpu" , tpu_library_path);
137
+ if (!pluginLoad.ok ()) {
138
+ auto str = pluginLoad.status ().message ();
139
+ char * err = (char *)malloc (str.size ()+1 );
140
+ memcpy (err, str.data (), str.size ()+1 );
141
+ *error = err;
142
+ return nullptr ;
143
+ }
144
+ absl::Status tpu_status = pjrt::InitializePjrtPlugin (" tpu" );
145
+ if (!tpu_status.ok ()) {
146
+ auto str = tpu_status.message ();
147
+ char * err = (char *)malloc (str.size ()+1 );
148
+ memcpy (err, str.data (), str.size ()+1 );
149
+ *error = err;
150
+ return nullptr ;
151
+ }
152
+ return xla::GetCApiClient (" TPU" ).value ().release ();
153
+ }
154
+
105
155
extern " C" int ClientNumDevices (PjRtClient* client) {
106
156
return client->device_count ();
107
157
}
0 commit comments