From 1fc973c0b53a151bba062576cc3ff489cc1ceeec Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Mon, 10 Mar 2025 21:03:41 -0700
Subject: [PATCH] [V1][Core] Fix memory issue with logits & sampling (#14508)

Signed-off-by: Roger Wang
Co-authored-by: Varun Sundar Rabindranath <3337719+varun-sundar-rabindranath@users.noreply.github.com>
---
 tests/basic_correctness/test_cumem.py     |  11 +-
 vllm/config.py                            |   5 +
 vllm/v1/worker/gpu_model_runner.py        | 185 ++++++++++++----------
 vllm/v1/worker/gpu_worker.py              |  23 +++
 vllm/v1/worker/lora_model_runner_mixin.py |   6 +-
 5 files changed, 139 insertions(+), 91 deletions(-)

diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py
index 61c79a7b..ba81f2bb 100644
--- a/tests/basic_correctness/test_cumem.py
+++ b/tests/basic_correctness/test_cumem.py
@@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
     used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
     # now the memory usage is mostly cudagraph memory pool,
     # and it should be less than the model weights (1B model, 2GiB weights)
-    assert used_bytes < 2 * GiB_bytes
+
+    # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
+    # is captured but cannot be released from PyTorch due to a known bug,
+    # therefore high memory usage after `llm.sleep` is called is expected.
+    # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
+    # in V1.
+    if use_v1:
+        assert used_bytes < 7 * GiB_bytes
+    else:
+        assert used_bytes < 2 * GiB_bytes
 
     llm.wake_up()
     output2 = llm.generate(prompt, sampling_params)
diff --git a/vllm/config.py b/vllm/config.py
index a6ac9f43..26c02563 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3525,6 +3525,11 @@ class VllmConfig:
                 not self.model_config.enforce_eager:
             batch_size_capture_list = [1, 2, 4
                                        ] + [i for i in range(8, 513, 8)]
+            max_num_tokens = self.scheduler_config.max_num_batched_tokens
+            batch_size_capture_list = [
+                size for size in batch_size_capture_list
+                if size <= max_num_tokens
+            ]
 
         self.compilation_config.init_with_cudagraph_sizes(
             batch_size_capture_list)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 57b05908..73279288 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1202,41 +1202,98 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         num_tokens: int,
     ) -> torch.Tensor:
-        model = self.model
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
-
-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=self.max_num_tokens,
-                        dtype=self.model_config.dtype,
-                        device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
 
-        with set_forward_context(None, self.vllm_config,
-                                 num_tokens=num_tokens):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds,
-            )
-        return hidden_states
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]
+
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=self.max_num_tokens,
+                            dtype=self.model_config.dtype,
+                            device=self.device))
+                intermediate_tensors = IntermediateTensors({
+                    k: v[:num_tokens]
+                    for k, v in self.intermediate_tensors.items()
+                })
+
+            with set_forward_context(None,
+                                     self.vllm_config,
+                                     num_tokens=num_tokens):
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
+
+            logit_indices = np.cumsum(num_scheduled_tokens) - 1
+            return hidden_states[logit_indices]
+
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+            bad_words_token_ids={},
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
@@ -1332,60 +1389,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Cache the dummy encoder outputs.
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
-        # For profile, have maximum num_reqs and that collectively have
-        # maximum num_tokens.
-        num_reqs = self.scheduler_config.max_num_seqs
-        num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
-
-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
-
-        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
-                                        dtype=np.int32)
-        logit_indices = np.cumsum(num_scheduled_tokens) - 1
-
-        with self.maybe_profile_with_lora(self.lora_config,
-                                          num_scheduled_tokens):
-            # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.max_num_tokens)
-            if get_pp_group().is_last_rank:
-                hidden_states = hidden_states[logit_indices]
-                logits = self.model.compute_logits(hidden_states, None)
-                dummy_tensors = lambda v: torch.full(
-                    (num_reqs, ), v, device=self.device)
-                dummy_metadata = SamplingMetadata(
-                    temperature=dummy_tensors(0.5),
-                    all_greedy=False,
-                    all_random=False,
-                    top_p=dummy_tensors(0.9),
-                    top_k=dummy_tensors(logits.size(1) - 1),
-                    min_p=None,
-                    generators={},
-                    max_num_logprobs=None,
-                    no_penalties=True,
-                    prompt_token_ids=torch.ones_like(logits,
-                                                     dtype=torch.int64),
-                    frequency_penalties=dummy_tensors(0.1),
-                    presence_penalties=dummy_tensors(0.1),
-                    repetition_penalties=dummy_tensors(0.1),
-                    output_token_ids=[[] for _ in range(num_reqs)],
-                    min_tokens={},
-                    logit_bias=[None for _ in range(num_reqs)],
-                    allowed_token_ids_mask=None,
-                    bad_words_token_ids={},
-                )
-                sampler_output = self.model.sample(
-                    logits=logits, sampling_metadata=dummy_metadata)
-            else:
-                logits = None
-                sampler_output = None
-                dummy_metadata = None
-            torch.cuda.synchronize()
-            del hidden_states, logits, sampler_output, dummy_metadata
-            self.encoder_cache.clear()
+        hidden_states = self._dummy_run(self.max_num_tokens)
+        if get_pp_group().is_last_rank:
+            sampler_output = self._dummy_sampler_run(hidden_states)
+        else:
+            sampler_output = None
+        torch.cuda.synchronize()
+        del hidden_states, sampler_output
+        self.encoder_cache.clear()
         gc.collect()
 
     def capture_model(self) -> None:
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index cc6268d6..040a27de 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -119,6 +119,8 @@ class Worker(WorkerBase):
         self.model_runner: GPUModelRunner = GPUModelRunner(
             self.vllm_config, self.device)
 
+    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
+    # to hijack tensor allocation.
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
@@ -211,6 +213,27 @@
             self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
+
+        # Warm up sampler and preallocate memory buffer for logits and other
+        # sampling related tensors of max possible shape to avoid memory
+        # fragmentation issue.
+        # NOTE: This is called after `capture_model` on purpose to prevent
+        # memory buffers from being cleared by `torch.cuda.empty_cache`.
+        try:
+            max_num_reqs = min(self.scheduler_config.max_num_seqs,
+                               self.scheduler_config.max_num_batched_tokens)
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=max_num_reqs))
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler. "
+                    "Please try lowering `gpu_memory_utilization` when "
+                    "initializing the engine.") from None
+            else:
+                raise e
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index 0b30a467..2814f0fd 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -83,8 +83,8 @@ class LoRAModelRunnerMixin:
                                lora_requests)
 
     @contextmanager
-    def maybe_profile_with_lora(self, lora_config: LoRAConfig,
-                                num_scheduled_tokens: np.ndarray):
+    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
+                                  num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
         else:
@@ -145,4 +145,4 @@
     def list_loras(self) -> set[int]:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
-        return self.lora_manager.list_adapters()
\ No newline at end of file
+        return self.lora_manager.list_adapters()