From 555aa21905a1d725b44a29ea8cbebf218ff14558 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 31 Mar 2025 20:22:34 +0800
Subject: [PATCH] [V1] Fully Transparent Implementation of CPU Offloading
 (#15354)

Signed-off-by: youkaichao
---
 CMakeLists.txt                              |  1 +
 csrc/cuda_view.cu                           | 39 +++++++++++++
 csrc/ops.h                                  |  2 +
 csrc/torch_bindings.cpp                     |  4 ++
 tests/basic_correctness/test_cpu_offload.py |  7 ---
 tests/kernels/test_uva.py                   | 61 +++++++++++++++++++++
 tests/quantization/test_cpu_offload.py      |  7 ---
 vllm/config.py                              |  5 +-
 vllm/engine/arg_utils.py                    |  6 --
 vllm/model_executor/models/utils.py         | 21 ++++++-
 vllm/utils.py                               | 16 ++++++
 vllm/v1/worker/gpu_model_runner.py          |  4 ++
 12 files changed, 148 insertions(+), 25 deletions(-)
 create mode 100644 csrc/cuda_view.cu
 create mode 100644 tests/kernels/test_uva.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9d15b77b..ab6185e9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -234,6 +234,7 @@ set(VLLM_EXT_SRC
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
   "csrc/layernorm_quant_kernels.cu"
+  "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
diff --git a/csrc/cuda_view.cu b/csrc/cuda_view.cu
new file mode 100644
index 00000000..938bd4ab
--- /dev/null
+++ b/csrc/cuda_view.cu
@@ -0,0 +1,39 @@
+#include <torch/all.h>
+#include <torch/cuda.h>
+#include <cuda_runtime.h>
+
+// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
+// memory, and that UVA (Unified Virtual Addressing) is enabled.
+torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
+  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
+
+  // Get raw host pointer from CPU tensor
+  void* host_ptr = cpu_tensor.data_ptr();
+
+  // Get a device pointer corresponding to the pinned host memory
+  void* device_ptr = nullptr;
+  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
+  TORCH_CHECK(err == cudaSuccess,
+              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
+
+  // We'll use the same sizes, strides, and dtype as the CPU tensor.
+  // TODO: check if layout is respected.
+  auto sizes = cpu_tensor.sizes();
+  auto strides = cpu_tensor.strides();
+  auto options = cpu_tensor.options().device(torch::kCUDA);
+
+  // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
+  // const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
+  // memory, so we don't free it here.
+  auto deleter = [](void*) {
+    // no-op, since the memory is owned by the original CPU tensor
+  };
+
+  torch::Tensor cuda_tensor =
+      torch::from_blob(device_ptr, sizes, strides, deleter, options);
+
+  TORCH_CHECK(cuda_tensor.device().is_cuda(),
+              "Resulting tensor is not on CUDA device");
+
+  return cuda_tensor;
+}
diff --git a/csrc/ops.h b/csrc/ops.h
index 1ea9f465..77d1ab76 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -119,6 +119,8 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 60ad6430..b0a23a36 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -31,6 +31,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
 
+  ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
+  ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
+           &get_cuda_view_from_cpu_tensor);
+
   // Attention ops
   // Compute the attention between an input query and the cached
   // keys/values using PagedAttention.
diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py
index 436e4363..be3ad123 100644
--- a/tests/basic_correctness/test_cpu_offload.py
+++ b/tests/basic_correctness/test_cpu_offload.py
@@ -1,15 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import pytest
-
 from ..utils import compare_two_settings
 
 
-@pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    monkeypatch.setenv('VLLM_USE_V1', '0')
-
-
 def test_cpu_offload():
     compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
                          ["--cpu-offload-gb", "1"])
diff --git a/tests/kernels/test_uva.py b/tests/kernels/test_uva.py
new file mode 100644
index 00000000..f641ae7b
--- /dev/null
+++ b/tests/kernels/test_uva.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+
+from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
+
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+
+@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_cpu_write(device):
+    torch.set_default_device(device)
+    cpu_tensor = torch.zeros(10,
+                             10,
+                             device="cpu",
+                             pin_memory=True,
+                             dtype=torch.int32)
+    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
+    assert cuda_view.device.type == "cuda"
+
+    assert cuda_view[0, 0] == 0
+    assert cuda_view[2, 3] == 0
+    assert cuda_view[4, 5] == 0
+
+    cpu_tensor[0, 0] = 1
+    cpu_tensor[2, 3] = 2
+    cpu_tensor[4, 5] = -1
+
+    cuda_view.mul_(2)
+    assert cuda_view[0, 0] == 2
+    assert cuda_view[2, 3] == 4
+    assert cuda_view[4, 5] == -2
+
+
+@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_gpu_write(device):
+    torch.set_default_device(device)
+    cpu_tensor = torch.zeros(10,
+                             10,
+                             device="cpu",
+                             pin_memory=True,
+                             dtype=torch.int32)
+    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
+    assert cuda_view.device.type == "cuda"
+
+    assert cuda_view[0, 0] == 0
+    assert cuda_view[2, 3] == 0
+    assert cuda_view[4, 5] == 0
+
+    cuda_view[0, 0] = 1
+    cuda_view[2, 3] = 2
+    cuda_view[4, 5] = -1
+    cuda_view.mul_(2)
+
+    assert cpu_tensor[0, 0] == 2
+    assert cpu_tensor[2, 3] == 4
+    assert cpu_tensor[4, 5] == -2
\ No newline at end of file
diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py
index a7d65185..a05eb494 100644
--- a/tests/quantization/test_cpu_offload.py
+++ b/tests/quantization/test_cpu_offload.py
@@ -10,13 +10,6 @@ from tests.quantization.utils import is_quant_method_supported
 from ..utils import compare_two_settings
 
 
-@pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    # Fall back to V0 if cpu offloading is enabled.
-    # Fixture is required to that baseline uses V0.
-    monkeypatch.setenv('VLLM_USE_V1', '0')
-
-
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 def test_cpu_offload_fp8():
diff --git a/vllm/config.py b/vllm/config.py
index 6a15109c..a02e4f71 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3562,9 +3562,10 @@ class VllmConfig:
 
         if self.cache_config is not None and \
             self.cache_config.cpu_offload_gb > 0 and \
-            self.compilation_config.level != CompilationLevel.NO_COMPILATION:
+            self.compilation_config.level != CompilationLevel.NO_COMPILATION \
+                and not envs.VLLM_USE_V1:
             logger.warning(
-                "CPU offload is not supported with `torch.compile` yet."
+                "CPU offload is not supported with `torch.compile` in v0 yet."
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ca511c74..1da021d7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1595,12 +1595,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # No CPU offloading yet.
-        if self.cpu_offload_gb != EngineArgs.cpu_offload_gb:
-            _raise_or_fallback(feature_name="--cpu-offload-gb",
-                               recommend_to_remove=False)
-            return False
-
         # Only Fp16 and Bf16 dtypes since we only support FA.
         V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
         if model_config.dtype not in V1_SUPPORTED_DTYPES:
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 1e3d78c7..d8c8b5b3 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -10,12 +10,14 @@ import torch.nn as nn
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available
+from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
+                        is_uva_available)
 
 logger = init_logger(__name__)
 
@@ -505,6 +507,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
         return module
 
     pin_memory = is_pin_memory_available()
+    uva_available = is_uva_available()
+
+    if envs.VLLM_USE_V1:
+        assert uva_available, ("V1 CPU offloading requires"
+                               " uva (pin memory) support")
+        uva_offloading = True
+    else:
+        uva_offloading = False
 
     # offload parameters to CPU
     # use pin_memory if possible, which helps cudagraph capture speed
@@ -523,11 +533,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
                                        device='cpu',
                                        pin_memory=pin_memory)
         cpu_data.copy_(p.data)
-        p.data = cpu_data
+        if not uva_offloading:
+            p.data = cpu_data
+        else:
+            # keep the cpu data alive
+            p._vllm_offloaded_cpu_data = cpu_data
+            p.data = get_cuda_view_from_cpu_tensor(cpu_data)
         _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
         offloaded_parameters = True
 
-    if offloaded_parameters:
+    if offloaded_parameters and not uva_offloading:
         original_forward = module.forward
 
         def forward(*args, **kwargs):
diff --git a/vllm/utils.py b/vllm/utils.py
index bf83b38a..f13f4d78 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -795,6 +795,14 @@ def is_pin_memory_available() -> bool:
     return current_platform.is_pin_memory_available()
 
 
+@cache
+def is_uva_available() -> bool:
+    """Check if Unified Virtual Addressing (UVA) is available."""
+    # UVA requires pinned memory.
+    # TODO: Add more requirements for UVA if needed.
+    return is_pin_memory_available()
+
+
 class DeviceMemoryProfiler:
 
     def __init__(self, device: Optional[torch.types.Device] = None):
@@ -1645,6 +1653,14 @@ def weak_ref_tensors(
     raise ValueError("Invalid type for tensors")
 
 
+def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
+    """
+    assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
+    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
+
+
 def is_in_doc_build() -> bool:
     try:
         from sphinx.ext.autodoc.mock import _MockModule
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e3df2a62..74f3124e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -69,6 +69,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
+        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
+        set_cpu_offload_max_bytes(
+            int(self.cache_config.cpu_offload_gb * 1024**3))
+
         model_config = self.model_config
         cache_config = self.cache_config
         scheduler_config = self.scheduler_config
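
Usage sketch (not part of the patch): a minimal example of how the V1 CPU offloading path and the new UVA helpers above are exercised, assuming a CUDA build of vLLM that contains these changes. The model name is taken from the tests and is only illustrative.

    import torch

    from vllm import LLM
    from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

    # Offload roughly 1 GiB of weights to pinned CPU memory. With V1, the GPU
    # reads the offloaded weights in place through a UVA view, so no forward
    # hook is needed to copy them back (hence "fully transparent").
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", cpu_offload_gb=1)
    print(llm.generate("Hello, my name is")[0].outputs[0].text)

    # The underlying primitive: a pinned CPU tensor and its zero-copy CUDA view
    # share the same storage, so a write on one side is visible on the other
    # (mirroring tests/kernels/test_uva.py).
    if is_uva_available():
        cpu_tensor = torch.zeros(4, 4, dtype=torch.int32, pin_memory=True)
        cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
        cpu_tensor[0, 0] = 7
        assert cuda_view[0, 0] == 7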