diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 445d74d6..00fa86b4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -140,6 +140,7 @@ steps: # install tensorizer for tensorize_vllm_model.py - pip install awscli tensorizer - python3 offline_inference.py + - python3 cpu_offload.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py diff --git a/examples/cpu_offload.py b/examples/cpu_offload.py new file mode 100644 index 00000000..b152e5bc --- /dev/null +++ b/examples/cpu_offload.py @@ -0,0 +1,22 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/config.py b/vllm/config.py index c87974d0..41911837 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -433,6 +433,7 @@ class CacheConfig: num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, + cpu_offload_gb: float = 0, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -441,6 +442,7 @@ class CacheConfig: self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching + self.cpu_offload_gb = cpu_offload_gb self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b972573c..28ae3448 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -45,6 +45,7 @@ class EngineArgs: disable_sliding_window: bool = False use_v2_block_manager: bool = False swap_space: int = 4 # GiB + cpu_offload_gb: int = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 @@ -303,6 +304,20 @@ class EngineArgs: type=int, default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU.') + parser.add_argument( + '--cpu-offload-gb', + type=float, + default=0, + help='The space in GiB to offload to CPU, per GPU. ' + 'Default is 0, which means no offloading. Intuitively, ' + 'this argument can be seen as a virtual way to increase ' + 'the GPU memory size. For example, if you have one 24 GB ' + 'GPU and set this to 10, virtually you can think of it as ' + 'a 34 GB GPU. Then you can load a 13B model with BF16 weight,' + 'which requires at least 26GB GPU memory. Note that this ' + 'requires fast CPU-GPU interconnect, as part of the model is' + 'loaded from CPU memory to GPU memory on the fly in each ' + 'model forward pass.') parser.add_argument( '--gpu-memory-utilization', type=float, @@ -633,6 +648,11 @@ class EngineArgs: raise ValueError( "BitsAndBytes load format and QLoRA adapter only support " f"'bitsandbytes' quantization, but got {self.quantization}") + + assert self.cpu_offload_gb >= 0, ( + "CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + multimodal_config = MultiModalConfig() device_config = DeviceConfig(device=self.device) @@ -666,7 +686,9 @@ class EngineArgs: cache_dtype=self.kv_cache_dtype, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), - enable_prefix_caching=self.enable_prefix_caching) + enable_prefix_caching=self.enable_prefix_caching, + cpu_offload_gb=self.cpu_offload_gb, + ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 57e81a63..cadaffa0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -69,6 +69,10 @@ class LLM: when their `best_of` sampling parameters are larger than 1. If all requests will have `best_of=1`, you can safely set this to 0. Otherwise, too small values may cause out-of-memory (OOM) errors. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading + the model weights. This virtually increases the GPU memory space + you can use to hold the model weights, at the cost of CPU-GPU data + transfer for every forward pass. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. @@ -114,6 +118,7 @@ class LLM: seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: int = 4, + cpu_offload_gb: float = 0, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, @@ -141,6 +146,7 @@ class LLM: seed=seed, gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index c135b203..b505d32d 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,8 +1,10 @@ from typing import Callable, Dict, List, Tuple import torch +from torch.func import functional_call from vllm.multimodal import BatchedTensors +from vllm.utils import is_pin_memory_available def merge_vision_embeddings(input_ids: torch.Tensor, @@ -52,6 +54,70 @@ class PPMissingLayer(torch.nn.Identity): super().__init__() +_CPU_OFFLOAD_BYTES = 0 +_CPU_OFFLOAD_MAX_BYTES = 0 + + +def set_cpu_offload_max_bytes(max_bytes: int) -> None: + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + _CPU_OFFLOAD_BYTES = 0 + _CPU_OFFLOAD_MAX_BYTES = max_bytes + + +def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: + device = next(module.parameters()).device + + if device == torch.device("cpu"): + return module + + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + return module + + pin_memory = is_pin_memory_available() + + # offload parameters to CPU + # use pin_memory if possible, which helps cudagraph capture speed + for p in module.parameters(): + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + # we use per-parameter offloading + # one module might have some parameters offloaded and some not + break + + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty(size=p.data.size(), + dtype=p.data.dtype, + layout=p.data.layout, + device='cpu', + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() + + state_dict: Dict[str, torch.Tensor] = module.state_dict() + + original_forward = module.forward + + def forward(*args, **kwargs): + module.forward = original_forward + device_state = { + # here we blindly call `to(device)` + # if the parameter is already on the device, it will be a no-op + k: v.to(device, non_blocking=True) + for k, v in state_dict.items() + } + output = functional_call(module, + device_state, + args=args, + kwargs=kwargs) + module.forward = forward + return output + + module.forward = forward + + return module + + def make_layers( num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module] ) -> Tuple[int, int, torch.nn.ModuleList]: @@ -64,9 +130,10 @@ def make_layers( get_pp_group().rank_in_group, get_pp_group().world_size) modules = torch.nn.ModuleList( - [PPMissingLayer() for _ in range(start_layer)] + - [layer_fn() for _ in range(start_layer, end_layer)] + - [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + [PPMissingLayer() for _ in range(start_layer)] + [ + maybe_offload_to_cpu(layer_fn()) + for _ in range(start_layer, end_layer) + ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 75a2607d..d8104436 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -39,6 +39,7 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import (supports_lora, supports_vision) +from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, MultiModalInputs) from vllm.prompt_adapter.layers import PromptAdapterMapping @@ -544,6 +545,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): self.flashinfer_prefill_workspace_buffer = None self.flashinfer_prefill_wrapper = None + set_cpu_offload_max_bytes( + int(self.cache_config.cpu_offload_gb * 1024**3)) + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model(model_config=self.model_config,