77
88#include < iostream>
99#include < limits>
10+ #include < map>
1011#include < unordered_map>
1112#include < vector>
1213
@@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1)
327328 }
328329}
329330
331+ using IPC_KEY = std::array<uint8_t , sizeof (cudaIpcMemHandle_t)>;
332+ static_assert (sizeof (IPC_KEY) == sizeof (cudaIpcMemHandle_t));
333+ static_assert (alignof (IPC_KEY) == alignof (cudaIpcMemHandle_t));
334+
330335class CustomAllreduce {
331336 public:
332337 int rank_;
@@ -341,7 +346,8 @@ class CustomAllreduce {
341346 // stores the registered device pointers from all ranks
342347 RankData *d_rank_data_base_, *d_rank_data_end_;
343348 std::vector<void *> graph_unreg_buffers_;
344- std::vector<void *> ipc_handles_;
349+ // a map from IPC handles to opened IPC pointers
350+ std::map<IPC_KEY, char *> ipc_handles_;
345351
346352 /* *
347353 * meta is a pointer to device metadata and temporary buffer for allreduce.
@@ -365,10 +371,7 @@ class CustomAllreduce {
365371 for (int i = 0 ; i < world_size_; i++) {
366372 Metadata *rank_meta;
367373 if (i != rank_) {
368- char *handle;
369- CUDACHECK (cudaIpcOpenMemHandle ((void **)&handle, handles[i],
370- cudaIpcMemLazyEnablePeerAccess));
371- ipc_handles_.push_back (handle);
374+ char *handle = open_ipc_handle (&handles[i]);
372375 handle += offsets[i];
373376 rank_meta = (Metadata *)handle;
374377 } else {
@@ -378,6 +381,19 @@ class CustomAllreduce {
378381 }
379382 }
380383
384+ char *open_ipc_handle (const void *ipc_handle) {
385+ auto [it, new_handle] =
386+ ipc_handles_.insert ({*((IPC_KEY *)ipc_handle), nullptr });
387+ if (new_handle) {
388+ char *ipc_ptr;
389+ CUDACHECK (cudaIpcOpenMemHandle ((void **)&ipc_ptr,
390+ *((const cudaIpcMemHandle_t *)ipc_handle),
391+ cudaIpcMemLazyEnablePeerAccess));
392+ it->second = ipc_ptr;
393+ }
394+ return it->second ;
395+ }
396+
381397 std::pair<std::vector<uint8_t >, std::vector<int64_t >>
382398 get_graph_buffer_ipc_meta () {
383399 auto num_buffers = graph_unreg_buffers_.size ();
@@ -413,11 +429,7 @@ class CustomAllreduce {
413429 RankData data;
414430 for (int i = 0 ; i < world_size_; i++) {
415431 if (i != rank_) {
416- char *handle;
417- CUDACHECK (cudaIpcOpenMemHandle (
418- (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data ()),
419- cudaIpcMemLazyEnablePeerAccess));
420- ipc_handles_.push_back (handle);
432+ char *handle = open_ipc_handle (handles[i].data ());
421433 handle += offsets[i];
422434 data.ptrs [i] = handle;
423435 } else {
@@ -448,13 +460,8 @@ class CustomAllreduce {
448460 auto &rd = rank_data[i];
449461 for (int j = 0 ; j < world_size_; j++) {
450462 if (j != rank_) {
451- char *handle;
452- CUDACHECK (cudaIpcOpenMemHandle (
453- (void **)&handle,
454- *((cudaIpcMemHandle_t *)&handles[j]
455- [i * sizeof (cudaIpcMemHandle_t)]),
456- cudaIpcMemLazyEnablePeerAccess));
457- ipc_handles_.push_back (handle);
463+ char *handle =
464+ open_ipc_handle (&handles[j][i * sizeof (cudaIpcMemHandle_t)]);
458465 handle += offsets[j][i];
459466 rd.ptrs [j] = handle;
460467 } else {
@@ -541,7 +548,7 @@ class CustomAllreduce {
541548 }
542549
543550 ~CustomAllreduce () {
544- for (auto ptr : ipc_handles_) {
551+ for (auto [_, ptr] : ipc_handles_) {
545552 CUDACHECK (cudaIpcCloseMemHandle (ptr));
546553 }
547554 }
0 commit comments