[Platform] move current_memory_usage() into platform (#11369)
Signed-off-by: Shanshan Shen <467638484@qq.com>
commit 9ddac56311
parent 1a51b9f872
vllm/platforms/cuda.py
@@ -143,6 +143,13 @@ class CudaPlatformBase(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1) -> str:
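A note on the reset-then-read pattern used in these implementations: torch.cuda.max_memory_allocated() reports the historical peak, so calling torch.cuda.reset_peak_memory_stats() immediately beforehand makes it return the current allocation instead. A minimal standalone sketch (illustration only, not part of the diff):

    import torch

    if torch.cuda.is_available():
        x = torch.empty(1024, 1024, device="cuda")  # allocates ~4 MiB of fp32
        torch.cuda.reset_peak_memory_stats()
        # After the reset, the recorded peak equals the live allocation, so
        # this returns the bytes currently allocated, not the high-water mark.
        print(torch.cuda.max_memory_allocated())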
vllm/platforms/interface.py
@@ -277,6 +277,15 @@ class Platform:
             return False
         return True
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        """
+        Return the memory usage in bytes.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_punica_wrapper(cls) -> str:
         """
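With the hook declared on the Platform base class, each backend only has to override a single classmethod. As a hedged sketch of what an out-of-tree platform might look like (the class name and the memory-query helper below are hypothetical, not part of this commit):

    from typing import Optional

    import torch

    from vllm.platforms.interface import Platform


    def _backend_memory_allocated(device) -> int:
        # Hypothetical stand-in for a vendor-specific memory query.
        return 0


    class MyAcceleratorPlatform(Platform):  # hypothetical backend
        @classmethod
        def get_current_memory_usage(cls,
                                     device: Optional[torch.types.Device] = None
                                     ) -> float:
            # Mirror the in-tree implementations: report live bytes on `device`.
            return float(_backend_memory_allocated(device))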
vllm/platforms/rocm.py
@@ -157,3 +157,10 @@ class RocmPlatform(Platform):
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
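The ROCm implementation deliberately mirrors the CUDA one: PyTorch exposes its HIP backend through the torch.cuda namespace, so torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() are the correct calls on AMD GPUs as well. This is the same reason the old profiler code below covered both platforms with a single is_cuda_alike() branch.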
vllm/platforms/xpu.py
@@ -94,3 +94,10 @@ class XPUPlatform(Platform):
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
         return False
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.xpu.reset_peak_memory_stats(device)
+        return torch.xpu.max_memory_allocated(device)
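All of the implementations above are reached the same way by callers. A hedged usage sketch of the platform-agnostic entry point this commit enables (current_platform is vLLM's detected-platform singleton):

    from vllm.platforms import current_platform

    # Dispatches to the CUDA, ROCm, or XPU override, whichever backend was
    # detected at import time; platforms without an override raise
    # NotImplementedError per the base-class stub.
    used_bytes = current_platform.get_current_memory_usage()
    print(f"{used_bytes / 1024 ** 2:.1f} MiB currently allocated")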
vllm/utils.py
@@ -710,13 +710,7 @@ class DeviceMemoryProfiler:
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            torch.cuda.reset_peak_memory_stats(self.device)
-            mem = torch.cuda.max_memory_allocated(self.device)
-        elif current_platform.is_xpu():
-            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
-            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
-        return mem
+        return current_platform.get_current_memory_usage(self.device)
 
     def __enter__(self):
         self.initial_memory = self.current_memory_usage()
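The simplified profiler is typically used as a context manager. A sketch of the common pattern, assuming DeviceMemoryProfiler's __exit__ records the delta in a consumed_memory attribute (as in vllm.utils at the time of this commit):

    import torch

    from vllm.utils import DeviceMemoryProfiler

    # Measure how much device memory a block of work leaves allocated;
    # consumed_memory is final usage minus initial usage, in bytes.
    with DeviceMemoryProfiler() as m:
        weights = torch.empty(4096, 4096, device="cuda")  # e.g. loading weights

    print(f"consumed {m.consumed_memory / 1024 ** 3:.2f} GiB")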