diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 88e4af9d..3906dcfc 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -62,9 +62,9 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, if (inp_size % 16 != 0) return false; if (!_is_weak_contiguous(inp)) return false; if (world_size == 2 || full_nvlink) return inp_size <= max_size; - // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size - // <= 512k - return world_size <= 4 && inp_size <= 512 * 1024; + // for 4 or more non NVLink-capable GPUs, custom allreduce provides little + // performance improvement over NCCL. + return false; } void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, @@ -126,7 +126,7 @@ void dispose(fptr_t _fa) { delete fa; } -int meta_size() { return sizeof(vllm::Metadata); } +int meta_size() { return sizeof(vllm::Signal); } void register_buffer(fptr_t _fa, torch::Tensor &t, const std::vector &handles, diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 54409e19..750e68d4 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -23,29 +23,17 @@ namespace vllm { +constexpr int kMaxBlocks = 64; +// note: we don't want to use atomics for signals because peer atomics are no +// supported on PCIe links struct Signal { - alignas(64) union { - uint64_t flag; - unsigned char data[8]; - } start; - alignas(64) union { - uint64_t flag; - unsigned char data[8]; - } end; + alignas(128) uint32_t start[kMaxBlocks][8]; + alignas(128) uint32_t end[kMaxBlocks][8]; }; -struct Metadata { - alignas(128) Signal sg; - alignas(128) int counter; -}; -static_assert(offsetof(Metadata, counter) == 128); -static_assert(sizeof(Metadata) == 256); - struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; -struct RankSignals { - volatile Signal *signals[8]; -}; +struct __align__(16) RankSignals { volatile Signal *signals[8]; }; // like std::array, but aligned template @@ -135,70 +123,49 @@ DINLINE O downcast(array_t val) { } } -// compute flag at compile time -__host__ __device__ constexpr uint64_t compute_flag(int ngpus) { - auto m = std::numeric_limits::max(); - return m >> ((8 - ngpus) * 8); -} - +// This function is meant to be used as the first synchronization in the all +// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +// prior memory accesses. Note: volatile writes will not be reordered against +// other volatile writes. template -DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta, +DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, int rank) { - constexpr auto FLAG = compute_flag(ngpus); - if (blockIdx.x == 0) { - if (threadIdx.x < ngpus) - // simultaneously write to the corresponding byte to all other ranks. 
- // Latency = 1 p2p write - sg.signals[threadIdx.x]->start.data[rank] = 255; - else if (threadIdx.x == 32) - // reset - meta->sg.end.flag = 0; - } - if (threadIdx.x == 0) { - while (meta->sg.start.flag != FLAG) + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->end[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->start[blockIdx.x][threadIdx.x]) ; } __syncthreads(); } +// This function is meant to be used as the second or the final synchronization +// barrier in the all reduce kernel. If it's the final synchronization barrier, +// we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta, +DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, int rank) { - constexpr auto FLAG = compute_flag(ngpus); __syncthreads(); - __shared__ int num; - if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1); - __syncthreads(); - - // Only the last completing block can perform the end synchronization - // This can ensures when the final busy wait ends, all ranks must have - // finished reading each other's buffer. - if (num == gridDim.x - 1) { - if (threadIdx.x == 32) { - // reset in a different warp - meta->counter = 0; - meta->sg.start.flag = 0; - } else if (threadIdx.x < ngpus) { - // simultaneously write to the corresponding byte to all other ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->end.data[rank] = 255; - } - // if this is the final sync, only one block needs it - // because kernel exit can serve as sync - if constexpr (final_sync) { - if (threadIdx.x == 0) { - while (meta->sg.end.flag != FLAG) - ; - } - } - } - if constexpr (!final_sync) { - if (threadIdx.x == 0) { - while (meta->sg.end.flag != FLAG) - ; - } - __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. Might be the case that hardware provides stronger guarantee than + // the memory model. + if constexpr (!final_sync) __threadfence_system(); + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->start[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. 
+ // Latency = 1 p2p write + sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->end[blockIdx.x][threadIdx.x]) + ; } + if constexpr (!final_sync) __syncthreads(); } template @@ -214,32 +181,32 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData *_dp, RankSignals sg, - volatile Metadata *meta, T *__restrict__ result, + volatile Signal *self_sg, T *__restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - start_sync(sg, meta, rank); + start_sync(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P *)result)[idx] = packed_reduce((const P **)&dp.ptrs[0], idx); } - end_sync(sg, meta, rank); + end_sync(sg, self_sg, rank); } template DINLINE P *get_tmp_buf(volatile Signal *sg) { - return (P *)(((Metadata *)sg) + 1); + return (P *)(((Signal *)sg) + 1); } template __global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData *_dp, RankSignals sg, - volatile Metadata *meta, T *__restrict__ result, + volatile Signal *self_sg, T *__restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -248,6 +215,7 @@ __global__ void __launch_bounds__(512, 1) int part = size / ngpus; int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; const P *ptrs[ngpus]; P *tmps[ngpus]; #pragma unroll @@ -257,75 +225,28 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); } auto tmp_out = tmps[0]; - start_sync(sg, meta, rank); + start_sync(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - // Maybe TODO: replace this with per-block release-acquire - // can save about 1-2us (not a lot though) - end_sync(sg, meta, rank); + end_sync(sg, self_sg, rank); - // stage 2: allgather - for (int idx = tid; idx < part; idx += stride) { + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed + // between threads that have the same tid. If thread i computes the sum of + // start + i in the first stage, then thread i also gathers start + i from all + // ranks. + for (int idx = tid; idx < largest_part; idx += stride) { #pragma unroll for (int i = 0; i < ngpus; i++) { - int dst_idx = ((rank + i) % ngpus) * part + idx; - ((P *)result)[dst_idx] = tmps[i][idx]; + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; + ((P *)result)[dst_idx] = tmps[i][idx]; + } } } - // process the last larger partition - int remaining = size - part * ngpus; - if (tid < remaining) { - int dst_idx = tid + part * ngpus; - ((P *)result)[dst_idx] = get_tmp_buf
<P>
(sg.signals[ngpus - 1])[part + tid]; - } - - // faster than this - // for (int idx = tid; idx < size; idx += stride) { - // int target_rank = idx / part; - // if (target_rank == ngpus) target_rank -= 1; - // ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part]; - // } -} - -template -__global__ void __launch_bounds__(512, 1) - cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg, - volatile Metadata *meta, - T *__restrict__ result, int rank, - int size) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; - using P = typename packed_t::P; - using A = typename packed_t::A; - auto tmp_out = get_tmp_buf
<P>
(sg.signals[rank]); - constexpr int hg = ngpus / 2; - // Actually not quite half butterfly. - // This is an all-to-all within each group containing half of the ranks - // followed by cross-group add. Equivalent to half butterfly when there - // are 4 GPUs, a common case for PCIe cards like T4 and A10. - const P *ptrs[hg]; - { - int start = rank - rank % hg; -#pragma unroll - for (int i = 0; i < hg; i++) { - ptrs[i] = (const P *)_dp->ptrs[i + start]; - } - } - start_sync(sg, meta, rank); - for (int idx = tid; idx < size; idx += stride) { - tmp_out[idx] = packed_reduce(ptrs, idx); - } - end_sync(sg, meta, rank); - - auto src = get_tmp_buf
<P>
(sg.signals[(ngpus - 1) - rank % ngpus]); - // do the cross group reduction - for (int idx = tid; idx < size; idx += stride) { - auto tmp = tmp_out[idx]; - packed_assign_add(tmp, src[idx]); - ((P *)result)[idx] = tmp; - } } using IPC_KEY = std::array; @@ -341,7 +262,7 @@ class CustomAllreduce { // below are device pointers RankSignals sg_; std::unordered_map buffers_; - Metadata *meta_; + Signal *self_sg_; // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; @@ -352,32 +273,32 @@ class CustomAllreduce { /** * meta is a pointer to device metadata and temporary buffer for allreduce. * - * There's a total of sizeof(Metadata) of prefix before the actual data, + * There's a total of sizeof(Signal) of prefix before the actual data, * so meta + 1 points to actual temporary buffer. * * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz, + CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, const cudaIpcMemHandle_t *handles, const std::vector &offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), - meta_(meta), + self_sg_(meta), d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Metadata *rank_meta; + Signal *rank_sg; if (i != rank_) { char *handle = open_ipc_handle(&handles[i]); handle += offsets[i]; - rank_meta = (Metadata *)handle; + rank_sg = (Signal *)handle; } else { - rank_meta = meta_; + rank_sg = self_sg_; } - sg_.signals[i] = &rank_meta->sg; + sg_.signals[i] = rank_sg; } } @@ -492,6 +413,10 @@ class CustomAllreduce { "custom allreduce currently requires input length to be multiple " "of " + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". 
Got " + + std::to_string(block_limit)); RankData *ptrs; cudaStreamCaptureStatus status; @@ -512,9 +437,9 @@ class CustomAllreduce { size /= d; auto bytes = size * sizeof(typename packed_t::P); int blocks = std::min(block_limit, (size + threads - 1) / threads); -#define KL(ngpus, name) \ - name \ - <<>>(ptrs, sg_, meta_, output, rank_, size); +#define KL(ngpus, name) \ + name<<>>(ptrs, sg_, self_sg_, output, \ + rank_, size); #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ @@ -526,8 +451,6 @@ class CustomAllreduce { } else { \ KL(ngpus, cross_device_reduce_2stage); \ } \ - } else { \ - KL(ngpus, cross_device_reduce_half_butterfly); \ } \ break; \ } @@ -556,7 +479,7 @@ class CustomAllreduce { /** * To inspect PTX/SASS, copy paste this header file to compiler explorer and add a template instantiation: - * template void CustomAllreduce::allreduce(cudaStream_t, half *, half *, - int, int, int); + * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, + half *, int, int, int); */ } // namespace vllm diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index 6b094e2f..c34a5038 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -92,7 +92,7 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, template void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, - int data_size) { + int data_size, bool performance_test) { T *result; cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); @@ -101,7 +101,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t data_handles[8]; - vllm::Metadata *buffer; + vllm::Signal *buffer; T *self_data_copy; /** * Allocate IPC buffer @@ -115,9 +115,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, * convenience. 
*/ CUDACHECK( - cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); - CUDACHECK(cudaMemset(buffer, 0, - 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); + cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); + CUDACHECK( + cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer)); @@ -133,7 +133,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, offsets, myRank); auto *self_data = reinterpret_cast(reinterpret_cast(buffer) + - sizeof(vllm::Metadata) + data_size * sizeof(T)); + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { std::vector handles; @@ -143,8 +143,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, char *end = (char *)&data_handles[i + 1]; handles.emplace_back(begin, end); } - std::vector offsets( - nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T)); + std::vector offsets(nRanks, + sizeof(vllm::Signal) + data_size * sizeof(T)); fa.register_buffer(handles, offsets, self_data); } @@ -169,81 +169,112 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, } else { ncclDtype = ncclFloat; } - - dummy_kernel<<<1, 1, 0, stream>>>(); - constexpr int warmup_iters = 5; - constexpr int num_iters = 25; - // warmup - for (int i = 0; i < warmup_iters; i++) { - NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, - stream)); - } - CUDACHECK(cudaEventRecord(start, stream)); - for (int i = 0; i < num_iters; i++) { - NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, - stream)); - } - CUDACHECK(cudaEventRecord(stop, stream)); - CUDACHECK(cudaStreamSynchronize(stream)); - float allreduce_ms = 0; - cudaEventElapsedTime(&allreduce_ms, start, stop); - - // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>(); - // set_data<<<16, 1024, 0, stream>>>(self_data, data_size, myRank); - - dummy_kernel<<<1, 1, 0, stream>>>(); - // warm up - for (int i = 0; i < warmup_iters; i++) { - fa.allreduce(stream, self_data, result, data_size, threads, block_limit); - } - CUDACHECK(cudaEventRecord(start, stream)); - for (int i = 0; i < num_iters; i++) { - fa.allreduce(stream, self_data, result, data_size, threads, block_limit); - } - CUDACHECK(cudaEventRecord(stop, stream)); - CUDACHECK(cudaStreamSynchronize(stream)); - - float duration_ms = 0; - cudaEventElapsedTime(&duration_ms, start, stop); - if (myRank == 0) - printf( - "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " - "time:%.2fus\n", - myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, - duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); - - // And wait for all the queued up work to complete - CUDACHECK(cudaStreamSynchronize(stream)); - - NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, - ncclSum, comm, stream)); - double *nccl_result, *my_result; CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double))); CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double))); - - convert_data<<<108, 1024, 0, stream>>>(self_data, result, nccl_result, - my_result, data_size); - CUDACHECK(cudaStreamSynchronize(stream)); - - for (unsigned long j = 0; j < data_size; j++) { - auto diff = abs(nccl_result[j] - my_result[j]); - if (diff >= 1e-2) { - printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", - myRank, j, 
nccl_result[j], my_result[j], ground_truth[j]); - break; + if (performance_test) { + dummy_kernel<<<1, 1, 0, stream>>>(); + constexpr int warmup_iters = 5; + constexpr int num_iters = 100; + // warmup + for (int i = 0; i < warmup_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, + comm, stream)); } - } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, + comm, stream)); + } + CUDACHECK(cudaEventRecord(stop, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + float allreduce_ms = 0; + cudaEventElapsedTime(&allreduce_ms, start, stop); - long double nccl_diffs = 0.0; - long double my_diffs = 0.0; - for (int j = 0; j < data_size; j++) { - nccl_diffs += abs(nccl_result[j] - ground_truth[j]); - my_diffs += abs(my_result[j] - ground_truth[j]); + dummy_kernel<<<1, 1, 0, stream>>>(); + // warm up + for (int i = 0; i < warmup_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, + block_limit); + } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, + block_limit); + } + CUDACHECK(cudaEventRecord(stop, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + + float duration_ms = 0; + cudaEventElapsedTime(&duration_ms, start, stop); + if (myRank == 0) + printf( + "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " + "time:%.2fus\n", + myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, + duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); + + // And wait for all the queued up work to complete + CUDACHECK(cudaStreamSynchronize(stream)); + + NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, + ncclSum, comm, stream)); + + convert_data<<<108, 1024, 0, stream>>>(self_data, result, nccl_result, + my_result, data_size); + CUDACHECK(cudaStreamSynchronize(stream)); + + for (unsigned long j = 0; j < data_size; j++) { + auto diff = abs(nccl_result[j] - my_result[j]); + if (diff >= 4e-2) { + printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", + myRank, j, nccl_result[j], my_result[j], ground_truth[j]); + break; + } + } + long double nccl_diffs = 0.0; + long double my_diffs = 0.0; + for (int j = 0; j < data_size; j++) { + nccl_diffs += abs(nccl_result[j] - ground_truth[j]); + my_diffs += abs(my_result[j] - ground_truth[j]); + } + if (myRank == 0) + std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size + << " me: " << my_diffs / data_size << std::endl; + } else { + for (int i = 0; i < 100; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, + block_limit); + CUDACHECK(cudaStreamSynchronize(stream)); + NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype, + ncclSum, comm, stream)); + convert_data<<<108, 1024, 0, stream>>>( + self_data_copy, result, nccl_result, my_result, data_size); + CUDACHECK(cudaStreamSynchronize(stream)); + + for (unsigned long j = 0; j < data_size; j++) { + auto diff = abs(nccl_result[j] - my_result[j]); + if (diff >= 4e-2) { + printf( + "Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", + myRank, j, nccl_result[j], my_result[j], ground_truth[j]); + break; + } + } + } + if (myRank == 0) + printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks, + data_size * sizeof(T) / 1024, threads, block_limit); + // long double nccl_diffs = 0.0; + // long double my_diffs = 0.0; + // for 
(int j = 0; j < data_size; j++) { + // nccl_diffs += abs(nccl_result[j] - ground_truth[j]); + // my_diffs += abs(my_result[j] - ground_truth[j]); + // } + // if (myRank == 0) + // std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size + // << " me: " << my_diffs / data_size << std::endl; } - if (myRank == 0) - std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size - << " me: " << my_diffs / data_size << std::endl; CUDACHECK(cudaFree(result)); CUDACHECK(cudaFree(self_data_copy)); @@ -269,14 +300,15 @@ int main(int argc, char **argv) { MPI_COMM_WORLD)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); + bool performance_test = true; cudaProfilerStart(); // for (int threads : {256, 512}) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) { // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); // } // } - for (int sz = 512; sz <= (32 << 20); sz *= 2) { - run(myRank, nRanks, comm, 512, 36, sz + 8 * 50); + for (int sz = 512; sz <= (8 << 20); sz *= 2) { + run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); } cudaProfilerStop(); diff --git a/vllm/config.py b/vllm/config.py index a86114f3..6dfb5158 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -506,15 +506,6 @@ class ParallelConfig: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") - # FIXME(woosuk): Fix the stability issues and re-enable the custom - # all-reduce kernel. - if not self.disable_custom_all_reduce and self.world_size > 1: - self.disable_custom_all_reduce = True - logger.info( - "Custom all-reduce kernels are temporarily disabled due to " - "stability issues. We will re-enable them once the issues are " - "resolved.") - class SchedulerConfig: """Scheduler configuration. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1f463bda..e9b3d46d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -83,7 +83,7 @@ class LLM: swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = False, + disable_custom_all_reduce: bool = True, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 0c749c04..396be894 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -37,16 +37,23 @@ def init_custom_ar() -> None: logger.warn( "Custom allreduce is disabled due to an unsupported world size: " "%d. Supported world sizes: %s. To silence this warning, specify" - "disable_custom_all_reduce=True explicitly.", world_size, + " disable_custom_all_reduce=True explicitly.", world_size, str(_SUPPORTED_WORLD_SIZES)) return if not _can_p2p(rank, world_size): logger.warn( "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability. To silence this warning, specify" - "disable_custom_all_reduce=True explicitly.") + " capability or P2P test failed. To silence this warning, specify" + " disable_custom_all_reduce=True explicitly.") return - _CA_HANDLE = CustomAllreduce(rank, world_size) + full_nvlink = _is_full_nvlink(rank, world_size) + if world_size > 2 and not full_nvlink: + logger.warn( + "Custom allreduce is disabled because it's not supported on more" + " than two PCIe-only GPUs. 
To silence this warning, specify" + " disable_custom_all_reduce=True explicitly.") + return + _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) def begin_capture() -> None: @@ -134,18 +141,48 @@ def _is_full_nvlink(rank, world_size): def _can_p2p(rank: int, world_size: int) -> bool: + num_dev = torch.cuda.device_count() + # note: num dev can be larger than world_size if we're only using + # first few GPUs + if num_dev < world_size: + logger.warn( + "Cannot test GPU P2P because not all GPUs are visible to the " + "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" + " is set.") + return False for i in range(world_size): if i == rank: continue if not torch.cuda.can_device_access_peer(rank, i): return False + # on some platforms, P2P support might be buggy and we need + # additional checks. See also: + # https://github.com/vllm-project/vllm/issues/2728 + if not _can_actually_p2p(rank, i): + return False return True +# code partly borrowed from +# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 +# License: MIT +def _can_actually_p2p(idx_a, idx_b): + dev_i = f"cuda:{idx_a}" + dev_j = f"cuda:{idx_b}" + a = torch.randn(5, device=dev_i) + 123.0 + b = a.to(dev_j) + c = b.to(dev_i) + return torch.all(a == c) + + class CustomAllreduce: # max_size: max supported allreduce size - def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: + def __init__(self, + rank, + world_size, + full_nvlink, + max_size=8192 * 1024) -> None: # buffers memory are owned by this Python class and passed to C++ # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate @@ -167,11 +204,10 @@ class CustomAllreduce: self.max_size = max_size self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) - self.full_nvlink = _is_full_nvlink(rank, world_size) + self.full_nvlink = full_nvlink self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink) - self.fast_cond = self.full_nvlink or world_size <= 2 self.register_buffer(self.buffer) def _get_ipc_meta(self, inp: torch.Tensor):
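
Note (reviewer addition, not part of the patch): a minimal host-side C++ sketch of how the new Signal layout from csrc/custom_all_reduce.cuh determines the value returned by meta_size() and where the temporary reduce-scatter buffer begins. The struct definition is copied from the diff above; the main() wrapper and the data_size/element-size numbers are illustrative assumptions only.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Copied from csrc/custom_all_reduce.cuh in this patch.
constexpr int kMaxBlocks = 64;
struct Signal {
  alignas(128) uint32_t start[kMaxBlocks][8];
  alignas(128) uint32_t end[kMaxBlocks][8];
};

int main() {
  // meta_size() in csrc/custom_all_reduce.cu now returns sizeof(vllm::Signal),
  // so the Python side allocates meta_size() + max_size bytes for self.meta.
  std::printf("sizeof(Signal) = %zu bytes\n", sizeof(Signal));

  // get_tmp_buf() returns (P *)(((Signal *)sg) + 1): the temporary buffer for
  // the 2-stage reduce-scatter starts immediately after the Signal prefix.
  // The test allocates 2 * data_size * sizeof(T) + sizeof(vllm::Signal) and
  // places self_data at this offset (data_size and elem_bytes are examples).
  const size_t data_size = 4096;   // hypothetical element count
  const size_t elem_bytes = 2;     // e.g. sizeof(half)
  const size_t self_data_offset = sizeof(Signal) + data_size * elem_bytes;
  std::printf("self_data offset in the test buffer = %zu bytes\n",
              self_data_offset);
  return 0;
}

With kMaxBlocks = 64 and up to 8 ranks, the two flag arrays occupy 2 * 64 * 8 * 4 = 4096 bytes, replacing the old 256-byte Metadata prefix; the alignas(128) on each array presumably keeps the start and end flags on separate cache lines so the busy-wait loads and the peer writes stay independent.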