[V1][Core] Fix memory issue with logits & sampling (#14508)
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Varun Sundar Rabindranath <3337719+varun-sundar-rabindranath@users.noreply.github.com>
parent c982ac5722
commit 1fc973c0b5
@@ -142,6 +142,15 @@ def test_end_to_end(model: str, use_v1: bool):
         used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
         # now the memory usage is mostly cudagraph memory pool,
         # and it should be less than the model weights (1B model, 2GiB weights)
-        assert used_bytes < 2 * GiB_bytes
+
+        # NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
+        # is captured but cannot be released from PyTorch due to a known bug,
+        # therefore high memory usage after `llm.sleep` is called is expected.
+        # FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
+        # in V1.
+        if use_v1:
+            assert used_bytes < 7 * GiB_bytes
+        else:
+            assert used_bytes < 2 * GiB_bytes
 
         llm.wake_up()
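The relaxed 7 GiB bound in the V1 branch accounts for the logits buffer (max_num_reqs x vocab_size) that stays resident after `llm.sleep()`. A minimal, standalone sketch of the kind of back-of-the-envelope estimate behind such a bound; the request count, vocabulary size, and dtype below are assumed example values, not numbers taken from this commit:

# Rough size of a resident logits buffer under assumed values.
max_num_reqs = 1024        # assumed scheduler max_num_seqs
vocab_size = 152_064       # assumed vocabulary size of a ~1B model
bytes_per_logit = 4        # assumed float32 logits

logits_buffer_bytes = max_num_reqs * vocab_size * bytes_per_logit
print(f"{logits_buffer_bytes / 2**30:.2f} GiB")  # ~0.58 GiB for these values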
@@ -3525,6 +3525,11 @@ class VllmConfig:
                 not self.model_config.enforce_eager:
                 batch_size_capture_list = [1, 2, 4
                                            ] + [i for i in range(8, 513, 8)]
+                max_num_tokens = self.scheduler_config.max_num_batched_tokens
+                batch_size_capture_list = [
+                    size for size in batch_size_capture_list
+                    if size <= max_num_tokens
+                ]
 
             self.compilation_config.init_with_cudagraph_sizes(
                 batch_size_capture_list)
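The added filter keeps cudagraph capture sizes within the scheduler's token budget, so no graph is captured for a batch size that can never be scheduled. A self-contained sketch of the same filtering, with max_num_batched_tokens replaced by an assumed example value:

max_num_tokens = 256  # assumed value of scheduler_config.max_num_batched_tokens

batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]
batch_size_capture_list = [
    size for size in batch_size_capture_list if size <= max_num_tokens
]
print(max(batch_size_capture_list))  # 256: no capture size exceeds the budget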
@@ -1202,6 +1202,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         num_tokens: int,
     ) -> torch.Tensor:
+
+        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
+        # for dummy run with LoRA so that the num_reqs collectively
+        # has num_tokens in total.
+        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
+        min_tokens_per_req = num_tokens // num_reqs
+        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
+                                        dtype=np.int32)
+
+        with self.maybe_dummy_run_with_lora(self.lora_config,
+                                            num_scheduled_tokens):
             model = self.model
             if self.is_multimodal_model:
                 input_ids = None
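A standalone sketch of the token-splitting logic added above: num_tokens is spread as evenly as possible over the requests, with the remainder absorbed by the last one. The concrete numbers (num_tokens=33, max_num_seqs=8) are assumed for illustration:

import numpy as np

num_tokens = 33   # assumed dummy-run token count
max_num_reqs = 8  # assumed scheduler max_num_seqs

num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
min_tokens_per_req = num_tokens // num_reqs             # 4
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs  # last request gets the remainder

assert sum(num_scheduled_tokens_list) == num_tokens
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
print(num_scheduled_tokens)  # [4 4 4 4 4 4 4 5]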
@@ -1228,7 +1245,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         for k, v in self.intermediate_tensors.items()
                     })
 
-            with set_forward_context(None, self.vllm_config,
+            with set_forward_context(None,
+                                     self.vllm_config,
                                      num_tokens=num_tokens):
                 hidden_states = model(
                     input_ids=input_ids,
@@ -1236,7 +1254,46 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                 )
-        return hidden_states
+
+        logit_indices = np.cumsum(num_scheduled_tokens) - 1
+        return hidden_states[logit_indices]
+
+    @torch.inference_mode()
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+
+        logits = self.model.compute_logits(hidden_states, None)
+        num_reqs = logits.size(0)
+
+        dummy_tensors = lambda v: torch.full(
+            (num_reqs, ), v, device=self.device)
+
+        dummy_metadata = SamplingMetadata(
+            temperature=dummy_tensors(0.5),
+            all_greedy=False,
+            all_random=False,
+            top_p=dummy_tensors(0.9),
+            top_k=dummy_tensors(logits.size(1) - 1),
+            min_p=None,
+            generators={},
+            max_num_logprobs=None,
+            no_penalties=True,
+            prompt_token_ids=None,
+            frequency_penalties=dummy_tensors(0.1),
+            presence_penalties=dummy_tensors(0.1),
+            repetition_penalties=dummy_tensors(0.1),
+            output_token_ids=[[] for _ in range(num_reqs)],
+            min_tokens={},
+            logit_bias=[None for _ in range(num_reqs)],
+            allowed_token_ids_mask=None,
+            bad_words_token_ids={},
+        )
+        sampler_output = self.model.sample(logits=logits,
+                                           sampling_metadata=dummy_metadata)
+
+        return sampler_output
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
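A small sketch of the logit_indices step added above: with per-request token counts, cumsum minus one gives the flat position of each request's last token, so only num_reqs rows of the hidden states reach the sampler warm-up and the logits tensor shrinks from (num_tokens, vocab_size) to (num_reqs, vocab_size). The token counts below are an assumed example:

import numpy as np

num_scheduled_tokens = np.array([4, 4, 5], dtype=np.int32)  # assumed per-request counts
logit_indices = np.cumsum(num_scheduled_tokens) - 1
print(logit_indices)  # [ 3  7 12] -- flat index of each request's last token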
@@ -1332,59 +1389,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Cache the dummy encoder outputs.
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
-        # For profile, have maximum num_reqs and that collectively have
-        # maximum num_tokens.
-        num_reqs = self.scheduler_config.max_num_seqs
-        num_tokens = self.max_num_tokens
-        min_tokens_per_req = num_tokens // num_reqs
-
-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
-
-        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
-                                        dtype=np.int32)
-        logit_indices = np.cumsum(num_scheduled_tokens) - 1
-
-        with self.maybe_profile_with_lora(self.lora_config,
-                                          num_scheduled_tokens):
-            # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.max_num_tokens)
         if get_pp_group().is_last_rank:
-            hidden_states = hidden_states[logit_indices]
-            logits = self.model.compute_logits(hidden_states, None)
-            dummy_tensors = lambda v: torch.full(
-                (num_reqs, ), v, device=self.device)
-            dummy_metadata = SamplingMetadata(
-                temperature=dummy_tensors(0.5),
-                all_greedy=False,
-                all_random=False,
-                top_p=dummy_tensors(0.9),
-                top_k=dummy_tensors(logits.size(1) - 1),
-                min_p=None,
-                generators={},
-                max_num_logprobs=None,
-                no_penalties=True,
-                prompt_token_ids=torch.ones_like(logits,
-                                                 dtype=torch.int64),
-                frequency_penalties=dummy_tensors(0.1),
-                presence_penalties=dummy_tensors(0.1),
-                repetition_penalties=dummy_tensors(0.1),
-                output_token_ids=[[] for _ in range(num_reqs)],
-                min_tokens={},
-                logit_bias=[None for _ in range(num_reqs)],
-                allowed_token_ids_mask=None,
-                bad_words_token_ids={},
-            )
-            sampler_output = self.model.sample(
-                logits=logits, sampling_metadata=dummy_metadata)
+            sampler_output = self._dummy_sampler_run(hidden_states)
        else:
-            logits = None
            sampler_output = None
-            dummy_metadata = None
        torch.cuda.synchronize()
-        del hidden_states, logits, sampler_output, dummy_metadata
+        del hidden_states, sampler_output
        self.encoder_cache.clear()
        gc.collect()
 
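The tail of profile_run keeps its existing cleanup pattern, now with fewer temporaries to drop since the sampling scaffolding lives in _dummy_sampler_run. A minimal, GPU-optional sketch of that pattern; the tensor shapes are placeholders:

import gc
import torch

hidden_states = torch.empty(8, 4096)   # placeholder profiling output
sampler_output = None                  # e.g. on a non-last pipeline rank

if torch.cuda.is_available():
    torch.cuda.synchronize()
del hidden_states, sampler_output      # drop the last references
gc.collect()                           # let Python reclaim them promptly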
@@ -119,6 +119,8 @@ class Worker(WorkerBase):
             self.model_runner: GPUModelRunner = GPUModelRunner(
                 self.vllm_config, self.device)
 
+    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
+    # to hijack tensor allocation.
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
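The FIXME above points at TorchDispatchMode as an alternative to hijacking allocations through a custom memory pool. A hedged, standalone sketch of what such an interception hook looks like; this illustrates the generic PyTorch mechanism only, not the allocator logic used by CuMemAllocator:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class AllocationLogger(TorchDispatchMode):
    """Logs every ATen op that returns a tensor, including factory ops that allocate."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if isinstance(out, torch.Tensor):
            print(f"{func}: {out.numel() * out.element_size()} bytes")
        return out

with AllocationLogger():
    x = torch.empty(1024, 1024)  # the factory op is intercepted and logged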
@@ -211,6 +213,27 @@ class Worker(WorkerBase):
                 self.model_runner._dummy_run(size)
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
+
+        # Warm up sampler and preallocate memory buffer for logits and other
+        # sampling related tensors of max possible shape to avoid memory
+        # fragmentation issue.
+        # NOTE: This is called after `capture_model` on purpose to prevent
+        # memory buffers from being cleared by `torch.cuda.empty_cache`.
+        try:
+            max_num_reqs = min(self.scheduler_config.max_num_seqs,
+                               self.scheduler_config.max_num_batched_tokens)
+            self.model_runner._dummy_sampler_run(
+                hidden_states=self.model_runner._dummy_run(
+                    num_tokens=max_num_reqs))
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                raise RuntimeError(
+                    "CUDA out of memory occurred when warming up sampler. "
+                    "Please try lowering `gpu_memory_utilization` when "
+                    "initializing the engine.") from None
+            else:
+                raise e
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
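A small sketch of the error-handling pattern introduced above: an out-of-memory RuntimeError raised during the warm-up is replaced with an actionable hint, while any other RuntimeError propagates unchanged. The `run_warmup` callable is a hypothetical stand-in for the `_dummy_sampler_run(_dummy_run(...))` chain shown in the hunk:

def warm_up_with_oom_hint(run_warmup) -> None:
    try:
        run_warmup()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            # Replace the raw allocator error with an actionable message.
            raise RuntimeError(
                "CUDA out of memory occurred when warming up sampler. "
                "Please try lowering `gpu_memory_utilization` when "
                "initializing the engine.") from None
        else:
            raise e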
@@ -83,7 +83,7 @@ class LoRAModelRunnerMixin:
                                  lora_requests)
 
     @contextmanager
-    def maybe_profile_with_lora(self, lora_config: LoRAConfig,
+    def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig,
                                   num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
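The renamed helper keeps the same shape: a context manager that is a no-op when LoRA is disabled and otherwise wraps the dummy run with LoRA setup. A generic sketch of that conditional pattern; the setup/teardown bodies here are placeholders, not the mixin's real LoRA handling:

from contextlib import contextmanager

@contextmanager
def maybe_with_feature(config):
    if config is None:
        # Feature disabled: behave as a transparent no-op context.
        yield
    else:
        # Hypothetical setup/teardown around the wrapped block.
        state = f"initialized for {config}"
        try:
            yield state
        finally:
            state = None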