[Core] Support fully transparent sleep mode (#11743)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, committed 2025-01-22 14:39:32 +08:00 via GitHub
parent 4004f144f3
commit 68ad4e3a8d
14 changed files with 877 additions and 40 deletions


@@ -76,7 +76,9 @@ steps:
- tests/basic_correctness/test_basic_correctness
- tests/basic_correctness/test_cpu_offload
- tests/basic_correctness/test_preemption
- tests/basic_correctness/test_cumem.py
commands:
- pytest -v -s basic_correctness/test_cumem.py
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py


@@ -181,6 +181,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
# Define other extension targets
#
#
# cumem_allocator extension
#
set(VLLM_CUMEM_EXT_SRC
"csrc/cumem_allocator.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_CUMEM_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling cumem allocator extension.")
# link against cuda driver library
list(APPEND CUMEM_LIBS cuda)
define_gpu_extension_target(
cumem_allocator
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_CUMEM_EXT_SRC}
LIBRARIES ${CUMEM_LIBS}
USE_SABI 3.8
WITH_SOABI)
endif()
#
# _C extension
#
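The new target is installed into the `vllm` package (see `DESTINATION vllm` above and the `setup.py` change below), so once built it should be importable as `vllm.cumem_allocator`. A minimal sanity check, assuming a CUDA build of vLLM, might look like this sketch:

# Sketch: confirm the cumem_allocator extension built and exposes the
# entry points that vllm/device_allocator/cumem.py imports further below.
from vllm import cumem_allocator

for name in ("init_module", "python_create_and_map", "python_unmap_and_release"):
    assert hasattr(cumem_allocator, name), f"missing symbol: {name}"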

csrc/cumem_allocator.cpp (new file, 310 lines)

@@ -0,0 +1,310 @@
// A CUDAPluggableAllocator based on cumem* APIs.
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle*
// need to be unsigned long long
#include <iostream>
extern "C" {
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <cuda.h>
#define CUDA_CHECK(condition) \
do { \
CUresult error = condition; \
if (error != 0) { \
char* error_string; \
cuGetErrorString(error, (const char**)&error_string); \
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl; \
} \
} while (0)
// Global references to Python callables
// NOTE: these are borrowed references, so we don't need to DECREF them.
// This brings the limitation that the allocator needs to be a singleton.
static PyObject* g_python_malloc_callback = nullptr;
static PyObject* g_python_free_callback = nullptr;
// ---------------------------------------------------------------------------
// Helper functions:
void ensure_context(unsigned long long device) {
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
if (!pctx) {
// Ensure device context.
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}
}
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
CUmemGenericAllocationHandle* p_memHandle) {
ensure_context(device);
// Define memory allocation properties
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
// Allocate memory using cuMemCreate
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = device;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
}
void unmap_and_release(unsigned long long device, ssize_t size,
CUdeviceptr d_mem,
CUmemGenericAllocationHandle* p_memHandle) {
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context(device);
CUDA_CHECK(cuMemUnmap(d_mem, size));
CUDA_CHECK(cuMemRelease(*p_memHandle));
}
PyObject* create_tuple_from_c_integers(unsigned long long a,
unsigned long long b,
unsigned long long c,
unsigned long long d) {
// Create a new tuple of size 4
PyObject* tuple = PyTuple_New(4);
if (!tuple) {
return NULL; // Return NULL on failure
}
// Convert integers to Python objects and set them in the tuple
PyTuple_SetItem(
tuple, 0,
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
// Note: PyTuple_SetItem "steals" a reference to each object,
// so we do not need to Py_DECREF the PyLong objects explicitly.
return tuple; // Return the created tuple
}
// ---------------------------------------------------------------------------
// Our exported C functions that call Python:
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
void* my_malloc(ssize_t size, int device, CUstream stream) {
ensure_context(device);
// first allocation: align the size, reserve a virtual address range, and
// allocate a CUmemGenericAllocationHandle
// Define memory allocation properties
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device;
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
// Query the allocation granularity and round the requested size up to it
size_t granularity;
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
CU_MEM_ALLOC_GRANULARITY_MINIMUM));
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
CUdeviceptr d_mem;
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)malloc(
sizeof(CUmemGenericAllocationHandle));
if (!g_python_malloc_callback) {
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
return nullptr;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* arg_tuple = create_tuple_from_c_integers(
(unsigned long long)device, (unsigned long long)alignedSize,
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
// Call g_python_malloc_callback
PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
Py_DECREF(arg_tuple);
if (!py_result) {
PyErr_Print();
PyGILState_Release(gstate);
return nullptr;
}
// drop the callback result before releasing the GIL
Py_DECREF(py_result);
PyGILState_Release(gstate);
// do the final mapping
create_and_map(device, alignedSize, d_mem, p_memHandle);
return (void*)d_mem;
}
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
// get memory handle from the pointer
if (!g_python_free_callback) {
std::cerr << "ERROR: g_python_free_callback not set.\n";
return;
}
// Acquire GIL (not in stable ABI officially, but often works)
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* py_ptr =
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
PyObject* py_result =
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
Py_XDECREF(py_ptr);
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
// release the result and the GIL before bailing out on the error path
Py_XDECREF(py_result);
PyGILState_Release(gstate);
return;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
&recv_d_mem, &recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
Py_DECREF(py_result);
PyGILState_Release(gstate);
return;
}
Py_DECREF(py_result);
PyGILState_Release(gstate);
// recv_size == size
// recv_device == device
// Free memory
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(device, size, d_mem, p_memHandle);
// free address and the handle
CUDA_CHECK(cuMemAddressFree(d_mem, size));
free(p_memHandle);
}
// ---------------------------------------------------------------------------
// Python extension boilerplate:
// Python-exposed function: init_module(python_malloc, python_free)
static PyObject* py_init_module(PyObject* self, PyObject* args) {
PyObject* malloc_callback = nullptr;
PyObject* free_callback = nullptr;
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
return nullptr;
}
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
return nullptr;
}
// Save the Python callables
// This module does not handle GC of these objects, so they must be kept alive
// outside of this module.
g_python_malloc_callback = malloc_callback;
g_python_free_callback = free_callback;
Py_RETURN_NONE;
}
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
return nullptr;
}
unsigned long long recv_device, recv_size;
unsigned long long recv_d_mem, recv_p_memHandle;
// Unpack the tuple into four C integers
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
&recv_p_memHandle)) {
// PyArg_ParseTuple sets an error if it fails
return nullptr;
}
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
CUmemGenericAllocationHandle* p_memHandle =
(CUmemGenericAllocationHandle*)recv_p_memHandle;
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
Py_RETURN_NONE;
}
static PyMethodDef module_methods[] = {
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
"Initialize module with python_malloc and python_free callables."},
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
"Create and map memory on the device."},
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
METH_VARARGS, "Unmap and release memory on the device."},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef cumem_allocator_module = {
PyModuleDef_HEAD_INIT, "cumem_allocator",
"cumem-based allocator for CUDAPluggableAllocator", -1, module_methods};
PyMODINIT_FUNC PyInit_cumem_allocator(void) {
// Initialize the module
PyObject* module = PyModule_Create(&cumem_allocator_module);
if (!module) {
return NULL;
}
return module;
}
} // extern "C"
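To make the control flow above concrete, here is a hedged Python sketch of how these exported symbols are intended to be driven; it mirrors what vllm/device_allocator/cumem.py does later in this diff, the `.so` path lookup is simplified, and the callback shapes follow my_malloc/my_free above:

import torch
from vllm import cumem_allocator

# my_malloc passes a (device, aligned_size, d_mem, p_memHandle) tuple to the
# malloc callback; my_free expects the same tuple back from the free callback.
allocations = {}

def malloc_cb(handle):
    allocations[handle[2]] = handle  # key by the device pointer (d_mem)

def free_cb(ptr):
    return allocations.pop(ptr)

cumem_allocator.init_module(malloc_cb, free_cb)
allocator = torch.cuda.memory.CUDAPluggableAllocator(
    cumem_allocator.__file__, "my_malloc", "my_free")
pool = torch.cuda.memory.MemPool(allocator._allocator)
with torch.cuda.memory.use_mem_pool(pool):
    t = torch.zeros(1024, device="cuda")  # this allocation is served by my_malloc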


@@ -301,6 +301,7 @@ class repackage_wheel(build_ext):
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",
"vllm/vllm_flash_attn/__init__.py",
"vllm/cumem_allocator.abi3.so",
# "vllm/_version.py", # not available in nightly wheels yet
]
file_members = filter(lambda x: x.filename in files_to_copy,
@@ -594,6 +595,7 @@ if _is_hip():
if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))


@@ -0,0 +1,112 @@
import torch
from vllm import LLM, SamplingParams
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.utils import GiB_bytes
from ..utils import fork_new_process_for_each_test
@fork_new_process_for_each_test
def test_basic_cumem():
# some tensors from default memory pool
shape = (1024, 1024)
x = torch.empty(shape, device='cuda')
x.zero_()
# some tensors from custom memory pool
allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool():
# custom memory pool
y = torch.empty(shape, device='cuda')
y.zero_()
y += 1
z = torch.empty(shape, device='cuda')
z.zero_()
z += 2
# they can be used together
output = x + y + z
assert torch.allclose(output, torch.ones_like(output) * 3)
free_bytes = torch.cuda.mem_get_info()[0]
allocator.sleep()
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
assert free_bytes_after_sleep > free_bytes
allocator.wake_up()
# they can be used together
output = x + y + z
assert torch.allclose(output, torch.ones_like(output) * 3)
@fork_new_process_for_each_test
def test_cumem_with_cudagraph():
allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool():
weight = torch.eye(1024, device='cuda')
with allocator.use_memory_pool(tag="discard"):
cache = torch.empty(1024, 1024, device='cuda')
def model(x):
out = x @ weight
cache[:out.size(0)].copy_(out)
return out + 1
x = torch.empty(128, 1024, device='cuda')
# warmup
model(x)
# capture cudagraph
model_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(model_graph):
y = model(x)
free_bytes = torch.cuda.mem_get_info()[0]
allocator.sleep()
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
assert free_bytes_after_sleep > free_bytes
allocator.wake_up()
# after waking up, the content in the weight tensor
# should be restored, but the content in the cache tensor
# should be discarded
# this operation is also compatible with cudagraph
x.random_()
model_graph.replay()
# cache content is as expected
assert torch.allclose(x, cache[:x.size(0)])
# output content is as expected
assert torch.allclose(y, x + 1)
@fork_new_process_for_each_test
def test_end_to_end():
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free  # in case other processes are running
llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
# which is difficult to measure in the test. therefore, we only
# test sleep level 1 here.
llm.sleep(level=1)
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
assert used_bytes < 2 * GiB_bytes
llm.wake_up()
output2 = llm.generate(prompt, sampling_params)
# the output before and after sleep should be identical
assert output[0].outputs[0].text == output2[0].outputs[0].text


@@ -195,40 +195,43 @@ class ModelConfig:
factors.append(self.rope_theta)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __init__(self,
model: str,
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None) -> None:
def __init__(
self,
model: str,
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None,
enable_sleep_mode: bool = False,
) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -277,6 +280,12 @@ class ModelConfig:
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.enable_sleep_mode = enable_sleep_mode
from vllm.platforms import current_platform
if self.enable_sleep_mode and not current_platform.is_cuda():
raise ValueError("Sleep mode is only supported on CUDA devices.")
hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, config_format)
@@ -348,7 +357,6 @@ class ModelConfig:
self.is_hybrid = self._init_is_hybrid()
self.has_inner_state = self._init_has_inner_state()
from vllm.platforms import current_platform
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:


@@ -0,0 +1,254 @@
# cumem-based pytorch pluggable allocator to implement sleep mode.
# other approaches were tried but failed:
# - cuda-python package binding
# - custom libcuda driver ctypes wrapper
# both of them failed because of a CUDA context mismatch: the allocations
# appear to be created from a different context, though it is not clear why.
# the only approach that works is to call the CUDA driver API from C.
import dataclasses
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Tuple, Union
import torch
from vllm.utils import is_pin_memory_available
def find_loaded_library(lib_name) -> Optional[str]:
"""
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which include
the shared libraries loaded by the process. We can use this file to find the path
of a loaded library.
""" # noqa
found_line = None
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found_line = line
break
if found_line is None:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = found_line.index("/")
path = found_line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \
f"Unexpected filename: {filename} for library {lib_name}"
return path
cumem_available = False
try:
from vllm.cumem_allocator import (init_module, python_create_and_map,
python_unmap_and_release)
from vllm.distributed.device_communicators.cuda_wrapper import (
CudaRTLibrary)
lib_name = find_loaded_library("cumem_allocator")
libcudart = CudaRTLibrary()
cumem_available = True
except ModuleNotFoundError:
# rocm platform does not support cumem allocator
init_module = None
python_create_and_map = None
python_unmap_and_release = None
CudaRTLibrary = None
lib_name = None
libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = Tuple[int, int, int, int]
@dataclasses.dataclass
class AllocationData:
handle: HandleType
tag: str
cpu_backup_tensor: Optional[torch.Tensor] = None
def create_and_map(allocation_handle: HandleType) -> None:
python_create_and_map(*allocation_handle)
def unmap_and_release(allocation_handle: HandleType) -> None:
python_unmap_and_release(*allocation_handle)
def get_pluggable_allocator(
python_malloc_fn: Callable[[int],
int], python_free_func: Callable[[int, int],
None]
) -> torch.cuda.memory.CUDAPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
lib_name, 'my_malloc', 'my_free')
return new_alloc
@contextmanager
def use_memory_pool_with_allocator(
python_malloc_fn: Callable[[int], int],
python_free_func: Callable[[int, int], None]) -> None:
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool):
yield mem_pool
class CuMemAllocator:
"""
A singleton class that manages a memory pool for CUDA tensors.
The memory in this pool can be offloaded or discarded when the
allocator sleeps.
Inside the `use_memory_pool(tag)` context, all tensors created are
allocated in the memory pool and carry the tag passed to the context.
When we call `sleep`, all tensors with the specified tags are
offloaded to CPU memory, and the rest of the tensors are discarded.
When we call `wake_up`, all tensors that were previously offloaded
are loaded back to GPU memory, and the rest of the tensors are
re-mapped to fresh, uninitialized memory.
Why does it need to be a singleton?
When allocated tensors are garbage collected, PyTorch calls the free
callback, which invokes the `python_free_callback` method of this class.
The C extension stores the callbacks of a single instance of this class
in global variables. If we created multiple instances, the globals would
be overwritten and the free callback would not work as expected.
"""
instance: "CuMemAllocator" = None
default_tag: str = "default"
@staticmethod
def get_instance() -> "CuMemAllocator":
"""
CuMemAllocator is a singleton class.
We cannot call the constructor directly.
Call this method to get the instance.
"""
assert cumem_available, "cumem allocator is not available"
if CuMemAllocator.instance is None:
CuMemAllocator.instance = CuMemAllocator()
return CuMemAllocator.instance
def __init__(self):
self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
Internal method to store the allocation data
when memory is allocated in the memory pool."""
py_d_mem = allocation_handle[2]
self.pointer_to_data[py_d_mem] = AllocationData(
allocation_handle, self.current_tag)
return
def python_free_callback(self, ptr: int) -> HandleType:
"""
Internal method to look up the allocation data
when memory is freed in the memory pool."""
data = self.pointer_to_data.pop(ptr)
if data.cpu_backup_tensor is not None:
data.cpu_backup_tensor = None
return data.handle
def sleep(
self,
offload_tags: Optional[Union[Tuple[str, ...],
str]] = None) -> None:
"""
Put the allocator in sleep mode.
All allocations carrying one of the specified tags will be offloaded
to CPU memory; all other allocations will be discarded.
:param offload_tags: The tags of the allocations to offload. All other
allocations will be discarded.
"""
if offload_tags is None:
# by default, allocated tensors are offloaded
# when the allocator sleeps
offload_tags = (CuMemAllocator.default_tag, )
elif isinstance(offload_tags, str):
offload_tags = (offload_tags, )
assert isinstance(offload_tags, tuple)
for ptr, data in self.pointer_to_data.items():
handle = data.handle
if data.tag in offload_tags:
size_in_bytes = handle[1]
cpu_backup_tensor = torch.empty(
size_in_bytes,
dtype=torch.uint8,
device='cpu',
pin_memory=is_pin_memory_available())
cpu_ptr = cpu_backup_tensor.data_ptr()
libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
data.cpu_backup_tensor = cpu_backup_tensor
unmap_and_release(handle)
def wake_up(self):
"""
Wake up the allocator from sleep mode.
All data that was previously offloaded will be loaded back to GPU
memory; the rest of the allocations are re-mapped to fresh,
uninitialized GPU memory."""
for ptr, data in self.pointer_to_data.items():
handle = data.handle
create_and_map(handle)
if data.cpu_backup_tensor is not None:
cpu_backup_tensor = data.cpu_backup_tensor
if cpu_backup_tensor is not None:
size_in_bytes = cpu_backup_tensor.numel(
) * cpu_backup_tensor.element_size()
cpu_ptr = cpu_backup_tensor.data_ptr()
libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
data.cpu_backup_tensor = None
@contextmanager
def use_memory_pool(self, tag: Optional[str] = None):
"""
A context manager to use the memory pool.
All memory allocations created inside the context are served from the
memory pool and carry the specified tag.
:param tag: The tag of the memory allocation. If None, the default tag
will be used.
"""
if tag is None:
tag = CuMemAllocator.default_tag
assert isinstance(tag, str)
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback):
yield
# Due to a PyTorch bug, calling torch.cuda.empty_cache() errors out
# when a pluggable allocator is in use, see
# https://github.com/pytorch/pytorch/issues/145168 .
# As a consequence, memory that is allocated and then freed in this
# pool is not released back to the system.
# This is fine for now, because we only use this allocator during
# weight loading and kv cache creation, where we only allocate memory.
# TODO: find a way to release the memory,
# i.e. by calling torch.cuda.empty_cache()
self.current_tag = old_tag
def get_current_usage(self) -> int:
"""
Get the total number of bytes allocated in the memory pool.
"""
sum_bytes: int = 0
for ptr, data in self.pointer_to_data.items():
handle = data.handle
sum_bytes += handle[1]
return sum_bytes
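A short usage sketch of the allocator above, showing how tags interact with `sleep`/`wake_up`; the tag names here are simply the ones the workers later in this diff choose, the allocator itself gives them no special meaning:

import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

with allocator.use_memory_pool(tag="weights"):
    weights = torch.randn(1024, 1024, device="cuda")   # will be offloaded
with allocator.use_memory_pool(tag="kv_cache"):
    kv_cache = torch.zeros(1024, 1024, device="cuda")  # will be discarded

# Offload "weights" to pinned CPU memory and drop everything else
# (this is what sleep level 1 does in the workers below).
allocator.sleep(offload_tags=("weights", ))
# Passing an empty tuple instead would drop everything (sleep level 2).

allocator.wake_up()
# "weights" holds its old values again; "kv_cache" is re-mapped to
# fresh memory and its previous content is gone.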


@@ -197,6 +197,7 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = None
enable_sleep_mode: bool = False
def __post_init__(self):
if not self.tokenizer:
@@ -955,6 +956,12 @@ class EngineArgs:
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path.")
parser.add_argument("--enable-sleep-mode",
action="store_true",
default=False,
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
return parser
@classmethod
@@ -999,7 +1006,9 @@ class EngineArgs:
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config)
generation_config=self.generation_config,
enable_sleep_mode=self.enable_sleep_mode,
)
def create_load_config(self) -> LoadConfig:
return LoadConfig(

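As a sketch of how the new flag reaches the model config (hedged: it assumes the usual `EngineArgs.create_engine_config()` path, which is not part of this hunk, and the model name is purely illustrative):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
vllm_config = engine_args.create_engine_config()
assert vllm_config.model_config.enable_sleep_mode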

@@ -1818,6 +1818,16 @@ class LLMEngine:
def stop_profile(self) -> None:
self.model_executor.stop_profile()
def sleep(self, level: int = 1) -> None:
assert self.vllm_config.model_config.enable_sleep_mode, (
"Sleep mode is not enabled in the model config")
self.model_executor.sleep(level=level)
def wake_up(self) -> None:
assert self.vllm_config.model_config.enable_sleep_mode, (
"Sleep mode is not enabled in the model config")
self.model_executor.wake_up()
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()


@@ -1132,6 +1132,29 @@ class LLM:
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
def sleep(self, level: int = 1):
"""
Put the engine to sleep. The engine should not process any requests
while sleeping; the caller must guarantee that no requests are in
flight between `sleep` and the matching `wake_up` call.
:param level: The sleep level. Level 1 offloads the model weights to
CPU memory and discards the kv cache; the kv cache content is
forgotten. It is suited for sleeping and waking up to run the
same model again, so make sure there is enough CPU memory to hold
the weights. Level 2 discards both the model weights and the kv
cache; the content of both is forgotten. It is suited for sleeping
and then running a different model or an updated model, where the
previous weights are not needed, and it reduces CPU memory
pressure.
"""
self.llm_engine.sleep(level=level)
def wake_up(self):
self.llm_engine.wake_up()
# LEGACY
def _convert_v1_inputs(
self,

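Tying the docstring above to the test added earlier in this diff, a minimal end-to-end usage sketch (model name illustrative):

from vllm import LLM, SamplingParams

llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
params = SamplingParams(temperature=0, max_tokens=10)

first = llm.generate("How are you?", params)
llm.sleep(level=1)   # weights offloaded to CPU, kv cache discarded
# ... GPU memory is now free for other work ...
llm.wake_up()
second = llm.generate("How are you?", params)
assert first[0].outputs[0].text == second[0].outputs[0].text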

@@ -193,6 +193,17 @@ class ExecutorBase(ABC):
def stop_profile(self) -> None:
self.collective_rpc("stop_profile")
def sleep(self, level: int = 1):
if self.cache_config.enable_prefix_caching:
# TODO: support sleep with prefix caching
# by resetting the prefix cache state,
# after https://github.com/vllm-project/vllm/pull/12284
raise ValueError("Cannot sleep when prefix caching is enabled.")
self.collective_rpc("sleep", kwargs=dict(level=level))
def wake_up(self):
self.collective_rpc("wake_up")
def save_sharded_state(
self,
path: str,


@@ -9,12 +9,14 @@ import torch.nn as nn
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
@@ -77,6 +79,23 @@ class Worker:
else:
self.profiler = None
def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)
def wake_up(self) -> None:
allocator = CuMemAllocator.get_instance()
allocator.wake_up()
def init_device(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
@@ -110,7 +129,17 @@ class Worker:
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
def load_model(self) -> None:
self.model_runner.load_model()
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights")
else:
from contextlib import nullcontext
context = nullcontext()
with context:
self.model_runner.load_model()
@torch.inference_mode()
def determine_available_memory(self) -> int:
@@ -167,7 +196,14 @@ class Worker:
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
self.model_runner.initialize_kv_cache(kv_cache_config)
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
else:
from contextlib import nullcontext
context = nullcontext()
with context:
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:

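The level-to-tag mapping used by both the v1 worker above and the v0 worker below can be summarized by this small sketch (a hypothetical helper for illustration, not a function in this diff):

def tags_to_offload(level: int) -> tuple:
    # Level 1: keep a CPU copy of the "weights" pool, discard the rest
    # (the "kv_cache" pool). Level 2: discard everything.
    return ("weights", ) if level == 1 else tuple()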

@@ -8,6 +8,7 @@ import torch.distributed
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_kv_transfer_initialized,
ensure_model_parallel_initialized,
init_distributed_environment,
@@ -120,6 +121,23 @@ class Worker(LocalOrDistributedWorkerBase):
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)
def wake_up(self) -> None:
allocator = CuMemAllocator.get_instance()
allocator.wake_up()
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
@@ -151,7 +169,17 @@ class Worker(LocalOrDistributedWorkerBase):
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights")
else:
from contextlib import nullcontext
context = nullcontext()
with context:
self.model_runner.load_model()
def save_sharded_state(
self,
@@ -270,7 +298,14 @@ class Worker(LocalOrDistributedWorkerBase):
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._init_cache_engine()
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
else:
from contextlib import nullcontext
context = nullcontext()
with context:
self._init_cache_engine()
self._warm_up_model()
def _init_cache_engine(self):