@@ -413,9 +413,9 @@ extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
413
413
414
414
// xla/python/xla.cc 390
415
415
extern " C" PjRtClient *
416
- MakeGPUClient (int node_id, int num_nodes, int *allowed_devices,
417
- int num_allowed_devices, double memory_fraction, bool preallocate ,
418
- const char *platform_name, const char **error,
416
+ MakeGPUClient (int node_id, int num_nodes, int64_t *allowed_devices,
417
+ int64_t num_allowed_devices, double memory_fraction,
418
+ bool preallocate, const char *platform_name, const char **error,
419
419
void *distributed_runtime_client) {
420
420
GpuClientOptions options;
421
421
@@ -437,10 +437,15 @@ MakeGPUClient(int node_id, int num_nodes, int *allowed_devices,
437
437
options.allocator_config .memory_fraction = memory_fraction;
438
438
options.node_id = node_id;
439
439
options.num_nodes = num_nodes;
440
- options.allowed_devices =
441
- allowed_devices ? std::set<int >(allowed_devices,
442
- allowed_devices + num_allowed_devices)
443
- : std::optional<std::set<int >>();
440
+ if (allowed_devices) {
441
+ std::set<int > allowed_devices_set;
442
+ for (int i = 0 ; i < num_allowed_devices; i++) {
443
+ allowed_devices_set.insert (static_cast <int >(allowed_devices[i]));
444
+ }
445
+ options.allowed_devices = allowed_devices_set;
446
+ } else {
447
+ options.allowed_devices = std::optional<std::set<int >>();
448
+ }
444
449
options.platform_name =
445
450
platform_name ? std::string (platform_name) : std::optional<std::string>();
446
451
// options.collectives = num_nodes;
@@ -1406,8 +1411,10 @@ ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id,
1406
1411
device_id, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir,
1407
1412
use_shardy_partitioner, num_replicas, num_partitions,
1408
1413
use_spmd_partitioning);
1414
+ xla::ifrt::DeviceListRef devices = MyValueOrThrow (
1415
+ xla::ifrt::GetDeviceListFromXlaCompileOptions (client, compile_options));
1409
1416
auto options = std::make_unique<xla::ifrt::XlaCompileOptions>(
1410
- xla::ifrt::XlaCompileOptions (compile_options ));
1417
+ compile_options, std::move (devices ));
1411
1418
1412
1419
mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap (cmod));
1413
1420
if (use_spmd_partitioning && use_shardy_partitioner) {
@@ -1635,10 +1642,12 @@ ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes,
1635
1642
kv_store);
1636
1643
}
1637
1644
1638
- extern " C" ifrt::Client *ifrt_make_pjrt_gpu_client (
1639
- int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1640
- double memory_fraction, bool preallocate, const char *platform_name,
1641
- const char **error, void *distributed_runtime_client) {
1645
+ extern " C" ifrt::Client *
1646
+ ifrt_make_pjrt_gpu_client (int node_id, int num_nodes, int64_t *allowed_devices,
1647
+ int64_t num_allowed_devices, double memory_fraction,
1648
+ bool preallocate, const char *platform_name,
1649
+ const char **error,
1650
+ void *distributed_runtime_client) {
1642
1651
PjRtClient *pjrt_client = MakeGPUClient (
1643
1652
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1644
1653
preallocate, platform_name, error, distributed_runtime_client);
@@ -2457,12 +2466,8 @@ extern "C" void ifrt_hlo_module_cost_analysis_properties(
2457
2466
2458
2467
#pragma endregion
2459
2468
2460
- extern " C" void dump_op (Operation *op) {
2461
- llvm::errs () << *op << " \n " ;
2462
- }
2463
- extern " C" void dump_mval (mlir::Value v) {
2464
- llvm::errs () << v << " \n " ;
2465
- }
2469
+ extern " C" void dump_op (Operation *op) { llvm::errs () << *op << " \n " ; }
2470
+ extern " C" void dump_mval (mlir::Value v) { llvm::errs () << v << " \n " ; }
2466
2471
extern " C" void dump_operation (Operation *op, const char *filename) {
2467
2472
std::error_code EC;
2468
2473
llvm::raw_fd_ostream file (filename, EC, llvm::sys::fs::OF_Text);
0 commit comments