[V1][Core] Fix memory issue with logits & sampling (#14508)

Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Varun Sundar Rabindranath <3337719+varun-sundar-rabindranath@users.noreply.github.com>
Roger Wang 2025-03-10 21:03:41 -07:00 committed by GitHub
parent c982ac5722
commit 1fc973c0b5
5 changed files with 139 additions and 91 deletions


@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
assert used_bytes < 2 * GiB_bytes
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be released from PyTorch due to a known bug,
# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
if use_v1:
assert used_bytes < 7 * GiB_bytes
else:
assert used_bytes < 2 * GiB_bytes
llm.wake_up()
output2 = llm.generate(prompt, sampling_params)
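
For intuition on the 7 GiB bound above: the retained memory scales with the (max_num_reqs x vocab_size) logits buffer mentioned in the NOTE. A rough back-of-the-envelope sketch, using illustrative values that are assumptions rather than this test's actual configuration:

# Rough size of the preallocated logits buffer that stays resident in V1.
# The values below are illustrative assumptions, not this test's real config.
max_num_reqs = 1024      # assumed scheduler max_num_seqs
vocab_size = 152_064     # assumed tokenizer vocabulary size
dtype_bytes = 4          # logits kept in float32

logits_buffer_gib = max_num_reqs * vocab_size * dtype_bytes / 2**30
print(f"base logits buffer: ~{logits_buffer_gib:.2f} GiB")
# Sampler intermediates (top-p/top-k sorting, etc.) scale with the same
# (max_num_reqs, vocab_size) shape, so the total retained memory can be
# several times this base figure.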


@ -3525,6 +3525,11 @@ class VllmConfig:
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]
max_num_tokens = self.scheduler_config.max_num_batched_tokens
batch_size_capture_list = [
size for size in batch_size_capture_list
if size <= max_num_tokens
]
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
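
A minimal standalone sketch of the filtering added above, with an assumed max_num_batched_tokens; it simply drops cudagraph capture sizes the scheduler could never batch:

# Sketch of the capture-list filtering, with an assumed token budget.
batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]
max_num_tokens = 192  # assumed scheduler_config.max_num_batched_tokens

batch_size_capture_list = [
    size for size in batch_size_capture_list if size <= max_num_tokens
]
print(batch_size_capture_list)  # [1, 2, 4, 8, 16, ..., 184, 192]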


@ -1202,41 +1202,98 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
num_tokens: int,
) -> torch.Tensor:
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for a dummy run with LoRA so that the num_reqs requests
# collectively have num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
hidden_states = model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
hidden_states = model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices]
@torch.inference_mode()
def _dummy_sampler_run(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
logits = self.model.compute_logits(hidden_states, None)
num_reqs = logits.size(0)
dummy_tensors = lambda v: torch.full(
(num_reqs, ), v, device=self.device)
dummy_metadata = SamplingMetadata(
temperature=dummy_tensors(0.5),
all_greedy=False,
all_random=False,
top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1),
min_p=None,
generators={},
max_num_logprobs=None,
no_penalties=True,
prompt_token_ids=None,
frequency_penalties=dummy_tensors(0.1),
presence_penalties=dummy_tensors(0.1),
repetition_penalties=dummy_tensors(0.1),
output_token_ids=[[] for _ in range(num_reqs)],
min_tokens={},
logit_bias=[None for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
sampler_output = self.model.sample(logits=logits,
sampling_metadata=dummy_metadata)
return sampler_output
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
@ -1332,60 +1389,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
num_reqs = self.scheduler_config.max_num_seqs
num_tokens = self.max_num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
with self.maybe_profile_with_lora(self.lora_config,
num_scheduled_tokens):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
dummy_tensors = lambda v: torch.full(
(num_reqs, ), v, device=self.device)
dummy_metadata = SamplingMetadata(
temperature=dummy_tensors(0.5),
all_greedy=False,
all_random=False,
top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1),
min_p=None,
generators={},
max_num_logprobs=None,
no_penalties=True,
prompt_token_ids=torch.ones_like(logits,
dtype=torch.int64),
frequency_penalties=dummy_tensors(0.1),
presence_penalties=dummy_tensors(0.1),
repetition_penalties=dummy_tensors(0.1),
output_token_ids=[[] for _ in range(num_reqs)],
min_tokens={},
logit_bias=[None for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
sampler_output = self.model.sample(
logits=logits, sampling_metadata=dummy_metadata)
else:
logits = None
sampler_output = None
dummy_metadata = None
torch.cuda.synchronize()
del hidden_states, logits, sampler_output, dummy_metadata
self.encoder_cache.clear()
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
sampler_output = self._dummy_sampler_run(hidden_states)
else:
sampler_output = None
torch.cuda.synchronize()
del hidden_states, sampler_output
self.encoder_cache.clear()
gc.collect()
def capture_model(self) -> None:
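
To make the `_dummy_run` change concrete: the dummy batch is now split across at most max_num_seqs synthetic requests, and only the last-token row of each request is returned, so `compute_logits` and the sampler see at most (max_num_reqs, vocab_size) rather than (max_num_tokens, vocab_size). A standalone sketch with assumed sizes:

# Standalone sketch of the token split and logit-index selection in _dummy_run.
# num_tokens and max_num_reqs below are assumed example values.
import numpy as np

num_tokens = 22     # assumed dummy batch size in tokens
max_num_reqs = 8    # assumed scheduler_config.max_num_seqs

num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens

num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
# Last-token index of each synthetic request; only these hidden-state rows
# are kept, so the logits tensor has num_reqs rows rather than num_tokens.
logit_indices = np.cumsum(num_scheduled_tokens) - 1
print(num_scheduled_tokens_list)   # [2, 2, 2, 2, 2, 2, 2, 8]
print(logit_indices.tolist())      # [1, 3, 5, 7, 9, 11, 13, 21]

`_dummy_sampler_run` then builds a SamplingMetadata with per-row top-p/top-k tensors so the warm-up exercises the sampler's larger intermediate allocations on that reduced (num_reqs, vocab_size) logits shape.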


@ -119,6 +119,8 @@ class Worker(WorkerBase):
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
@ -211,6 +213,27 @@ class Worker(WorkerBase):
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Warm up the sampler and preallocate the memory buffer for logits and
# other sampling-related tensors of the max possible shape to avoid
# memory fragmentation issues.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
try:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
self.model_runner._dummy_sampler_run(
hidden_states=self.model_runner._dummy_run(
num_tokens=max_num_reqs))
except RuntimeError as e:
if 'out of memory' in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up sampler. "
"Please try lowering `gpu_memory_utilization` when "
"initializing the engine.") from None
else:
raise e
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
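
The `min()` above follows from the fact that a step cannot contain more requests than it has scheduled tokens, so the warm-up covers the largest logits batch the sampler can ever see. A small sketch with assumed config values:

# Sketch of the warm-up sizing, using assumed scheduler config values.
max_num_seqs = 256            # assumed scheduler_config.max_num_seqs
max_num_batched_tokens = 128  # assumed scheduler_config.max_num_batched_tokens

# Each request contributes at least one token per step, so the number of
# requests in a step is bounded by the token budget as well as max_num_seqs.
max_num_reqs = min(max_num_seqs, max_num_batched_tokens)
print(max_num_reqs)  # 128 -> logits warmed up at shape (128, vocab_size)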


@ -83,8 +83,8 @@ class LoRAModelRunnerMixin:
lora_requests)
@contextmanager
def maybe_profile_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
if lora_config is None:
yield
else:
@ -145,4 +145,4 @@ class LoRAModelRunnerMixin:
def list_loras(self) -> set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
return self.lora_manager.list_adapters()