Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

Expand Down Expand Up @@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1)
}
}

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
public:
int rank_;
Expand All @@ -341,7 +346,8 @@ class CustomAllreduce {
// stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_;
std::vector<void *> ipc_handles_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_;

/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
Expand All @@ -365,10 +371,7 @@ class CustomAllreduce {
for (int i = 0; i < world_size_; i++) {
Metadata *rank_meta;
if (i != rank_) {
char *handle;
CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i],
cudaIpcMemLazyEnablePeerAccess));
ipc_handles_.push_back(handle);
char *handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_meta = (Metadata *)handle;
} else {
Expand All @@ -378,6 +381,19 @@ class CustomAllreduce {
}
}

char *open_ipc_handle(const void *ipc_handle) {
auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
if (new_handle) {
char *ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
Expand Down Expand Up @@ -413,11 +429,7 @@ class CustomAllreduce {
RankData data;
for (int i = 0; i < world_size_; i++) {
if (i != rank_) {
char *handle;
CUDACHECK(cudaIpcOpenMemHandle(
(void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()),
cudaIpcMemLazyEnablePeerAccess));
ipc_handles_.push_back(handle);
char *handle = open_ipc_handle(handles[i].data());
handle += offsets[i];
data.ptrs[i] = handle;
} else {
Expand Down Expand Up @@ -448,13 +460,8 @@ class CustomAllreduce {
auto &rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char *handle;
CUDACHECK(cudaIpcOpenMemHandle(
(void **)&handle,
*((cudaIpcMemHandle_t *)&handles[j]
[i * sizeof(cudaIpcMemHandle_t)]),
cudaIpcMemLazyEnablePeerAccess));
ipc_handles_.push_back(handle);
char *handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
Expand Down Expand Up @@ -541,7 +548,7 @@ class CustomAllreduce {
}

~CustomAllreduce() {
for (auto ptr : ipc_handles_) {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(cudaIpcCloseMemHandle(ptr));
}
}
Expand Down