[core] Add tags parameter to wake_up() (#15500)

Signed-off-by: Eric <erictang000@gmail.com>
This commit is contained in:
Eric Tang 2025-04-02 01:59:27 -07:00 committed by GitHub
parent 90969fb39a
commit ddb94c2605
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 144 additions and 71 deletions

View File

@ -155,6 +155,24 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
llm.wake_up() llm.wake_up()
output2 = llm.generate(prompt, sampling_params) output2 = llm.generate(prompt, sampling_params)
# cmp output # cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text assert output[0].outputs[0].text == output2[0].outputs[0].text
llm.sleep(level=1)
llm.wake_up(tags=["weights"])
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
# should just reallocate memory for weights (1B model, ~2GiB weights)
if use_v1:
assert used_bytes < 10 * GiB_bytes
else:
assert used_bytes < 6 * GiB_bytes
# now allocate kv cache memory
llm.wake_up(tags=["kv_cache"])
output3 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output3[0].outputs[0].text

View File

@ -25,16 +25,37 @@ def test_sleep_mode():
"VLLM_SERVER_DEV_MODE": "1", "VLLM_SERVER_DEV_MODE": "1",
"CUDA_VISIBLE_DEVICES": "0" "CUDA_VISIBLE_DEVICES": "0"
}) as remote_server: }) as remote_server:
response = requests.post(remote_server.url_for("sleep"),
response = requests.post(remote_server.url_for("/sleep"),
params={"level": "1"}) params={"level": "1"})
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping")) response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json().get("is_sleeping") is True assert response.json().get("is_sleeping") is True
response = requests.post(remote_server.url_for("/wake_up")) response = requests.post(remote_server.url_for("wake_up"))
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping")) response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is False
# test wake up with tags
response = requests.post(remote_server.url_for("sleep"),
params={"level": "1"})
assert response.status_code == 200
response = requests.post(remote_server.url_for("wake_up"),
params={"tags": ["weights"]})
assert response.status_code == 200
# is sleeping should be false after waking up any part of the engine
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is True
response = requests.post(remote_server.url_for("wake_up"),
params={"tags": ["kv_cache"]})
assert response.status_code == 200
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json().get("is_sleeping") is False assert response.json().get("is_sleeping") is False

View File

@ -208,22 +208,28 @@ class CuMemAllocator:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None) -> None:
""" """
Wake up the allocator from sleep mode. Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.""" memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
"""
for ptr, data in self.pointer_to_data.items(): for ptr, data in self.pointer_to_data.items():
handle = data.handle if tags is None or data.tag in tags:
create_and_map(handle) handle = data.handle
if data.cpu_backup_tensor is not None: create_and_map(handle)
cpu_backup_tensor = data.cpu_backup_tensor if data.cpu_backup_tensor is not None:
if cpu_backup_tensor is not None: cpu_backup_tensor = data.cpu_backup_tensor
size_in_bytes = cpu_backup_tensor.numel( if cpu_backup_tensor is not None:
) * cpu_backup_tensor.element_size() size_in_bytes = cpu_backup_tensor.numel(
cpu_ptr = cpu_backup_tensor.data_ptr() ) * cpu_backup_tensor.element_size()
libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) cpu_ptr = cpu_backup_tensor.data_ptr()
data.cpu_backup_tensor = None libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
data.cpu_backup_tensor = None
@contextmanager @contextmanager
def use_memory_pool(self, tag: Optional[str] = None): def use_memory_pool(self, tag: Optional[str] = None):

View File

@ -1225,8 +1225,8 @@ class AsyncLLMEngine(EngineClient):
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
self.engine.sleep(level) self.engine.sleep(level)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up() self.engine.wake_up(tags)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
return self.engine.is_sleeping() return self.engine.is_sleeping()

View File

