diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 6e71bb9a..54409e19 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -327,6 +328,10 @@ __global__ void __launch_bounds__(512, 1) } } +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); + class CustomAllreduce { public: int rank_; @@ -341,7 +346,8 @@ class CustomAllreduce { // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; - std::vector ipc_handles_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; /** * meta is a pointer to device metadata and temporary buffer for allreduce. @@ -365,10 +371,7 @@ class CustomAllreduce { for (int i = 0; i < world_size_; i++) { Metadata *rank_meta; if (i != rank_) { - char *handle; - CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i], - cudaIpcMemLazyEnablePeerAccess)); - ipc_handles_.push_back(handle); + char *handle = open_ipc_handle(&handles[i]); handle += offsets[i]; rank_meta = (Metadata *)handle; } else { @@ -378,6 +381,19 @@ class CustomAllreduce { } } + char *open_ipc_handle(const void *ipc_handle) { + auto [it, new_handle] = + ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); + if (new_handle) { + char *ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, + *((const cudaIpcMemHandle_t *)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + std::pair, std::vector> get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); @@ -413,11 +429,7 @@ class CustomAllreduce { RankData data; for (int i = 0; i < world_size_; i++) { if (i != rank_) { - char *handle; - CUDACHECK(cudaIpcOpenMemHandle( - (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()), - cudaIpcMemLazyEnablePeerAccess)); - ipc_handles_.push_back(handle); + char *handle = open_ipc_handle(handles[i].data()); handle += offsets[i]; data.ptrs[i] = handle; } else { @@ -448,13 +460,8 @@ class CustomAllreduce { auto &rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { - char *handle; - CUDACHECK(cudaIpcOpenMemHandle( - (void **)&handle, - *((cudaIpcMemHandle_t *)&handles[j] - [i * sizeof(cudaIpcMemHandle_t)]), - cudaIpcMemLazyEnablePeerAccess)); - ipc_handles_.push_back(handle); + char *handle = + open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; } else { @@ -541,7 +548,7 @@ class CustomAllreduce { } ~CustomAllreduce() { - for (auto ptr : ipc_handles_) { + for (auto [_, ptr] : ipc_handles_) { CUDACHECK(cudaIpcCloseMemHandle(ptr)); } }