diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 80cefcb4..2587e3a1 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -143,6 +143,13 @@ class CudaPlatformBase(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1) -> str:
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 85fde767..f2ecec32 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -277,6 +277,15 @@ class Platform:
             return False
         return True
 
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        """
+        Return the memory usage in bytes.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_punica_wrapper(cls) -> str:
         """
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 43105d78..67a9e816 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -157,3 +157,10 @@ class RocmPlatform(Platform):
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.cuda.reset_peak_memory_stats(device)
+        return torch.cuda.max_memory_allocated(device)
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index f34376b4..031abdc0 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -94,3 +94,10 @@ class XPUPlatform(Platform):
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
         return False
+
+    @classmethod
+    def get_current_memory_usage(cls,
+                                 device: Optional[torch.types.Device] = None
+                                 ) -> float:
+        torch.xpu.reset_peak_memory_stats(device)
+        return torch.xpu.max_memory_allocated(device)
diff --git a/vllm/utils.py b/vllm/utils.py
index 9a509da3..7477e702 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -710,13 +710,7 @@ class DeviceMemoryProfiler:
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            torch.cuda.reset_peak_memory_stats(self.device)
-            mem = torch.cuda.max_memory_allocated(self.device)
-        elif current_platform.is_xpu():
-            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
-            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
-        return mem
+        return current_platform.get_current_memory_usage(self.device)
 
     def __enter__(self):
         self.initial_memory = self.current_memory_usage()
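
A minimal usage sketch (not part of the patch): after this change, callers measure device memory through `DeviceMemoryProfiler`, which delegates to the current platform's `get_current_memory_usage()` instead of branching on CUDA/XPU itself. This assumes the existing `DeviceMemoryProfiler` context-manager API in `vllm/utils.py` (optional `device` argument, `consumed_memory` attribute set on exit).

```python
# Sketch only: measuring device memory allocated inside a with-block.
# Assumes DeviceMemoryProfiler accepts an optional device and exposes
# `consumed_memory` (bytes) after the block exits.
import torch

from vllm.utils import DeviceMemoryProfiler

with DeviceMemoryProfiler(torch.device("cuda:0")) as profiler:
    # Stand-in for model loading; any allocation on the device is counted.
    weights = torch.empty(1024, 1024, dtype=torch.float32, device="cuda:0")

# Internally this now calls current_platform.get_current_memory_usage(),
# so the same caller code works on CUDA, ROCm, and XPU platforms.
print(f"Consumed {profiler.consumed_memory / 1024**2:.1f} MiB")
```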