[Core] Support fully transparent sleep mode (#11743)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
4004f144f3
commit
68ad4e3a8d
@ -76,7 +76,9 @@ steps:
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
- tests/basic_correctness/test_cpu_offload
|
||||
- tests/basic_correctness/test_preemption
|
||||
- tests/basic_correctness/test_cumem.py
|
||||
commands:
|
||||
- pytest -v -s basic_correctness/test_cumem.py
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
@ -181,6 +181,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
||||
# Define other extension targets
|
||||
#
|
||||
|
||||
#
|
||||
# cumem_allocator extension
|
||||
#
|
||||
|
||||
set(VLLM_CUMEM_EXT_SRC
|
||||
"csrc/cumem_allocator.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_CUMEM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Enabling cumem allocator extension.")
|
||||
# link against cuda driver library
|
||||
list(APPEND CUMEM_LIBS cuda)
|
||||
define_gpu_extension_target(
|
||||
cumem_allocator
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_CUMEM_EXT_SRC}
|
||||
LIBRARIES ${CUMEM_LIBS}
|
||||
USE_SABI 3.8
|
||||
WITH_SOABI)
|
||||
endif()
|
||||
|
||||
#
|
||||
# _C extension
|
||||
#
|
||||
|
310
csrc/cumem_allocator.cpp
Normal file
310
csrc/cumem_allocator.cpp
Normal file
@ -0,0 +1,310 @@
|
||||
// A CUDAPluggableAllocator based on cumem* APIs.
|
||||
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle*
|
||||
// need to be unsigned long long
|
||||
#include <iostream>
|
||||
|
||||
extern "C" {
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda.h>
|
||||
|
||||
#define CUDA_CHECK(condition) \
|
||||
do { \
|
||||
CUresult error = condition; \
|
||||
if (error != 0) { \
|
||||
char* error_string; \
|
||||
cuGetErrorString(error, (const char**)&error_string); \
|
||||
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
|
||||
<< __LINE__ << std::endl; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Global references to Python callables
|
||||
// NOTE: this is borrowed reference, so we don't need to DECREF them.
|
||||
// This brings the limitation that the allocator needs to be singleton.
|
||||
static PyObject* g_python_malloc_callback = nullptr;
|
||||
static PyObject* g_python_free_callback = nullptr;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper functions:
|
||||
|
||||
void ensure_context(unsigned long long device) {
|
||||
CUcontext pctx;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
// Ensure device context.
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
}
|
||||
|
||||
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
ensure_context(device);
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Allocate memory using cuMemCreate
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
|
||||
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));
|
||||
|
||||
CUmemAccessDesc accessDesc = {};
|
||||
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
accessDesc.location.id = device;
|
||||
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
|
||||
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
|
||||
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
|
||||
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
}
|
||||
|
||||
void unmap_and_release(unsigned long long device, ssize_t size,
|
||||
CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
|
||||
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
ensure_context(device);
|
||||
CUDA_CHECK(cuMemUnmap(d_mem, size));
|
||||
CUDA_CHECK(cuMemRelease(*p_memHandle));
|
||||
}
|
||||
|
||||
PyObject* create_tuple_from_c_integers(unsigned long long a,
|
||||
unsigned long long b,
|
||||
unsigned long long c,
|
||||
unsigned long long d) {
|
||||
// Create a new tuple of size 4
|
||||
PyObject* tuple = PyTuple_New(4);
|
||||
if (!tuple) {
|
||||
return NULL; // Return NULL on failure
|
||||
}
|
||||
|
||||
// Convert integers to Python objects and set them in the tuple
|
||||
PyTuple_SetItem(
|
||||
tuple, 0,
|
||||
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong
|
||||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
|
||||
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
|
||||
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
|
||||
|
||||
// Note: PyTuple_SetItem "steals" a reference to each object,
|
||||
// so we do not need to Py_DECREF the PyLong objects explicitly.
|
||||
|
||||
return tuple; // Return the created tuple
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Our exported C functions that call Python:
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void* my_malloc(ssize_t size, int device, CUstream stream) {
|
||||
ensure_context(device);
|
||||
|
||||
// first allocation, align the size, and reserve an address, and also allocate
|
||||
// a CUmemGenericAllocationHandle
|
||||
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Check if the allocation is supported
|
||||
size_t granularity;
|
||||
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
|
||||
CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
|
||||
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
|
||||
|
||||
CUdeviceptr d_mem;
|
||||
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
|
||||
|
||||
// allocate the CUmemGenericAllocationHandle
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)malloc(
|
||||
sizeof(CUmemGenericAllocationHandle));
|
||||
|
||||
if (!g_python_malloc_callback) {
|
||||
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* arg_tuple = create_tuple_from_c_integers(
|
||||
(unsigned long long)device, (unsigned long long)alignedSize,
|
||||
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
|
||||
|
||||
// Call g_python_malloc_callback
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
|
||||
Py_DECREF(arg_tuple);
|
||||
|
||||
if (!py_result) {
|
||||
PyErr_Print();
|
||||
PyGILState_Release(gstate);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// do the final mapping
|
||||
create_and_map(device, alignedSize, d_mem, p_memHandle);
|
||||
|
||||
return (void*)d_mem;
|
||||
}
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
|
||||
// get memory handle from the pointer
|
||||
if (!g_python_free_callback) {
|
||||
std::cerr << "ERROR: g_python_free_callback not set.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* py_ptr =
|
||||
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
|
||||
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
|
||||
|
||||
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
|
||||
&recv_d_mem, &recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// recv_size == size
|
||||
// recv_device == device
|
||||
|
||||
// Free memory
|
||||
|
||||
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
unmap_and_release(device, size, d_mem, p_memHandle);
|
||||
|
||||
// free address and the handle
|
||||
CUDA_CHECK(cuMemAddressFree(d_mem, size));
|
||||
free(p_memHandle);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Python extension boilerplate:
|
||||
|
||||
// Python-exposed function: init_module(python_malloc, python_free)
|
||||
static PyObject* py_init_module(PyObject* self, PyObject* args) {
|
||||
PyObject* malloc_callback = nullptr;
|
||||
PyObject* free_callback = nullptr;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Save the Python callables
|
||||
// This module does not handle GC of these objects, so they must be kept alive
|
||||
// outside of this module.
|
||||
g_python_malloc_callback = malloc_callback;
|
||||
g_python_free_callback = free_callback;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
|
||||
"Initialize module with python_malloc and python_free callables."},
|
||||
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
|
||||
"Create and map memory on the device."},
|
||||
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
|
||||
METH_VARARGS, "Unmap and release memory on the device."},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef cumem_allocator_module = {
|
||||
PyModuleDef_HEAD_INIT, "cumem_allocator",
|
||||
"cumem-based allocator for CUDAPluggableAllocator", -1, module_methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_cumem_allocator(void) {
|
||||
// Initialize the module
|
||||
PyObject* module = PyModule_Create(&cumem_allocator_module);
|
||||
if (!module) {
|
||||
return NULL;
|
||||
}
|
||||
return module;
|
||||
}
|
||||
} // extern "C"
|
2
setup.py
2
setup.py
@ -301,6 +301,7 @@ class repackage_wheel(build_ext):
|
||||
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
|
||||
"vllm/vllm_flash_attn/flash_attn_interface.py",
|
||||
"vllm/vllm_flash_attn/__init__.py",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
# "vllm/_version.py", # not available in nightly wheels yet
|
||||
]
|
||||
file_members = filter(lambda x: x.filename in files_to_copy,
|
||||
@ -594,6 +595,7 @@ if _is_hip():
|
||||
if _is_cuda():
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
112
tests/basic_correctness/test_cumem.py
Normal file
112
tests/basic_correctness/test_cumem.py
Normal file
@ -0,0 +1,112 @@
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
from ..utils import fork_new_process_for_each_test
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_basic_cumem():
|
||||
# some tensors from default memory pool
|
||||
shape = (1024, 1024)
|
||||
x = torch.empty(shape, device='cuda')
|
||||
x.zero_()
|
||||
|
||||
# some tensors from custom memory pool
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
# custom memory pool
|
||||
y = torch.empty(shape, device='cuda')
|
||||
y.zero_()
|
||||
y += 1
|
||||
z = torch.empty(shape, device='cuda')
|
||||
z.zero_()
|
||||
z += 2
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_cumem_with_cudagraph():
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
weight = torch.eye(1024, device='cuda')
|
||||
with allocator.use_memory_pool(tag="discard"):
|
||||
cache = torch.empty(1024, 1024, device='cuda')
|
||||
|
||||
def model(x):
|
||||
out = x @ weight
|
||||
cache[:out.size(0)].copy_(out)
|
||||
return out + 1
|
||||
|
||||
x = torch.empty(128, 1024, device='cuda')
|
||||
|
||||
# warmup
|
||||
model(x)
|
||||
|
||||
# capture cudagraph
|
||||
model_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(model_graph):
|
||||
y = model(x)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# after waking up, the content in the weight tensor
|
||||
# should be restored, but the content in the cache tensor
|
||||
# should be discarded
|
||||
|
||||
# this operation is also compatible with cudagraph
|
||||
|
||||
x.random_()
|
||||
model_graph.replay()
|
||||
|
||||
# cache content is as expected
|
||||
assert torch.allclose(x, cache[:x.size(0)])
|
||||
|
||||
# output content is as expected
|
||||
assert torch.allclose(y, x + 1)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_end_to_end():
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
|
||||
# which is difficult to measure in the test. therefore, we only
|
||||
# test sleep level 1 here.
|
||||
llm.sleep(level=1)
|
||||
|
||||
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
|
||||
# now the memory usage is mostly cudagraph memory pool,
|
||||
# and it should be less than the model weights (1B model, 2GiB weights)
|
||||
assert used_bytes < 2 * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
@ -195,40 +195,43 @@ class ModelConfig:
|
||||
factors.append(self.rope_theta)
|
||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
|
||||
def __init__(self,
|
||||
model: str,
|
||||
task: Union[TaskOption, Literal["draft"]],
|
||||
tokenizer: str,
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
dtype: Union[str, torch.dtype],
|
||||
seed: int,
|
||||
allowed_local_media_path: str = "",
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
spec_target_max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
quantization_param_path: Optional[str] = None,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_seq_len_to_capture: Optional[int] = None,
|
||||
max_logprobs: int = 20,
|
||||
disable_sliding_window: bool = False,
|
||||
skip_tokenizer_init: bool = False,
|
||||
served_model_name: Optional[Union[str, List[str]]] = None,
|
||||
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
|
||||
use_async_output_proc: bool = True,
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
disable_mm_preprocessor_cache: bool = False,
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None,
|
||||
override_pooler_config: Optional["PoolerConfig"] = None,
|
||||
logits_processor_pattern: Optional[str] = None,
|
||||
generation_config: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
task: Union[TaskOption, Literal["draft"]],
|
||||
tokenizer: str,
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
dtype: Union[str, torch.dtype],
|
||||
seed: int,
|
||||
allowed_local_media_path: str = "",
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
rope_theta: Optional[float] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
spec_target_max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
quantization_param_path: Optional[str] = None,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_seq_len_to_capture: Optional[int] = None,
|
||||
max_logprobs: int = 20,
|
||||
disable_sliding_window: bool = False,
|
||||
skip_tokenizer_init: bool = False,
|
||||
served_model_name: Optional[Union[str, List[str]]] = None,
|
||||
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
|
||||
use_async_output_proc: bool = True,
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
disable_mm_preprocessor_cache: bool = False,
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None,
|
||||
override_pooler_config: Optional["PoolerConfig"] = None,
|
||||
logits_processor_pattern: Optional[str] = None,
|
||||
generation_config: Optional[str] = None,
|
||||
enable_sleep_mode: bool = False,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
@ -277,6 +280,12 @@ class ModelConfig:
|
||||
self.max_logprobs = max_logprobs
|
||||
self.disable_sliding_window = disable_sliding_window
|
||||
self.skip_tokenizer_init = skip_tokenizer_init
|
||||
self.enable_sleep_mode = enable_sleep_mode
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if self.enable_sleep_mode and not current_platform.is_cuda():
|
||||
raise ValueError("Sleep mode is only supported on CUDA devices.")
|
||||
|
||||
hf_config = get_config(self.model, trust_remote_code, revision,
|
||||
code_revision, config_format)
|
||||
@ -348,7 +357,6 @@ class ModelConfig:
|
||||
self.is_hybrid = self._init_is_hybrid()
|
||||
self.has_inner_state = self._init_has_inner_state()
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_neuron():
|
||||
self.override_neuron_config = override_neuron_config
|
||||
else:
|
||||
|
0
vllm/device_allocator/__init__.py
Normal file
0
vllm/device_allocator/__init__.py
Normal file
254
vllm/device_allocator/cumem.py
Normal file
254
vllm/device_allocator/cumem.py
Normal file
@ -0,0 +1,254 @@
|
||||
# cumem-based pytorch pluggable allocator to implement sleep mode.
|
||||
# other approaches tried but failed:
|
||||
# - cuda-python package binding
|
||||
# - custom libcuda driver ctypes wrapper
|
||||
# both of them failed because of cuda context mismatch.
|
||||
# not sure why, they are created from a different context.
|
||||
# the only successful approach is to call cuda driver API in C.
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of the
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
found_line = None
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found_line = line
|
||||
break
|
||||
if found_line is None:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = found_line.index("/")
|
||||
path = found_line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(lib_name), \
|
||||
f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
|
||||
|
||||
|
||||
cumem_available = False
|
||||
try:
|
||||
from vllm.cumem_allocator import (init_module, python_create_and_map,
|
||||
python_unmap_and_release)
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import (
|
||||
CudaRTLibrary)
|
||||
lib_name = find_loaded_library("cumem_allocator")
|
||||
libcudart = CudaRTLibrary()
|
||||
cumem_available = True
|
||||
except ModuleNotFoundError:
|
||||
# rocm platform does not support cumem allocator
|
||||
init_module = None
|
||||
python_create_and_map = None
|
||||
python_unmap_and_release = None
|
||||
CudaRTLibrary = None
|
||||
lib_name = None
|
||||
libcudart = None
|
||||
|
||||
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
|
||||
HandleType = Tuple[int, int, int, int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AllocationData:
|
||||
handle: HandleType
|
||||
tag: str
|
||||
cpu_backup_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def create_and_map(allocation_handle: HandleType) -> None:
|
||||
python_create_and_map(*allocation_handle)
|
||||
|
||||
|
||||
def unmap_and_release(allocation_handle: HandleType) -> None:
|
||||
python_unmap_and_release(*allocation_handle)
|
||||
|
||||
|
||||
def get_pluggable_allocator(
|
||||
python_malloc_fn: Callable[[int],
|
||||
int], python_free_func: Callable[[int, int],
|
||||
None]
|
||||
) -> torch.cuda.memory.CUDAPluggableAllocator:
|
||||
init_module(python_malloc_fn, python_free_func)
|
||||
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
|
||||
lib_name, 'my_malloc', 'my_free')
|
||||
return new_alloc
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool_with_allocator(
|
||||
python_malloc_fn: Callable[[int], int],
|
||||
python_free_func: Callable[[int, int], None]) -> None:
|
||||
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
||||
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
|
||||
with torch.cuda.memory.use_mem_pool(mem_pool):
|
||||
yield mem_pool
|
||||
|
||||
|
||||
class CuMemAllocator:
|
||||
"""
|
||||
A singleton class that manages a memory pool for CUDA tensors.
|
||||
The memory in this pool can be offloaded or discarded when the
|
||||
allocator sleeps.
|
||||
|
||||
Inside the `use_memory_pool(tag)` context, all tensors created will
|
||||
be allocated in the memory pool, and has the same tag as the
|
||||
tag passed to the context.
|
||||
|
||||
When we call `sleep`, all tensors with the specified tag will be
|
||||
offloaded to CPU memory, and the rest of the tensors will be discarded.
|
||||
When we call `wake_up`, all tensors that are previously offloaded
|
||||
will be loaded back to GPU memory, and the rest of the tensors will
|
||||
have empty memory.
|
||||
|
||||
Why it needs to be a singleton?
|
||||
When allocated tensors are garbage collected, PyTorch will call
|
||||
the free callback, which will call the `python_free_callback` method.
|
||||
The C-extension uses a global variable to store the function of an
|
||||
instance of this class. If we create multiple instances of this class,
|
||||
the global variable will be overwritten and the free callback will
|
||||
not work as expected.
|
||||
"""
|
||||
instance: "CuMemAllocator" = None
|
||||
default_tag: str = "default"
|
||||
|
||||
@staticmethod
|
||||
def get_instance() -> "CuMemAllocator":
|
||||
"""
|
||||
CuMemAllocator is a singleton class.
|
||||
We cannot call the constructor directly.
|
||||
Call this method to get the instance.
|
||||
"""
|
||||
assert cumem_available, "cumem allocator is not available"
|
||||
if CuMemAllocator.instance is None:
|
||||
CuMemAllocator.instance = CuMemAllocator()
|
||||
return CuMemAllocator.instance
|
||||
|
||||
def __init__(self):
|
||||
self.pointer_to_data: Dict[int, AllocationData] = {}
|
||||
self.current_tag: str = CuMemAllocator.default_tag
|
||||
|
||||
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
|
||||
"""
|
||||
Internal method to store the allocation data
|
||||
when memory is allocated in the memory pool."""
|
||||
py_d_mem = allocation_handle[2]
|
||||
self.pointer_to_data[py_d_mem] = AllocationData(
|
||||
allocation_handle, self.current_tag)
|
||||
return
|
||||
|
||||
def python_free_callback(self, ptr: int) -> HandleType:
|
||||
"""
|
||||
Internal method to look up the allocation data
|
||||
when memory is freed in the memory pool."""
|
||||
data = self.pointer_to_data.pop(ptr)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
data.cpu_backup_tensor = None
|
||||
return data.handle
|
||||
|
||||
def sleep(
|
||||
self,
|
||||
offload_tags: Optional[Union[Tuple[str, ...],
|
||||
str]] = None) -> None:
|
||||
"""
|
||||
Put the allocator in sleep mode.
|
||||
All data in the memory allocation with the specified tag will be
|
||||
offloaded to CPU memory, and others will be discarded.
|
||||
|
||||
:param offload_tags: The tags of the memory allocation that will be
|
||||
offloaded. The rest of the memory allocation will be discarded.
|
||||
"""
|
||||
if offload_tags is None:
|
||||
# by default, allocated tensors are offloaded
|
||||
# when the allocator sleeps
|
||||
offload_tags = (CuMemAllocator.default_tag, )
|
||||
elif isinstance(offload_tags, str):
|
||||
offload_tags = (offload_tags, )
|
||||
|
||||
assert isinstance(offload_tags, tuple)
|
||||
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
if data.tag in offload_tags:
|
||||
size_in_bytes = handle[1]
|
||||
cpu_backup_tensor = torch.empty(
|
||||
size_in_bytes,
|
||||
dtype=torch.uint8,
|
||||
device='cpu',
|
||||
pin_memory=is_pin_memory_available())
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
|
||||
data.cpu_backup_tensor = cpu_backup_tensor
|
||||
unmap_and_release(handle)
|
||||
|
||||
def wake_up(self):
|
||||
"""
|
||||
Wake up the allocator from sleep mode.
|
||||
All data that is previously offloaded will be loaded back to GPU
|
||||
memory, and the rest of the data will have empty memory."""
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
create_and_map(handle)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
cpu_backup_tensor = data.cpu_backup_tensor
|
||||
if cpu_backup_tensor is not None:
|
||||
size_in_bytes = cpu_backup_tensor.numel(
|
||||
) * cpu_backup_tensor.element_size()
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
|
||||
data.cpu_backup_tensor = None
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool(self, tag: Optional[str] = None):
|
||||
"""
|
||||
A context manager to use the memory pool.
|
||||
All memory allocation created inside the context will be allocated
|
||||
in the memory pool, and has the specified tag.
|
||||
|
||||
:param tag: The tag of the memory allocation. If None, the default tag
|
||||
will be used.
|
||||
"""
|
||||
if tag is None:
|
||||
tag = CuMemAllocator.default_tag
|
||||
|
||||
assert isinstance(tag, str)
|
||||
|
||||
old_tag = self.current_tag
|
||||
self.current_tag = tag
|
||||
with use_memory_pool_with_allocator(self.python_malloc_callback,
|
||||
self.python_free_callback):
|
||||
yield
|
||||
# PyTorch's bug, calling torch.cuda.empty_cache() will error
|
||||
# when using pluggable allocator, see
|
||||
# https://github.com/pytorch/pytorch/issues/145168 .
|
||||
# if we have some memory allocated and then freed,
|
||||
# the memory will not be released.
|
||||
# right now it is fine, because we only use this allocator
|
||||
# during weight loading and kv cache creation, where we only
|
||||
# allocate memory.
|
||||
# TODO: we need to find a way to release the memory,
|
||||
# i.e. calling torch.cuda.empty_cache()
|
||||
self.current_tag = old_tag
|
||||
|
||||
def get_current_usage(self) -> int:
|
||||
"""
|
||||
Get the total number of bytes allocated in the memory pool.
|
||||
"""
|
||||
sum_bytes: int = 0
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
sum_bytes += handle[1]
|
||||
return sum_bytes
|
@ -197,6 +197,7 @@ class EngineArgs:
|
||||
kv_transfer_config: Optional[KVTransferConfig] = None
|
||||
|
||||
generation_config: Optional[str] = None
|
||||
enable_sleep_mode: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.tokenizer:
|
||||
@ -955,6 +956,12 @@ class EngineArgs:
|
||||
"loaded from model. If set to a folder path, the generation config "
|
||||
"will be loaded from the specified folder path.")
|
||||
|
||||
parser.add_argument("--enable-sleep-mode",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable sleep mode for the engine. "
|
||||
"(only cuda platform is supported)")
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -999,7 +1006,9 @@ class EngineArgs:
|
||||
override_neuron_config=self.override_neuron_config,
|
||||
override_pooler_config=self.override_pooler_config,
|
||||
logits_processor_pattern=self.logits_processor_pattern,
|
||||
generation_config=self.generation_config)
|
||||
generation_config=self.generation_config,
|
||||
enable_sleep_mode=self.enable_sleep_mode,
|
||||
)
|
||||
|
||||
def create_load_config(self) -> LoadConfig:
|
||||
return LoadConfig(
|
||||
|
@ -1818,6 +1818,16 @@ class LLMEngine:
|
||||
def stop_profile(self) -> None:
|
||||
self.model_executor.stop_profile()
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
assert self.vllm_config.model_config.enable_sleep_mode, (
|
||||
"Sleep mode is not enabled in the model config")
|
||||
self.model_executor.sleep(level=level)
|
||||
|
||||
def wake_up(self) -> None:
|
||||
assert self.vllm_config.model_config.enable_sleep_mode, (
|
||||
"Sleep mode is not enabled in the model config")
|
||||
self.model_executor.wake_up()
|
||||
|
||||
def check_health(self) -> None:
|
||||
if self.tokenizer:
|
||||
self.tokenizer.check_health()
|
||||
|
@ -1132,6 +1132,29 @@ class LLM:
|
||||
def stop_profile(self) -> None:
|
||||
self.llm_engine.stop_profile()
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
"""
|
||||
Put the engine to sleep. The engine should not process any requests.
|
||||
The caller should guarantee that no requests are being processed
|
||||
during the sleep period, before `wake_up` is called.
|
||||
|
||||
:param level: The sleep level. Level 1 sleep will offload the model
|
||||
weights and discard the kv cache. The content of kv cache is
|
||||
forgotten. Level 1 sleep is good for sleeping and waking up the
|
||||
engine to run the same model again. The model weights are backed
|
||||
up in CPU memory. Please make sure there's enough CPU memory to
|
||||
store the model weights. Level 2 sleep will discard both the model
|
||||
weights and the kv cache. The content of both the model weights
|
||||
and kv cache is forgotten. Level 2 sleep is good for sleeping and
|
||||
waking up the engine to run a different model or update the model,
|
||||
where previous model weights are not needed. It reduces CPU memory
|
||||
pressure.
|
||||
"""
|
||||
self.llm_engine.sleep(level=level)
|
||||
|
||||
def wake_up(self):
|
||||
self.llm_engine.wake_up()
|
||||
|
||||
# LEGACY
|
||||
def _convert_v1_inputs(
|
||||
self,
|
||||
|
@ -193,6 +193,17 @@ class ExecutorBase(ABC):
|
||||
def stop_profile(self) -> None:
|
||||
self.collective_rpc("stop_profile")
|
||||
|
||||
def sleep(self, level: int = 1):
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
# TODO: support sleep with prefix caching
|
||||
# by resetting the prefix cache state,
|
||||
# after https://github.com/vllm-project/vllm/pull/12284
|
||||
raise ValueError("Cannot sleep when prefix caching is enabled.")
|
||||
self.collective_rpc("sleep", kwargs=dict(level=level))
|
||||
|
||||
def wake_up(self):
|
||||
self.collective_rpc("wake_up")
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
|
@ -9,12 +9,14 @@ import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.scheduler import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
@ -77,6 +79,23 @@ class Worker:
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
logger.info(
|
||||
"Sleep mode freed %.2f GiB memory, "
|
||||
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
||||
used_bytes / GiB_bytes)
|
||||
|
||||
def wake_up(self) -> None:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.wake_up()
|
||||
|
||||
def init_device(self):
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
@ -110,7 +129,17 @@ class Worker:
|
||||
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model_runner.load_model()
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be "
|
||||
"used for one instance per process.")
|
||||
context = allocator.use_memory_pool(tag="weights")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.load_model()
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_memory(self) -> int:
|
||||
@ -167,7 +196,14 @@ class Worker:
|
||||
|
||||
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def compile_or_warm_up_model(self) -> None:
|
||||
if not self.model_config.enforce_eager:
|
||||
|
@ -8,6 +8,7 @@ import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.distributed import (ensure_kv_transfer_initialized,
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
@ -120,6 +121,23 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
logger.info(
|
||||
"Sleep mode freed %.2f GiB memory, "
|
||||
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
|
||||
used_bytes / GiB_bytes)
|
||||
|
||||
def wake_up(self) -> None:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
allocator.wake_up()
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "cuda":
|
||||
# torch.distributed.all_reduce does not free the input tensor until
|
||||
@ -151,7 +169,17 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be "
|
||||
"used for one instance per process.")
|
||||
context = allocator.use_memory_pool(tag="weights")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self.model_runner.load_model()
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
@ -270,7 +298,14 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self._init_cache_engine()
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
else:
|
||||
from contextlib import nullcontext
|
||||
context = nullcontext()
|
||||
with context:
|
||||
self._init_cache_engine()
|
||||
self._warm_up_model()
|
||||
|
||||
def _init_cache_engine(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user