#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#if defined(USE_ROCM)
typedef __hip_bfloat16 nv_bfloat16;
#endif

#include <iostream>
#include <array>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

namespace vllm {

#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

// Maximal number of blocks in allreduce kernel.
constexpr int kMaxBlocks = 36;

// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
const int defaultBlockLimit = 36;
CUpointer_attribute rangeStartAddrAttr = CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#else
const int defaultBlockLimit = 16;
hipPointer_attribute rangeStartAddrAttr =
    HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
#endif

// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;

// Two sets of peer counters are needed for two syncs: starting and ending an
// operation. The reason is that it's possible for a peer GPU block to arrive
// at the second sync point while the current GPU block hasn't passed the
// first sync point. Thus, the peer GPU may write counter+1 while the current
// GPU is still busy-waiting on counter. We use alternating counter arrays to
// avoid this possibility.
struct Signal {
  alignas(128) FlagType start[kMaxBlocks][8];
  alignas(128) FlagType end[kMaxBlocks][8];
  alignas(128) FlagType _flag[kMaxBlocks];  // incremental flags for each rank
};
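
// Layout note (added for clarity): each counter row holds one slot per peer
// rank (up to 8) and there is one row per thread block, so in the barriers
// below rank r signals peer p by writing peer->start[blockIdx.x][r] and then
// spins on its own start[blockIdx.x][p].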

struct __align__(16) RankData {
  const void* ptrs[8];
};

struct __align__(16) RankSignals {
  Signal* signals[8];
};
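
// Note (added for clarity): the fixed-size arrays above cap the supported
// world size at eight ranks, which matches the dispatch in
// CustomAllreduce::allreduce (only 2, 4, 6, or 8 GPUs are accepted).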

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
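
// Illustrative check (added; not part of the original header): for T = half,
// each packed load/store moves 16 bytes, i.e. 8 elements, and the accumulator
// holds the same 8 lanes in float.
static_assert(sizeof(packed_t<half>::P) == 16, "packed ops are 16 bytes");
static_assert(packed_t<half>::P::size == 8, "8 half elements per packed op");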

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with PyTorch, the + operator for half and
// bfloat is disabled, so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float& assign_add(float& a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

#if !defined(USE_ROCM)

static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
#else
  asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
               "l"(flag_addr));
#endif
}

static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
#else
  asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
               : "=r"(flag)
               : "l"(flag_addr));
#endif
  return flag;
}

static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr));
  return flag;
}

// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value
    // from peer.
    st_flag_volatile(peer_counter_ptr, flag);
    while (ld_flag_volatile(self_counter_ptr) != flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
    auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
    // Write the expected counter value to peer and wait for correct value from
    // peer.
    if constexpr (!final_sync) {
      st_flag_release(peer_counter_ptr, flag);
      while (ld_flag_acquire(self_counter_ptr) != flag);
    } else {
      st_flag_volatile(peer_counter_ptr, flag);
      while (ld_flag_volatile(self_counter_ptr) != flag);
    }
  }
  if constexpr (!final_sync) __syncthreads();

  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

#else

template <int ngpus>
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
                              int rank) {
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
                            flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
    // wait until we receive the expected flag from all ranks
    while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
                                  __ATOMIC_RELAXED,
                                  __MEMORY_SCOPE_DEVICE) < flag);
  }
  __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

template <int ngpus, bool final_sync = false>
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
  __syncthreads();
  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
  if (threadIdx.x < ngpus) {
    // simultaneously write to the corresponding flag of all ranks.
    // Latency = 1 p2p write
    __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
                            flag,
                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
                            __MEMORY_SCOPE_SYSTEM);
    // wait until we receive the expected flag from all ranks
    while (
        __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
                               __MEMORY_SCOPE_DEVICE) < flag);
  }
  if constexpr (!final_sync) __syncthreads();
  // use one thread to update flag
  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
}

#endif

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the addresses, so the accumulation order is the
  // same for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  barrier_at_start<ngpus>(sg, self_sg, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  barrier_at_end<ngpus, true>(sg, self_sg, rank);
}
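
// Added note: in the one-shot kernel above every rank reads all peers' full
// inputs and writes the full output itself. The dispatch in
// CustomAllreduce::allreduce uses it for world size 2, and otherwise only on
// fully connected topologies with small messages (< 512 KiB for up to 4
// ranks, < 256 KiB for 6 or 8); larger messages fall through to the
// two-stage kernel below.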

template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}
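
// Added note: the temporary buffer used by the two-stage kernel lives
// immediately after the Signal header in the same IPC-shared allocation,
// i.e. the "a few MB" region described in CustomAllreduce's class comment.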

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
                               T* __restrict__ result, int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  barrier_at_start<ngpus>(sg, self_sg, rank);

  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  barrier_at_end<ngpus>(sg, self_sg, rank);

  // stage 2: allgather. Note: it's important to match the tid between
  // the two stages, because visibility across devices is only guaranteed
  // between threads that have the same tid. If thread i computes the sum of
  // start + i in the first stage, then thread i also gathers start + i from
  // all ranks.
  for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
}
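
// Worked example (added for clarity, hypothetical sizes): with ngpus = 4 and
// size = 10 packed elements, part = 2, so ranks 0-2 reduce index ranges
// [0,2), [2,4), [4,6) into their tmp buffers, while the last rank takes the
// remainder, [6,10); largest_part = 2 + 10 % 4 = 4 bounds the gather loop so
// every rank copies the full result back from all peers.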

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
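
// Added note: cudaIpcMemHandle_t has no comparison operators, so its raw
// bytes are wrapped in a std::array, which compares lexicographically,
// allowing the handle to serve as the key of the ipc_handles_ map below.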

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  // Full NVLink or xGMI connection between GPUs.
  bool fully_connected_;

  RankSignals sg_;
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // Stores rank data from all ranks. This is mainly for CUDA graph purposes.
  // For CUDA graph to work, all kernel arguments must be fixed during graph
  // capture time. However, the peer pointers are not known during graph
  // capture time. Therefore, during capture, we increment the rank data
  // pointer and use that as the argument to the kernel. The kernel arguments
  // are stored in graph_unreg_buffers_. The actual peer pointers will be
  // filled in at the memory pointed to by the pointers in
  // graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
  //
  // The overall process looks like this:
  // 1. Graph capture.
  // 2. Each rank obtains the IPC handles for each address used during CUDA
  //    graph capture using get_graph_buffer_ipc_meta.
  // 3. (In Python) all gather the IPC handles.
  // 4. Obtain the peer pointers by opening the IPC handles, and store them in
  //    the rank data array at corresponding positions.
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * Signals are an array of ipc-enabled buffers from all ranks.
   * For each buffer, the layout is as follows:
   * | -- sizeof(Signal) -- | ------ a few MB ----- |
   * The first section is for allreduce synchronization, and the second
   * section is for storing the intermediate results required by some
   * allreduce algos.
   *
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   */
  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
                  int rank, int world_size, bool fully_connected = true)
      : rank_(rank),
        world_size_(world_size),
        fully_connected_(fully_connected),
        self_sg_(signals[rank]),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      sg_.signals[i] = signals[i];
    }
  }

  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::string handles(handle_sz * num_buffers, static_cast<char>(0));
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  /**
   * Register already-shared IPC pointers.
   */
  void register_buffer(void** ptrs) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      data.ptrs[i] = ptrs[i];
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[ptrs[rank_]] = d_data;
  }

  // Note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the
  // remote possibility of different allocation patterns between ranks. For
  // example, rank 1 may get the same input address for the second allreduce,
  // but rank 2 got a different address. IPC handles have an internal
  // reference counting mechanism, so the overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }
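
  // Added note: handles[j] above is the byte string produced by rank j's
  // get_graph_buffer_ipc_meta() (one cudaIpcMemHandle_t per captured buffer),
  // all-gathered on the Python side, and offsets[j][i] is buffer i's offset
  // from the start of its allocation on rank j.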

  /**
   * Performs allreduce, assuming input has already been registered.
   *
   * The default block and grid configs are the result of a careful grid
   * search. Using 36 blocks gives the best or close-to-best runtime on the
   * devices I tried: A100, A10, A30, T4, V100. You'll notice that NCCL
   * kernels also use only a small number of SMs. I'm not quite sure of the
   * underlying reason, but my guess is that too many SMs cause contention on
   * the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = defaultBlockLimit) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));
    if (block_limit > kMaxBlocks)
      throw std::runtime_error("max supported block limit is " +
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name)                                                       \
  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                 rank_, size);
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (fully_connected_) {                    \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual "
            "num gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
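
// Illustrative host-side usage (a sketch added for clarity; allocation of the
// Signal buffers and the pointer/IPC exchange are assumed to happen
// elsewhere):
//   Signal* signals[8];    // this rank's Signal buffer + opened peer buffers
//   void* rank_data;       // device buffer holding RankData entries
//   vllm::CustomAllreduce ar(signals, rank_data, rank_data_sz, rank,
//                            world_size);
//   void* peer_inputs[8];  // pre-exchanged pointers to each rank's input
//   ar.register_buffer(peer_inputs);
//   ar.allreduce<half>(stream, input, output, num_elements);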

/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and
 * add a template instantiation:
 * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
 *                                                      half *, int, int, int);
 */
}  // namespace vllm