[V1][Core] Fix memory issue with logits & sampling (#13776)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
0b7f06b447
commit
8d5aa466fb
@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
|
||||
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
|
||||
# now the memory usage is mostly cudagraph memory pool,
|
||||
# and it should be less than the model weights (1B model, 2GiB weights)
|
||||
assert used_bytes < 2 * GiB_bytes
|
||||
|
||||
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
|
||||
# is captured but cannot be releasesd from PyTorch due to a known bug,
|
||||
# therefore high memory usage after `llm.sleep` is called is expected.
|
||||
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
|
||||
# in V1.
|
||||
if use_v1:
|
||||
assert used_bytes < 7 * GiB_bytes
|
||||
else:
|
||||
assert used_bytes < 2 * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
@ -1238,6 +1238,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_sampler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
num_reqs = logits.size(0)
|
||||
|
||||
dummy_tensors = lambda v: torch.full(
|
||||
(num_reqs, ), v, device=self.device)
|
||||
|
||||
dummy_metadata = SamplingMetadata(
|
||||
temperature=dummy_tensors(0.5),
|
||||
all_greedy=False,
|
||||
all_random=False,
|
||||
top_p=dummy_tensors(0.9),
|
||||
top_k=dummy_tensors(logits.size(1) - 1),
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=None,
|
||||
frequency_penalties=dummy_tensors(0.1),
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
)
|
||||
sampler_output = self.model.sample(logits=logits,
|
||||
sampling_metadata=dummy_metadata)
|
||||
|
||||
return sampler_output
|
||||
|
||||
def profile_run(self) -> None:
|
||||
# Profile with multimodal encoder & encoder cache.
|
||||
# TODO: handle encoder-decoder models once we support them.
|
||||
@ -1353,37 +1389,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
if get_pp_group().is_last_rank:
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
dummy_tensors = lambda v: torch.full(
|
||||
(num_reqs, ), v, device=self.device)
|
||||
dummy_metadata = SamplingMetadata(
|
||||
temperature=dummy_tensors(0.5),
|
||||
all_greedy=False,
|
||||
all_random=False,
|
||||
top_p=dummy_tensors(0.9),
|
||||
top_k=dummy_tensors(logits.size(1) - 1),
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=torch.ones_like(logits,
|
||||
dtype=torch.int64),
|
||||
frequency_penalties=dummy_tensors(0.1),
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
)
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits, sampling_metadata=dummy_metadata)
|
||||
sampler_output = self._dummy_sampler_run(hidden_states)
|
||||
else:
|
||||
logits = None
|
||||
sampler_output = None
|
||||
dummy_metadata = None
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, logits, sampler_output, dummy_metadata
|
||||
del hidden_states, sampler_output
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
|
@ -119,6 +119,8 @@ class Worker(WorkerBase):
|
||||
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||
self.vllm_config, self.device)
|
||||
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
@ -211,6 +213,25 @@ class Worker(WorkerBase):
|
||||
self.model_runner._dummy_run(size)
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model()
|
||||
|
||||
# Warm up sampler and preallocate memory buffer for logits and other
|
||||
# sampling related tensors of max possible shape to avoid memory
|
||||
# fragmentation issue.
|
||||
# NOTE: This is called after `capture_model` on purpose to prevent
|
||||
# memory buffers from being cleared by `torch.cuda.empty_cache`.
|
||||
try:
|
||||
self.model_runner._dummy_sampler_run(
|
||||
hidden_states=self.model_runner._dummy_run(
|
||||
num_tokens=self.scheduler_config.max_num_seqs))
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA out of memory occurred when warming up sampler. "
|
||||
"Please try lowering `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from None
|
||||
else:
|
||||
raise e
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user