[core] further polish memory profiling (#12126)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2025-01-18 12:25:08 +08:00 (committed by GitHub)
parent c09503ddd6
commit da02cb4b27
3 changed files with 85 additions and 67 deletions

tests/test_utils.py

@@ -9,10 +9,10 @@ import torch
 from vllm_test_utils import monitor

 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
-                        StoreBoolean, bind_kv_cache, deprecate_kwargs,
-                        get_open_port, memory_profiling, merge_async_iterators,
-                        supports_kw)
+from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
+                        PlaceholderModule, StoreBoolean, bind_kv_cache,
+                        deprecate_kwargs, get_open_port, memory_profiling,
+                        merge_async_iterators, supports_kw)

 from .utils import error_on_warning, fork_new_process_for_each_test
@@ -284,14 +284,13 @@ def test_memory_profiling():
     # 512 MiB allocation outside of this instance
     handle1 = lib.cudaMalloc(512 * 1024 * 1024)

-    baseline_memory_in_bytes = \
-        torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
+    baseline_snapshot = MemorySnapshot()

     # load weights
     weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)

-    weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB
+    weights_memory = 128 * 1024 * 1024 * 4  # 512 MiB

     def measure_current_non_torch():
         free, total = torch.cuda.mem_get_info()
@@ -300,8 +299,8 @@ def test_memory_profiling():
         current_non_torch = current_used - current_torch
         return current_non_torch

-    with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
-                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
+    with memory_profiling(baseline_snapshot=baseline_snapshot,
+                          weights_memory=weights_memory) as result, \
            monitor(measure_current_non_torch) as monitored_values:
         # make a memory spike, 1 GiB
         spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
@@ -316,13 +315,12 @@ def test_memory_profiling():
     assert measured_diff == 256 * 1024 * 1024
     # Check that the memory usage is within 5% of the expected values
-    # 5% tolerance is caused by PyTorch caching allocator,
-    # we cannot control PyTorch's behavior of its internal buffers,
+    # 5% tolerance is caused by cuda runtime.
+    # we cannot control cuda runtime in the granularity of bytes,
     # which causes a small error (<10 MiB in practice)
-    non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
-    torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa
+    non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024)  # noqa
     assert abs(non_torch_ratio - 1) <= 0.05
-    assert abs(torch_peak_ratio - 1) <= 0.05
+    assert result.torch_peak_increase == 1024 * 1024 * 1024

     del weights
     lib.cudaFree(handle1)
     lib.cudaFree(handle2)
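
For reference, the reworked API can be exercised outside this test roughly as follows. This is a minimal sketch with illustrative tensor sizes, assuming a CUDA device and the vllm.utils introduced by this commit:

    import torch

    from vllm.utils import MemorySnapshot, memory_profiling

    # Snapshot taken before this "instance" allocates anything;
    # MemorySnapshot now auto-measures on construction.
    baseline = MemorySnapshot()

    # Stand-in for model weights: ~256 MiB of float32 (illustrative size).
    weights = torch.zeros(64, 1024, 1024, device="cuda", dtype=torch.float32)
    weights_bytes = weights.numel() * weights.element_size()

    with memory_profiling(baseline, weights_memory=weights_bytes) as result:
        # Transient activations: counted towards torch_peak_increase.
        act = torch.zeros(32, 1024, 1024, device="cuda", dtype=torch.float32)
        del act

    print(result.torch_peak_increase, result.non_torch_increase,
          result.non_kv_cache_memory)
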

vllm/utils.py

@@ -1923,36 +1923,57 @@ def kill_process_tree(pid: int):

 @dataclass
 class MemorySnapshot:
     """Memory snapshot."""
-    torch_peak_in_bytes: int = 0
-    torch_memory_in_bytes: int = 0
+    torch_peak: int = 0
+    cuda_memory: int = 0
+    torch_memory: int = 0
+    non_torch_memory: int = 0
     timestamp: float = 0.0
+    auto_measure: bool = True
+
+    def __post_init__(self):
+        if self.auto_measure:
+            self.measure()

     def measure(self):
-        self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
+        # we measure the torch peak memory usage via allocated_bytes,
+        # rather than `torch.cuda.memory_reserved()` .
+        # After `torch.cuda.reset_peak_memory_stats()`,
+        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
+        # when we call `torch.cuda.empty_cache()` or OOM happens.
+        self.torch_peak = torch.cuda.memory_stats().get(
+            "allocated_bytes.all.peak", 0)
+
+        self.cuda_memory = torch.cuda.mem_get_info(
+        )[1] - torch.cuda.mem_get_info()[0]
+
         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
-        self.torch_memory_in_bytes = torch.cuda.memory_reserved()
+        # this is used to measure the non-torch memory usage
+        self.torch_memory = torch.cuda.memory_reserved()
+
+        self.non_torch_memory = self.cuda_memory - self.torch_memory
         self.timestamp = time.time()

     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
-        """support a - b"""
         return MemorySnapshot(
-            torch_peak_in_bytes=self.torch_peak_in_bytes -
-            other.torch_peak_in_bytes,
-            torch_memory_in_bytes=self.torch_memory_in_bytes -
-            other.torch_memory_in_bytes,
-            timestamp=self.timestamp - other.timestamp)
+            torch_peak=self.torch_peak - other.torch_peak,
+            cuda_memory=self.cuda_memory - other.cuda_memory,
+            torch_memory=self.torch_memory - other.torch_memory,
+            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
+            timestamp=self.timestamp - other.timestamp,
+            auto_measure=False,
+        )


 @dataclass
 class MemoryProfilingResult:
-    """Memory profiling result.
-    """ # noqa
-    baseline_memory_in_bytes: int = 0
-    non_kv_cache_memory_in_bytes: int = 0
-    torch_peak_increase_in_bytes: int = 0
-    non_torch_increase_in_bytes: int = 0
-    weights_memory_in_bytes: float = 0
+    """Memory profiling result. All numbers are in bytes.
+    """
+    non_kv_cache_memory: int = 0
+    torch_peak_increase: int = 0
+    non_torch_increase: int = 0
+    weights_memory: float = 0
+    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
     before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0
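
The reworked MemorySnapshot records cuda_memory (total bytes used on the device) and derives non_torch_memory = cuda_memory - torch_memory, auto-measures itself on construction, and keeps subtraction component-wise. A small sketch of these semantics with made-up values (auto_measure=False so nothing touches the GPU); the field names come from the diff above:

    from vllm.utils import MemorySnapshot

    GiB = 1 << 30

    before = MemorySnapshot(cuda_memory=6 * GiB, torch_memory=5 * GiB,
                            non_torch_memory=1 * GiB, auto_measure=False)
    after = MemorySnapshot(cuda_memory=9 * GiB, torch_memory=7 * GiB,
                           non_torch_memory=2 * GiB, auto_measure=False)

    diff = after - before              # component-wise, via __sub__
    assert diff.non_torch_memory == 1 * GiB
    assert diff.auto_measure is False  # a difference is never re-measured
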
@@ -1960,18 +1981,14 @@ class MemoryProfilingResult:

 @contextlib.contextmanager
 def memory_profiling(
-        baseline_memory_in_bytes: int, weights_memory_in_bytes: int
-) -> Generator[MemoryProfilingResult, None, None]:
+        baseline_snapshot: MemorySnapshot,
+        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
     """Memory profiling context manager.
-    baseline_memory_in_bytes: memory used by all the components other than
-        the current vLLM instance. It contains: memory used by other processes, memory
-        used by another vLLM instance in the same process, etc. It is usually measured
-        before the current vLLM instance initialize the device. And we assume it is
-        constant during the profiling of the current vLLM instance.
-    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
+    baseline_snapshot: the memory snapshot before the current vLLM instance.
+    weights_memory: memory used by PyTorch when loading the model weights.
         Note that, before loading the model weights, we also initialize the device
         and distributed environment, which may consume some memory. This part is not
-        included in the weights_memory_in_bytes because PyTorch does not control it.
+        included in the weights_memory because PyTorch does not control it.

     The memory in one GPU can be classified into 3 categories:
     1. memory used by anything other than the current vLLM instance.
@@ -2006,20 +2023,21 @@ def memory_profiling(
     b. 2 GiB reserved for the peak activation tensors (category 2)
     c. 1 GiB used by non-torch components (category 3)

-    The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
+    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.

-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).

-    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
-    subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
+    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
     """ # noqa
+    gc.collect()
+    torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()

     result = MemoryProfilingResult()

-    result.baseline_memory_in_bytes = baseline_memory_in_bytes
+    result.before_create = baseline_snapshot
     # the part of memory used for holding the model weights
-    result.weights_memory_in_bytes = weights_memory_in_bytes
+    result.weights_memory = weights_memory

     result.before_profile.measure()
@@ -2030,13 +2048,12 @@ def memory_profiling(
     result.after_profile.measure()

-    diff = result.after_profile - result.before_profile
-    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
-    current_cuda_memory_bytes = torch.cuda.mem_get_info(
-    )[1] - torch.cuda.mem_get_info()[0]
-    result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes  # noqa
-    result.profile_time = diff.timestamp
-    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes  # noqa
+    diff_profile = result.after_profile - result.before_profile
+    diff_from_create = result.after_profile - result.before_create
+    result.torch_peak_increase = diff_profile.torch_peak
+    result.non_torch_increase = diff_from_create.non_torch_memory
+    result.profile_time = diff_profile.timestamp
+    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory  # noqa


 # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
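
Note the asymmetry in the bookkeeping above: torch_peak_increase is taken over the profiling window only (after_profile - before_profile), while non_torch_increase is taken from instance creation (after_profile - before_create), so, per the updated docstring, non-torch allocations made between instance creation and the end of profiling (for example during device/distributed setup and weight loading) are also charged to this instance. A worked example of the final accounting, with illustrative numbers (the 2 GiB and 1 GiB match the docstring's example for categories b and c; the weights value is arbitrary):

    GiB = 1 << 30

    weights_memory = 1 * GiB       # (a) given by the caller
    torch_peak_increase = 2 * GiB  # (b) peak activation tensors while profiling
    non_torch_increase = 1 * GiB   # (c) non-torch components

    non_kv_cache_memory = (non_torch_increase + torch_peak_increase +
                           weights_memory)
    assert non_kv_cache_memory == 4 * GiB
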

vllm/worker/worker.py

@@ -21,7 +21,8 @@ from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
-from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
+from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
+                        memory_profiling)
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -137,7 +138,8 @@ class Worker(LocalOrDistributedWorkerBase):
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
-            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+            torch.cuda.reset_peak_memory_stats()
+            self.baseline_snapshot = MemorySnapshot()
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -192,10 +194,9 @@ class Worker(LocalOrDistributedWorkerBase):
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
-                              self.init_gpu_memory,
-                              weights_memory_in_bytes=self.model_runner.
-                              model_memory_usage) as result:
+        with memory_profiling(
+                self.baseline_snapshot,
+                weights_memory=self.model_runner.model_memory_usage) as result:
             self.model_runner.profile_run()

         self._assert_memory_footprint_increased_during_profiling()
@@ -203,7 +204,7 @@ class Worker(LocalOrDistributedWorkerBase):
         memory_for_current_instance = total_gpu_memory * \
             self.cache_config.gpu_memory_utilization
         available_kv_cache_memory = (memory_for_current_instance -
-                                     result.non_kv_cache_memory_in_bytes)
+                                     result.non_kv_cache_memory)

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@@ -226,11 +227,11 @@ class Worker(LocalOrDistributedWorkerBase):
             f"({self.cache_config.gpu_memory_utilization:.2f})"
             f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
             "model weights take "
-            f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
+            f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
             " non_torch_memory takes "
-            f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
+            f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
             " PyTorch activation peak memory takes "
-            f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
+            f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
             " the rest of the memory reserved for KV Cache is "
             f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
@@ -246,11 +247,13 @@ class Worker(LocalOrDistributedWorkerBase):
     def _assert_memory_footprint_increased_during_profiling(self):
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        free_gpu_memory, _ = torch.cuda.mem_get_info()
-        assert self.init_gpu_memory - free_gpu_memory > 0, (
+        free_gpu_memory, total = torch.cuda.mem_get_info()
+        cuda_memory = total - free_gpu_memory
+        assert self.baseline_snapshot.cuda_memory < cuda_memory, (
             "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
+            f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
+            f"currently used memory {cuda_memory}. "
+            f"This happens when the GPU memory was "
             "not properly cleaned up before initializing the vLLM instance.")

     def initialize_cache(self, num_gpu_blocks: int,