From ddb94c26058cd7658e2435699ce353b3406235ac Mon Sep 17 00:00:00 2001 From: Eric Tang <46737979+erictang000@users.noreply.github.com> Date: Wed, 2 Apr 2025 01:59:27 -0700 Subject: [PATCH] [core] Add tags parameter to wake_up() (#15500) Signed-off-by: Eric --- tests/basic_correctness/test_cumem.py | 20 ++++++++++++- tests/entrypoints/openai/test_sleep.py | 31 +++++++++++++++++---- vllm/device_allocator/cumem.py | 32 ++++++++++++--------- vllm/engine/async_llm_engine.py | 4 +-- vllm/engine/llm_engine.py | 4 +-- vllm/engine/multiprocessing/__init__.py | 5 ++-- vllm/engine/multiprocessing/client.py | 4 +-- vllm/engine/multiprocessing/engine.py | 6 ++-- vllm/engine/protocol.py | 2 +- vllm/entrypoints/llm.py | 37 +++++++++++++++---------- vllm/entrypoints/openai/api_server.py | 9 ++++-- vllm/executor/executor_base.py | 25 +++++++++++++---- vllm/v1/engine/async_llm.py | 4 +-- vllm/v1/engine/core.py | 4 +-- vllm/v1/engine/core_client.py | 16 +++++------ vllm/v1/engine/llm_engine.py | 4 +-- vllm/v1/worker/gpu_worker.py | 4 +-- vllm/worker/worker.py | 4 +-- 18 files changed, 144 insertions(+), 71 deletions(-) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index 31aa8982..76b266aa 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -155,6 +155,24 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): llm.wake_up() output2 = llm.generate(prompt, sampling_params) - # cmp output assert output[0].outputs[0].text == output2[0].outputs[0].text + + llm.sleep(level=1) + llm.wake_up(tags=["weights"]) + + free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info() + used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline + + # should just reallocate memory for weights (1B model, ~2GiB weights) + if use_v1: + assert used_bytes < 10 * GiB_bytes + else: + assert used_bytes < 6 * GiB_bytes + + # now allocate kv cache memory + llm.wake_up(tags=["kv_cache"]) + output3 = llm.generate(prompt, sampling_params) + + # cmp output + assert output[0].outputs[0].text == output3[0].outputs[0].text diff --git a/tests/entrypoints/openai/test_sleep.py b/tests/entrypoints/openai/test_sleep.py index 66d8d929..3ca8a9a4 100644 --- a/tests/entrypoints/openai/test_sleep.py +++ b/tests/entrypoints/openai/test_sleep.py @@ -25,16 +25,37 @@ def test_sleep_mode(): "VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0" }) as remote_server: - - response = requests.post(remote_server.url_for("/sleep"), + response = requests.post(remote_server.url_for("sleep"), params={"level": "1"}) assert response.status_code == 200 - response = requests.get(remote_server.url_for("/is_sleeping")) + response = requests.get(remote_server.url_for("is_sleeping")) assert response.status_code == 200 assert response.json().get("is_sleeping") is True - response = requests.post(remote_server.url_for("/wake_up")) + response = requests.post(remote_server.url_for("wake_up")) assert response.status_code == 200 - response = requests.get(remote_server.url_for("/is_sleeping")) + response = requests.get(remote_server.url_for("is_sleeping")) + assert response.status_code == 200 + assert response.json().get("is_sleeping") is False + + # test wake up with tags + response = requests.post(remote_server.url_for("sleep"), + params={"level": "1"}) + assert response.status_code == 200 + + response = requests.post(remote_server.url_for("wake_up"), + params={"tags": ["weights"]}) + assert response.status_code == 200 + + # is sleeping should be false after 
waking up any part of the engine + response = requests.get(remote_server.url_for("is_sleeping")) + assert response.status_code == 200 + assert response.json().get("is_sleeping") is True + + response = requests.post(remote_server.url_for("wake_up"), + params={"tags": ["kv_cache"]}) + assert response.status_code == 200 + + response = requests.get(remote_server.url_for("is_sleeping")) assert response.status_code == 200 assert response.json().get("is_sleeping") is False diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index f666c18c..9ff77f14 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -208,22 +208,28 @@ class CuMemAllocator: gc.collect() torch.cuda.empty_cache() - def wake_up(self): + def wake_up(self, tags: Optional[list[str]] = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU - memory, and the rest of the data will have empty memory.""" + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory. + + :param tags: The tags of the memory allocation that will be loaded + back to GPU memory. If None, all memory allocation will be loaded + back to GPU memory. + """ for ptr, data in self.pointer_to_data.items(): - handle = data.handle - create_and_map(handle) - if data.cpu_backup_tensor is not None: - cpu_backup_tensor = data.cpu_backup_tensor - if cpu_backup_tensor is not None: - size_in_bytes = cpu_backup_tensor.numel( - ) * cpu_backup_tensor.element_size() - cpu_ptr = cpu_backup_tensor.data_ptr() - libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) - data.cpu_backup_tensor = None + if tags is None or data.tag in tags: + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + data.cpu_backup_tensor = None @contextmanager def use_memory_pool(self, tag: Optional[str] = None): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3e337731..7f9f85e1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1225,8 +1225,8 @@ class AsyncLLMEngine(EngineClient): async def sleep(self, level: int = 1) -> None: self.engine.sleep(level) - async def wake_up(self) -> None: - self.engine.wake_up() + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) async def is_sleeping(self) -> bool: return self.engine.is_sleeping() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 10677878..f842581b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1938,10 +1938,10 @@ class LLMEngine: "Sleep mode is not enabled in the model config") self.model_executor.sleep(level=level) - def wake_up(self) -> None: + def wake_up(self, tags: Optional[list[str]] = None) -> None: assert self.vllm_config.model_config.enable_sleep_mode, ( "Sleep mode is not enabled in the model config") - self.model_executor.wake_up() + self.model_executor.wake_up(tags) def is_sleeping(self) -> bool: return self.model_executor.is_sleeping diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index fdad5358..cafd8150 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ 
b/vllm/engine/multiprocessing/__init__.py @@ -133,8 +133,9 @@ class RPCSleepRequest(Enum): SLEEP_LEVEL_2 = 2 -class RPCWakeUpRequest(Enum): - WAKE_UP = 1 +@dataclass +class RPCWakeUpRequest: + tags: Optional[list[str]] = None @dataclass diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index db91c5d3..f058b132 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient): return await self._send_one_way_rpc_request( request=RPCSleepRequest(level), socket=self.input_socket) - async def wake_up(self) -> None: + async def wake_up(self, tags: Optional[list[str]] = None) -> None: """Wake up the engine""" return await self._send_one_way_rpc_request( - request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket) + request=RPCWakeUpRequest(tags), socket=self.input_socket) async def is_sleeping(self) -> bool: """Check whether the engine is sleeping""" diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 739cbedc..6ed5ae0a 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -274,7 +274,7 @@ class MQLLMEngine: elif isinstance(request, RPCSleepRequest): self.sleep(request.value) elif isinstance(request, RPCWakeUpRequest): - self.wake_up() + self.wake_up(request.tags) elif isinstance(request, RPCIsSleepingRequest): self._handle_is_sleeping_request(request) else: @@ -415,8 +415,8 @@ class MQLLMEngine: def sleep(self, level: int = 1) -> None: self.engine.sleep(level) - def wake_up(self) -> None: - self.engine.wake_up() + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) def is_sleeping(self) -> bool: return self.engine.is_sleeping() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d2f2c226..e2974b02 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -282,7 +282,7 @@ class EngineClient(ABC): ... @abstractmethod - async def wake_up(self) -> None: + async def wake_up(self, tags: Optional[list[str]] = None) -> None: """Wake up the engine""" ... diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 7c354be2..f39b011c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1200,26 +1200,35 @@ class LLM: The caller should guarantee that no requests are being processed during the sleep period, before `wake_up` is called. - :param level: The sleep level. Level 1 sleep will offload the model - weights and discard the kv cache. The content of kv cache is - forgotten. Level 1 sleep is good for sleeping and waking up the - engine to run the same model again. The model weights are backed - up in CPU memory. Please make sure there's enough CPU memory to - store the model weights. Level 2 sleep will discard both the model - weights and the kv cache. The content of both the model weights - and kv cache is forgotten. Level 2 sleep is good for sleeping and - waking up the engine to run a different model or update the model, - where previous model weights are not needed. It reduces CPU memory - pressure. + Args: + level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. The content of kv cache + is forgotten. Level 1 sleep is good for sleeping and waking + up the engine to run the same model again. The model weights + are backed up in CPU memory. Please make sure there's enough + CPU memory to store the model weights. 
Level 2 sleep will + discard both the model weights and the kv cache. The content + of both the model weights and kv cache is forgotten. Level 2 + sleep is good for sleeping and waking up the engine to run a + different model or update the model, where previous model + weights are not needed. It reduces CPU memory pressure. """ self.reset_prefix_cache() self.llm_engine.sleep(level=level) - def wake_up(self): + def wake_up(self, tags: Optional[list[str]] = None): """ Wake up the engine from sleep mode. See the :meth:`sleep` method - for more details.""" - self.llm_engine.wake_up() + for more details. + + Args: + tags: An optional list of tags to reallocate the engine memory + for specific memory allocations. Values must be in + ("weights", "kv_cache",). If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + engine is used again. + """ + self.llm_engine.wake_up(tags) # LEGACY def _convert_v1_inputs( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1e7d9eb8..6a8bdd06 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -705,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE: async def sleep(raw_request: Request): # get POST params level = raw_request.query_params.get("level", "1") - logger.info("sleep the engine with level %s", level) await engine_client(raw_request).sleep(int(level)) # FIXME: in v0 with frontend multiprocessing, the sleep command # is sent but does not finish yet when we return a response. @@ -713,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE: @router.post("/wake_up") async def wake_up(raw_request: Request): - logger.info("wake up the engine") - await engine_client(raw_request).wake_up() + tags = raw_request.query_params.getlist("tags") + if tags == []: + # set to None to wake up all tags if no tags are provided + tags = None + logger.info("wake up the engine with tags: %s", tags) + await engine_client(raw_request).wake_up(tags) # FIXME: in v0 with frontend multiprocessing, the wake-up command # is sent but does not finish yet when we return a response. 
return Response(status_code=200) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 6f5adb4f..58796e5d 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -51,6 +51,7 @@ class ExecutorBase(ABC): self.observability_config = vllm_config.observability_config self._init_executor() self.is_sleeping = False + self.sleeping_tags: set[str] = set() @abstractmethod def _init_executor(self) -> None: @@ -204,20 +205,34 @@ class ExecutorBase(ABC): time_before_sleep = time.perf_counter() self.collective_rpc("sleep", kwargs=dict(level=level)) time_after_sleep = time.perf_counter() + self.sleeping_tags = {"weights", "kv_cache"} self.is_sleeping = True logger.info("It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep) - def wake_up(self): + def wake_up(self, tags: Optional[list[str]] = None): if not self.is_sleeping: logger.warning("Executor is not sleeping.") return + if tags: + for tag in tags: + if tag not in self.sleeping_tags: + logger.warning("Tag %s is not in sleeping tags %s", tag, + self.sleeping_tags) + return time_before_wakeup = time.perf_counter() - self.collective_rpc("wake_up") + self.collective_rpc("wake_up", kwargs=dict(tags=tags)) time_after_wakeup = time.perf_counter() - self.is_sleeping = False - logger.info("It took %.6f seconds to wake up.", - time_after_wakeup - time_before_wakeup) + logger.info("It took %.6f seconds to wake up tags %s.", + time_after_wakeup - time_before_wakeup, + tags if tags is not None else self.sleeping_tags) + if tags: + for tag in tags: + self.sleeping_tags.remove(tag) + else: + self.sleeping_tags.clear() + if not self.sleeping_tags: + self.is_sleeping = False def save_sharded_state( self, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a8d86e70..b77a6824 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -424,8 +424,8 @@ class AsyncLLM(EngineClient): async def sleep(self, level: int = 1) -> None: await self.engine_core.sleep_async(level) - async def wake_up(self) -> None: - await self.engine_core.wake_up_async() + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + await self.engine_core.wake_up_async(tags) async def is_sleeping(self) -> bool: return await self.engine_core.is_sleeping_async() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d915d474..19c7799b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -264,8 +264,8 @@ class EngineCore: def sleep(self, level: int = 1): self.model_executor.sleep(level) - def wake_up(self): - self.model_executor.wake_up() + def wake_up(self, tags: Optional[list[str]] = None): + self.model_executor.wake_up(tags) def is_sleeping(self) -> bool: return self.model_executor.is_sleeping diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 3dc33a12..99774ff4 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -92,7 +92,7 @@ class EngineCoreClient(ABC): def sleep(self, level: int = 1) -> None: raise NotImplementedError - def wake_up(self) -> None: + def wake_up(self, tags: Optional[list[str]] = None) -> None: raise NotImplementedError def is_sleeping(self) -> bool: @@ -141,7 +141,7 @@ class EngineCoreClient(ABC): async def sleep_async(self, level: int = 1) -> None: raise NotImplementedError - async def wake_up_async(self) -> None: + async def wake_up_async(self, tags: Optional[list[str]] = None) -> None: raise NotImplementedError async def is_sleeping_async(self) -> bool: @@ -206,8 
+206,8 @@ class InprocClient(EngineCoreClient): def sleep(self, level: int = 1) -> None: self.engine_core.sleep(level) - def wake_up(self) -> None: - self.engine_core.wake_up() + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine_core.wake_up(tags) def is_sleeping(self) -> bool: return self.engine_core.is_sleeping() @@ -520,8 +520,8 @@ class SyncMPClient(MPClient): def sleep(self, level: int = 1) -> None: self.call_utility("sleep", level) - def wake_up(self) -> None: - self.call_utility("wake_up") + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.call_utility("wake_up", tags) def is_sleeping(self) -> bool: return self.call_utility("is_sleeping") @@ -647,8 +647,8 @@ class AsyncMPClient(MPClient): async def sleep_async(self, level: int = 1) -> None: await self.call_utility_async("sleep", level) - async def wake_up_async(self) -> None: - await self.call_utility_async("wake_up") + async def wake_up_async(self, tags: Optional[list[str]] = None) -> None: + await self.call_utility_async("wake_up", tags) async def is_sleeping_async(self) -> bool: return await self.call_utility_async("is_sleeping") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 764c643b..4c67186f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -245,8 +245,8 @@ class LLMEngine: def sleep(self, level: int = 1): self.engine_core.sleep(level) - def wake_up(self): - self.engine_core.wake_up() + def wake_up(self, tags: Optional[list[str]] = None): + self.engine_core.wake_up(tags) def is_sleeping(self) -> bool: return self.engine_core.is_sleeping() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 51b9f567..19144368 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -83,9 +83,9 @@ class Worker(WorkerBase): "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, used_bytes / GiB_bytes) - def wake_up(self) -> None: + def wake_up(self, tags: Optional[list[str]] = None) -> None: allocator = CuMemAllocator.get_instance() - allocator.wake_up() + allocator.wake_up(tags) def init_device(self): if self.device_config.device.type == "cuda": diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ad94a6a4..d59f20f4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -135,9 +135,9 @@ class Worker(LocalOrDistributedWorkerBase): "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, used_bytes / GiB_bytes) - def wake_up(self) -> None: + def wake_up(self, tags: Optional[list[str]] = None) -> None: allocator = CuMemAllocator.get_instance() - allocator.wake_up() + allocator.wake_up(tags=tags) def init_device(self) -> None: if self.device_config.device.type == "cuda":
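
For reference, a minimal usage sketch of the partial wake-up flow this patch enables, mirroring what `tests/basic_correctness/test_cumem.py` above exercises. It is not part of the patch; the model name, prompt, and sampling settings are placeholders.

```python
from vllm import LLM, SamplingParams

# enable_sleep_mode is required for sleep()/wake_up(); the model is a placeholder.
llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)

output = llm.generate(prompt, sampling_params)

# Level 1 sleep: model weights are offloaded to CPU, the KV cache is discarded.
llm.sleep(level=1)

# Wake up only the weight allocations, e.g. to reload or update weights
# before paying for KV cache memory again.
llm.wake_up(tags=["weights"])

# Wake up the remaining KV cache allocations; all tags (or a single call with
# tags=None) must be woken before the engine serves requests again.
llm.wake_up(tags=["kv_cache"])

output2 = llm.generate(prompt, sampling_params)
assert output[0].outputs[0].text == output2[0].outputs[0].text
```

Over the dev-mode HTTP API the same flow is `POST /sleep?level=1`, then `POST /wake_up?tags=weights` followed by `POST /wake_up?tags=kv_cache` (or a bare `POST /wake_up` to restore everything), as exercised in `tests/entrypoints/openai/test_sleep.py`.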