From 8d5aa466fbff258159ecb0ac85134f5356b5344c Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Sat, 8 Mar 2025 06:11:04 -0800
Subject: [PATCH] [V1][Core] Fix memory issue with logits & sampling (#13776)

Signed-off-by: Roger Wang
---
 tests/basic_correctness/test_cumem.py | 11 ++++-
 vllm/v1/worker/gpu_model_runner.py    | 66 +++++++++++++++------------
 vllm/v1/worker/gpu_worker.py          | 21 +++++++++
 3 files changed, 69 insertions(+), 29 deletions(-)

diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py
index 61c79a7b..ba81f2bb 100644
--- a/tests/basic_correctness/test_cumem.py
+++ b/tests/basic_correctness/test_cumem.py
@@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
     used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
     # now the memory usage is mostly cudagraph memory pool,
     # and it should be less than the model weights (1B model, 2GiB weights)
-    assert used_bytes < 2 * GiB_bytes
+
+    # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
+    # is captured but cannot be released from PyTorch due to a known bug,
+    # so high memory usage after `llm.sleep` is called is expected.
+    # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
+    # in V1.
+    if use_v1:
+        assert used_bytes < 7 * GiB_bytes
+    else:
+        assert used_bytes < 2 * GiB_bytes
 
     llm.wake_up()
     output2 = llm.generate(prompt, sampling_params)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0cdf8f1a..81dec429 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1238,6 +1238,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         return hidden_states
 
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
+
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
         # TODO: handle encoder-decoder models once we support them.
@@ -1353,37 +1389,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         hidden_states = self._dummy_run(self.max_num_tokens)
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
-            logits = self.model.compute_logits(hidden_states, None)
-            dummy_tensors = lambda v: torch.full(
-                (num_reqs, ), v, device=self.device)
-            dummy_metadata = SamplingMetadata(
-                temperature=dummy_tensors(0.5),
-                all_greedy=False,
-                all_random=False,
-                top_p=dummy_tensors(0.9),
-                top_k=dummy_tensors(logits.size(1) - 1),
-                min_p=None,
-                generators={},
-                max_num_logprobs=None,
-                no_penalties=True,
-                prompt_token_ids=torch.ones_like(logits,
-                                                 dtype=torch.int64),
-                frequency_penalties=dummy_tensors(0.1),
-                presence_penalties=dummy_tensors(0.1),
-                repetition_penalties=dummy_tensors(0.1),
-                output_token_ids=[[] for _ in range(num_reqs)],
-                min_tokens={},
-                logit_bias=[None for _ in range(num_reqs)],
-                allowed_token_ids_mask=None,
-            )
-            sampler_output = self.model.sample(
-                logits=logits, sampling_metadata=dummy_metadata)
+            sampler_output = self._dummy_sampler_run(hidden_states)
         else:
-            logits = None
             sampler_output = None
-            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits, sampler_output, dummy_metadata
+        del hidden_states, sampler_output
         self.encoder_cache.clear()
         gc.collect()
 
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index cc6268d6..01025732 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -119,6 +119,8 @@ class Worker(WorkerBase):
         self.model_runner: GPUModelRunner = GPUModelRunner(
             self.vllm_config, self.device)
 
+    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
+    # to hijack tensor allocation.
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
@@ -211,6 +213,25 @@ class Worker(WorkerBase):
             self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
+
+        # Warm up the sampler and preallocate the memory buffer for logits
+        # and other sampling-related tensors of max possible shape to avoid
+        # memory fragmentation issues.
+        # NOTE: This is called after `capture_model` on purpose to prevent
+        # memory buffers from being cleared by `torch.cuda.empty_cache`.
+        try:
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=self.scheduler_config.max_num_seqs))
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler. "
+                    "Please try lowering `gpu_memory_utilization` when "
+                    "initializing the engine.") from None
+            else:
+                raise e
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
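
Note: the snippet below is a minimal, standalone sketch (not vLLM code and not
part of the diff above) of the caching-allocator behaviour the sampler warm-up
relies on: allocating the worst-case (max_num_reqs x vocab_size) logits tensor
once, after CUDA graph capture, leaves that block in PyTorch's caching
allocator so later, smaller per-step sampling allocations can reuse it instead
of fragmenting memory. The function and parameter names are illustrative
stand-ins, not vLLM APIs.

    import torch

    def warm_up_logits_buffer(max_num_reqs: int, vocab_size: int) -> None:
        """Touch the worst-case logits allocation once (illustrative only)."""
        logits = torch.empty((max_num_reqs, vocab_size),
                             dtype=torch.float32, device="cuda")
        # Mimic the temporaries a sampling pass would create on top of logits.
        torch.softmax(logits, dim=-1)
        del logits
        # Deliberately no torch.cuda.empty_cache() here: that would hand the
        # cached block back to the driver and defeat the warm-up, which is why
        # the patch runs the sampler warm-up only after capture_model().

In the patch itself this role is played by Worker.compile_or_warm_up_model()
calling self.model_runner._dummy_sampler_run(...) right after capture_model(),
with hidden states produced by a _dummy_run() of max_num_seqs tokens.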