[V1] Fully Transparent Implementation of CPU Offloading (#15354)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-03-31 20:22:34 +08:00 committed by GitHub
parent e7ae3bf3d6
commit 555aa21905
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 148 additions and 25 deletions

View File

@ -234,6 +234,7 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"

39
csrc/cuda_view.cu Normal file
View File

@ -0,0 +1,39 @@
#include <torch/all.h>
#include <torch/cuda.h>
#include <cuda_runtime.h>
// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
// Get raw host pointer from CPU tensor
void* host_ptr = cpu_tensor.data_ptr();
// Get a device pointer corresponding to the pinned host memory
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
// We'll use the same sizes, strides, and dtype as the CPU tensor.
// TODO: check if layout is respected.
auto sizes = cpu_tensor.sizes();
auto strides = cpu_tensor.strides();
auto options = cpu_tensor.options().device(torch::kCUDA);
// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto deleter = [](void*) {
// no-op, since the memory is owned by the original CPU tensor
};
torch::Tensor cuda_tensor =
torch::from_blob(device_ptr, sizes, strides, deleter, options);
TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
return cuda_tensor;
}

View File

@ -119,6 +119,8 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,

View File

@ -31,6 +31,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.

View File

@ -1,15 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from ..utils import compare_two_settings
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
monkeypatch.setenv('VLLM_USE_V1', '0')
def test_cpu_offload():
compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
["--cpu-offload-gb", "1"])

61
tests/kernels/test_uva.py Normal file
View File

@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
assert cuda_view[0, 0] == 0
assert cuda_view[2, 3] == 0
assert cuda_view[4, 5] == 0
cpu_tensor[0, 0] = 1
cpu_tensor[2, 3] = 2
cpu_tensor[4, 5] = -1
cuda_view.mul_(2)
assert cuda_view[0, 0] == 2
assert cuda_view[2, 3] == 4
assert cuda_view[4, 5] == -2
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
assert cuda_view[0, 0] == 0
assert cuda_view[2, 3] == 0
assert cuda_view[4, 5] == 0
cuda_view[0, 0] = 1
cuda_view[2, 3] = 2
cuda_view[4, 5] = -1
cuda_view.mul_(2)
assert cpu_tensor[0, 0] == 2
assert cpu_tensor[2, 3] == 4
assert cpu_tensor[4, 5] == -2

View File

@ -10,13 +10,6 @@ from tests.quantization.utils import is_quant_method_supported
from ..utils import compare_two_settings
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
# Fall back to V0 if cpu offloading is enabled.
# Fixture is required to that baseline uses V0.
monkeypatch.setenv('VLLM_USE_V1', '0')
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.")
def test_cpu_offload_fp8():

View File

@ -3562,9 +3562,10 @@ class VllmConfig:
if self.cache_config is not None and \
self.cache_config.cpu_offload_gb > 0 and \
self.compilation_config.level != CompilationLevel.NO_COMPILATION:
self.compilation_config.level != CompilationLevel.NO_COMPILATION \
and not envs.VLLM_USE_V1:
logger.warning(
"CPU offload is not supported with `torch.compile` yet."
"CPU offload is not supported with `torch.compile` in v0 yet."
" Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION

View File

@ -1595,12 +1595,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No CPU offloading yet.
if self.cpu_offload_gb != EngineArgs.cpu_offload_gb:
_raise_or_fallback(feature_name="--cpu-offload-gb",
recommend_to_remove=False)
return False
# Only Fp16 and Bf16 dtypes since we only support FA.
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
if model_config.dtype not in V1_SUPPORTED_DTYPES:

View File

@ -10,12 +10,14 @@ import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available)
logger = init_logger(__name__)
@ -505,6 +507,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
return module
pin_memory = is_pin_memory_available()
uva_available = is_uva_available()
if envs.VLLM_USE_V1:
assert uva_available, ("V1 CPU offloading requires"
" uva (pin memory) support")
uva_offloading = True
else:
uva_offloading = False
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
@ -523,11 +533,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
device='cpu',
pin_memory=pin_memory)
cpu_data.copy_(p.data)
if not uva_offloading:
p.data = cpu_data
else:
# keep the cpu data alive
p._vllm_offloaded_cpu_data = cpu_data
p.data = get_cuda_view_from_cpu_tensor(cpu_data)
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True
if offloaded_parameters:
if offloaded_parameters and not uva_offloading:
original_forward = module.forward
def forward(*args, **kwargs):

View File

@ -795,6 +795,14 @@ def is_pin_memory_available() -> bool:
return current_platform.is_pin_memory_available()
@cache
def is_uva_available() -> bool:
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed.
return is_pin_memory_available()
class DeviceMemoryProfiler:
def __init__(self, device: Optional[torch.types.Device] = None):
@ -1645,6 +1653,14 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule

View File

@ -69,6 +69,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config