[core] further polish memory profiling (#12126)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
c09503ddd6
commit
da02cb4b27
@ -9,10 +9,10 @@ import torch
|
||||
from vllm_test_utils import monitor
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
|
||||
StoreBoolean, bind_kv_cache, deprecate_kwargs,
|
||||
get_open_port, memory_profiling, merge_async_iterators,
|
||||
supports_kw)
|
||||
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
|
||||
PlaceholderModule, StoreBoolean, bind_kv_cache,
|
||||
deprecate_kwargs, get_open_port, memory_profiling,
|
||||
merge_async_iterators, supports_kw)
|
||||
|
||||
from .utils import error_on_warning, fork_new_process_for_each_test
|
||||
|
||||
@ -284,14 +284,13 @@ def test_memory_profiling():
|
||||
# 512 MiB allocation outside of this instance
|
||||
handle1 = lib.cudaMalloc(512 * 1024 * 1024)
|
||||
|
||||
baseline_memory_in_bytes = \
|
||||
torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
|
||||
baseline_snapshot = MemorySnapshot()
|
||||
|
||||
# load weights
|
||||
|
||||
weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)
|
||||
|
||||
weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
|
||||
weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB
|
||||
|
||||
def measure_current_non_torch():
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
@ -300,8 +299,8 @@ def test_memory_profiling():
|
||||
current_non_torch = current_used - current_torch
|
||||
return current_non_torch
|
||||
|
||||
with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
|
||||
weights_memory_in_bytes=weights_memory_in_bytes) as result, \
|
||||
with memory_profiling(baseline_snapshot=baseline_snapshot,
|
||||
weights_memory=weights_memory) as result, \
|
||||
monitor(measure_current_non_torch) as monitored_values:
|
||||
# make a memory spike, 1 GiB
|
||||
spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
|
||||
@ -316,13 +315,12 @@ def test_memory_profiling():
|
||||
assert measured_diff == 256 * 1024 * 1024
|
||||
|
||||
# Check that the memory usage is within 5% of the expected values
|
||||
# 5% tolerance is caused by PyTorch caching allocator,
|
||||
# we cannot control PyTorch's behavior of its internal buffers,
|
||||
# 5% tolerance is caused by cuda runtime.
|
||||
# we cannot control cuda runtime in the granularity of bytes,
|
||||
# which causes a small error (<10 MiB in practice)
|
||||
non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
|
||||
torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
|
||||
non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa
|
||||
assert abs(non_torch_ratio - 1) <= 0.05
|
||||
assert abs(torch_peak_ratio - 1) <= 0.05
|
||||
assert result.torch_peak_increase == 1024 * 1024 * 1024
|
||||
del weights
|
||||
lib.cudaFree(handle1)
|
||||
lib.cudaFree(handle2)
|
||||
|
@ -1923,36 +1923,57 @@ def kill_process_tree(pid: int):
|
||||
@dataclass
|
||||
class MemorySnapshot:
|
||||
"""Memory snapshot."""
|
||||
torch_peak_in_bytes: int = 0
|
||||
torch_memory_in_bytes: int = 0
|
||||
torch_peak: int = 0
|
||||
cuda_memory: int = 0
|
||||
torch_memory: int = 0
|
||||
non_torch_memory: int = 0
|
||||
timestamp: float = 0.0
|
||||
auto_measure: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.auto_measure:
|
||||
self.measure()
|
||||
|
||||
def measure(self):
|
||||
self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
|
||||
# we measure the torch peak memory usage via allocated_bytes,
|
||||
# rather than `torch.cuda.memory_reserved()` .
|
||||
# After `torch.cuda.reset_peak_memory_stats()`,
|
||||
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
|
||||
# when we call `torch.cuda.empty_cache()` or OOM happens.
|
||||
self.torch_peak = torch.cuda.memory_stats().get(
|
||||
"allocated_bytes.all.peak", 0)
|
||||
|
||||
self.cuda_memory = torch.cuda.mem_get_info(
|
||||
)[1] - torch.cuda.mem_get_info()[0]
|
||||
|
||||
# torch.cuda.memory_reserved() is how many bytes
|
||||
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
|
||||
self.torch_memory_in_bytes = torch.cuda.memory_reserved()
|
||||
# this is used to measure the non-torch memory usage
|
||||
self.torch_memory = torch.cuda.memory_reserved()
|
||||
|
||||
self.non_torch_memory = self.cuda_memory - self.torch_memory
|
||||
self.timestamp = time.time()
|
||||
|
||||
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
|
||||
"""support a - b"""
|
||||
return MemorySnapshot(
|
||||
torch_peak_in_bytes=self.torch_peak_in_bytes -
|
||||
other.torch_peak_in_bytes,
|
||||
torch_memory_in_bytes=self.torch_memory_in_bytes -
|
||||
other.torch_memory_in_bytes,
|
||||
timestamp=self.timestamp - other.timestamp)
|
||||
torch_peak=self.torch_peak - other.torch_peak,
|
||||
cuda_memory=self.cuda_memory - other.cuda_memory,
|
||||
torch_memory=self.torch_memory - other.torch_memory,
|
||||
non_torch_memory=self.non_torch_memory - other.non_torch_memory,
|
||||
timestamp=self.timestamp - other.timestamp,
|
||||
auto_measure=False,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryProfilingResult:
|
||||
"""Memory profiling result.
|
||||
""" # noqa
|
||||
baseline_memory_in_bytes: int = 0
|
||||
non_kv_cache_memory_in_bytes: int = 0
|
||||
torch_peak_increase_in_bytes: int = 0
|
||||
non_torch_increase_in_bytes: int = 0
|
||||
weights_memory_in_bytes: float = 0
|
||||
"""Memory profiling result. All numbers are in bytes.
|
||||
"""
|
||||
non_kv_cache_memory: int = 0
|
||||
torch_peak_increase: int = 0
|
||||
non_torch_increase: int = 0
|
||||
weights_memory: float = 0
|
||||
before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
profile_time: float = 0.0
|
||||
@ -1960,18 +1981,14 @@ class MemoryProfilingResult:
|
||||
|
||||
@contextlib.contextmanager
|
||||
def memory_profiling(
|
||||
baseline_memory_in_bytes: int, weights_memory_in_bytes: int
|
||||
) -> Generator[MemoryProfilingResult, None, None]:
|
||||
baseline_snapshot: MemorySnapshot,
|
||||
weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
|
||||
"""Memory profiling context manager.
|
||||
baseline_memory_in_bytes: memory used by all the components other than
|
||||
the current vLLM instance. It contains: memory used by other processes, memory
|
||||
used by another vLLM instance in the same process, etc. It is usually measured
|
||||
before the current vLLM instance initialize the device. And we assume it is
|
||||
constant during the profiling of the current vLLM instance.
|
||||
weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
|
||||
baseline_snapshot: the memory snapshot before the current vLLM instance.
|
||||
weights_memory: memory used by PyTorch when loading the model weights.
|
||||
Note that, before loading the model weights, we also initialize the device
|
||||
and distributed environment, which may consume some memory. This part is not
|
||||
included in the weights_memory_in_bytes because PyTorch does not control it.
|
||||
included in the weights_memory because PyTorch does not control it.
|
||||
|
||||
The memory in one GPU can be classified into 3 categories:
|
||||
1. memory used by anything other than the current vLLM instance.
|
||||
@ -2006,20 +2023,21 @@ def memory_profiling(
|
||||
b. 2 GiB reserved for the peak activation tensors (category 2)
|
||||
c. 1 GiB used by non-torch components (category 3)
|
||||
|
||||
The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
|
||||
The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
|
||||
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
|
||||
|
||||
(c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
|
||||
subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
|
||||
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
|
||||
""" # noqa
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
result = MemoryProfilingResult()
|
||||
|
||||
result.baseline_memory_in_bytes = baseline_memory_in_bytes
|
||||
result.before_create = baseline_snapshot
|
||||
# the part of memory used for holding the model weights
|
||||
result.weights_memory_in_bytes = weights_memory_in_bytes
|
||||
result.weights_memory = weights_memory
|
||||
|
||||
result.before_profile.measure()
|
||||
|
||||
@ -2030,13 +2048,12 @@ def memory_profiling(
|
||||
|
||||
result.after_profile.measure()
|
||||
|
||||
diff = result.after_profile - result.before_profile
|
||||
result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
|
||||
current_cuda_memory_bytes = torch.cuda.mem_get_info(
|
||||
)[1] - torch.cuda.mem_get_info()[0]
|
||||
result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa
|
||||
result.profile_time = diff.timestamp
|
||||
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
|
||||
diff_profile = result.after_profile - result.before_profile
|
||||
diff_from_create = result.after_profile - result.before_create
|
||||
result.torch_peak_increase = diff_profile.torch_peak
|
||||
result.non_torch_increase = diff_from_create.non_torch_memory
|
||||
result.profile_time = diff_profile.timestamp
|
||||
result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
|
||||
|
@ -21,7 +21,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||||
from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
|
||||
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
|
||||
memory_profiling)
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
|
||||
@ -137,7 +138,8 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
_check_if_gpu_supports_dtype(self.model_config.dtype)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.baseline_snapshot = MemorySnapshot()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
@ -192,10 +194,9 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
|
||||
self.init_gpu_memory,
|
||||
weights_memory_in_bytes=self.model_runner.
|
||||
model_memory_usage) as result:
|
||||
with memory_profiling(
|
||||
self.baseline_snapshot,
|
||||
weights_memory=self.model_runner.model_memory_usage) as result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
self._assert_memory_footprint_increased_during_profiling()
|
||||
@ -203,7 +204,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
memory_for_current_instance = total_gpu_memory * \
|
||||
self.cache_config.gpu_memory_utilization
|
||||
available_kv_cache_memory = (memory_for_current_instance -
|
||||
result.non_kv_cache_memory_in_bytes)
|
||||
result.non_kv_cache_memory)
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
@ -226,11 +227,11 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
f"({self.cache_config.gpu_memory_utilization:.2f})"
|
||||
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
|
||||
"model weights take "
|
||||
f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
|
||||
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
|
||||
" non_torch_memory takes "
|
||||
f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
|
||||
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
|
||||
" PyTorch activation peak memory takes "
|
||||
f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
|
||||
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
|
||||
" the rest of the memory reserved for KV Cache is "
|
||||
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
|
||||
|
||||
@ -246,11 +247,13 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
def _assert_memory_footprint_increased_during_profiling(self):
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
assert self.init_gpu_memory - free_gpu_memory > 0, (
|
||||
free_gpu_memory, total = torch.cuda.mem_get_info()
|
||||
cuda_memory = total - free_gpu_memory
|
||||
assert self.baseline_snapshot.cuda_memory < cuda_memory, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {self.init_gpu_memory}, current free memory"
|
||||
f" {free_gpu_memory}. This happens when the GPU memory was "
|
||||
f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
|
||||
f"currently used memory {cuda_memory}. "
|
||||
f"This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
|
Loading…
x
Reference in New Issue
Block a user