2024-01-28 04:46:35 +08:00
|
|
|
#include <ATen/cuda/Exceptions.h>
|
|
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
|
|
#include <c10/cuda/CUDAStream.h>
|
2024-06-09 16:23:30 -04:00
|
|
|
#include <torch/all.h>
|
2024-01-28 04:46:35 +08:00
|
|
|
|
|
|
|
#include "custom_all_reduce.cuh"
|
|
|
|
|
2024-11-06 23:50:47 -08:00
|
|
|
// Fake pointer type, must match fptr_t type in ops.h.
// Raw pointers cross the Python binding boundary encoded as int64_t; this alias
// marks every parameter/return value where such an encoded pointer appears.
typedef int64_t fptr_t;
// The int64_t encoding is only lossless if it is exactly pointer-sized.
static_assert(sizeof(fptr_t) == sizeof(void*));
|
2024-01-28 04:46:35 +08:00
|
|
|
|
2024-11-06 23:50:47 -08:00
|
|
|
// Creates a heap-allocated vllm::CustomAllreduce and returns it encoded as an
// fptr_t (release it with dispose()).
//
// fake_ipc_ptrs: one int64-encoded vllm::Signal* per rank; its length defines
//                the world size (must be even and at most 8).
// rank_data:     tensor whose storage is handed to the communicator; its
//                data_ptr() and numel() are forwarded unchanged.
// rank:          this process' index, in [0, world_size).
// Throws std::invalid_argument when any of the above constraints is violated.
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank,
                      bool fully_connected) {
  const int world_size = fake_ipc_ptrs.size();

  // Guard clauses, checked in this exact order.
  if (world_size > 8) {
    throw std::invalid_argument("world size > 8 is not supported");
  }
  if (world_size % 2 != 0) {
    throw std::invalid_argument("Odd num gpus is not supported for now");
  }
  if (rank < 0 || rank >= world_size) {
    throw std::invalid_argument("invalid rank passed in");
  }

  // Decode the int64 handles; only the first world_size slots are populated.
  vllm::Signal* signals[8];
  for (int idx = 0; idx < world_size; ++idx) {
    signals[idx] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[idx]);
  }

  auto* comm = new vllm::CustomAllreduce(signals, rank_data.data_ptr(),
                                         rank_data.numel(), rank, world_size,
                                         fully_connected);
  return reinterpret_cast<fptr_t>(comm);
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Make sure tensor t's data lies completely within ((char*)t.data_ptr()) +
|
|
|
|
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
|
|
|
* because it allows transpose of contiguous slice (i.e. slicing the first
|
|
|
|
* dimension). Currently, we require this because stride information is not
|
|
|
|
* passed into the kernels and we treat input tensors as flat.
|
|
|
|
*
|
|
|
|
* Examples
|
|
|
|
* A = torch.zeros(3, 3, 3)
|
|
|
|
* 1. A: OK
|
|
|
|
* 2. A[1:]: OK
|
|
|
|
* 3. A.permute(2, 0, 1): OK
|
|
|
|
* 4. A[1:].permute(2, 0, 1): OK
|
|
|
|
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
|
|
|
* 6. A[:, 1:, 1:]: Not OK
|
|
|
|
*/
|
2024-05-22 03:18:41 -04:00
|
|
|
bool _is_weak_contiguous(torch::Tensor& t) {
  // Fully contiguous tensors trivially satisfy the flat-layout requirement.
  if (t.is_contiguous()) {
    return true;
  }
  // Otherwise accept the view only when its elements occupy exactly the bytes
  // from its storage offset to the end of the underlying storage.
  const auto elem_bytes = t.element_size();
  const auto tail_bytes =
      t.storage().nbytes() - t.storage_offset() * elem_bytes;
  return tail_bytes == t.numel() * elem_bytes;
}
|
|
|
|
|
2024-11-06 23:50:47 -08:00
|
|
|
/**
|
|
|
|
* Performs an out-of-place allreduce and stores result in out.
|
|
|
|
*
|
|
|
|
* If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
|
|
|
|
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
|
|
|
|
* copied into _reg_buffer.
|
|
|
|
*/
|
|
|
|
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto* comm = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  // Work on inp's device and enqueue on that device's current stream.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  // inp/out must agree in dtype and element count, and both must be
  // addressable as flat buffers since no stride information is passed on.
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));

  auto input_size = inp.numel() * inp.element_size();
  void* src = reinterpret_cast<void*>(_reg_buffer);
  if (src == nullptr) {
    // No staging buffer supplied: reduce directly from inp's memory.
    src = inp.data_ptr();
  } else {
    // Stage inp into the caller-provided registered buffer (stream-ordered).
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(src, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  }

  // Dispatch on dtype; every supported case returns, anything else falls
  // through to the error below.
  switch (out.scalar_type()) {
    case at::ScalarType::Float:
      comm->allreduce<float>(stream, static_cast<float*>(src),
                             static_cast<float*>(out.data_ptr()),
                             out.numel());
      return;
    case at::ScalarType::Half:
      comm->allreduce<half>(stream, static_cast<half*>(src),
                            static_cast<half*>(out.data_ptr()), out.numel());
      return;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
      comm->allreduce<nv_bfloat16>(
          stream, static_cast<nv_bfloat16*>(src),
          static_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      return;
#endif
    default:
      break;
  }
  throw std::runtime_error(
      "custom allreduce only supports float32, float16 and bfloat16");
}
|
|
|
|
|
|
|
|
// Destroys a communicator previously returned by init_custom_ar.
// _fa must not be used after this call.
void dispose(fptr_t _fa) {
  auto* instance = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  delete instance;
}
|
|
|
|
|
2024-06-09 16:23:30 -04:00
|
|
|
// Returns sizeof(vllm::Signal) in bytes, i.e. the size of each object whose
// address is passed per rank to init_custom_ar.
int64_t meta_size() {
  constexpr auto kSignalBytes = sizeof(vllm::Signal);
  return static_cast<int64_t>(kSignalBytes);
}
|
2024-01-28 04:46:35 +08:00
|
|
|
|
2024-11-06 23:50:47 -08:00
|
|
|
// Registers one buffer per rank with the communicator.
// fake_ipc_ptrs: int64-encoded device pointers, exactly one per rank
//                (its size must equal the communicator's world_size_).
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
  // Decode into a fixed array; only world_size_ leading slots are filled.
  void* decoded[8];
  size_t slot = 0;
  for (const fptr_t encoded : fake_ipc_ptrs) {
    decoded[slot++] = reinterpret_cast<void*>(encoded);
  }
  fa->register_buffer(decoded);
}
|
|
|
|
|
2024-11-06 23:50:47 -08:00
|
|
|
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
|
|
|
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
|
|
|
|
get_graph_buffer_ipc_meta(fptr_t _fa) {
|
2024-05-22 03:18:41 -04:00
|
|
|
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
|
2024-11-06 23:50:47 -08:00
|
|
|
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
|
|
|
|
std::vector<int64_t> bytes(handle.begin(), handle.end());
|
|
|
|
return std::make_tuple(bytes, offsets);
|
2024-01-28 04:46:35 +08:00
|
|
|
}
|
|
|
|
|
2024-11-06 23:50:47 -08:00
|
|
|
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
|
|
|
// Registers the IPC-shared graph buffers of every rank with the communicator.
//
// handles: per-rank handle bytes, each widened to int64_t (as produced by
//          get_graph_buffer_ipc_meta); narrowed back to std::string here.
// offsets: per-rank offset lists, forwarded unchanged.
// Both lists must contain exactly one entry per rank; this mirrors the size
// check done by register_buffer and rejects malformed input before it reaches
// the communicator.
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  TORCH_CHECK(handles.size() == offsets.size(),
              "register_graph_buffers: got ", handles.size(), " handles but ",
              offsets.size(), " offset lists");
  TORCH_CHECK(handles.size() == static_cast<size_t>(fa->world_size_),
              "register_graph_buffers: expected one handle per rank (",
              fa->world_size_, "), got ", handles.size());

  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (const auto& handle : handles) {
    // Narrow each int64_t back to the char it was widened from.
    bytes.emplace_back(handle.begin(), handle.end());
  }
  fa->register_graph_buffers(bytes, offsets);
}
|
2025-04-01 07:49:12 +02:00
|
|
|
|
|
|
|
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
|
|
|
int64_t size) {
|
|
|
|
auto device_index = c10::cuda::current_device();
|
|
|
|
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
|
|
|
void* buffer;
|
|
|
|
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
|
|
|
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
|
|
|
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
|
|
|
|
|
|
|
// Allocate buffer
|
|
|
|
#if defined(USE_ROCM)
|
|
|
|
// data buffers need to be "uncached" for signal on MI200
|
|
|
|
AT_CUDA_CHECK(
|
|
|
|
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
|
|
|
#else
|
|
|
|
AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
|
|
|
|
#endif
|
|
|
|
AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
|
|
|
|
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
|
|
|
|
AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
|
|
|
|
|
|
|
// Create IPC memhandle for the allocated buffer.
|
|
|
|
// Will use it in open_mem_handle.
|
|
|
|
auto options =
|
|
|
|
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
|
|
|
auto handle =
|
|
|
|
torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
|
|
|
|
AT_CUDA_CHECK(
|
|
|
|
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
|
|
|
|
|
|
|
|
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Opens an IPC memory handle exported by another process (see
// allocate_shared_buffer_and_handle) and returns the mapped device pointer
// encoded as fptr_t.
//
// mem_handle: CPU tensor whose contiguous bytes hold a cudaIpcMemHandle_t.
// The tensor's bytes are read on the host, so it must live in CPU memory and
// contain at least sizeof(cudaIpcMemHandle_t) contiguous bytes; otherwise the
// read would go out of bounds (or dereference a device pointer on the host).
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
  TORCH_CHECK(mem_handle.is_cpu(),
              "open_mem_handle: mem_handle must be a CPU tensor");
  TORCH_CHECK(mem_handle.is_contiguous(),
              "open_mem_handle: mem_handle must be contiguous");
  TORCH_CHECK(mem_handle.nbytes() >= sizeof(cudaIpcMemHandle_t),
              "open_mem_handle: mem_handle holds ", mem_handle.nbytes(),
              " bytes, expected at least ", sizeof(cudaIpcMemHandle_t));

  void* ipc_ptr;
  AT_CUDA_CHECK(cudaIpcOpenMemHandle(
      (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
      cudaIpcMemLazyEnablePeerAccess));
  return reinterpret_cast<fptr_t>(ipc_ptr);
}
|
|
|
|
|
|
|
|
// Frees a device buffer previously returned by
// allocate_shared_buffer_and_handle; CUDA errors surface via AT_CUDA_CHECK.
void free_shared_buffer(fptr_t buffer) {
  void* const device_ptr = reinterpret_cast<void*>(buffer);
  AT_CUDA_CHECK(cudaFree(device_ptr));
}
|