[V1][Core] Fix memory issue with logits & sampling (#13776)

Signed-off-by: Roger Wang <ywang@roblox.com>
Roger Wang 2025-03-08 06:11:04 -08:00 committed by GitHub
parent 0b7f06b447
commit 8d5aa466fb
3 changed files with 69 additions and 29 deletions
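For context, the test modified below exercises vLLM's sleep mode. A minimal sketch of that flow, for orientation only (the model name and keyword arguments here are assumptions for illustration, not taken from this diff):

from vllm import LLM

# Hypothetical sleep-mode round trip; argument names are assumed.
llm = LLM(model="facebook/opt-125m", enable_sleep_mode=True)
llm.sleep(level=1)   # release weights / KV cache to reclaim GPU memory
# ... measure free GPU memory here, as the test does ...
llm.wake_up()        # restore the engine for further generation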


@@ -142,6 +142,15 @@ def test_end_to_end(model: str, use_v1: bool):
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be released from PyTorch due to a known bug,
# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
if use_v1:
    assert used_bytes < 7 * GiB_bytes
else:
    assert used_bytes < 2 * GiB_bytes
llm.wake_up()
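The looser 7 GiB bound for V1 covers the cudagraph memory pool plus the sampling buffers that stay resident after `llm.sleep`. As a rough sanity check, the logits buffer alone scales as max_num_reqs x vocab_size; a back-of-the-envelope sketch with assumed values (not taken from this test):

# Illustrative only: all numbers below are assumptions, not vLLM defaults.
max_num_reqs = 1024        # assumed scheduler limit (max_num_seqs)
vocab_size = 152_064       # assumed vocabulary size
bytes_per_logit = 4        # fp32 logits
logits_bytes = max_num_reqs * vocab_size * bytes_per_logit
print(f"logits buffer ~= {logits_bytes / 1024**3:.2f} GiB")  # ~0.58 GiB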


@@ -1238,6 +1238,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return hidden_states

@torch.inference_mode()
def _dummy_sampler_run(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    logits = self.model.compute_logits(hidden_states, None)
    num_reqs = logits.size(0)
    dummy_tensors = lambda v: torch.full(
        (num_reqs, ), v, device=self.device)
    dummy_metadata = SamplingMetadata(
        temperature=dummy_tensors(0.5),
        all_greedy=False,
        all_random=False,
        top_p=dummy_tensors(0.9),
        top_k=dummy_tensors(logits.size(1) - 1),
        min_p=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=dummy_tensors(0.1),
        presence_penalties=dummy_tensors(0.1),
        repetition_penalties=dummy_tensors(0.1),
        output_token_ids=[[] for _ in range(num_reqs)],
        min_tokens={},
        logit_bias=[None for _ in range(num_reqs)],
        allowed_token_ids_mask=None,
    )
    sampler_output = self.model.sample(logits=logits,
                                       sampling_metadata=dummy_metadata)
    return sampler_output

def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
@@ -1353,37 +1389,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
    hidden_states = hidden_states[logit_indices]
    logits = self.model.compute_logits(hidden_states, None)
    dummy_tensors = lambda v: torch.full(
        (num_reqs, ), v, device=self.device)
    dummy_metadata = SamplingMetadata(
        temperature=dummy_tensors(0.5),
        all_greedy=False,
        all_random=False,
        top_p=dummy_tensors(0.9),
        top_k=dummy_tensors(logits.size(1) - 1),
        min_p=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=torch.ones_like(logits,
                                         dtype=torch.int64),
        frequency_penalties=dummy_tensors(0.1),
        presence_penalties=dummy_tensors(0.1),
        repetition_penalties=dummy_tensors(0.1),
        output_token_ids=[[] for _ in range(num_reqs)],
        min_tokens={},
        logit_bias=[None for _ in range(num_reqs)],
        allowed_token_ids_mask=None,
    )
    sampler_output = self.model.sample(
        logits=logits, sampling_metadata=dummy_metadata)
    sampler_output = self._dummy_sampler_run(hidden_states)
else:
    logits = None
    sampler_output = None
    dummy_metadata = None
torch.cuda.synchronize()
del hidden_states, logits, sampler_output, dummy_metadata
del hidden_states, sampler_output
self.encoder_cache.clear()
gc.collect()
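Factoring the dummy sampling out of `profile_run` lets the worker reuse the same helper after cudagraph capture (see the next file), so the max-shape logits and sampling tensors are allocated through a single entry point. A hypothetical caller-side sketch (`runner` and the shapes are illustrative, not from this diff):

# Hypothetical: `runner` is a constructed GPUModelRunner; shapes are made up.
hidden_states = runner._dummy_run(num_tokens=max_num_seqs)
sampler_output = runner._dummy_sampler_run(hidden_states=hidden_states)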


@@ -119,6 +119,8 @@ class Worker(WorkerBase):
self.model_runner: GPUModelRunner = GPUModelRunner(
    self.vllm_config, self.device)

# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
    if self.vllm_config.model_config.enable_sleep_mode:
        allocator = CuMemAllocator.get_instance()
@@ -211,6 +213,25 @@
    self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
    self.model_runner.capture_model()

# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
try:
    self.model_runner._dummy_sampler_run(
        hidden_states=self.model_runner._dummy_run(
            num_tokens=self.scheduler_config.max_num_seqs))
except RuntimeError as e:
    if 'out of memory' in str(e):
        raise RuntimeError(
            "CUDA out of memory occurred when warming up sampler. "
            "Please try lowering `gpu_memory_utilization` when "
            "initializing the engine.") from None
    else:
        raise e

# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
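
Running the sampler warm-up only after `capture_model` keeps the max-shape allocation resident in PyTorch's caching allocator instead of having it flushed by `torch.cuda.empty_cache` during capture, per the NOTE above. A minimal sketch of the allocator-reuse idea this relies on (not vLLM code; shapes are assumed):

import torch

# Not vLLM code: illustrates why a one-time max-shape allocation helps.
# Once the warm-up tensor is freed, the caching allocator keeps its block,
# so later logits allocations of the same or smaller size reuse it rather
# than fragmenting memory around the cudagraph pools.
max_num_reqs, vocab_size = 1024, 152_064               # assumed shapes
warmup_logits = torch.empty(max_num_reqs, vocab_size, device="cuda")
del warmup_logits                                      # block stays cached
print(torch.cuda.memory_reserved())                    # still includes the block
smaller = torch.empty(256, vocab_size, device="cuda")  # served from the cache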