@ -1938,10 +1938,10 @@ class LLMEngine:
"Sleep mode is not enabled in the model config") "Sleep mode is not enabled in the model config")
self.model_executor.sleep(level=level) self.model_executor.sleep(level=level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
assert self.vllm_config.model_config.enable_sleep_mode, ( assert self.vllm_config.model_config.enable_sleep_mode, (
"Sleep mode is not enabled in the model config") "Sleep mode is not enabled in the model config")
self.model_executor.wake_up() self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping

View File

@ -133,8 +133,9 @@ class RPCSleepRequest(Enum):
SLEEP_LEVEL_2 = 2 SLEEP_LEVEL_2 = 2
class RPCWakeUpRequest(Enum): @dataclass
WAKE_UP = 1 class RPCWakeUpRequest:
tags: Optional[list[str]] = None
@dataclass @dataclass

View File

@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient):
return await self._send_one_way_rpc_request( return await self._send_one_way_rpc_request(
request=RPCSleepRequest(level), socket=self.input_socket) request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine""" """Wake up the engine"""
return await self._send_one_way_rpc_request( return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket) request=RPCWakeUpRequest(tags), socket=self.input_socket)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping""" """Check whether the engine is sleeping"""

View File

@ -274,7 +274,7 @@ class MQLLMEngine:
elif isinstance(request, RPCSleepRequest): elif isinstance(request, RPCSleepRequest):
self.sleep(request.value) self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest): elif isinstance(request, RPCWakeUpRequest):
self.wake_up() self.wake_up(request.tags)
elif isinstance(request, RPCIsSleepingRequest): elif isinstance(request, RPCIsSleepingRequest):
self._handle_is_sleeping_request(request) self._handle_is_sleeping_request(request)
else: else:
@ -415,8 +415,8 @@ class MQLLMEngine:
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine.sleep(level) self.engine.sleep(level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up() self.engine.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine.is_sleeping() return self.engine.is_sleeping()

View File

@ -282,7 +282,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine""" """Wake up the engine"""
... ...

View File

@ -1200,26 +1200,35 @@ class LLM:
The caller should guarantee that no requests are being processed The caller should guarantee that no requests are being processed
during the sleep period, before `wake_up` is called. during the sleep period, before `wake_up` is called.
:param level: The sleep level. Level 1 sleep will offload the model Args:
weights and discard the kv cache. The content of kv cache is level: The sleep level. Level 1 sleep will offload the model
forgotten. Level 1 sleep is good for sleeping and waking up the weights and discard the kv cache. The content of kv cache
engine to run the same model again. The model weights are backed is forgotten. Level 1 sleep is good for sleeping and waking
up in CPU memory. Please make sure there's enough CPU memory to up the engine to run the same model again. The model weights
store the model weights. Level 2 sleep will discard both the model are backed up in CPU memory. Please make sure there's enough
weights and the kv cache. The content of both the model weights CPU memory to store the model weights. Level 2 sleep will
and kv cache is forgotten. Level 2 sleep is good for sleeping and discard both the model weights and the kv cache. The content
waking up the engine to run a different model or update the model, of both the model weights and kv cache is forgotten. Level 2
where previous model weights are not needed. It reduces CPU memory sleep is good for sleeping and waking up the engine to run a
pressure. different model or update the model, where previous model
weights are not needed. It reduces CPU memory pressure.
""" """
self.reset_prefix_cache() self.reset_prefix_cache()
self.llm_engine.sleep(level=level) self.llm_engine.sleep(level=level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
""" """
Wake up the engine from sleep mode. See the :meth:`sleep` method Wake up the engine from sleep mode. See the :meth:`sleep` method
for more details.""" for more details.
self.llm_engine.wake_up()
Args:
tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in
("weights", "kv_cache",). If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the
engine is used again.
"""
self.llm_engine.wake_up(tags)
# LEGACY # LEGACY
def _convert_v1_inputs( def _convert_v1_inputs(

View File

@ -705,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE:
async def sleep(raw_request: Request): async def sleep(raw_request: Request):
# get POST params # get POST params
level = raw_request.query_params.get("level", "1") level = raw_request.query_params.get("level", "1")
logger.info("sleep the engine with level %s", level)
await engine_client(raw_request).sleep(int(level)) await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command # FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
@ -713,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE:
@router.post("/wake_up") @router.post("/wake_up")
async def wake_up(raw_request: Request): async def wake_up(raw_request: Request):
logger.info("wake up the engine") tags = raw_request.query_params.getlist("tags")
await engine_client(raw_request).wake_up() if tags == []:
# set to None to wake up all tags if no tags are provided
tags = None
logger.info("wake up the engine with tags: %s", tags)
await engine_client(raw_request).wake_up(tags)
# FIXME: in v0 with frontend multiprocessing, the wake-up command # FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
return Response(status_code=200) return Response(status_code=200)

View File

@ -51,6 +51,7 @@ class ExecutorBase(ABC):
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self._init_executor() self._init_executor()
self.is_sleeping = False self.is_sleeping = False
self.sleeping_tags: set[str] = set()
@abstractmethod @abstractmethod
def _init_executor(self) -> None: def _init_executor(self) -> None:
@ -204,20 +205,34 @@ class ExecutorBase(ABC):
time_before_sleep = time.perf_counter() time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level)) self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter() time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.", logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep) time_after_sleep - time_before_sleep)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
if not self.is_sleeping: if not self.is_sleeping:
logger.warning("Executor is not sleeping.") logger.warning("Executor is not sleeping.")
return return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning("Tag %s is not in sleeping tags %s", tag,
self.sleeping_tags)
return
time_before_wakeup = time.perf_counter() time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up") self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter() time_after_wakeup = time.perf_counter()
self.is_sleeping = False logger.info("It took %.6f seconds to wake up tags %s.",
logger.info("It took %.6f seconds to wake up.", time_after_wakeup - time_before_wakeup,
time_after_wakeup - time_before_wakeup) tags if tags is not None else self.sleeping_tags)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
def save_sharded_state( def save_sharded_state(
self, self,

View File

@ -424,8 +424,8 @@ class AsyncLLM(EngineClient):
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level) await self.engine_core.sleep_async(level)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
await self.engine_core.wake_up_async() await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async() return await self.engine_core.is_sleeping_async()

View File

@ -264,8 +264,8 @@ class EngineCore:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.model_executor.sleep(level) self.model_executor.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.model_executor.wake_up() self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping

View File

@ -92,7 +92,7 @@ class EngineCoreClient(ABC):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
@ -141,7 +141,7 @@ class EngineCoreClient(ABC):
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
@ -206,8 +206,8 @@ class InprocClient(EngineCoreClient):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
@ -520,8 +520,8 @@ class SyncMPClient(MPClient):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.call_utility("sleep", level) self.call_utility("sleep", level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.call_utility("wake_up") self.call_utility("wake_up", tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.call_utility("is_sleeping") return self.call_utility("is_sleeping")
@ -647,8 +647,8 @@ class AsyncMPClient(MPClient):
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
await self.call_utility_async("sleep", level) await self.call_utility_async("sleep", level)
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
await self.call_utility_async("wake_up") await self.call_utility_async("wake_up", tags)
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
return await self.call_utility_async("is_sleeping") return await self.call_utility_async("is_sleeping")

View File

@ -245,8 +245,8 @@ class LLMEngine:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()

View File

@ -83,9 +83,9 @@ class Worker(WorkerBase):
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes) used_bytes / GiB_bytes)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.wake_up() allocator.wake_up(tags)
def init_device(self): def init_device(self):
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":

View File

@ -135,9 +135,9 @@ class Worker(LocalOrDistributedWorkerBase):
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes) used_bytes / GiB_bytes)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.wake_up() allocator.wake_up(tags=tags)
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":