🐛 fix torch memory profiling (#9516)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent 337ed76671
commit 380e18639f
@@ -107,8 +107,7 @@ def validate_generated_texts(hf_runner,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
-                     enforce_eager=False,
-                     gpu_memory_utilization=0.8) as llm:
+                     enforce_eager=False) as llm:
         vllm_outputs = llm.generate_greedy(prompts, 8)
         vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
 
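
Illustrative sketch (not part of the committed diff): gpu_memory_utilization caps the fraction of total device memory the engine may budget, which is why the test above can drop its explicit 0.8 override and rely on vLLM's default of 0.9. The helper name below is hypothetical; it only shows the budget arithmetic, assuming torch with a CUDA device.

import torch

def memory_target_bytes(gpu_memory_utilization: float = 0.9) -> int:
    # mem_get_info() returns (free_bytes, total_bytes) for the current device;
    # the utilization fraction is applied to the total device memory.
    _free_bytes, total_bytes = torch.cuda.mem_get_info()
    return int(total_bytes * gpu_memory_utilization)

# e.g. on a 10 GiB device: 0.9 -> 9.0 GiB target, 0.8 -> 8.0 GiB target
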
@@ -54,16 +54,17 @@ def test_gpu_memory_profiling():
         gpu_blocks, _ = worker.determine_num_available_blocks()
 
     # Peak vram usage by torch should be 0.7077 GiB
-    # Non-torch allocations should be 0.0079 GiB
+    # No memory should be allocated outside of torch
     # 9.0 GiB should be the utilization target
-    # 8.2843 GiB should be available for the KV cache
+    # 8.2923 GiB should be available for the KV cache
     block_size = CacheEngine.get_cache_block_size(
         engine_config.cache_config, engine_config.model_config,
         engine_config.parallel_config)
 
-    expected_blocks = (8.2843 * 1024**3) // block_size
+    expected_blocks = (8.2923 * 1024**3) // block_size
 
     # Check within a small tolerance for portability
     # Hardware, kernel, or dependency changes could all affect memory
-    # utilization
-    assert abs(gpu_blocks - expected_blocks) < 5
+    # utilization.
+    # A 10 block tolerance here should be about 6MB of wiggle room.
+    assert abs(gpu_blocks - expected_blocks) < 10
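
The updated expectations follow from the arithmetic implied by the comments: a 10 GiB device at 0.9 utilization gives the 9.0 GiB target, subtracting the 0.7077 GiB torch peak leaves 8.2923 GiB for the KV cache, and a 10-block tolerance is "about 6MB" because each block is on the order of 0.6 MB here. A self-contained sketch of that arithmetic (the 10 GiB total and the per-block size are assumptions for illustration; the real test obtains block_size from CacheEngine.get_cache_block_size):

GiB = 1024**3

total_gpu_memory = 10 * GiB                    # assumed device size
utilization_target = 0.9 * total_gpu_memory    # 9.0 GiB, matching the comment
peak_torch_memory = 0.7077 * GiB               # torch peak from the profile run
kv_cache_memory = utilization_target - peak_torch_memory   # ~8.2923 GiB

block_size = 589_824    # hypothetical bytes per KV-cache block (~0.56 MiB)
expected_blocks = kv_cache_memory // block_size

# 10 blocks of ~0.56 MiB each is the "about 6MB of wiggle room" in the test.
print(int(expected_blocks))
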
@@ -232,10 +232,11 @@ class Worker(LocalOrDistributedWorkerBase):
         # gpu outside of `torch`. NCCL operations, for example, can use a few
         # GB during a forward pass
         torch.cuda.empty_cache()
-        # After emptying the torch cache, any other increase in gpu ram should
-        # be from non-torch allocations.
-        non_torch_allocations = free_memory_pre_profile - \
-            torch.cuda.mem_get_info()[0]
+        torch_allocated_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.current"]
+        total_allocated_bytes = torch.cuda.mem_get_info(
+        )[1] - torch.cuda.mem_get_info()[0]
+        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
         if non_torch_allocations > 0:
             peak_memory += non_torch_allocations
 
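
The fix changes how non-torch memory is attributed: instead of comparing free memory before and after the profile run, which can also pick up changes in torch's own allocations, it subtracts what torch reports as currently allocated from everything in use on the device. A self-contained sketch of that measurement pattern (assumes a CUDA device with an initialized CUDA context; the helper name is illustrative, not from the codebase):

import torch

def non_torch_bytes() -> int:
    # Return cached-but-unused blocks to the driver so they are not counted
    # as memory in use.
    torch.cuda.empty_cache()
    # Bytes currently held by torch tensors on this device.
    torch_allocated = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    # mem_get_info() -> (free_bytes, total_bytes); their difference is
    # everything in use on the device, regardless of who allocated it.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    total_in_use = total_bytes - free_bytes
    # The remainder is memory torch does not know about: CUDA context, NCCL
    # buffers, other libraries, or other processes sharing the GPU.
    return total_in_use - torch_allocated
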
@@ -259,10 +260,12 @@ class Worker(LocalOrDistributedWorkerBase):
         logger.info(
             "Memory profiling results: total_gpu_memory=%.2fGiB"
             " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
+            " memory_usage_post_profile=%.2fGib"
             " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
             " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
             (peak_memory - non_torch_allocations) / (1024**3),
+            total_allocated_bytes / (1024**3),
             non_torch_allocations / (1024**3),
             available_kv_cache_memory / (1024**3),
             self.cache_config.gpu_memory_utilization)
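
A short sketch of the two fields added to the log line above: initial_memory_usage is how much of the device was already in use when profiling started, and memory_usage_post_profile is everything still in use afterwards. The numeric values below are placeholders for illustration only.

GiB = 1024**3

total_gpu_memory = 10 * GiB                  # device size, as in the test scenario
free_memory_pre_profile = int(9.75 * GiB)    # placeholder: free before profiling
total_allocated_bytes = int(0.71 * GiB)      # placeholder: in use after profiling

initial_memory_usage = (total_gpu_memory - free_memory_pre_profile) / GiB
memory_usage_post_profile = total_allocated_bytes / GiB

print("initial_memory_usage=%.2fGiB memory_usage_post_profile=%.2fGiB"
      % (initial_memory_usage, memory_usage_post_profile))
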