Implement custom all reduce kernels (#2192)
This commit is contained in:
parent
220a47627b
commit
380170038e
148
csrc/custom_all_reduce.cu
Normal file
148
csrc/custom_all_reduce.cu
Normal file
@ -0,0 +1,148 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "custom_all_reduce.cuh"
|
||||
|
||||
// fake pointer type
|
||||
using fptr_t = uint64_t;
|
||||
static_assert(sizeof(void *) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (world_size != handles.size())
|
||||
throw std::invalid_argument(
|
||||
"handles length should equal to offsets length");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
cudaIpcMemHandle_t ipc_handles[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor &t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink) {
|
||||
auto inp_size = inp.numel() * inp.element_size();
|
||||
// custom allreduce requires input byte size to be multiples of 16
|
||||
if (inp_size % 16 != 0) return false;
|
||||
if (!_is_weak_contiguous(inp)) return false;
|
||||
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
|
||||
// 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
|
||||
// <= 512k
|
||||
return world_size <= 4 && inp_size <= 512 * 1024;
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||
cudaStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
|
||||
reinterpret_cast<float *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
|
||||
reinterpret_cast<half *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||
"registered buffer is too small to contain the input");
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||
input_size, cudaMemcpyDeviceToDevice, stream));
|
||||
_all_reduce(_fa, reg_buffer, out, stream);
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int meta_size() { return sizeof(vllm::Metadata); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
return fa->get_graph_buffer_ipc_meta();
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
555
csrc/custom_all_reduce.cuh
Normal file
555
csrc/custom_all_reduce.cuh
Normal file
@ -0,0 +1,555 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
struct Signal {
|
||||
alignas(64) union {
|
||||
uint64_t flag;
|
||||
unsigned char data[8];
|
||||
} start;
|
||||
alignas(64) union {
|
||||
uint64_t flag;
|
||||
unsigned char data[8];
|
||||
} end;
|
||||
};
|
||||
|
||||
struct Metadata {
|
||||
alignas(128) Signal sg;
|
||||
alignas(128) int counter;
|
||||
};
|
||||
static_assert(offsetof(Metadata, counter) == 128);
|
||||
static_assert(sizeof(Metadata) == 256);
|
||||
|
||||
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
|
||||
|
||||
struct RankSignals {
|
||||
volatile Signal *signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) { return __half2float(val); }
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half &assign_add(half &a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float &assign_add(float &a, float b) { return a += b; }
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// compute flag at compile time
|
||||
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
|
||||
auto m = std::numeric_limits<uint64_t>::max();
|
||||
return m >> ((8 - ngpus) * 8);
|
||||
}
|
||||
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
|
||||
int rank) {
|
||||
constexpr auto FLAG = compute_flag(ngpus);
|
||||
if (blockIdx.x == 0) {
|
||||
if (threadIdx.x < ngpus)
|
||||
// simultaneously write to the corresponding byte to all other ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start.data[rank] = 255;
|
||||
else if (threadIdx.x == 32)
|
||||
// reset
|
||||
meta->sg.end.flag = 0;
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
while (meta->sg.start.flag != FLAG)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
|
||||
int rank) {
|
||||
constexpr auto FLAG = compute_flag(ngpus);
|
||||
__syncthreads();
|
||||
__shared__ int num;
|
||||
if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
|
||||
__syncthreads();
|
||||
|
||||
// Only the last completing block can perform the end synchronization
|
||||
// This can ensures when the final busy wait ends, all ranks must have
|
||||
// finished reading each other's buffer.
|
||||
if (num == gridDim.x - 1) {
|
||||
if (threadIdx.x == 32) {
|
||||
// reset in a different warp
|
||||
meta->counter = 0;
|
||||
meta->sg.start.flag = 0;
|
||||
} else if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding byte to all other ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end.data[rank] = 255;
|
||||
}
|
||||
// if this is the final sync, only one block needs it
|
||||
// because kernel exit can serve as sync
|
||||
if constexpr (final_sync) {
|
||||
if (threadIdx.x == 0) {
|
||||
while (meta->sg.end.flag != FLAG)
|
||||
;
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!final_sync) {
|
||||
if (threadIdx.x == 0) {
|
||||
while (meta->sg.end.flag != FLAG)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P *ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
|
||||
volatile Metadata *meta, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, meta, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P *)result)[idx] =
|
||||
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, meta, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P *get_tmp_buf(volatile Signal *sg) {
|
||||
return (P *)(((Metadata *)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
|
||||
volatile Metadata *meta, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
const P *ptrs[ngpus];
|
||||
P *tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P *)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, meta, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
// Maybe TODO: replace this with per-block release-acquire
|
||||
// can save about 1-2us (not a lot though)
|
||||
end_sync<ngpus>(sg, meta, rank);
|
||||
|
||||
// stage 2: allgather
|
||||
for (int idx = tid; idx < part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int dst_idx = ((rank + i) % ngpus) * part + idx;
|
||||
((P *)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
// process the last larger partition
|
||||
int remaining = size - part * ngpus;
|
||||
if (tid < remaining) {
|
||||
int dst_idx = tid + part * ngpus;
|
||||
((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
|
||||
}
|
||||
|
||||
// faster than this
|
||||
// for (int idx = tid; idx < size; idx += stride) {
|
||||
// int target_rank = idx / part;
|
||||
// if (target_rank == ngpus) target_rank -= 1;
|
||||
// ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
|
||||
// }
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
|
||||
volatile Metadata *meta,
|
||||
T *__restrict__ result, int rank,
|
||||
int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
|
||||
constexpr int hg = ngpus / 2;
|
||||
// Actually not quite half butterfly.
|
||||
// This is an all-to-all within each group containing half of the ranks
|
||||
// followed by cross-group add. Equivalent to half butterfly when there
|
||||
// are 4 GPUs, a common case for PCIe cards like T4 and A10.
|
||||
const P *ptrs[hg];
|
||||
{
|
||||
int start = rank - rank % hg;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < hg; i++) {
|
||||
ptrs[i] = (const P *)_dp->ptrs[i + start];
|
||||
}
|
||||
}
|
||||
start_sync<ngpus>(sg, meta, rank);
|
||||
for (int idx = tid; idx < size; idx += stride) {
|
||||
tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, meta, rank);
|
||||
|
||||
auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
|
||||
// do the cross group reduction
|
||||
for (int idx = tid; idx < size; idx += stride) {
|
||||
auto tmp = tmp_out[idx];
|
||||
packed_assign_add(tmp, src[idx]);
|
||||
((P *)result)[idx] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void *, RankData *> buffers_;
|
||||
Metadata *meta_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void *> graph_unreg_buffers_;
|
||||
std::vector<void *> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Metadata) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
|
||||
const cudaIpcMemHandle_t *handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
meta_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Metadata *rank_meta;
|
||||
if (i != rank_) {
|
||||
char *handle;
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void **)&handle, handles[i],
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
ipc_handles_.push_back(handle);
|
||||
handle += offsets[i];
|
||||
rank_meta = (Metadata *)handle;
|
||||
} else {
|
||||
rank_meta = meta_;
|
||||
}
|
||||
sg_.signals[i] = &rank_meta->sg;
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void *base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr,
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(cudaIpcGetMemHandle(
|
||||
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " +
|
||||
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, void *self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char *handle;
|
||||
CUDACHECK(cudaIpcOpenMemHandle(
|
||||
(void **)&handle, *((const cudaIpcMemHandle_t *)handles[i].data()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
ipc_handles_.push_back(handle);
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(
|
||||
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto &rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char *handle;
|
||||
CUDACHECK(cudaIpcOpenMemHandle(
|
||||
(void **)&handle,
|
||||
*((cudaIpcMemHandle_t *)&handles[j]
|
||||
[i * sizeof(cudaIpcMemHandle_t)]),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
ipc_handles_.push_back(handle);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
|
||||
sizeof(RankData) * num_buffers,
|
||||
cudaMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T *input, T *output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
|
||||
RankData *ptrs;
|
||||
cudaStreamCaptureStatus status;
|
||||
CUDACHECK(cudaStreamIsCapturing(stream, &status));
|
||||
if (status == cudaStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " +
|
||||
std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||
" is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus> \
|
||||
<<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_half_butterfly); \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto ptr : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
|
||||
int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
284
csrc/custom_all_reduce_test.cu
Normal file
284
csrc/custom_all_reduce_test.cu
Normal file
@ -0,0 +1,284 @@
|
||||
/**
|
||||
* This is a standalone test for custom allreduce.
|
||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||
* export MPI_HOME=XXX
|
||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||
*
|
||||
* Warning: this C++ test is not designed to be very readable and was used
|
||||
* during the rapid prototyping process.
|
||||
*
|
||||
* To run:
|
||||
* mpirun -np 8 ./custom_all_reduce_test
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "cuda_profiler_api.h"
|
||||
#include "custom_all_reduce.cuh"
|
||||
#include "mpi.h"
|
||||
#include "nccl.h"
|
||||
|
||||
#define MPICHECK(cmd) \
|
||||
do { \
|
||||
int e = cmd; \
|
||||
if (e != MPI_SUCCESS) { \
|
||||
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define NCCLCHECK(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
__global__ void dummy_kernel() {
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_data(T *data, int size, int myRank) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
data[idx] = myRank * 0.11f;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
|
||||
double *fdata2, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
fdata1[idx] = data1[idx];
|
||||
fdata2[idx] = data2[idx];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
|
||||
int myRank, int nRanks, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
double sum = 0.0;
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
|
||||
T hval = val; // downcast first
|
||||
sum += static_cast<double>(hval);
|
||||
if (i == myRank) data[idx] = hval;
|
||||
}
|
||||
ground_truth[idx] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
int data_size) {
|
||||
T *result;
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
|
||||
CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));
|
||||
|
||||
cudaIpcMemHandle_t self_data_handle;
|
||||
cudaIpcMemHandle_t data_handles[8];
|
||||
vllm::Metadata *buffer;
|
||||
T *self_data_copy;
|
||||
/**
|
||||
* Allocate IPC buffer
|
||||
*
|
||||
* The first section is a temporary buffer for storing intermediate allreduce
|
||||
* results, if a particular algorithm requires it. The second section is for
|
||||
* the input to the allreduce. The actual API takes the input pointer as an
|
||||
* argument (that is, they can and usually should be allocated separately).
|
||||
* But since the input pointers and the temporary buffer all require IPC
|
||||
* registration, they are allocated and registered together in the test for
|
||||
* convenience.
|
||||
*/
|
||||
CUDACHECK(
|
||||
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
|
||||
CUDACHECK(cudaMemset(buffer, 0,
|
||||
2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
|
||||
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||
CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
|
||||
|
||||
MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
|
||||
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
|
||||
MPI_BYTE, MPI_COMM_WORLD));
|
||||
|
||||
void *rank_data;
|
||||
size_t rank_data_sz = 16 * 1024 * 1024;
|
||||
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
|
||||
std::vector<int64_t> offsets(nRanks, 0);
|
||||
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
|
||||
offsets, myRank);
|
||||
auto *self_data =
|
||||
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
|
||||
sizeof(vllm::Metadata) + data_size * sizeof(T));
|
||||
// hack buffer registration
|
||||
{
|
||||
std::vector<std::string> handles;
|
||||
handles.reserve(nRanks);
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
char *begin = (char *)&data_handles[i];
|
||||
char *end = (char *)&data_handles[i + 1];
|
||||
handles.emplace_back(begin, end);
|
||||
}
|
||||
std::vector<int64_t> offsets(
|
||||
nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
|
||||
fa.register_buffer(handles, offsets, self_data);
|
||||
}
|
||||
|
||||
double *ground_truth;
|
||||
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
|
||||
curandState_t *states;
|
||||
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
|
||||
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
|
||||
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
|
||||
nRanks, data_size);
|
||||
CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
cudaEvent_t start, stop;
|
||||
CUDACHECK(cudaEventCreate(&start));
|
||||
CUDACHECK(cudaEventCreate(&stop));
|
||||
|
||||
ncclDataType_t ncclDtype;
|
||||
if (std::is_same<T, half>::value) {
|
||||
ncclDtype = ncclFloat16;
|
||||
} else if (std::is_same<T, nv_bfloat16>::value) {
|
||||
ncclDtype = ncclBfloat16;
|
||||
} else {
|
||||
ncclDtype = ncclFloat;
|
||||
}
|
||||
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
constexpr int warmup_iters = 5;
|
||||
constexpr int num_iters = 25;
|
||||
// warmup
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
|
||||
stream));
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
|
||||
stream));
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(stop, stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
float allreduce_ms = 0;
|
||||
cudaEventElapsedTime(&allreduce_ms, start, stop);
|
||||
|
||||
// if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
// set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);
|
||||
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
// warm up
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(stop, stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
float duration_ms = 0;
|
||||
cudaEventElapsedTime(&duration_ms, start, stop);
|
||||
if (myRank == 0)
|
||||
printf(
|
||||
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
|
||||
"time:%.2fus\n",
|
||||
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
|
||||
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
|
||||
|
||||
// And wait for all the queued up work to complete
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
|
||||
ncclSum, comm, stream));
|
||||
|
||||
double *nccl_result, *my_result;
|
||||
CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
|
||||
CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
|
||||
|
||||
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
|
||||
my_result, data_size);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
for (unsigned long j = 0; j < data_size; j++) {
|
||||
auto diff = abs(nccl_result[j] - my_result[j]);
|
||||
if (diff >= 1e-2) {
|
||||
printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
|
||||
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
long double nccl_diffs = 0.0;
|
||||
long double my_diffs = 0.0;
|
||||
for (int j = 0; j < data_size; j++) {
|
||||
nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
|
||||
my_diffs += abs(my_result[j] - ground_truth[j]);
|
||||
}
|
||||
if (myRank == 0)
|
||||
std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
|
||||
<< " me: " << my_diffs / data_size << std::endl;
|
||||
|
||||
CUDACHECK(cudaFree(result));
|
||||
CUDACHECK(cudaFree(self_data_copy));
|
||||
CUDACHECK(cudaFree(rank_data));
|
||||
CUDACHECK(cudaFree(buffer));
|
||||
CUDACHECK(cudaFree(states));
|
||||
CUDACHECK(cudaFreeHost(ground_truth));
|
||||
CUDACHECK(cudaFreeHost(nccl_result));
|
||||
CUDACHECK(cudaFreeHost(my_result));
|
||||
CUDACHECK(cudaStreamDestroy(stream));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int nRanks, myRank;
|
||||
MPICHECK(MPI_Init(&argc, &argv));
|
||||
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
|
||||
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
|
||||
CUDACHECK(cudaSetDevice(myRank));
|
||||
ncclUniqueId id;
|
||||
ncclComm_t comm;
|
||||
if (myRank == 0) ncclGetUniqueId(&id);
|
||||
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
|
||||
MPI_COMM_WORLD));
|
||||
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
|
||||
|
||||
cudaProfilerStart();
|
||||
// for (int threads : {256, 512}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
|
||||
// }
|
||||
// }
|
||||
for (int sz = 512; sz <= (32 << 20); sz *= 2) {
|
||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
|
||||
}
|
||||
|
||||
cudaProfilerStop();
|
||||
return EXIT_SUCCESS;
|
||||
}
|
22
csrc/ops.h
22
csrc/ops.h
@ -97,3 +97,25 @@ torch::Tensor gptq_gemm(
|
||||
void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm);
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = uint64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink);
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out);
|
||||
void dispose(fptr_t _fa);
|
||||
int meta_size();
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets);
|
||||
#endif
|
||||
|
@ -88,4 +88,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Custom all-reduce kernels
|
||||
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
|
||||
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
|
||||
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
|
||||
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
|
||||
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
|
||||
custom_ar.def("dispose", &dispose, "dispose");
|
||||
custom_ar.def("meta_size", &meta_size, "meta_size");
|
||||
custom_ar.def("register_buffer", ®ister_buffer, "register_buffer");
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
|
||||
"get_graph_buffer_ipc_meta");
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers,
|
||||
"register_graph_buffers");
|
||||
#endif
|
||||
|
||||
}
|
||||
|
@ -10,3 +10,4 @@ fastapi
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.0 # Required for OpenAI server.
|
||||
aioprometheus[starlette]
|
||||
pynvml == 11.5.0
|
||||
|
6
setup.py
6
setup.py
@ -51,8 +51,8 @@ if _is_hip():
|
||||
"Cannot find ROCM_HOME. ROCm must be available to build the package."
|
||||
)
|
||||
NVCC_FLAGS += ["-DUSE_ROCM"]
|
||||
NVCC_FLAGS += [f"-U__HIP_NO_HALF_CONVERSIONS__"]
|
||||
NVCC_FLAGS += [f"-U__HIP_NO_HALF_OPERATORS__"]
|
||||
NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"]
|
||||
NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"]
|
||||
|
||||
if _is_cuda() and CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
@ -307,6 +307,7 @@ vllm_extension_sources = [
|
||||
|
||||
if _is_cuda():
|
||||
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
|
||||
vllm_extension_sources.append("csrc/custom_all_reduce.cu")
|
||||
|
||||
if not _is_neuron():
|
||||
vllm_extension = CUDAExtension(
|
||||
@ -316,6 +317,7 @@ if not _is_neuron():
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
libraries=["cuda"] if _is_cuda() else [],
|
||||
)
|
||||
ext_modules.append(vllm_extension)
|
||||
|
||||
|
@ -6,25 +6,13 @@ import pytest
|
||||
import torch
|
||||
import ray
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_all_gather,
|
||||
broadcast_tensor_dict,
|
||||
)
|
||||
from vllm.worker.worker import _init_distributed_environment
|
||||
|
||||
|
||||
def init_test_distributed_environment(pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
parallel_config = ParallelConfig(pipeline_parallel_size,
|
||||
tensor_parallel_size,
|
||||
worker_use_ray=True)
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
_init_distributed_environment(parallel_config, rank,
|
||||
distributed_init_method)
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
@ -101,16 +89,4 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
|
||||
broadcast_tensor_dict_test_worker
|
||||
])
|
||||
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
||||
# Using ray helps debugging the error when it failed
|
||||
# as compared to multiprocessing.
|
||||
ray.init()
|
||||
|
||||
distributed_init_port = get_open_port()
|
||||
refs = []
|
||||
for rank in range(tensor_parallel_size):
|
||||
refs.append(
|
||||
test_target.remote(tensor_parallel_size, rank,
|
||||
distributed_init_port))
|
||||
ray.get(refs)
|
||||
|
||||
ray.shutdown()
|
||||
multi_process_tensor_parallel(tensor_parallel_size, test_target)
|
||||
|
85
tests/distributed/test_custom_all_reduce.py
Normal file
85
tests/distributed/test_custom_all_reduce.py
Normal file
@ -0,0 +1,85 @@
|
||||
import random
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
|
||||
random.seed(42)
|
||||
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
|
||||
for i, v in enumerate(test_sizes):
|
||||
test_sizes[i] -= v % 8
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def graph_allreduce(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
custom_ar.init_custom_ar()
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
with custom_ar.capture():
|
||||
# use integers so result matches NCCL exactly
|
||||
inp1 = torch.randint(1,
|
||||
16, (sz, ),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device())
|
||||
inp2 = torch.randint(1,
|
||||
16, (sz, ),
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device())
|
||||
torch.cuda.synchronize()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
out1 = tensor_model_parallel_all_reduce(inp1)
|
||||
# the input buffer is immediately modified to test
|
||||
# synchronization
|
||||
dist.all_reduce(inp1)
|
||||
out2 = tensor_model_parallel_all_reduce(inp2)
|
||||
dist.all_reduce(inp2)
|
||||
graph.replay()
|
||||
assert torch.allclose(out1, inp1)
|
||||
assert torch.allclose(out2, inp2)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def eager_allreduce(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
sz = 1024
|
||||
custom_ar.init_custom_ar()
|
||||
fa = custom_ar.get_handle()
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = fa.all_reduce_unreg(inp)
|
||||
assert torch.allclose(out, inp * world_size)
|
||||
|
||||
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
|
||||
out = fa.all_reduce_unreg(inp)
|
||||
assert torch.allclose(out, inp * world_size)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [2])
|
||||
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
|
||||
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
||||
multi_process_tensor_parallel(tensor_parallel_size, test_target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
multi_process_tensor_parallel(2, graph_allreduce)
|
@ -328,6 +328,8 @@ class ParallelConfig:
|
||||
worker_use_ray: Whether to use Ray for model workers. Will be set to
|
||||
True if either pipeline_parallel_size or tensor_parallel_size is
|
||||
greater than 1.
|
||||
disable_custom_all_reduce: Disable the custom all-reduce kernel and
|
||||
fall back to NCCL.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -336,11 +338,13 @@ class ParallelConfig:
|
||||
tensor_parallel_size: int,
|
||||
worker_use_ray: bool,
|
||||
max_parallel_loading_workers: Optional[int] = None,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
) -> None:
|
||||
self.pipeline_parallel_size = pipeline_parallel_size
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.max_parallel_loading_workers = max_parallel_loading_workers
|
||||
self.disable_custom_all_reduce = disable_custom_all_reduce
|
||||
|
||||
self.world_size = pipeline_parallel_size * tensor_parallel_size
|
||||
if self.world_size > 1:
|
||||
@ -351,6 +355,16 @@ class ParallelConfig:
|
||||
if self.pipeline_parallel_size > 1:
|
||||
raise NotImplementedError(
|
||||
"Pipeline parallelism is not supported yet.")
|
||||
if is_hip():
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.info(
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
"supported on AMD GPUs.")
|
||||
elif self.pipeline_parallel_size > 1:
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.info(
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
"supported with pipeline parallelism.")
|
||||
|
||||
|
||||
class SchedulerConfig:
|
||||
|
@ -35,6 +35,7 @@ class EngineArgs:
|
||||
quantization: Optional[str] = None
|
||||
enforce_eager: bool = False
|
||||
max_context_len_to_capture: int = 8192
|
||||
disable_custom_all_reduce: bool = False
|
||||
enable_lora: bool = False
|
||||
max_loras: int = 1
|
||||
max_lora_rank: int = 16
|
||||
@ -208,6 +209,10 @@ class EngineArgs:
|
||||
help='maximum context length covered by CUDA '
|
||||
'graphs. When a sequence has context length '
|
||||
'larger than this, we fall back to eager mode.')
|
||||
parser.add_argument('--disable-custom-all-reduce',
|
||||
action='store_true',
|
||||
default=EngineArgs.disable_custom_all_reduce,
|
||||
help='See ParallelConfig')
|
||||
# LoRA related configs
|
||||
parser.add_argument('--enable-lora',
|
||||
action='store_true',
|
||||
@ -269,7 +274,8 @@ class EngineArgs:
|
||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
self.worker_use_ray,
|
||||
self.max_parallel_loading_workers)
|
||||
self.max_parallel_loading_workers,
|
||||
self.disable_custom_all_reduce)
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.max_model_len,
|
||||
|
@ -82,6 +82,7 @@ class LLMEngine:
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"enforce_eager={model_config.enforce_eager}, "
|
||||
f"seed={model_config.seed})")
|
||||
|
@ -64,6 +64,7 @@ class LLM:
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode.
|
||||
disable_custom_all_reduce: See ParallelConfig
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -82,6 +83,7 @@ class LLM:
|
||||
swap_space: int = 4,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
@ -101,6 +103,7 @@ class LLM:
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
max_context_len_to_capture=max_context_len_to_capture,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
@ -10,17 +10,27 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group.
|
||||
|
||||
NOTE: This operation is applied in-place on the input tensor.
|
||||
NOTE: This operation will be applied in-place on the input tensor if
|
||||
disable_custom_all_reduce is set to True. Otherwise, this operation may or
|
||||
may not be applied in place depending on whether custom all reduce is
|
||||
invoked for a particular tensor, which further depends on the tensor size
|
||||
and GPU topology.
|
||||
|
||||
TLDR: always assume this function modifies its input, but use the return
|
||||
value as the output.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if get_tensor_model_parallel_world_size() == 1:
|
||||
return input_
|
||||
# All-reduce.
|
||||
out = custom_all_reduce(input_)
|
||||
if out is not None:
|
||||
return out
|
||||
torch.distributed.all_reduce(input_,
|
||||
group=get_tensor_model_parallel_group())
|
||||
return input_
|
||||
|
223
vllm/model_executor/parallel_utils/custom_all_reduce.py
Normal file
223
vllm/model_executor/parallel_utils/custom_all_reduce.py
Normal file
@ -0,0 +1,223 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank)
|
||||
|
||||
try:
|
||||
from vllm._C import custom_ar
|
||||
import pynvml
|
||||
except ImportError:
|
||||
# For AMD GPUs
|
||||
custom_ar = None
|
||||
pynvml = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_CA_HANDLE = None
|
||||
_IS_CAPTURING = False
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
|
||||
def init_custom_ar() -> None:
|
||||
global _CA_HANDLE
|
||||
if _CA_HANDLE is not None:
|
||||
return
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
if world_size not in _SUPPORTED_WORLD_SIZES:
|
||||
logger.warn(
|
||||
"Custom allreduce is disabled due to an unsupported world size: "
|
||||
"%d. Supported world sizes: %s. To slience this warning, specify"
|
||||
"disable_custom_all_reduce=True explicitly.", world_size,
|
||||
str(_SUPPORTED_WORLD_SIZES))
|
||||
return
|
||||
if not _can_p2p(rank, world_size):
|
||||
logger.warn(
|
||||
"Custom allreduce is disabled because your platform lacks GPU P2P"
|
||||
" capability. To slience this warning, specify"
|
||||
"disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
_CA_HANDLE = CustomAllreduce(rank, world_size)
|
||||
|
||||
|
||||
def begin_capture() -> None:
|
||||
global _IS_CAPTURING
|
||||
_IS_CAPTURING = True
|
||||
|
||||
|
||||
def end_capture() -> None:
|
||||
global _IS_CAPTURING
|
||||
_IS_CAPTURING = False
|
||||
|
||||
|
||||
def is_capturing() -> bool:
|
||||
return _IS_CAPTURING and _CA_HANDLE is not None
|
||||
|
||||
|
||||
def get_handle() -> Optional["CustomAllreduce"]:
|
||||
return _CA_HANDLE
|
||||
|
||||
|
||||
@contextmanager
|
||||
def capture():
|
||||
try:
|
||||
begin_capture()
|
||||
yield
|
||||
finally:
|
||||
end_capture()
|
||||
handle = get_handle()
|
||||
if handle is not None:
|
||||
handle.register_graph_buffers()
|
||||
|
||||
|
||||
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
ca_handle = get_handle()
|
||||
# when custom allreduce is disabled, this will be None
|
||||
if ca_handle is None:
|
||||
return
|
||||
if is_capturing():
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
if ca_handle.should_custom_ar(input):
|
||||
return ca_handle.all_reduce_reg(input)
|
||||
else:
|
||||
if ca_handle.should_custom_ar(input):
|
||||
# if warm up, mimic the allocation pattern
|
||||
# since custom allreduce is out-of-place
|
||||
return torch.empty_like(input)
|
||||
else:
|
||||
# note: outside of cuda graph context,
|
||||
# custom allreduce incurs a cost of cudaMemcpy, which should
|
||||
# be small(<=1% of overall latency) compared to the performance
|
||||
# gains of using custom kernels
|
||||
if ca_handle.should_custom_ar(input):
|
||||
return ca_handle.all_reduce_unreg(input)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _nvml():
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
yield
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
|
||||
# query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
@_nvml()
|
||||
def _is_full_nvlink(rank, world_size):
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
|
||||
for i in range(world_size):
|
||||
if i != rank:
|
||||
try:
|
||||
link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
|
||||
if not link_state:
|
||||
return False
|
||||
except pynvml.NVMLError as error:
|
||||
logger.info(
|
||||
f"NVLink detection failed with message \"{str(error)}\". "
|
||||
"This is normal if your machine has no NVLink equipped")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if not torch.cuda.can_device_access_peer(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
def __init__(self, rank, world_size, max_size=8192 * 1024) -> None:
|
||||
# buffers memory are owned by this Python class and passed to C++
|
||||
# meta data composes of two parts: meta data for synchronization
|
||||
# (256 bytes) and a temporary buffer for storing intermediate
|
||||
# allreduce results.
|
||||
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
|
||||
dtype=torch.uint8,
|
||||
device="cuda")
|
||||
# This is a pre-registered IPC buffer. In eager mode, input tensors
|
||||
# are first copied into this buffer before allreduce is performed
|
||||
self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
|
||||
# This is a buffer for storing the tuples of pointers pointing to
|
||||
# IPC buffers from all ranks. Each registered tuple has size of
|
||||
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
|
||||
# is enough for 131072 such tuples. The largest model I've seen only
|
||||
# needs less than 10000 of registered tuples.
|
||||
self.rank_data = torch.empty(8 * 1024 * 1024,
|
||||
dtype=torch.uint8,
|
||||
device="cuda")
|
||||
self.max_size = max_size
|
||||
self.world_size = world_size
|
||||
handles, offsets = self._get_ipc_meta(self.meta)
|
||||
self.full_nvlink = _is_full_nvlink(rank, world_size)
|
||||
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
|
||||
handles, offsets, rank,
|
||||
self.full_nvlink)
|
||||
self.fast_cond = self.full_nvlink or world_size <= 2
|
||||
self.register_buffer(self.buffer)
|
||||
|
||||
def _get_ipc_meta(self, inp: torch.Tensor):
|
||||
data = inp.untyped_storage()._share_cuda_()
|
||||
shard_data = (
|
||||
data[1], # ipc handle to base ptr
|
||||
data[3], # offset of base ptr
|
||||
)
|
||||
return self._gather_ipc_meta(shard_data)
|
||||
|
||||
def _gather_ipc_meta(self, shard_data):
|
||||
all_data = [None] * self.world_size
|
||||
dist.all_gather_object(all_data, shard_data)
|
||||
|
||||
handles = []
|
||||
offsets = []
|
||||
for i in range(len(all_data)):
|
||||
handles.append(all_data[i][0])
|
||||
offsets.append(all_data[i][1])
|
||||
return handles, offsets
|
||||
|
||||
def register_buffer(self, inp: torch.Tensor):
|
||||
handles, offsets = self._get_ipc_meta(inp)
|
||||
custom_ar.register_buffer(self._ptr, inp, handles, offsets)
|
||||
|
||||
def register_graph_buffers(self):
|
||||
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
|
||||
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
|
||||
logger.info("Registering %d cuda graph addresses", len(offset))
|
||||
custom_ar.register_graph_buffers(self._ptr, handles, offsets)
|
||||
|
||||
def should_custom_ar(self, inp: torch.Tensor):
|
||||
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
|
||||
self.full_nvlink)
|
||||
|
||||
# all reduce, assuming inp tensor is IPC registered with register_buffer,
|
||||
# or, in the context of cuda graphs, register_graph_buffers
|
||||
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
custom_ar.all_reduce_reg(self._ptr, inp, out)
|
||||
return out
|
||||
|
||||
# all reduce, assuming inp tensor is NOT IPC registered
|
||||
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
|
||||
if out is None:
|
||||
out = torch.empty_like(inp)
|
||||
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
|
||||
return out
|
||||
|
||||
def close(self):
|
||||
if self._ptr:
|
||||
custom_ar.dispose(self._ptr)
|
||||
self._ptr = 0
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
38
vllm/test_utils.py
Normal file
38
vllm/test_utils.py
Normal file
@ -0,0 +1,38 @@
|
||||
import ray
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.worker.worker import init_distributed_environment
|
||||
|
||||
|
||||
def init_test_distributed_environment(
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
rank: int,
|
||||
distributed_init_port: str,
|
||||
) -> None:
|
||||
parallel_config = ParallelConfig(pipeline_parallel_size,
|
||||
tensor_parallel_size,
|
||||
worker_use_ray=True)
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
init_distributed_environment(parallel_config, rank,
|
||||
distributed_init_method)
|
||||
|
||||
|
||||
def multi_process_tensor_parallel(
|
||||
tensor_parallel_size: int,
|
||||
test_target,
|
||||
) -> None:
|
||||
# Using ray helps debugging the error when it failed
|
||||
# as compared to multiprocessing.
|
||||
ray.init()
|
||||
|
||||
distributed_init_port = get_open_port()
|
||||
refs = []
|
||||
for rank in range(tensor_parallel_size):
|
||||
refs.append(
|
||||
test_target.remote(tensor_parallel_size, rank,
|
||||
distributed_init_port))
|
||||
ray.get(refs)
|
||||
|
||||
ray.shutdown()
|
@ -10,6 +10,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict)
|
||||
from vllm.model_executor.parallel_utils import custom_all_reduce
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
@ -651,37 +652,38 @@ class ModelRunner:
|
||||
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
# Create dummy input_metadata.
|
||||
input_metadata = InputMetadata(
|
||||
is_prompt=False,
|
||||
slot_mapping=slot_mapping[:batch_size],
|
||||
prompt_lens=None,
|
||||
max_seq_len=None,
|
||||
start_loc=None,
|
||||
max_context_len=self.max_context_len_to_capture,
|
||||
context_lens=context_lens[:batch_size],
|
||||
block_tables=block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
|
||||
if self.lora_config:
|
||||
lora_mapping = LoRAMapping(
|
||||
[0] * batch_size,
|
||||
[0] * batch_size,
|
||||
with custom_all_reduce.capture():
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
# Create dummy input_metadata.
|
||||
input_metadata = InputMetadata(
|
||||
is_prompt=False,
|
||||
slot_mapping=slot_mapping[:batch_size],
|
||||
prompt_lens=None,
|
||||
max_seq_len=None,
|
||||
start_loc=None,
|
||||
max_context_len=self.max_context_len_to_capture,
|
||||
context_lens=context_lens[:batch_size],
|
||||
block_tables=block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
self.set_active_loras(set(), lora_mapping)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens[:batch_size],
|
||||
input_positions[:batch_size],
|
||||
kv_caches,
|
||||
input_metadata,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
if self.lora_config:
|
||||
lora_mapping = LoRAMapping(
|
||||
[0] * batch_size,
|
||||
[0] * batch_size,
|
||||
)
|
||||
self.set_active_loras(set(), lora_mapping)
|
||||
|
||||
graph_runner = CUDAGraphRunner(self.model)
|
||||
graph_runner.capture(
|
||||
input_tokens[:batch_size],
|
||||
input_positions[:batch_size],
|
||||
kv_caches,
|
||||
input_metadata,
|
||||
memory_pool=self.graph_memory_pool,
|
||||
)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[batch_size] = graph_runner
|
||||
|
||||
end_time = time.perf_counter()
|
||||
elapsed_time = end_time - start_time
|
||||
|
@ -11,6 +11,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict)
|
||||
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
ensure_model_parallel_initialized)
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
@ -78,9 +79,10 @@ class Worker:
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
|
||||
# Initialize the distributed environment.
|
||||
_init_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method)
|
||||
|
||||
init_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method)
|
||||
if not self.parallel_config.disable_custom_all_reduce:
|
||||
init_custom_ar()
|
||||
# Initialize the model.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@ -219,7 +221,7 @@ class Worker:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
|
||||
def _init_distributed_environment(
|
||||
def init_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user