diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu new file mode 100644 index 00000000..88e4af9d --- /dev/null +++ b/csrc/custom_all_reduce.cu @@ -0,0 +1,148 @@ +#include +#include +#include +#include + +#include "custom_all_reduce.cuh" + +// fake pointer type +using fptr_t = uint64_t; +static_assert(sizeof(void *) == sizeof(fptr_t)); + +fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, + const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink) { + int world_size = offsets.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (world_size != handles.size()) + throw std::invalid_argument( + "handles length should equal to offsets length"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + cudaIpcMemHandle_t ipc_handles[8]; + for (int i = 0; i < world_size; i++) { + std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); + } + return (fptr_t) new vllm::CustomAllreduce( + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); +} + +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. A.permute(2, 0, 1): OK + * 4. A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. 
A[:, 1:, 1:]: Not OK + */ +bool _is_weak_contiguous(torch::Tensor &t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); +} + +bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, + bool full_nvlink) { + auto inp_size = inp.numel() * inp.element_size(); + // custom allreduce requires input byte size to be multiples of 16 + if (inp_size % 16 != 0) return false; + if (!_is_weak_contiguous(inp)) return false; + if (world_size == 2 || full_nvlink) return inp_size <= max_size; + // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size + // <= 512k + return world_size <= 4 && inp_size <= 512 * 1024; +} + +void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, + cudaStream_t stream) { + auto fa = reinterpret_cast(_fa); + TORCH_CHECK(_is_weak_contiguous(out)); + switch (out.scalar_type()) { + case at::ScalarType::Float: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + break; + } + case at::ScalarType::Half: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + break; + } +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + fa->allreduce( + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "custom allreduce only supports float32, float16 and bfloat16"); + } +} + +void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + _all_reduce(_fa, inp, out, stream); +} + +void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, + torch::Tensor &out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), + input_size, cudaMemcpyDeviceToDevice, stream)); + _all_reduce(_fa, reg_buffer, out, stream); +} + +void dispose(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + delete fa; +} + +int meta_size() { return sizeof(vllm::Metadata); } + +void register_buffer(fptr_t _fa, torch::Tensor &t, + const std::vector &handles, + const std::vector &offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_buffer(handles, offsets, t.data_ptr()); +} + +std::pair, std::vector> get_graph_buffer_ipc_meta( + fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + return fa->get_graph_buffer_ipc_meta(); +} + +void register_graph_buffers(fptr_t _fa, const std::vector &handles, + const std::vector> &offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_graph_buffers(handles, offsets); +} diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh new file mode 100644 index 00000000..6e71bb9a --- /dev/null +++ b/csrc/custom_all_reduce.cuh @@ -0,0 +1,555 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include 
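// A consolidated restatement of the kernel-selection logic implemented in
// should_custom_ar() and CustomAllreduce::allreduce() in this patch (the
// thresholds below are the hard-coded ones; max_size defaults to 8 MB on the
// Python side):
//
//   world_size == 2                               -> cross_device_reduce_1stage
//   full NVLink, <= 4 GPUs, input < 512 KB        -> cross_device_reduce_1stage
//   full NVLink, <= 8 GPUs, input < 256 KB        -> cross_device_reduce_1stage
//   full NVLink, larger inputs (up to max_size)   -> cross_device_reduce_2stage
//   no full NVLink, <= 4 GPUs, input <= 512 KB    -> cross_device_reduce_half_butterfly
//
// Inputs must additionally be a multiple of 16 bytes and "weakly contiguous";
// anything else falls back to the regular NCCL all-reduce.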
+#include +#include + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace vllm { + +struct Signal { + alignas(64) union { + uint64_t flag; + unsigned char data[8]; + } start; + alignas(64) union { + uint64_t flag; + unsigned char data[8]; + } end; +}; + +struct Metadata { + alignas(128) Signal sg; + alignas(128) int counter; +}; +static_assert(offsetof(Metadata, counter) == 128); +static_assert(sizeof(Metadata) == 256); + +struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; + +struct RankSignals { + volatile Signal *signals[8]; +}; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { return __half2float(val); } + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half &assign_add(half &a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float &assign_add(float &a, float b) { return a += b; } + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } +template <> +DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t &packed_assign_add(array_t &a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +// compute flag at compile time +__host__ __device__ constexpr uint64_t compute_flag(int ngpus) { + auto m = std::numeric_limits::max(); + return m >> ((8 - ngpus) * 8); +} + +template +DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta, + int rank) { + constexpr auto FLAG = compute_flag(ngpus); + if (blockIdx.x == 0) { + if (threadIdx.x < ngpus) + // simultaneously write to the corresponding byte to all other ranks. 
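      // Concretely: each of the ngpus ranks sets the byte at its own index in
      // every peer's 8-byte start flag to 0xFF, and compute_flag(ngpus) is the
      // flag value once all ngpus bytes are set (e.g. compute_flag(2) == 0xFFFF,
      // compute_flag(4) == 0xFFFFFFFF), which is exactly what the busy-wait
      // below compares against.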
+ // Latency = 1 p2p write + sg.signals[threadIdx.x]->start.data[rank] = 255; + else if (threadIdx.x == 32) + // reset + meta->sg.end.flag = 0; + } + if (threadIdx.x == 0) { + while (meta->sg.start.flag != FLAG) + ; + } + __syncthreads(); +} + +template +DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta, + int rank) { + constexpr auto FLAG = compute_flag(ngpus); + __syncthreads(); + __shared__ int num; + if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1); + __syncthreads(); + + // Only the last completing block can perform the end synchronization + // This can ensures when the final busy wait ends, all ranks must have + // finished reading each other's buffer. + if (num == gridDim.x - 1) { + if (threadIdx.x == 32) { + // reset in a different warp + meta->counter = 0; + meta->sg.start.flag = 0; + } else if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding byte to all other ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->end.data[rank] = 255; + } + // if this is the final sync, only one block needs it + // because kernel exit can serve as sync + if constexpr (final_sync) { + if (threadIdx.x == 0) { + while (meta->sg.end.flag != FLAG) + ; + } + } + } + if constexpr (!final_sync) { + if (threadIdx.x == 0) { + while (meta->sg.end.flag != FLAG) + ; + } + __syncthreads(); + } +} + +template +DINLINE P packed_reduce(const P *ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast
<P>
(tmp); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_1stage(RankData *_dp, RankSignals sg, + volatile Metadata *meta, T *__restrict__ result, + int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; + start_sync(sg, meta, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + ((P *)result)[idx] = + packed_reduce((const P **)&dp.ptrs[0], idx); + } + end_sync(sg, meta, rank); +} + +template +DINLINE P *get_tmp_buf(volatile Signal *sg) { + return (P *)(((Metadata *)sg) + 1); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_2stage(RankData *_dp, RankSignals sg, + volatile Metadata *meta, T *__restrict__ result, + int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + const P *ptrs[ngpus]; + P *tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (const P *)_dp->ptrs[target]; + tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); + } + auto tmp_out = tmps[0]; + start_sync(sg, meta, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } + // Maybe TODO: replace this with per-block release-acquire + // can save about 1-2us (not a lot though) + end_sync(sg, meta, rank); + + // stage 2: allgather + for (int idx = tid; idx < part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int dst_idx = ((rank + i) % ngpus) * part + idx; + ((P *)result)[dst_idx] = tmps[i][idx]; + } + } + // process the last larger partition + int remaining = size - part * ngpus; + if (tid < remaining) { + int dst_idx = tid + part * ngpus; + ((P *)result)[dst_idx] = get_tmp_buf
<P>
(sg.signals[ngpus - 1])[part + tid]; + } + + // faster than this + // for (int idx = tid; idx < size; idx += stride) { + // int target_rank = idx / part; + // if (target_rank == ngpus) target_rank -= 1; + // ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part]; + // } +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg, + volatile Metadata *meta, + T *__restrict__ result, int rank, + int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + auto tmp_out = get_tmp_buf
<P>
(sg.signals[rank]); + constexpr int hg = ngpus / 2; + // Actually not quite half butterfly. + // This is an all-to-all within each group containing half of the ranks + // followed by cross-group add. Equivalent to half butterfly when there + // are 4 GPUs, a common case for PCIe cards like T4 and A10. + const P *ptrs[hg]; + { + int start = rank - rank % hg; +#pragma unroll + for (int i = 0; i < hg; i++) { + ptrs[i] = (const P *)_dp->ptrs[i + start]; + } + } + start_sync(sg, meta, rank); + for (int idx = tid; idx < size; idx += stride) { + tmp_out[idx] = packed_reduce(ptrs, idx); + } + end_sync(sg, meta, rank); + + auto src = get_tmp_buf
<P>
(sg.signals[(ngpus - 1) - rank % ngpus]); + // do the cross group reduction + for (int idx = tid; idx < size; idx += stride) { + auto tmp = tmp_out[idx]; + packed_assign_add(tmp, src[idx]); + ((P *)result)[idx] = tmp; + } +} + +class CustomAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + // below are device pointers + RankSignals sg_; + std::unordered_map buffers_; + Metadata *meta_; + + // stores the registered device pointers from all ranks + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + std::vector ipc_handles_; + + /** + * meta is a pointer to device metadata and temporary buffer for allreduce. + * + * There's a total of sizeof(Metadata) of prefix before the actual data, + * so meta + 1 points to actual temporary buffer. + * + * note: this class does not own any device memory. Any required buffers + * are passed in from the constructor + */ + CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t *handles, + const std::vector &offsets, int rank, + bool full_nvlink = true) + : rank_(rank), + world_size_(offsets.size()), + full_nvlink_(full_nvlink), + meta_(meta), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { + Metadata *rank_meta; + if (i != rank_) { + char *handle; + CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i], + cudaIpcMemLazyEnablePeerAccess)); + ipc_handles_.push_back(handle); + handle += offsets[i]; + rank_meta = (Metadata *)handle; + } else { + rank_meta = meta_; + } + sg_.signals[i] = &rank_meta->sg; + } + } + + std::pair, std::vector> + get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::vector handles(handle_sz * num_buffers, 0); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void *base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char *)ptr) - ((char *)base_ptr); + } + return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + void register_buffer(const std::vector &handles, + const std::vector &offsets, void *self) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + if (i != rank_) { + char *handle; + CUDACHECK(cudaIpcOpenMemHandle( + (void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()), + cudaIpcMemLazyEnablePeerAccess)); + ipc_handles_.push_back(handle); + handle += offsets[i]; + data.ptrs[i] = handle; + } else { + data.ptrs[i] = self; + } + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + buffers_[self] = d_data; + } + + // note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. 
That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. + void register_graph_buffers( + const std::vector &handles, + const std::vector> &offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto &rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char *handle; + CUDACHECK(cudaIpcOpenMemHandle( + (void **)&handle, + *((cudaIpcMemHandle_t *)&handles[j] + [i * sizeof(cudaIpcMemHandle_t)]), + cudaIpcMemLazyEnablePeerAccess)); + ipc_handles_.push_back(handle); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + sizeof(RankData) * num_buffers, + cudaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * This is the result after careful grid search. Using 36 blocks give the best + * or close to the best runtime on the devices I tried: A100, A10, A30, T4, + * V100. You'll notice that NCCL kernels also only take a small amount of SMs. + * Not quite sure the underlying reason, but my guess is that too many SMs + * will cause contention on NVLink bus. + */ + template + void allreduce(cudaStream_t stream, T *input, T *output, int size, + int threads = 512, int block_limit = 36) { + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + + RankData *ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name \ + <<>>(ptrs, sg_, meta_, output, rank_, size); +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } else { \ + KL(ngpus, cross_device_reduce_half_butterfly); \ + } \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). 
Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~CustomAllreduce() { + for (auto ptr : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } + } +}; +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void CustomAllreduce::allreduce(cudaStream_t, half *, half *, + int, int, int); +*/ +} // namespace vllm diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu new file mode 100644 index 00000000..6b094e2f --- /dev/null +++ b/csrc/custom_all_reduce_test.cu @@ -0,0 +1,284 @@ +/** + * This is a standalone test for custom allreduce. + * To compile, make sure you have MPI and NCCL installed in your system. + * export MPI_HOME=XXX + * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o + * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi + * + * Warning: this C++ test is not designed to be very readable and was used + * during the rapid prototyping process. + * + * To run: + * mpirun -np 8 ./custom_all_reduce_test + */ +#include +#include +#include +#include + +#include +#include + +#include "cuda_profiler_api.h" +#include "custom_all_reduce.cuh" +#include "mpi.h" +#include "nccl.h" + +#define MPICHECK(cmd) \ + do { \ + int e = cmd; \ + if (e != MPI_SUCCESS) { \ + printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define NCCLCHECK(cmd) \ + do { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \ + ncclGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +__global__ void dummy_kernel() { + for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms +} + +template +__global__ void set_data(T *data, int size, int myRank) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + data[idx] = myRank * 0.11f; + } +} + +template +__global__ void convert_data(const T *data1, const T *data2, double *fdata1, + double *fdata2, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + fdata1[idx] = data1[idx]; + fdata2[idx] = data2[idx]; + } +} + +__global__ void init_rand(curandState_t *state, int size, int nRanks) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + for (int i = 0; i < nRanks; i++) { + curand_init(i + 1, idx, 0, &state[idx * nRanks + i]); + } + } +} + +template +__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, + int myRank, int nRanks, int size) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + double sum = 0.0; + for (int i = 0; i < nRanks; i++) { + double val = curand_uniform_double(&state[idx * nRanks + i]) * 4; + T hval = val; // downcast first + sum += static_cast(hval); + if (i == myRank) data[idx] = hval; + } + ground_truth[idx] = sum; + } +} + +template +void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, + int data_size) { + T *result; + cudaStream_t stream; + CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); + CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T))); + + cudaIpcMemHandle_t self_data_handle; + cudaIpcMemHandle_t data_handles[8]; + vllm::Metadata *buffer; + T *self_data_copy; + /** + * Allocate IPC 
buffer + * + * The first section is a temporary buffer for storing intermediate allreduce + * results, if a particular algorithm requires it. The second section is for + * the input to the allreduce. The actual API takes the input pointer as an + * argument (that is, they can and usually should be allocated separately). + * But since the input pointers and the temporary buffer all require IPC + * registration, they are allocated and registered together in the test for + * convenience. + */ + CUDACHECK( + cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); + CUDACHECK(cudaMemset(buffer, 0, + 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); + CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); + CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer)); + + MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t), + MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), + MPI_BYTE, MPI_COMM_WORLD)); + + void *rank_data; + size_t rank_data_sz = 16 * 1024 * 1024; + CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); + std::vector offsets(nRanks, 0); + vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, + offsets, myRank); + auto *self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Metadata) + data_size * sizeof(T)); + // hack buffer registration + { + std::vector handles; + handles.reserve(nRanks); + for (int i = 0; i < nRanks; i++) { + char *begin = (char *)&data_handles[i]; + char *end = (char *)&data_handles[i + 1]; + handles.emplace_back(begin, end); + } + std::vector offsets( + nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T)); + fa.register_buffer(handles, offsets, self_data); + } + + double *ground_truth; + CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); + curandState_t *states; + CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); + init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); + gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, + nRanks, data_size); + CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + cudaEvent_t start, stop; + CUDACHECK(cudaEventCreate(&start)); + CUDACHECK(cudaEventCreate(&stop)); + + ncclDataType_t ncclDtype; + if (std::is_same::value) { + ncclDtype = ncclFloat16; + } else if (std::is_same::value) { + ncclDtype = ncclBfloat16; + } else { + ncclDtype = ncclFloat; + } + + dummy_kernel<<<1, 1, 0, stream>>>(); + constexpr int warmup_iters = 5; + constexpr int num_iters = 25; + // warmup + for (int i = 0; i < warmup_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, + stream)); + } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, + stream)); + } + CUDACHECK(cudaEventRecord(stop, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + float allreduce_ms = 0; + cudaEventElapsedTime(&allreduce_ms, start, stop); + + // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>(); + // set_data<<<16, 1024, 0, stream>>>(self_data, data_size, myRank); + + dummy_kernel<<<1, 1, 0, stream>>>(); + // warm up + for (int i = 0; i < warmup_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, block_limit); + } + CUDACHECK(cudaEventRecord(start, stream)); + for (int i = 0; i < num_iters; i++) { + fa.allreduce(stream, self_data, result, data_size, threads, block_limit); + } + 
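  // The start/stop events bracket all num_iters allreduces enqueued on this
  // stream, so the report below divides the elapsed time by num_iters
  // (duration_ms * 1e3 / num_iters) to get the average per-call latency in
  // microseconds.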
CUDACHECK(cudaEventRecord(stop, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + + float duration_ms = 0; + cudaEventElapsedTime(&duration_ms, start, stop); + if (myRank == 0) + printf( + "Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl " + "time:%.2fus\n", + myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, + duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); + + // And wait for all the queued up work to complete + CUDACHECK(cudaStreamSynchronize(stream)); + + NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, + ncclSum, comm, stream)); + + double *nccl_result, *my_result; + CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double))); + CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double))); + + convert_data<<<108, 1024, 0, stream>>>(self_data, result, nccl_result, + my_result, data_size); + CUDACHECK(cudaStreamSynchronize(stream)); + + for (unsigned long j = 0; j < data_size; j++) { + auto diff = abs(nccl_result[j] - my_result[j]); + if (diff >= 1e-2) { + printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n", + myRank, j, nccl_result[j], my_result[j], ground_truth[j]); + break; + } + } + + long double nccl_diffs = 0.0; + long double my_diffs = 0.0; + for (int j = 0; j < data_size; j++) { + nccl_diffs += abs(nccl_result[j] - ground_truth[j]); + my_diffs += abs(my_result[j] - ground_truth[j]); + } + if (myRank == 0) + std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size + << " me: " << my_diffs / data_size << std::endl; + + CUDACHECK(cudaFree(result)); + CUDACHECK(cudaFree(self_data_copy)); + CUDACHECK(cudaFree(rank_data)); + CUDACHECK(cudaFree(buffer)); + CUDACHECK(cudaFree(states)); + CUDACHECK(cudaFreeHost(ground_truth)); + CUDACHECK(cudaFreeHost(nccl_result)); + CUDACHECK(cudaFreeHost(my_result)); + CUDACHECK(cudaStreamDestroy(stream)); +} + +int main(int argc, char **argv) { + int nRanks, myRank; + MPICHECK(MPI_Init(&argc, &argv)); + MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); + MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks)); + CUDACHECK(cudaSetDevice(myRank)); + ncclUniqueId id; + ncclComm_t comm; + if (myRank == 0) ncclGetUniqueId(&id); + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPI_COMM_WORLD)); + NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); + + cudaProfilerStart(); + // for (int threads : {256, 512}) { + // for (int block_limit = 16; block_limit < 112; block_limit += 4) { + // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // } + // } + for (int sz = 512; sz <= (32 << 20); sz *= 2) { + run(myRank, nRanks, comm, 512, 36, sz + 8 * 50); + } + + cudaProfilerStop(); + return EXIT_SUCCESS; +} diff --git a/csrc/ops.h b/csrc/ops.h index d4961964..6e996fd0 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -97,3 +97,25 @@ torch::Tensor gptq_gemm( void gptq_shuffle( torch::Tensor q_weight, torch::Tensor q_perm); + + +#ifndef USE_ROCM +using fptr_t = uint64_t; +fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, + const std::vector &handles, + const std::vector &offsets, int rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, + bool full_nvlink); +void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, + torch::Tensor &out); +void dispose(fptr_t _fa); +int meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor &t, + const std::vector &handles, + 
const std::vector &offsets); +std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector &handles, + const std::vector> &offsets); +#endif diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 88af7eac..f94efadf 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -88,4 +88,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &get_max_shared_memory_per_block_device_attribute, "Gets the maximum shared memory per block device attribute."); +#ifndef USE_ROCM + // Custom all-reduce kernels + pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); + custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); + custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); + custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); + custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); + custom_ar.def("dispose", &dispose, "dispose"); + custom_ar.def("meta_size", &meta_size, "meta_size"); + custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); + custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, + "get_graph_buffer_ipc_meta"); + custom_ar.def("register_graph_buffers", ®ister_graph_buffers, + "register_graph_buffers"); +#endif + } diff --git a/requirements.txt b/requirements.txt index 299bad38..19871bdc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. aioprometheus[starlette] +pynvml == 11.5.0 diff --git a/setup.py b/setup.py index 88fa4952..2f624269 100644 --- a/setup.py +++ b/setup.py @@ -51,8 +51,8 @@ if _is_hip(): "Cannot find ROCM_HOME. ROCm must be available to build the package." ) NVCC_FLAGS += ["-DUSE_ROCM"] - NVCC_FLAGS += [f"-U__HIP_NO_HALF_CONVERSIONS__"] - NVCC_FLAGS += [f"-U__HIP_NO_HALF_OPERATORS__"] + NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"] + NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"] if _is_cuda() and CUDA_HOME is None: raise RuntimeError( @@ -307,6 +307,7 @@ vllm_extension_sources = [ if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") + vllm_extension_sources.append("csrc/custom_all_reduce.cu") if not _is_neuron(): vllm_extension = CUDAExtension( @@ -316,6 +317,7 @@ if not _is_neuron(): "cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS, }, + libraries=["cuda"] if _is_cuda() else [], ) ext_modules.append(vllm_extension) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index b12e563f..9474cb21 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -6,25 +6,13 @@ import pytest import torch import ray -from vllm.config import ParallelConfig -from vllm.utils import get_open_port from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather, broadcast_tensor_dict, ) -from vllm.worker.worker import _init_distributed_environment - - -def init_test_distributed_environment(pipeline_parallel_size: int, - tensor_parallel_size: int, rank: int, - distributed_init_port: str): - parallel_config = ParallelConfig(pipeline_parallel_size, - tensor_parallel_size, - worker_use_ray=True) - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - _init_distributed_environment(parallel_config, rank, - distributed_init_method) +from vllm.test_utils import (init_test_distributed_environment, + multi_process_tensor_parallel) @ray.remote(num_gpus=1, max_calls=1) @@ 
-101,16 +89,4 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, broadcast_tensor_dict_test_worker ]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - # Using ray helps debugging the error when it failed - # as compared to multiprocessing. - ray.init() - - distributed_init_port = get_open_port() - refs = [] - for rank in range(tensor_parallel_size): - refs.append( - test_target.remote(tensor_parallel_size, rank, - distributed_init_port)) - ray.get(refs) - - ray.shutdown() + multi_process_tensor_parallel(tensor_parallel_size, test_target) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py new file mode 100644 index 00000000..ed496559 --- /dev/null +++ b/tests/distributed/test_custom_all_reduce.py @@ -0,0 +1,85 @@ +import random + +import os +import pytest +import ray +import torch +import torch.distributed as dist + +from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.test_utils import (init_test_distributed_environment, + multi_process_tensor_parallel) + +random.seed(42) +test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] +for i, v in enumerate(test_sizes): + test_sizes[i] -= v % 8 + + +@ray.remote(num_gpus=1, max_calls=1) +def graph_allreduce(world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(1, world_size, rank, + distributed_init_port) + + custom_ar.init_custom_ar() + for sz in test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + with custom_ar.capture(): + # use integers so result matches NCCL exactly + inp1 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + inp2 = torch.randint(1, + 16, (sz, ), + dtype=dtype, + device=torch.cuda.current_device()) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2) + graph.replay() + assert torch.allclose(out1, inp1) + assert torch.allclose(out2, inp2) + + +@ray.remote(num_gpus=1, max_calls=1) +def eager_allreduce(world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(1, world_size, rank, + distributed_init_port) + + sz = 1024 + custom_ar.init_custom_ar() + fa = custom_ar.get_handle() + inp = torch.ones(sz, dtype=torch.float32, device=device) + out = fa.all_reduce_unreg(inp) + assert torch.allclose(out, inp * world_size) + + inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) + out = fa.all_reduce_unreg(inp) + assert torch.allclose(out, inp * world_size) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) +def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): + multi_process_tensor_parallel(tensor_parallel_size, test_target) + + +if __name__ == "__main__": + multi_process_tensor_parallel(2, 
graph_allreduce) diff --git a/vllm/config.py b/vllm/config.py index 4f1ce87c..da97eaa7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -328,6 +328,8 @@ class ParallelConfig: worker_use_ray: Whether to use Ray for model workers. Will be set to True if either pipeline_parallel_size or tensor_parallel_size is greater than 1. + disable_custom_all_reduce: Disable the custom all-reduce kernel and + fall back to NCCL. """ def __init__( @@ -336,11 +338,13 @@ class ParallelConfig: tensor_parallel_size: int, worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, + disable_custom_all_reduce: bool = False, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers + self.disable_custom_all_reduce = disable_custom_all_reduce self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: @@ -351,6 +355,16 @@ class ParallelConfig: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") + if is_hip(): + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported on AMD GPUs.") + elif self.pipeline_parallel_size > 1: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported with pipeline parallelism.") class SchedulerConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 090fa95b..968362c4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -35,6 +35,7 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -208,6 +209,10 @@ class EngineArgs: help='maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') + parser.add_argument('--disable-custom-all-reduce', + action='store_true', + default=EngineArgs.disable_custom_all_reduce, + help='See ParallelConfig') # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -269,7 +274,8 @@ class EngineArgs: parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, - self.max_parallel_loading_workers) + self.max_parallel_loading_workers, + self.disable_custom_all_reduce) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0dedc232..87752eea 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -82,6 +82,7 @@ class LLMEngine: f"download_dir={model_config.download_dir!r}, " f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " f"seed={model_config.seed})") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index aab0c961..614e6fa5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -64,6 +64,7 @@ class LLM: max_context_len_to_capture: Maximum context len covered by CUDA graphs. 
When a sequence has context length larger than this, we fall back to eager mode. + disable_custom_all_reduce: See ParallelConfig """ def __init__( @@ -82,6 +83,7 @@ class LLM: swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -101,6 +103,7 @@ class LLM: swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index fff6920b..65671994 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -10,17 +10,27 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, ) +from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group. - NOTE: This operation is applied in-place on the input tensor. + NOTE: This operation will be applied in-place on the input tensor if + disable_custom_all_reduce is set to True. Otherwise, this operation may or + may not be applied in place depending on whether custom all reduce is + invoked for a particular tensor, which further depends on the tensor size + and GPU topology. + + TLDR: always assume this function modifies its input, but use the return + value as the output. """ # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ - # All-reduce. + out = custom_all_reduce(input_) + if out is not None: + return out torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) return input_ diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py new file mode 100644 index 00000000..5b88649c --- /dev/null +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -0,0 +1,223 @@ +from contextlib import contextmanager +from typing import Optional + +import torch +import torch.distributed as dist + +from vllm.logger import init_logger +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) + +try: + from vllm._C import custom_ar + import pynvml +except ImportError: + # For AMD GPUs + custom_ar = None + pynvml = None + +logger = init_logger(__name__) + +_CA_HANDLE = None +_IS_CAPTURING = False +_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + +def init_custom_ar() -> None: + global _CA_HANDLE + if _CA_HANDLE is not None: + return + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + if world_size not in _SUPPORTED_WORLD_SIZES: + logger.warn( + "Custom allreduce is disabled due to an unsupported world size: " + "%d. Supported world sizes: %s. To slience this warning, specify" + "disable_custom_all_reduce=True explicitly.", world_size, + str(_SUPPORTED_WORLD_SIZES)) + return + if not _can_p2p(rank, world_size): + logger.warn( + "Custom allreduce is disabled because your platform lacks GPU P2P" + " capability. 
To slience this warning, specify" + "disable_custom_all_reduce=True explicitly.") + return + _CA_HANDLE = CustomAllreduce(rank, world_size) + + +def begin_capture() -> None: + global _IS_CAPTURING + _IS_CAPTURING = True + + +def end_capture() -> None: + global _IS_CAPTURING + _IS_CAPTURING = False + + +def is_capturing() -> bool: + return _IS_CAPTURING and _CA_HANDLE is not None + + +def get_handle() -> Optional["CustomAllreduce"]: + return _CA_HANDLE + + +@contextmanager +def capture(): + try: + begin_capture() + yield + finally: + end_capture() + handle = get_handle() + if handle is not None: + handle.register_graph_buffers() + + +def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: + ca_handle = get_handle() + # when custom allreduce is disabled, this will be None + if ca_handle is None: + return + if is_capturing(): + if torch.cuda.is_current_stream_capturing(): + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_reg(input) + else: + if ca_handle.should_custom_ar(input): + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_unreg(input) + + +@contextmanager +def _nvml(): + try: + pynvml.nvmlInit() + yield + finally: + pynvml.nvmlShutdown() + + +# query if the set of gpus are fully connected by nvlink (1 hop) +@_nvml() +def _is_full_nvlink(rank, world_size): + handle = pynvml.nvmlDeviceGetHandleByIndex(rank) + for i in range(world_size): + if i != rank: + try: + link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i) + if not link_state: + return False + except pynvml.NVMLError as error: + logger.info( + f"NVLink detection failed with message \"{str(error)}\". " + "This is normal if your machine has no NVLink equipped") + return False + return True + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if not torch.cuda.can_device_access_peer(rank, i): + return False + return True + + +class CustomAllreduce: + + # max_size: max supported allreduce size + def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: + # buffers memory are owned by this Python class and passed to C++ + # meta data composes of two parts: meta data for synchronization + # (256 bytes) and a temporary buffer for storing intermediate + # allreduce results. + self.meta = torch.zeros(custom_ar.meta_size() + max_size, + dtype=torch.uint8, + device="cuda") + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. 
+ self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device="cuda") + self.max_size = max_size + self.world_size = world_size + handles, offsets = self._get_ipc_meta(self.meta) + self.full_nvlink = _is_full_nvlink(rank, world_size) + self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, + handles, offsets, rank, + self.full_nvlink) + self.fast_cond = self.full_nvlink or world_size <= 2 + self.register_buffer(self.buffer) + + def _get_ipc_meta(self, inp: torch.Tensor): + data = inp.untyped_storage()._share_cuda_() + shard_data = ( + data[1], # ipc handle to base ptr + data[3], # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + all_data = [None] * self.world_size + dist.all_gather_object(all_data, shard_data) + + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0]) + offsets.append(all_data[i][1]) + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + custom_ar.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + custom_ar.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, + self.full_nvlink) + + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + custom_ar.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) + return out + + def close(self): + if self._ptr: + custom_ar.dispose(self._ptr) + self._ptr = 0 + + def __del__(self): + self.close() diff --git a/vllm/test_utils.py b/vllm/test_utils.py new file mode 100644 index 00000000..4f74c050 --- /dev/null +++ b/vllm/test_utils.py @@ -0,0 +1,38 @@ +import ray + +from vllm.config import ParallelConfig +from vllm.utils import get_open_port +from vllm.worker.worker import init_distributed_environment + + +def init_test_distributed_environment( + pipeline_parallel_size: int, + tensor_parallel_size: int, + rank: int, + distributed_init_port: str, +) -> None: + parallel_config = ParallelConfig(pipeline_parallel_size, + tensor_parallel_size, + worker_use_ray=True) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment(parallel_config, rank, + distributed_init_method) + + +def multi_process_tensor_parallel( + tensor_parallel_size: int, + test_target, +) -> None: + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
+ ray.init() + + distributed_init_port = get_open_port() + refs = [] + for rank in range(tensor_parallel_size): + refs.append( + test_target.remote(tensor_parallel_size, rank, + distributed_init_port)) + ray.get(refs) + + ray.shutdown() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 98511561..60f5b71d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,6 +10,7 @@ from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) +from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager @@ -651,37 +652,38 @@ class ModelRunner: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. - for batch_size in reversed(batch_size_capture_list): - # Create dummy input_metadata. - input_metadata = InputMetadata( - is_prompt=False, - slot_mapping=slot_mapping[:batch_size], - prompt_lens=None, - max_seq_len=None, - start_loc=None, - max_context_len=self.max_context_len_to_capture, - context_lens=context_lens[:batch_size], - block_tables=block_tables[:batch_size], - use_cuda_graph=True, - ) - - if self.lora_config: - lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, + with custom_all_reduce.capture(): + for batch_size in reversed(batch_size_capture_list): + # Create dummy input_metadata. + input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping[:batch_size], + prompt_lens=None, + max_seq_len=None, + start_loc=None, + max_context_len=self.max_context_len_to_capture, + context_lens=context_lens[:batch_size], + block_tables=block_tables[:batch_size], + use_cuda_graph=True, ) - self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( - input_tokens[:batch_size], - input_positions[:batch_size], - kv_caches, - input_metadata, - memory_pool=self.graph_memory_pool, - ) - self.graph_memory_pool = graph_runner.graph.pool() - self.graph_runners[batch_size] = graph_runner + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens[:batch_size], + input_positions[:batch_size], + kv_caches, + input_metadata, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner end_time = time.perf_counter() elapsed_time = end_time - start_time diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 84528358..f1dad64b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,6 +11,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar from vllm.model_executor.parallel_utils.parallel_state import ( ensure_model_parallel_initialized) from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -78,9 +79,10 @@ class Worker: _check_if_gpu_supports_dtype(self.model_config.dtype) # Initialize the 
distributed environment. - _init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) - + init_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method) + if not self.parallel_config.disable_custom_all_reduce: + init_custom_ar() # Initialize the model. set_random_seed(self.model_config.seed) @@ -219,7 +221,7 @@ class Worker: return self.model_runner.list_loras() -def _init_distributed_environment( +def init_distributed_environment( parallel_config: ParallelConfig, rank: int, distributed_init_method: Optional[str] = None,