[BugFix] Some fixes for custom allreduce kernels (#2760)
parent e90fc21f2e
commit f721096d48
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
     std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
   }
   return (fptr_t) new vllm::CustomAllreduce(
-      reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
+      reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
       rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
 }
 
@@ -62,9 +62,9 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
   if (inp_size % 16 != 0) return false;
   if (!_is_weak_contiguous(inp)) return false;
   if (world_size == 2 || full_nvlink) return inp_size <= max_size;
-  // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
-  // <= 512k
-  return world_size <= 4 && inp_size <= 512 * 1024;
+  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
+  // performance improvement over NCCL.
+  return false;
 }
 
 void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
@@ -126,7 +126,7 @@ void dispose(fptr_t _fa) {
   delete fa;
 }
 
-int meta_size() { return sizeof(vllm::Metadata); }
+int meta_size() { return sizeof(vllm::Signal); }
 
 void register_buffer(fptr_t _fa, torch::Tensor &t,
                      const std::vector<std::string> &handles,
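For reference, the revised `should_custom_ar` policy above can be exercised in isolation. The sketch below is an illustration only (the function name and the sample sizes are invented here, and the `_is_weak_contiguous` check is omitted); it restates the decision this commit makes: custom allreduce now runs only for 2 GPUs or a fully NVLinked topology, and otherwise defers to NCCL.

    #include <cstdio>

    // Standalone restatement of the policy above (illustrative only).
    // inp_size is the tensor size in bytes; max_size is the registered buffer size.
    static bool should_custom_ar_sketch(long inp_size, long max_size,
                                        int world_size, bool full_nvlink) {
      if (inp_size % 16 != 0) return false;  // custom allreduce needs 16-byte multiples
      if (world_size == 2 || full_nvlink) return inp_size <= max_size;
      // 4 or more GPUs without full NVLink: always fall back to NCCL
      // (the old "<= 512k on 4 PCIe GPUs" path was removed by this commit).
      return false;
    }

    int main() {
      // 2 GPUs, 1 MiB input, 8 MiB registered buffer -> custom allreduce (prints 1)
      std::printf("%d\n", should_custom_ar_sketch(1 << 20, 8 << 20, 2, false));
      // 4 PCIe-only GPUs -> NCCL after this commit (prints 0)
      std::printf("%d\n", should_custom_ar_sketch(1 << 20, 8 << 20, 4, false));
      // 8 NVLinked GPUs but input larger than the buffer -> NCCL (prints 0)
      std::printf("%d\n", should_custom_ar_sketch(16 << 20, 8 << 20, 8, true));
    }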
@@ -23,29 +23,17 @@
 
 namespace vllm {
 
+constexpr int kMaxBlocks = 64;
+// note: we don't want to use atomics for signals because peer atomics are no
+// supported on PCIe links
 struct Signal {
-  alignas(64) union {
-    uint64_t flag;
-    unsigned char data[8];
-  } start;
-  alignas(64) union {
-    uint64_t flag;
-    unsigned char data[8];
-  } end;
+  alignas(128) uint32_t start[kMaxBlocks][8];
+  alignas(128) uint32_t end[kMaxBlocks][8];
 };
 
-struct Metadata {
-  alignas(128) Signal sg;
-  alignas(128) int counter;
-};
-static_assert(offsetof(Metadata, counter) == 128);
-static_assert(sizeof(Metadata) == 256);
-
 struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
 
-struct RankSignals {
-  volatile Signal *signals[8];
-};
+struct __align__(16) RankSignals { volatile Signal *signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
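With `kMaxBlocks = 64`, the new `Signal` struct carries one 32-bit start flag and one 32-bit end flag per (block, rank) pair: each array is 64 * 8 * 4 = 2048 bytes, and the whole struct is 4096 bytes, which is what the updated `meta_size()` binding above now reports. A minimal compile-time check of that arithmetic, assuming only the struct definition shown here:

    #include <cstdint>

    constexpr int kMaxBlocks = 64;
    struct Signal {
      alignas(128) uint32_t start[kMaxBlocks][8];
      alignas(128) uint32_t end[kMaxBlocks][8];
    };

    // One 4-byte flag per (block, rank) pair, 8 ranks max.
    static_assert(sizeof(uint32_t[kMaxBlocks][8]) == 64 * 8 * 4, "2048 bytes per array");
    static_assert(sizeof(Signal) == 2 * 64 * 8 * 4, "4096 bytes total");
    static_assert(alignof(Signal) == 128, "128-byte aligned as requested");

    int main() { return 0; }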
@@ -135,70 +123,49 @@ DINLINE O downcast(array_t<float, O::size> val) {
   }
 }
 
-// compute flag at compile time
-__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
-  auto m = std::numeric_limits<uint64_t>::max();
-  return m >> ((8 - ngpus) * 8);
-}
-
 // This function is meant to be used as the first synchronization in the all
 // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
 // prior memory accesses. Note: volatile writes will not be reordered against
 // other volatile writes.
 template <int ngpus>
-DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
+DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
                         int rank) {
-  constexpr auto FLAG = compute_flag(ngpus);
-  if (blockIdx.x == 0) {
-    if (threadIdx.x < ngpus)
-      // simultaneously write to the corresponding byte to all other ranks.
-      // Latency = 1 p2p write
-      sg.signals[threadIdx.x]->start.data[rank] = 255;
-    else if (threadIdx.x == 32)
-      // reset
-      meta->sg.end.flag = 0;
-  }
-  if (threadIdx.x == 0) {
-    while (meta->sg.start.flag != FLAG)
+  if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    self_sg->end[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
+    // wait until we got true from all ranks
+    while (!self_sg->start[blockIdx.x][threadIdx.x])
       ;
   }
   __syncthreads();
 }
 
 // This function is meant to be used as the second or the final synchronization
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
-DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
+DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
                       int rank) {
-  constexpr auto FLAG = compute_flag(ngpus);
   __syncthreads();
-  __shared__ int num;
-  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
-  __syncthreads();
-
-  // Only the last completing block can perform the end synchronization
-  // This can ensures when the final busy wait ends, all ranks must have
-  // finished reading each other's buffer.
-  if (num == gridDim.x - 1) {
-    if (threadIdx.x == 32) {
-      // reset in a different warp
-      meta->counter = 0;
-      meta->sg.start.flag = 0;
-    } else if (threadIdx.x < ngpus) {
-      // simultaneously write to the corresponding byte to all other ranks.
-      // Latency = 1 p2p write
-      sg.signals[threadIdx.x]->end.data[rank] = 255;
-    }
-    // if this is the final sync, only one block needs it
-    // because kernel exit can serve as sync
-    if constexpr (final_sync) {
-      if (threadIdx.x == 0) {
-        while (meta->sg.end.flag != FLAG)
-          ;
-      }
-    }
-  }
-  if constexpr (!final_sync) {
-    if (threadIdx.x == 0) {
-      while (meta->sg.end.flag != FLAG)
-        ;
-    }
-    __syncthreads();
+  // eliminate the case that prior writes are not visible after signals become
+  // visible. Note that I did not managed to make this happen through a lot of
+  // testing. Might be the case that hardware provides stronger guarantee than
+  // the memory model.
+  if constexpr (!final_sync) __threadfence_system();
+  if (threadIdx.x < ngpus) {
+    // reset flag for next time
+    self_sg->start[blockIdx.x][threadIdx.x] = 0;
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
+    // wait until we got true from all ranks
+    while (!self_sg->end[blockIdx.x][threadIdx.x])
+      ;
   }
+  if constexpr (!final_sync) __syncthreads();
 }
 
 template <typename P, int ngpus, typename A>
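The rewritten `start_sync`/`end_sync` replace the old global counter-plus-flag scheme with a symmetric handshake per (block, rank) pair: each participating thread resets the opposite flag, writes a 1 into the corresponding slot of every peer's `Signal`, then spins on its own slot until every peer has checked in. The sketch below replays that protocol on the host with `std::atomic`, two threads standing in for two ranks, and a single block; it is an illustration of the handshake only, not the kernel code (the real kernels use volatile peer-to-peer writes, with one thread per peer rather than a loop).

    #include <atomic>
    #include <cstdio>
    #include <thread>

    // Host-side analogue of the per-(block, rank) flag handshake. Each "rank"
    // writes a 1 into every peer's start slot, then spins on its own slots.
    constexpr int kRanks = 2;
    std::atomic<int> start_flags[kRanks][kRanks];  // [owning rank][written by rank]

    void rank_thread(int rank) {
      for (int peer = 0; peer < kRanks; ++peer)  // one "p2p write" per peer
        start_flags[peer][rank].store(1, std::memory_order_release);
      for (int peer = 0; peer < kRanks; ++peer)  // wait until all peers checked in
        while (!start_flags[rank][peer].load(std::memory_order_acquire)) {
        }
      std::printf("rank %d passed the start barrier\n", rank);
    }

    int main() {
      for (auto &row : start_flags)
        for (auto &flag : row) flag.store(0);
      std::thread t0(rank_thread, 0), t1(rank_thread, 1);
      t0.join();
      t1.join();
    }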
@@ -214,32 +181,32 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
-                               volatile Metadata *meta, T *__restrict__ result,
+                               volatile Signal *self_sg, T *__restrict__ result,
                                int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
   // for all ranks, ensuring bitwise identical results
   auto dp = *_dp;
-  start_sync<ngpus>(sg, meta, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // do the actual reduction
   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
        idx += gridDim.x * blockDim.x) {
     ((P *)result)[idx] =
         packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
   }
-  end_sync<ngpus, true>(sg, meta, rank);
+  end_sync<ngpus, true>(sg, self_sg, rank);
 }
 
 template <typename P>
 DINLINE P *get_tmp_buf(volatile Signal *sg) {
-  return (P *)(((Metadata *)sg) + 1);
+  return (P *)(((Signal *)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
-                               volatile Metadata *meta, T *__restrict__ result,
+                               volatile Signal *self_sg, T *__restrict__ result,
                                int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
@@ -248,6 +215,7 @@ __global__ void __launch_bounds__(512, 1)
   int part = size / ngpus;
   int start = rank * part;
   int end = rank == ngpus - 1 ? size : start + part;
+  int largest_part = part + size % ngpus;
   const P *ptrs[ngpus];
   P *tmps[ngpus];
 #pragma unroll
@@ -257,75 +225,28 @@ __global__ void __launch_bounds__(512, 1)
     tmps[i] = get_tmp_buf<P>(sg.signals[target]);
   }
   auto tmp_out = tmps[0];
-  start_sync<ngpus>(sg, meta, rank);
+  start_sync<ngpus>(sg, self_sg, rank);
   // stage 1: reduce scatter
   for (int idx = start + tid; idx < end; idx += stride) {
     tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
   }
   // Maybe TODO: replace this with per-block release-acquire
   // can save about 1-2us (not a lot though)
-  end_sync<ngpus>(sg, meta, rank);
+  end_sync<ngpus>(sg, self_sg, rank);
 
-  // stage 2: allgather
-  for (int idx = tid; idx < part; idx += stride) {
+  // stage 2: allgather. Note: it's important to match the tid between
+  // the two stages, because visibility across devices is only guaranteed
+  // between threads that have the same tid. If thread i computes the sum of
+  // start + i in the first stage, then thread i also gathers start + i from all
+  // ranks.
+  for (int idx = tid; idx < largest_part; idx += stride) {
 #pragma unroll
     for (int i = 0; i < ngpus; i++) {
-      int dst_idx = ((rank + i) % ngpus) * part + idx;
-      ((P *)result)[dst_idx] = tmps[i][idx];
+      int gather_from_rank = ((rank + i) % ngpus);
+      if (gather_from_rank == ngpus - 1 || idx < part) {
+        int dst_idx = gather_from_rank * part + idx;
+        ((P *)result)[dst_idx] = tmps[i][idx];
+      }
     }
   }
-  // process the last larger partition
-  int remaining = size - part * ngpus;
-  if (tid < remaining) {
-    int dst_idx = tid + part * ngpus;
-    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
-  }
-
-  // faster than this
-  // for (int idx = tid; idx < size; idx += stride) {
-  //   int target_rank = idx / part;
-  //   if (target_rank == ngpus) target_rank -= 1;
-  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
-  // }
 }
 
-template <typename T, int ngpus>
-__global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
-                                       volatile Metadata *meta,
-                                       T *__restrict__ result, int rank,
-                                       int size) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int stride = gridDim.x * blockDim.x;
-  using P = typename packed_t<T>::P;
-  using A = typename packed_t<T>::A;
-  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
-  constexpr int hg = ngpus / 2;
-  // Actually not quite half butterfly.
-  // This is an all-to-all within each group containing half of the ranks
-  // followed by cross-group add. Equivalent to half butterfly when there
-  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
-  const P *ptrs[hg];
-  {
-    int start = rank - rank % hg;
-#pragma unroll
-    for (int i = 0; i < hg; i++) {
-      ptrs[i] = (const P *)_dp->ptrs[i + start];
-    }
-  }
-  start_sync<ngpus>(sg, meta, rank);
-  for (int idx = tid; idx < size; idx += stride) {
-    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
-  }
-  end_sync<ngpus>(sg, meta, rank);
-
-  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
-  // do the cross group reduction
-  for (int idx = tid; idx < size; idx += stride) {
-    auto tmp = tmp_out[idx];
-    packed_assign_add(tmp, src[idx]);
-    ((P *)result)[idx] = tmp;
-  }
-}
 
 using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
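The comment about matching `tid` between the two stages is the core of this fix: cross-device visibility after `end_sync` is only guaranteed between threads with the same index, so the tail of the last rank's larger partition must be gathered by the same loop indices that produced it, instead of the separate `remaining` loop the old code used. As a concrete illustration (not code from the commit), take `size = 10` packed elements on `ngpus = 4`: `part = 2`, the last rank owns 4 elements, and `largest_part = 4`. The host-side sketch below replays the new gather indexing for one rank and confirms every output element is written exactly once.

    #include <cstdio>
    #include <vector>

    int main() {
      // Replay of the stage-2 gather indexing for one rank (illustration only).
      const int ngpus = 4, size = 10;                // size in packed elements
      const int part = size / ngpus;                 // 2
      const int largest_part = part + size % ngpus;  // 4, the last rank's share
      const int rank = 1;                            // any rank covers all of result
      std::vector<int> writes(size, 0);

      // idx plays the role of tid; the grid-stride loop is omitted for clarity.
      for (int idx = 0; idx < largest_part; idx++) {
        for (int i = 0; i < ngpus; i++) {
          int gather_from_rank = (rank + i) % ngpus;
          // only the last rank owns indices at or beyond `part`
          if (gather_from_rank == ngpus - 1 || idx < part) {
            int dst_idx = gather_from_rank * part + idx;
            writes[dst_idx]++;
          }
        }
      }
      for (int j = 0; j < size; j++)
        std::printf("result[%d] written %d time(s)\n", j, writes[j]);
    }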
@@ -341,7 +262,7 @@ class CustomAllreduce {
   // below are device pointers
   RankSignals sg_;
   std::unordered_map<void *, RankData *> buffers_;
-  Metadata *meta_;
+  Signal *self_sg_;
 
   // stores the registered device pointers from all ranks
   RankData *d_rank_data_base_, *d_rank_data_end_;
@@ -352,32 +273,32 @@ class CustomAllreduce {
   /**
    * meta is a pointer to device metadata and temporary buffer for allreduce.
    *
-   * There's a total of sizeof(Metadata) of prefix before the actual data,
+   * There's a total of sizeof(Signal) of prefix before the actual data,
    * so meta + 1 points to actual temporary buffer.
    *
    * note: this class does not own any device memory. Any required buffers
    * are passed in from the constructor
    */
-  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
+  CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
                   const cudaIpcMemHandle_t *handles,
                   const std::vector<int64_t> &offsets, int rank,
                   bool full_nvlink = true)
       : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
-       meta_(meta),
+       self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
     for (int i = 0; i < world_size_; i++) {
-      Metadata *rank_meta;
+      Signal *rank_sg;
       if (i != rank_) {
         char *handle = open_ipc_handle(&handles[i]);
         handle += offsets[i];
-        rank_meta = (Metadata *)handle;
+        rank_sg = (Signal *)handle;
       } else {
-        rank_meta = meta_;
+        rank_sg = self_sg_;
       }
-      sg_.signals[i] = &rank_meta->sg;
+      sg_.signals[i] = rank_sg;
     }
   }
 
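The constructor comment above describes the layout that `get_tmp_buf` relies on: the IPC allocation is `sizeof(Signal)` bytes of synchronization flags followed immediately by the temporary reduce-scatter buffer, so `meta + 1` (equivalently `((Signal *)sg) + 1`) is the first byte of scratch space, and the test harness allocates `2 * data_size * sizeof(T) + sizeof(vllm::Signal)` bytes in one block. A small host-side sketch of that layout (illustration only; the 1024-element fp16-sized buffer is an arbitrary example):

    #include <cstdint>
    #include <cstdio>

    constexpr int kMaxBlocks = 64;
    struct Signal {
      alignas(128) uint32_t start[kMaxBlocks][8];
      alignas(128) uint32_t end[kMaxBlocks][8];
    };

    // [Signal flags | temporary buffer], mirroring the test harness allocation of
    // 2 * data_size * sizeof(T) + sizeof(vllm::Signal) bytes in a single block.
    alignas(128) static unsigned char buffer[sizeof(Signal) + 2 * 1024 * sizeof(uint16_t)];

    int main() {
      auto *sg = reinterpret_cast<Signal *>(buffer);
      auto *tmp = reinterpret_cast<unsigned char *>(sg + 1);  // get_tmp_buf logic
      std::printf("sizeof(Signal) = %zu, temp buffer starts at offset %td\n",
                  sizeof(Signal), tmp - buffer);
    }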
@@ -492,6 +413,10 @@ class CustomAllreduce {
           "custom allreduce currently requires input length to be multiple "
           "of " +
           std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
 
     RankData *ptrs;
     cudaStreamCaptureStatus status;
@@ -512,9 +437,9 @@ class CustomAllreduce {
     size /= d;
     auto bytes = size * sizeof(typename packed_t<T>::P);
     int blocks = std::min(block_limit, (size + threads - 1) / threads);
-#define KL(ngpus, name) \
-  name<T, ngpus> \
-      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
+#define KL(ngpus, name) \
+  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
+                                                 rank_, size);
 #define REDUCE_CASE(ngpus) \
   case ngpus: { \
     if (world_size_ == 2) { \
@@ -526,8 +451,6 @@ class CustomAllreduce {
       } else { \
         KL(ngpus, cross_device_reduce_2stage); \
       } \
-    } else { \
-      KL(ngpus, cross_device_reduce_half_butterfly); \
     } \
     break; \
   }
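For the launch shape used by `allreduce()`, the grid is `min(block_limit, ceil(size / threads))` blocks of `threads` threads, and the new guard above caps `block_limit` at `kMaxBlocks` (64) because each block indexes its own row of flags in `Signal`. The sketch below works through that arithmetic with the values the test harness passes (512 threads, `block_limit` 36), assuming fp16 data where one 16-byte packed access covers 8 elements; it is illustrative only.

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Launch-shape arithmetic mirroring allreduce() above (illustration only).
      const int threads = 512;
      const int block_limit = 36;               // must now be <= kMaxBlocks (64)
      const long data_size = 4l * 1024 * 1024;  // fp16 elements
      const long packed = data_size / 8;        // "size" after size /= d (8 halves per 16 B)
      const long blocks =
          std::min<long>(block_limit, (packed + threads - 1) / threads);
      std::printf("packed accesses = %ld, blocks launched = %ld\n", packed, blocks);
    }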
@@ -556,7 +479,7 @@ class CustomAllreduce {
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
- * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
- int, int, int);
+ * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
+ half *, int, int, int);
  */
 }  // namespace vllm
@@ -92,7 +92,7 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
 
 template <typename T>
 void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
-         int data_size) {
+         int data_size, bool performance_test) {
   T *result;
   cudaStream_t stream;
   CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
@@ -101,7 +101,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
 
   cudaIpcMemHandle_t self_data_handle;
   cudaIpcMemHandle_t data_handles[8];
-  vllm::Metadata *buffer;
+  vllm::Signal *buffer;
   T *self_data_copy;
   /**
    * Allocate IPC buffer
@@ -115,9 +115,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
    * convenience.
    */
   CUDACHECK(
-      cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
-  CUDACHECK(cudaMemset(buffer, 0,
-                       2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
+      cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
+  CUDACHECK(
+      cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
   CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
   CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
 
@@ -133,7 +133,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
                     offsets, myRank);
   auto *self_data =
       reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
-                            sizeof(vllm::Metadata) + data_size * sizeof(T));
+                            sizeof(vllm::Signal) + data_size * sizeof(T));
   // hack buffer registration
   {
     std::vector<std::string> handles;
@@ -143,8 +143,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
       char *end = (char *)&data_handles[i + 1];
       handles.emplace_back(begin, end);
     }
-    std::vector<int64_t> offsets(
-        nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
+    std::vector<int64_t> offsets(nRanks,
+                                 sizeof(vllm::Signal) + data_size * sizeof(T));
     fa.register_buffer(handles, offsets, self_data);
   }
 
@@ -169,81 +169,112 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
   } else {
     ncclDtype = ncclFloat;
   }
 
-  dummy_kernel<<<1, 1, 0, stream>>>();
-  constexpr int warmup_iters = 5;
-  constexpr int num_iters = 25;
-  // warmup
-  for (int i = 0; i < warmup_iters; i++) {
-    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
-                            stream));
-  }
-  CUDACHECK(cudaEventRecord(start, stream));
-  for (int i = 0; i < num_iters; i++) {
-    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
-                            stream));
-  }
-  CUDACHECK(cudaEventRecord(stop, stream));
-  CUDACHECK(cudaStreamSynchronize(stream));
-  float allreduce_ms = 0;
-  cudaEventElapsedTime(&allreduce_ms, start, stop);
-
-  // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
-  // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);
-
-  dummy_kernel<<<1, 1, 0, stream>>>();
-  // warm up
-  for (int i = 0; i < warmup_iters; i++) {
-    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
-  }
-  CUDACHECK(cudaEventRecord(start, stream));
-  for (int i = 0; i < num_iters; i++) {
-    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
-  }
-  CUDACHECK(cudaEventRecord(stop, stream));
-  CUDACHECK(cudaStreamSynchronize(stream));
-
-  float duration_ms = 0;
-  cudaEventElapsedTime(&duration_ms, start, stop);
-  if (myRank == 0)
-    printf(
-        "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
-        "time:%.2fus\n",
-        myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
-        duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
-
-  // And wait for all the queued up work to complete
-  CUDACHECK(cudaStreamSynchronize(stream));
-
-  NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
-                          ncclSum, comm, stream));
-
   double *nccl_result, *my_result;
   CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
   CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
 
-  convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
-                                            my_result, data_size);
-  CUDACHECK(cudaStreamSynchronize(stream));
-
-  for (unsigned long j = 0; j < data_size; j++) {
-    auto diff = abs(nccl_result[j] - my_result[j]);
-    if (diff >= 1e-2) {
-      printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
-             myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
-      break;
+  if (performance_test) {
+    dummy_kernel<<<1, 1, 0, stream>>>();
+    constexpr int warmup_iters = 5;
+    constexpr int num_iters = 100;
+    // warmup
+    for (int i = 0; i < warmup_iters; i++) {
+      NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
+                              comm, stream));
     }
-  }
+    CUDACHECK(cudaEventRecord(start, stream));
+    for (int i = 0; i < num_iters; i++) {
+      NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
+                              comm, stream));
+    }
+    CUDACHECK(cudaEventRecord(stop, stream));
+    CUDACHECK(cudaStreamSynchronize(stream));
+    float allreduce_ms = 0;
+    cudaEventElapsedTime(&allreduce_ms, start, stop);
 
-  long double nccl_diffs = 0.0;
-  long double my_diffs = 0.0;
-  for (int j = 0; j < data_size; j++) {
-    nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
-    my_diffs += abs(my_result[j] - ground_truth[j]);
+    dummy_kernel<<<1, 1, 0, stream>>>();
+    // warm up
+    for (int i = 0; i < warmup_iters; i++) {
+      fa.allreduce<T>(stream, self_data, result, data_size, threads,
+                      block_limit);
+    }
+    CUDACHECK(cudaEventRecord(start, stream));
+    for (int i = 0; i < num_iters; i++) {
+      fa.allreduce<T>(stream, self_data, result, data_size, threads,
+                      block_limit);
+    }
+    CUDACHECK(cudaEventRecord(stop, stream));
+    CUDACHECK(cudaStreamSynchronize(stream));
+
+    float duration_ms = 0;
+    cudaEventElapsedTime(&duration_ms, start, stop);
+    if (myRank == 0)
+      printf(
+          "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
+          "time:%.2fus\n",
+          myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
+          duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
+
+    // And wait for all the queued up work to complete
+    CUDACHECK(cudaStreamSynchronize(stream));
+
+    NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
+                            ncclSum, comm, stream));
+
+    convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
+                                              my_result, data_size);
+    CUDACHECK(cudaStreamSynchronize(stream));
+
+    for (unsigned long j = 0; j < data_size; j++) {
+      auto diff = abs(nccl_result[j] - my_result[j]);
+      if (diff >= 4e-2) {
+        printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
+               myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
+        break;
+      }
+    }
+    long double nccl_diffs = 0.0;
+    long double my_diffs = 0.0;
+    for (int j = 0; j < data_size; j++) {
+      nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
+      my_diffs += abs(my_result[j] - ground_truth[j]);
+    }
+    if (myRank == 0)
+      std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
+                << " me: " << my_diffs / data_size << std::endl;
+  } else {
+    for (int i = 0; i < 100; i++) {
+      fa.allreduce<T>(stream, self_data, result, data_size, threads,
+                      block_limit);
+      CUDACHECK(cudaStreamSynchronize(stream));
+      NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
+                              ncclSum, comm, stream));
+      convert_data<T><<<108, 1024, 0, stream>>>(
+          self_data_copy, result, nccl_result, my_result, data_size);
+      CUDACHECK(cudaStreamSynchronize(stream));
+
+      for (unsigned long j = 0; j < data_size; j++) {
+        auto diff = abs(nccl_result[j] - my_result[j]);
+        if (diff >= 4e-2) {
+          printf(
+              "Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
+              myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
+          break;
+        }
+      }
+    }
+    if (myRank == 0)
+      printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks,
+             data_size * sizeof(T) / 1024, threads, block_limit);
+    // long double nccl_diffs = 0.0;
+    // long double my_diffs = 0.0;
+    // for (int j = 0; j < data_size; j++) {
+    //   nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
+    //   my_diffs += abs(my_result[j] - ground_truth[j]);
+    // }
+    // if (myRank == 0)
+    //   std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
+    //     << " me: " << my_diffs / data_size << std::endl;
   }
-  if (myRank == 0)
-    std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
-              << " me: " << my_diffs / data_size << std::endl;
 
   CUDACHECK(cudaFree(result));
   CUDACHECK(cudaFree(self_data_copy));
@@ -269,14 +300,15 @@ int main(int argc, char **argv) {
                           MPI_COMM_WORLD));
   NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
 
+  bool performance_test = true;
   cudaProfilerStart();
   // for (int threads : {256, 512}) {
   //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
   //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
   //   }
   // }
-  for (int sz = 512; sz <= (32 << 20); sz *= 2) {
-    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
+  for (int sz = 512; sz <= (8 << 20); sz *= 2) {
+    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
   }
 
   cudaProfilerStop();
@@ -506,15 +506,6 @@ class ParallelConfig:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
 
-        # FIXME(woosuk): Fix the stability issues and re-enable the custom
-        # all-reduce kernel.
-        if not self.disable_custom_all_reduce and self.world_size > 1:
-            self.disable_custom_all_reduce = True
-            logger.info(
-                "Custom all-reduce kernels are temporarily disabled due to "
-                "stability issues. We will re-enable them once the issues are "
-                "resolved.")
-
 
 class SchedulerConfig:
     """Scheduler configuration.
@@ -83,7 +83,7 @@ class LLM:
         swap_space: int = 4,
         enforce_eager: bool = False,
         max_context_len_to_capture: int = 8192,
-        disable_custom_all_reduce: bool = False,
+        disable_custom_all_reduce: bool = True,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -37,16 +37,23 @@ def init_custom_ar() -> None:
         logger.warn(
             "Custom allreduce is disabled due to an unsupported world size: "
             "%d. Supported world sizes: %s. To silence this warning, specify"
-            "disable_custom_all_reduce=True explicitly.", world_size,
+            " disable_custom_all_reduce=True explicitly.", world_size,
             str(_SUPPORTED_WORLD_SIZES))
         return
     if not _can_p2p(rank, world_size):
         logger.warn(
             "Custom allreduce is disabled because your platform lacks GPU P2P"
-            " capability. To silence this warning, specify"
-            "disable_custom_all_reduce=True explicitly.")
+            " capability or P2P test failed. To silence this warning, specify"
+            " disable_custom_all_reduce=True explicitly.")
         return
-    _CA_HANDLE = CustomAllreduce(rank, world_size)
+    full_nvlink = _is_full_nvlink(rank, world_size)
+    if world_size > 2 and not full_nvlink:
+        logger.warn(
+            "Custom allreduce is disabled because it's not supported on more"
+            " than two PCIe-only GPUs. To silence this warning, specify"
+            " disable_custom_all_reduce=True explicitly.")
+        return
+    _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
 
 
 def begin_capture() -> None:
@@ -134,18 +141,48 @@ def _is_full_nvlink(rank, world_size):
 
 
 def _can_p2p(rank: int, world_size: int) -> bool:
+    num_dev = torch.cuda.device_count()
+    # note: num dev can be larger than world_size if we're only using
+    # first few GPUs
+    if num_dev < world_size:
+        logger.warn(
+            "Cannot test GPU P2P because not all GPUs are visible to the "
+            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
+            " is set.")
+        return False
     for i in range(world_size):
         if i == rank:
             continue
         if not torch.cuda.can_device_access_peer(rank, i):
             return False
+        # on some platforms, P2P support might be buggy and we need
+        # additional checks. See also:
+        # https://github.com/vllm-project/vllm/issues/2728
+        if not _can_actually_p2p(rank, i):
+            return False
     return True
 
 
+# code partly borrowed from
+# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
+# License: MIT
+def _can_actually_p2p(idx_a, idx_b):
+    dev_i = f"cuda:{idx_a}"
+    dev_j = f"cuda:{idx_b}"
+    a = torch.randn(5, device=dev_i) + 123.0
+    b = a.to(dev_j)
+    c = b.to(dev_i)
+    return torch.all(a == c)
+
+
 class CustomAllreduce:
 
     # max_size: max supported allreduce size
-    def __init__(self, rank, world_size, max_size=8192 * 1024) -> None:
+    def __init__(self,
+                 rank,
+                 world_size,
+                 full_nvlink,
+                 max_size=8192 * 1024) -> None:
         # buffers memory are owned by this Python class and passed to C++
         # meta data composes of two parts: meta data for synchronization
         # (256 bytes) and a temporary buffer for storing intermediate
@@ -167,11 +204,10 @@ class CustomAllreduce:
         self.max_size = max_size
         self.world_size = world_size
         handles, offsets = self._get_ipc_meta(self.meta)
-        self.full_nvlink = _is_full_nvlink(rank, world_size)
+        self.full_nvlink = full_nvlink
         self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                              handles, offsets, rank,
                                              self.full_nvlink)
-        self.fast_cond = self.full_nvlink or world_size <= 2
         self.register_buffer(self.buffer)
 
     def _get_ipc_meta(self, inp: torch.Tensor):