[V1] Fully Transparent Implementation of CPU Offloading (#15354)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Commit 555aa21905 (parent e7ae3bf3d6)

CMakeLists.txt
@@ -234,6 +234,7 @@ set(VLLM_EXT_SRC
     "csrc/activation_kernels.cu"
     "csrc/layernorm_kernels.cu"
     "csrc/layernorm_quant_kernels.cu"
+    "csrc/cuda_view.cu"
     "csrc/quantization/gptq/q_gemm.cu"
     "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
     "csrc/quantization/fp8/common.cu"

csrc/cuda_view.cu (new file)
@@ -0,0 +1,39 @@
#include <torch/all.h>
#include <torch/cuda.h>
#include <cuda_runtime.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");

  // Get raw host pointer from CPU tensor
  void* host_ptr = cpu_tensor.data_ptr();

  // Get a device pointer corresponding to the pinned host memory
  void* device_ptr = nullptr;
  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  TORCH_CHECK(err == cudaSuccess,
              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

  // We'll use the same sizes, strides, and dtype as the CPU tensor.
  // TODO: check if layout is respected.
  auto sizes = cpu_tensor.sizes();
  auto strides = cpu_tensor.strides();
  auto options = cpu_tensor.options().device(torch::kCUDA);

  // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
  // const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
  // memory, so we don't free it here.
  auto deleter = [](void*) {
    // no-op, since the memory is owned by the original CPU tensor
  };

  torch::Tensor cuda_tensor =
      torch::from_blob(device_ptr, sizes, strides, deleter, options);

  TORCH_CHECK(cuda_tensor.device().is_cuda(),
              "Resulting tensor is not on CUDA device");

  return cuda_tensor;
}
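
For orientation, a minimal Python-side sketch of the zero-copy behavior this kernel enables, using the is_uva_available and get_cuda_view_from_cpu_tensor helpers this commit adds to vllm/utils.py further below; a host write to the pinned CPU tensor is immediately visible through the CUDA view, with no explicit copy:

# Sketch only; assumes a CUDA-capable machine with the compiled vLLM extension.
import torch
from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

if is_uva_available():
    cpu_t = torch.zeros(4, 4, dtype=torch.int32, pin_memory=True)
    gpu_view = get_cuda_view_from_cpu_tensor(cpu_t)
    cpu_t[0, 0] = 7             # write through the host pointer
    assert gpu_view[0, 0] == 7  # the CUDA view aliases the same pinned memory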

csrc/ops.h
@@ -119,6 +119,8 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,

csrc/torch_bindings.cpp
@@ -31,6 +31,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
 
+  ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
+  ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
+           &get_cuda_view_from_cpu_tensor);
+
   // Attention ops
   // Compute the attention between an input query and the cached
   // keys/values using PagedAttention.
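
Note that the implementation is registered under the torch::kCPU dispatch key because the op's input tensor lives on the CPU, even though the returned view is CUDA-tagged. A hedged sketch of calling the raw op (the vllm.utils.get_cuda_view_from_cpu_tensor wrapper added later in this diff is a thin wrapper around this):

# Sketch only; assumes the vLLM _C extension has been built and loaded
# (importing vllm is assumed to trigger the op registration above).
import torch
import vllm  # noqa: F401

pinned = torch.ones(8, device="cpu", pin_memory=True)
view = torch.ops._C.get_cuda_view_from_cpu_tensor(pinned)
print(view.device.type)  # "cuda": a CUDA view over the pinned host buffer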

tests/basic_correctness/test_cpu_offload.py
@@ -1,15 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import pytest
-
 from ..utils import compare_two_settings
 
 
-@pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    monkeypatch.setenv('VLLM_USE_V1', '0')
-
-
 def test_cpu_offload():
     compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
                          ["--cpu-offload-gb", "1"])
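
With the V0-only fixture removed, this test now exercises CPU offloading on the default engine as well. For reference, a hedged sketch of the user-facing knob the test drives through --cpu-offload-gb, here via the offline LLM API:

# Sketch only; model name taken from the test above.
from vllm import LLM

# Offload up to 1 GiB of model weights to pinned CPU memory.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", cpu_offload_gb=1)
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)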

tests/kernels/test_uva.py (new file)
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cpu_write(device):
    torch.set_default_device(device)
    cpu_tensor = torch.zeros(10,
                             10,
                             device="cpu",
                             pin_memory=True,
                             dtype=torch.int32)
    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
    assert cuda_view.device.type == "cuda"

    assert cuda_view[0, 0] == 0
    assert cuda_view[2, 3] == 0
    assert cuda_view[4, 5] == 0

    cpu_tensor[0, 0] = 1
    cpu_tensor[2, 3] = 2
    cpu_tensor[4, 5] = -1

    cuda_view.mul_(2)
    assert cuda_view[0, 0] == 2
    assert cuda_view[2, 3] == 4
    assert cuda_view[4, 5] == -2


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gpu_write(device):
    torch.set_default_device(device)
    cpu_tensor = torch.zeros(10,
                             10,
                             device="cpu",
                             pin_memory=True,
                             dtype=torch.int32)
    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
    assert cuda_view.device.type == "cuda"

    assert cuda_view[0, 0] == 0
    assert cuda_view[2, 3] == 0
    assert cuda_view[4, 5] == 0

    cuda_view[0, 0] = 1
    cuda_view[2, 3] = 2
    cuda_view[4, 5] = -1
    cuda_view.mul_(2)

    assert cpu_tensor[0, 0] == 2
    assert cpu_tensor[2, 3] == 4
    assert cpu_tensor[4, 5] == -2

tests/quantization/test_cpu_offload.py
@@ -10,13 +10,6 @@ from tests.quantization.utils import is_quant_method_supported
 from ..utils import compare_two_settings
 
 
-@pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    # Fall back to V0 if cpu offloading is enabled.
-    # Fixture is required to that baseline uses V0.
-    monkeypatch.setenv('VLLM_USE_V1', '0')
-
-
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
 def test_cpu_offload_fp8():

vllm/config.py
@@ -3562,9 +3562,10 @@ class VllmConfig:
 
         if self.cache_config is not None and \
             self.cache_config.cpu_offload_gb > 0 and \
-            self.compilation_config.level != CompilationLevel.NO_COMPILATION:
+            self.compilation_config.level != CompilationLevel.NO_COMPILATION \
+                and not envs.VLLM_USE_V1:
             logger.warning(
-                "CPU offload is not supported with `torch.compile` yet."
+                "CPU offload is not supported with `torch.compile` in v0 yet."
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
 

vllm/engine/arg_utils.py
@@ -1595,12 +1595,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # No CPU offloading yet.
-        if self.cpu_offload_gb != EngineArgs.cpu_offload_gb:
-            _raise_or_fallback(feature_name="--cpu-offload-gb",
-                               recommend_to_remove=False)
-            return False
-
         # Only Fp16 and Bf16 dtypes since we only support FA.
         V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
         if model_config.dtype not in V1_SUPPORTED_DTYPES:

vllm/model_executor/models/utils.py
@@ -10,12 +10,14 @@ import torch.nn as nn
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available
+from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
+                        is_uva_available)
 
 logger = init_logger(__name__)
 

@@ -505,6 +507,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
         return module
 
     pin_memory = is_pin_memory_available()
+    uva_available = is_uva_available()
+
+    if envs.VLLM_USE_V1:
+        assert uva_available, ("V1 CPU offloading requires"
+                               " uva (pin memory) support")
+        uva_offloading = True
+    else:
+        uva_offloading = False
 
     # offload parameters to CPU
     # use pin_memory if possible, which helps cudagraph capture speed

@@ -523,11 +533,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
                                        device='cpu',
                                        pin_memory=pin_memory)
         cpu_data.copy_(p.data)
-        p.data = cpu_data
+        if not uva_offloading:
+            p.data = cpu_data
+        else:
+            # keep the cpu data alive
+            p._vllm_offloaded_cpu_data = cpu_data
+            p.data = get_cuda_view_from_cpu_tensor(cpu_data)
         _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
         offloaded_parameters = True
 
-    if offloaded_parameters:
+    if offloaded_parameters and not uva_offloading:
         original_forward = module.forward
 
         def forward(*args, **kwargs):
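
To summarize the behavioral change: in V0, offloaded parameters stay on the CPU and a wrapped forward copies them back to the GPU on every call (the functional_call path kept above for non-UVA offloading), while in V1 the parameter's data becomes a UVA CUDA view of the pinned CPU buffer, so the module's original forward works unchanged. A minimal hedged sketch of the V1 path for a single layer (hypothetical example, not vLLM's actual wrapper code):

# Hypothetical illustration of UVA-based offloading for one layer.
import torch
import torch.nn as nn
from vllm.utils import get_cuda_view_from_cpu_tensor

layer = nn.Linear(16, 16).cuda()
for p in layer.parameters():
    cpu_data = torch.empty_strided(size=p.data.size(),
                                   stride=p.data.stride(),
                                   dtype=p.data.dtype,
                                   layout=p.data.layout,
                                   device="cpu",
                                   pin_memory=True)
    cpu_data.copy_(p.data)
    # V1 style (this PR): the parameter itself becomes a CUDA view, so the
    # original forward() works unchanged and reads the weights over UVA.
    p._offloaded_cpu_data = cpu_data          # keep the pinned buffer alive
    p.data = get_cuda_view_from_cpu_tensor(cpu_data)

x = torch.randn(2, 16, device="cuda")
y = layer(x)                                  # no explicit weight copy needed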

vllm/utils.py
@@ -795,6 +795,14 @@ def is_pin_memory_available() -> bool:
     return current_platform.is_pin_memory_available()
 
 
+@cache
+def is_uva_available() -> bool:
+    """Check if Unified Virtual Addressing (UVA) is available."""
+    # UVA requires pinned memory.
+    # TODO: Add more requirements for UVA if needed.
+    return is_pin_memory_available()
+
+
 class DeviceMemoryProfiler:
 
     def __init__(self, device: Optional[torch.types.Device] = None):

@@ -1645,6 +1653,14 @@ def weak_ref_tensors(
     raise ValueError("Invalid type for tensors")
 
 
+def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
+    """
+    assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
+    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
+
+
 def is_in_doc_build() -> bool:
     try:
         from sphinx.ext.autodoc.mock import _MockModule
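
The wrapper enforces the precondition the C++ kernel relies on: the host buffer must be page-locked, since cudaHostGetDevicePointer is only valid for mapped pinned memory. A small hedged sketch of what callers see:

# Sketch only; assumes CUDA is available and the vLLM extension is loaded.
import torch
from vllm.utils import get_cuda_view_from_cpu_tensor

pageable = torch.zeros(4)                 # regular, pageable CPU memory
try:
    get_cuda_view_from_cpu_tensor(pageable)
except AssertionError as err:
    print(err)                            # "CPU tensor must be pinned"

pinned = torch.zeros(4, pin_memory=True)  # page-locked: UVA mapping works
print(get_cuda_view_from_cpu_tensor(pinned).device.type)  # "cuda"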

vllm/v1/worker/gpu_model_runner.py
@@ -69,6 +69,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
+        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
+        set_cpu_offload_max_bytes(
+            int(self.cache_config.cpu_offload_gb * 1024**3))
+
         model_config = self.model_config
         cache_config = self.cache_config
         scheduler_config = self.scheduler_config
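
The V1 model runner now seeds the same global byte budget that the V0 worker already used, so maybe_offload_to_cpu can take effect at model load time; the user-facing value is in GiB, so cpu_offload_gb=1 translates to 1 * 1024**3 = 1,073,741,824 bytes. A hedged sketch of the resulting accounting (mirroring the budget check in models/utils.py, which may overshoot by up to one parameter):

# Hypothetical walk-through of the per-parameter byte budget.
cpu_offload_gb = 1
max_bytes = int(cpu_offload_gb * 1024**3)   # 1_073_741_824

offloaded = 0
param_nbytes = [512 * 1024**2, 384 * 1024**2, 384 * 1024**2]  # made-up sizes
for nbytes in param_nbytes:
    if offloaded >= max_bytes:              # budget already exhausted
        break                               # remaining params stay on the GPU
    offloaded += nbytes                     # this parameter is offloaded to CPU
print(offloaded)                            # 1_342_177_280: slight overshoot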