[core] Add tags parameter to wake_up() (#15500)

Signed-off-by: Eric <erictang000@gmail.com>
Eric Tang 2025-04-02 01:59:27 -07:00 committed by GitHub
parent 90969fb39a
commit ddb94c2605
18 changed files with 144 additions and 71 deletions

tests/basic_correctness/test_cumem.py

@@ -155,6 +155,24 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
         llm.wake_up()
         output2 = llm.generate(prompt, sampling_params)
         # cmp output
         assert output[0].outputs[0].text == output2[0].outputs[0].text
+
+        llm.sleep(level=1)
+        llm.wake_up(tags=["weights"])
+
+        free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
+        used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
+
+        # should just reallocate memory for weights (1B model, ~2GiB weights)
+        if use_v1:
+            assert used_bytes < 10 * GiB_bytes
+        else:
+            assert used_bytes < 6 * GiB_bytes
+
+        # now allocate kv cache memory
+        llm.wake_up(tags=["kv_cache"])
+        output3 = llm.generate(prompt, sampling_params)
+        # cmp output
+        assert output[0].outputs[0].text == output3[0].outputs[0].text
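For reference, the arithmetic above measures only the engine's own usage: `used_bytes_baseline` is captured earlier in the test, before the model is loaded. A minimal sketch of that accounting, with the setup in between elided:

    import torch

    # before loading the model: memory held by other processes on the GPU
    free_gpu_bytes_baseline, total = torch.cuda.mem_get_info()
    used_bytes_baseline = total - free_gpu_bytes_baseline

    # ... load the model, llm.sleep(level=1), llm.wake_up(tags=["weights"]) ...

    # after the partial wake-up, only the weights should be resident again
    free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
    used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline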

tests/entrypoints/openai/test_sleep.py

@@ -25,16 +25,37 @@ def test_sleep_mode():
                                 "VLLM_SERVER_DEV_MODE": "1",
                                 "CUDA_VISIBLE_DEVICES": "0"
                             }) as remote_server:
-        response = requests.post(remote_server.url_for("/sleep"),
+        response = requests.post(remote_server.url_for("sleep"),
                                  params={"level": "1"})
         assert response.status_code == 200
-        response = requests.get(remote_server.url_for("/is_sleeping"))
+        response = requests.get(remote_server.url_for("is_sleeping"))
         assert response.status_code == 200
         assert response.json().get("is_sleeping") is True

-        response = requests.post(remote_server.url_for("/wake_up"))
+        response = requests.post(remote_server.url_for("wake_up"))
         assert response.status_code == 200
-        response = requests.get(remote_server.url_for("/is_sleeping"))
+        response = requests.get(remote_server.url_for("is_sleeping"))
         assert response.status_code == 200
         assert response.json().get("is_sleeping") is False
+
+        # test wake up with tags
+        response = requests.post(remote_server.url_for("sleep"),
+                                 params={"level": "1"})
+        assert response.status_code == 200
+
+        response = requests.post(remote_server.url_for("wake_up"),
+                                 params={"tags": ["weights"]})
+        assert response.status_code == 200
+
+        # the engine should still report sleeping until every tag is woken up
+        response = requests.get(remote_server.url_for("is_sleeping"))
+        assert response.status_code == 200
+        assert response.json().get("is_sleeping") is True
+
+        response = requests.post(remote_server.url_for("wake_up"),
+                                 params={"tags": ["kv_cache"]})
+        assert response.status_code == 200
+
+        response = requests.get(remote_server.url_for("is_sleeping"))
+        assert response.status_code == 200
+        assert response.json().get("is_sleeping") is False

vllm/device_allocator/cumem.py

@@ -208,22 +208,28 @@ class CuMemAllocator:
         gc.collect()
         torch.cuda.empty_cache()

-    def wake_up(self):
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
         """
         Wake up the allocator from sleep mode.
         All data that is previously offloaded will be loaded back to GPU
-        memory, and the rest of the data will have empty memory."""
+        memory, and the rest of the data will have empty memory.
+
+        :param tags: The tags of the memory allocations that will be loaded
+            back to GPU memory. If None, all memory allocations will be
+            loaded back to GPU memory.
+        """
         for ptr, data in self.pointer_to_data.items():
-            handle = data.handle
-            create_and_map(handle)
-            if data.cpu_backup_tensor is not None:
-                cpu_backup_tensor = data.cpu_backup_tensor
-                if cpu_backup_tensor is not None:
-                    size_in_bytes = cpu_backup_tensor.numel(
-                    ) * cpu_backup_tensor.element_size()
-                    cpu_ptr = cpu_backup_tensor.data_ptr()
-                    libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
-                    data.cpu_backup_tensor = None
+            if tags is None or data.tag in tags:
+                handle = data.handle
+                create_and_map(handle)
+                if data.cpu_backup_tensor is not None:
+                    cpu_backup_tensor = data.cpu_backup_tensor
+                    if cpu_backup_tensor is not None:
+                        size_in_bytes = cpu_backup_tensor.numel(
+                        ) * cpu_backup_tensor.element_size()
+                        cpu_ptr = cpu_backup_tensor.data_ptr()
+                        libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
+                        data.cpu_backup_tensor = None

     @contextmanager
     def use_memory_pool(self, tag: Optional[str] = None):
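The heart of the change is the `tags is None or data.tag in tags` filter: every allocation is recorded with a tag when it is created under `use_memory_pool(tag=...)`, and wake-up now remaps only the matching allocations. A stripped-down sketch of the same selection logic, with a hypothetical `AllocationData` standing in for the allocator's per-pointer bookkeeping (no CUDA involved):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class AllocationData:          # hypothetical stand-in for the real record
        tag: str                   # e.g. "weights" or "kv_cache"
        restored: bool = False

    def wake_up(allocations: dict[int, AllocationData],
                tags: Optional[list[str]] = None) -> None:
        for ptr, data in allocations.items():
            if tags is None or data.tag in tags:   # same filter as above
                data.restored = True               # real code remaps GPU memory

    allocs = {0x1: AllocationData("weights"), 0x2: AllocationData("kv_cache")}
    wake_up(allocs, tags=["weights"])
    assert allocs[0x1].restored and not allocs[0x2].restored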

vllm/engine/async_llm_engine.py

@@ -1225,8 +1225,8 @@ class AsyncLLMEngine(EngineClient):
     async def sleep(self, level: int = 1) -> None:
         self.engine.sleep(level)

-    async def wake_up(self) -> None:
-        self.engine.wake_up()
+    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        self.engine.wake_up(tags)

     async def is_sleeping(self) -> bool:
         return self.engine.is_sleeping()

vllm/engine/llm_engine.py

@@ -1938,10 +1938,10 @@ class LLMEngine:
             "Sleep mode is not enabled in the model config")
         self.model_executor.sleep(level=level)

-    def wake_up(self) -> None:
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
         assert self.vllm_config.model_config.enable_sleep_mode, (
             "Sleep mode is not enabled in the model config")
-        self.model_executor.wake_up()
+        self.model_executor.wake_up(tags)

     def is_sleeping(self) -> bool:
         return self.model_executor.is_sleeping

vllm/engine/multiprocessing/__init__.py

@@ -133,8 +133,9 @@ class RPCSleepRequest(Enum):
     SLEEP_LEVEL_2 = 2

-class RPCWakeUpRequest(Enum):
-    WAKE_UP = 1
+@dataclass
+class RPCWakeUpRequest:
+    tags: Optional[list[str]] = None

 @dataclass
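Switching `RPCWakeUpRequest` from an `Enum` to a `@dataclass` lets the request carry its argument across the process boundary; the engine side then reads `request.tags` back out (see the dispatch change below). A minimal sketch of the round trip, using plain `pickle` to stand in for the RPC transport:

    import pickle
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class RPCWakeUpRequest:
        tags: Optional[list[str]] = None

    # client side: the request object now carries its payload
    wire_bytes = pickle.dumps(RPCWakeUpRequest(tags=["weights"]))

    # engine side: dispatch on type, then read the payload
    request = pickle.loads(wire_bytes)
    if isinstance(request, RPCWakeUpRequest):
        print(request.tags)  # ['weights']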

vllm/engine/multiprocessing/client.py

@@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient):
         return await self._send_one_way_rpc_request(
             request=RPCSleepRequest(level), socket=self.input_socket)

-    async def wake_up(self) -> None:
+    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
         """Wake up the engine"""
         return await self._send_one_way_rpc_request(
-            request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket)
+            request=RPCWakeUpRequest(tags), socket=self.input_socket)

     async def is_sleeping(self) -> bool:
         """Check whether the engine is sleeping"""

vllm/engine/multiprocessing/engine.py

@@ -274,7 +274,7 @@ class MQLLMEngine:
         elif isinstance(request, RPCSleepRequest):
             self.sleep(request.value)
         elif isinstance(request, RPCWakeUpRequest):
-            self.wake_up()
+            self.wake_up(request.tags)
         elif isinstance(request, RPCIsSleepingRequest):
             self._handle_is_sleeping_request(request)
         else:
@@ -415,8 +415,8 @@ class MQLLMEngine:
     def sleep(self, level: int = 1) -> None:
         self.engine.sleep(level)

-    def wake_up(self) -> None:
-        self.engine.wake_up()
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        self.engine.wake_up(tags)

     def is_sleeping(self) -> bool:
         return self.engine.is_sleeping()

vllm/engine/protocol.py

@@ -282,7 +282,7 @@ class EngineClient(ABC):
         ...

     @abstractmethod
-    async def wake_up(self) -> None:
+    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
         """Wake up the engine"""
         ...

vllm/entrypoints/llm.py

@@ -1200,26 +1200,35 @@ class LLM:
         The caller should guarantee that no requests are being processed
         during the sleep period, before `wake_up` is called.

-        :param level: The sleep level. Level 1 sleep will offload the model
-            weights and discard the kv cache. The content of kv cache is
-            forgotten. Level 1 sleep is good for sleeping and waking up the
-            engine to run the same model again. The model weights are backed
-            up in CPU memory. Please make sure there's enough CPU memory to
-            store the model weights. Level 2 sleep will discard both the model
-            weights and the kv cache. The content of both the model weights
-            and kv cache is forgotten. Level 2 sleep is good for sleeping and
-            waking up the engine to run a different model or update the model,
-            where previous model weights are not needed. It reduces CPU memory
-            pressure.
+        Args:
+            level: The sleep level. Level 1 sleep will offload the model
+                weights and discard the kv cache. The content of kv cache
+                is forgotten. Level 1 sleep is good for sleeping and waking
+                up the engine to run the same model again. The model weights
+                are backed up in CPU memory. Please make sure there's enough
+                CPU memory to store the model weights. Level 2 sleep will
+                discard both the model weights and the kv cache. The content
+                of both the model weights and kv cache is forgotten. Level 2
+                sleep is good for sleeping and waking up the engine to run a
+                different model or update the model, where previous model
+                weights are not needed. It reduces CPU memory pressure.
         """
         self.reset_prefix_cache()
         self.llm_engine.sleep(level=level)

-    def wake_up(self):
+    def wake_up(self, tags: Optional[list[str]] = None):
         """
         Wake up the engine from sleep mode. See the :meth:`sleep` method
-        for more details."""
-        self.llm_engine.wake_up()
+        for more details.
+
+        Args:
+            tags: An optional list of tags to reallocate the engine memory
+                for specific memory allocations. Values must be in
+                ("weights", "kv_cache"). If None, all memory is reallocated.
+                wake_up should be called with all tags (or None) before the
+                engine is used again.
+        """
+        self.llm_engine.wake_up(tags)

     # LEGACY
     def _convert_v1_inputs(
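Taken together, the two methods now support a staged wake-up. A usage sketch (model name and prompt are illustrative; sleep mode requires a CUDA device):

    from vllm import LLM

    llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

    llm.sleep(level=1)              # offload weights to CPU, discard kv cache

    llm.wake_up(tags=["weights"])   # reload only the weights first, leaving
                                    # GPU headroom free for other work
    llm.wake_up(tags=["kv_cache"])  # reallocate the kv cache; the engine is
                                    # only usable again once all tags are awake

    print(llm.generate("Hello, my name is")[0].outputs[0].text)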

vllm/entrypoints/openai/api_server.py

@@ -705,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE:
     async def sleep(raw_request: Request):
         # get POST params
         level = raw_request.query_params.get("level", "1")
-        logger.info("sleep the engine with level %s", level)
         await engine_client(raw_request).sleep(int(level))
         # FIXME: in v0 with frontend multiprocessing, the sleep command
         # is sent but does not finish yet when we return a response.
@@ -713,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE:
     @router.post("/wake_up")
     async def wake_up(raw_request: Request):
-        logger.info("wake up the engine")
-        await engine_client(raw_request).wake_up()
+        tags = raw_request.query_params.getlist("tags")
+        if tags == []:
+            # set to None to wake up all tags if no tags are provided
+            tags = None
+        logger.info("wake up the engine with tags: %s", tags)
+        await engine_client(raw_request).wake_up(tags)
         # FIXME: in v0 with frontend multiprocessing, the wake-up command
         # is sent but does not finish yet when we return a response.
         return Response(status_code=200)
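`query_params.getlist("tags")` expects the tags as repeated query parameters (`/wake_up?tags=weights&tags=kv_cache`), which is what `requests` emits when a list is passed as a parameter value. A sketch against a dev-mode server (base URL assumed):

    import requests

    base = "http://localhost:8000"  # server started with VLLM_SERVER_DEV_MODE=1

    requests.post(f"{base}/sleep", params={"level": "1"})
    # sends POST /wake_up?tags=weights&tags=kv_cache
    requests.post(f"{base}/wake_up", params={"tags": ["weights", "kv_cache"]})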

vllm/executor/executor_base.py

@@ -51,6 +51,7 @@ class ExecutorBase(ABC):
         self.observability_config = vllm_config.observability_config
         self._init_executor()
         self.is_sleeping = False
+        self.sleeping_tags: set[str] = set()

     @abstractmethod
     def _init_executor(self) -> None:
@@ -204,20 +205,34 @@ class ExecutorBase(ABC):
         time_before_sleep = time.perf_counter()
         self.collective_rpc("sleep", kwargs=dict(level=level))
         time_after_sleep = time.perf_counter()
+        self.sleeping_tags = {"weights", "kv_cache"}
         self.is_sleeping = True
         logger.info("It took %.6f seconds to fall asleep.",
                     time_after_sleep - time_before_sleep)

-    def wake_up(self):
+    def wake_up(self, tags: Optional[list[str]] = None):
         if not self.is_sleeping:
             logger.warning("Executor is not sleeping.")
             return
+        if tags:
+            for tag in tags:
+                if tag not in self.sleeping_tags:
+                    logger.warning("Tag %s is not in sleeping tags %s", tag,
+                                   self.sleeping_tags)
+                    return
         time_before_wakeup = time.perf_counter()
-        self.collective_rpc("wake_up")
+        self.collective_rpc("wake_up", kwargs=dict(tags=tags))
         time_after_wakeup = time.perf_counter()
-        self.is_sleeping = False
-        logger.info("It took %.6f seconds to wake up.",
-                    time_after_wakeup - time_before_wakeup)
+        logger.info("It took %.6f seconds to wake up tags %s.",
+                    time_after_wakeup - time_before_wakeup,
+                    tags if tags is not None else self.sleeping_tags)
+        if tags:
+            for tag in tags:
+                self.sleeping_tags.remove(tag)
+        else:
+            self.sleeping_tags.clear()
+        if not self.sleeping_tags:
+            self.is_sleeping = False

     def save_sharded_state(
         self,
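`is_sleeping` is now derived state: it only flips back to `False` once `sleeping_tags` is empty, so a partial wake-up leaves the executor reporting that it is still asleep. A small trace of the bookkeeping for a level-1 sleep:

    sleeping_tags: set[str] = set()

    # sleep() seeds the full tag set
    sleeping_tags = {"weights", "kv_cache"}
    is_sleeping = True

    # wake_up(tags=["weights"]) removes only that tag
    sleeping_tags.remove("weights")
    is_sleeping = bool(sleeping_tags)   # still True: "kv_cache" is pending

    # wake_up() with no tags clears the rest
    sleeping_tags.clear()
    is_sleeping = bool(sleeping_tags)   # False: fully awake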

vllm/v1/engine/async_llm.py

@@ -424,8 +424,8 @@ class AsyncLLM(EngineClient):
     async def sleep(self, level: int = 1) -> None:
         await self.engine_core.sleep_async(level)

-    async def wake_up(self) -> None:
-        await self.engine_core.wake_up_async()
+    async def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        await self.engine_core.wake_up_async(tags)

     async def is_sleeping(self) -> bool:
         return await self.engine_core.is_sleeping_async()

vllm/v1/engine/core.py

@@ -264,8 +264,8 @@ class EngineCore:
     def sleep(self, level: int = 1):
         self.model_executor.sleep(level)

-    def wake_up(self):
-        self.model_executor.wake_up()
+    def wake_up(self, tags: Optional[list[str]] = None):
+        self.model_executor.wake_up(tags)

     def is_sleeping(self) -> bool:
         return self.model_executor.is_sleeping

vllm/v1/engine/core_client.py

@@ -92,7 +92,7 @@ class EngineCoreClient(ABC):
     def sleep(self, level: int = 1) -> None:
         raise NotImplementedError

-    def wake_up(self) -> None:
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
         raise NotImplementedError

     def is_sleeping(self) -> bool:
@@ -141,7 +141,7 @@ class EngineCoreClient(ABC):
     async def sleep_async(self, level: int = 1) -> None:
         raise NotImplementedError

-    async def wake_up_async(self) -> None:
+    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
         raise NotImplementedError

     async def is_sleeping_async(self) -> bool:
@@ -206,8 +206,8 @@ class InprocClient(EngineCoreClient):
     def sleep(self, level: int = 1) -> None:
         self.engine_core.sleep(level)

-    def wake_up(self) -> None:
-        self.engine_core.wake_up()
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        self.engine_core.wake_up(tags)

     def is_sleeping(self) -> bool:
         return self.engine_core.is_sleeping()
@@ -520,8 +520,8 @@ class SyncMPClient(MPClient):
     def sleep(self, level: int = 1) -> None:
         self.call_utility("sleep", level)

-    def wake_up(self) -> None:
-        self.call_utility("wake_up")
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
+        self.call_utility("wake_up", tags)

     def is_sleeping(self) -> bool:
         return self.call_utility("is_sleeping")
@@ -647,8 +647,8 @@ class AsyncMPClient(MPClient):
     async def sleep_async(self, level: int = 1) -> None:
         await self.call_utility_async("sleep", level)

-    async def wake_up_async(self) -> None:
-        await self.call_utility_async("wake_up")
+    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
+        await self.call_utility_async("wake_up", tags)

     async def is_sleeping_async(self) -> bool:
         return await self.call_utility_async("is_sleeping")

vllm/v1/engine/llm_engine.py

@@ -245,8 +245,8 @@ class LLMEngine:
     def sleep(self, level: int = 1):
         self.engine_core.sleep(level)

-    def wake_up(self):
-        self.engine_core.wake_up()
+    def wake_up(self, tags: Optional[list[str]] = None):
+        self.engine_core.wake_up(tags)

     def is_sleeping(self) -> bool:
         return self.engine_core.is_sleeping()

vllm/v1/worker/gpu_worker.py

@@ -83,9 +83,9 @@ class Worker(WorkerBase):
                     "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
                     used_bytes / GiB_bytes)

-    def wake_up(self) -> None:
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
         allocator = CuMemAllocator.get_instance()
-        allocator.wake_up()
+        allocator.wake_up(tags)

     def init_device(self):
         if self.device_config.device.type == "cuda":

vllm/worker/worker.py

@@ -135,9 +135,9 @@ class Worker(LocalOrDistributedWorkerBase):
                     "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
                     used_bytes / GiB_bytes)

-    def wake_up(self) -> None:
+    def wake_up(self, tags: Optional[list[str]] = None) -> None:
         allocator = CuMemAllocator.get_instance()
-        allocator.wake_up()
+        allocator.wake_up(tags=tags)

     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
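End to end, a tag is attached when memory is allocated and consumed when it is restored: the workers allocate weights and kv cache inside `CuMemAllocator.use_memory_pool(tag=...)` regions, and `wake_up(tags=...)` at the bottom of the stack remaps only the matching pools. A condensed sketch of that pairing (the surrounding worker code is elided; `offload_tags` is the allocator's existing sleep parameter):

    from vllm.device_allocator.cumem import CuMemAllocator

    allocator = CuMemAllocator.get_instance()

    with allocator.use_memory_pool(tag="weights"):
        ...  # load model weights

    with allocator.use_memory_pool(tag="kv_cache"):
        ...  # allocate kv cache tensors

    allocator.sleep(offload_tags=("weights", ))  # back up weights, drop kv cache
    allocator.wake_up(tags=["weights"])          # restore the weights only
    allocator.wake_up(tags=["kv_cache"])         # then re-map kv cache memory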