|
59 | 59 | #include "xla/tsl/profiler/rpc/client/capture_profile.h"
|
60 | 60 | #include "xla/tsl/profiler/rpc/profiler_server.h"
|
61 | 61 | #include "xla/python/profiler_utils.h"
|
| 62 | +#include "tsl/platform/init_main.h" |
62 | 63 |
|
63 | 64 | #include "xla/python/ifrt/hlo/hlo_program.h"
|
64 | 65 | #include "llvm/ExecutionEngine/ExecutionEngine.h"
|
@@ -205,7 +206,11 @@ T *unwrap_absl_statusor(absl::StatusOr<T> status, char **error_msg) {
|
205 | 206 | // int xla::_LayoutProto_default_instance_;
|
206 | 207 |
|
207 | 208 | extern "C" void InitializeLogs() {
|
208 |
| - absl::InitializeLog(); |
| 209 | + const char* binary = "julia"; |
| 210 | + int argc = 1; |
| 211 | + char* argv[] = {(char*)binary}; |
| 212 | + char** argv2 = &argv[0]; |
| 213 | + tsl::port::InitMain(binary, &argc, &argv2); |
209 | 214 | LLVMInitializeX86Target();
|
210 | 215 | LLVMInitializeX86TargetInfo();
|
211 | 216 | LLVMInitializeX86TargetMC();
|
@@ -668,7 +673,9 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
|
668 | 673 | options.executable_build_options.set_device_assignment(device_assignment);
|
669 | 674 |
|
670 | 675 | // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
|
671 |
| - xla::ExportShardyForHloRoundTrip(cmodop); |
| 676 | + auto status = xla::ExportShardyForHloRoundTrip(cmodop); |
| 677 | + if (!status.ok()) |
| 678 | + ReactantThrowError(status.ToString().c_str()); |
672 | 679 | } else {
|
673 | 680 | assert(device_id >= 0);
|
674 | 681 |
|
@@ -867,8 +874,6 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
|
867 | 874 | uint8_t *is_arg_donatable,
|
868 | 875 | int num_results, PjRtBuffer **op_results,
|
869 | 876 | uint8_t *futures, FutureType **future_results) {
|
870 |
| - auto client = exec->client(); |
871 |
| - |
872 | 877 | // Ensure argument_handles is structured as num_mesh_ids x num_args
|
873 | 878 | std::vector<std::vector<PjRtBuffer *>> argument_handles(num_mesh_ids);
|
874 | 879 | int num_args = op_args_len / num_mesh_ids;
|
|
0 commit comments