[core][model] yet another cpu offload implementation (#6496)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
youkaichao 2024-07-17 20:54:35 -07:00 committed by GitHub
parent 18fecc3559
commit 1c27d25fb5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 128 additions and 4 deletions

View File

@ -140,6 +140,7 @@ steps:
# install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py

22
examples/cpu_offload.py Normal file
View File

@ -0,0 +1,22 @@
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@ -433,6 +433,7 @@ class CacheConfig:
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
@ -441,6 +442,7 @@ class CacheConfig:
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()

View File

@ -45,6 +45,7 @@ class EngineArgs:
disable_sliding_window: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
cpu_offload_gb: int = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
@ -303,6 +304,20 @@ class EngineArgs:
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--cpu-offload-gb',
type=float,
default=0,
help='The space in GiB to offload to CPU, per GPU. '
'Default is 0, which means no offloading. Intuitively, '
'this argument can be seen as a virtual way to increase '
'the GPU memory size. For example, if you have one 24 GB '
'GPU and set this to 10, virtually you can think of it as '
'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
'which requires at least 26GB GPU memory. Note that this '
'requires fast CPU-GPU interconnect, as part of the model is'
'loaded from CPU memory to GPU memory on the fly in each '
'model forward pass.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
@ -633,6 +648,11 @@ class EngineArgs:
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
assert self.cpu_offload_gb >= 0, (
"CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")
multimodal_config = MultiModalConfig()
device_config = DeviceConfig(device=self.device)
@ -666,7 +686,9 @@ class EngineArgs:
cache_dtype=self.kv_cache_dtype,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching)
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
)
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,

View File

@ -69,6 +69,10 @@ class LLM:
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
transfer for every forward pass.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
@ -114,6 +118,7 @@ class LLM:
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
@ -141,6 +146,7 @@ class LLM:
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,

View File

@ -1,8 +1,10 @@
from typing import Callable, Dict, List, Tuple
import torch
from torch.func import functional_call
from vllm.multimodal import BatchedTensors
from vllm.utils import is_pin_memory_available
def merge_vision_embeddings(input_ids: torch.Tensor,
@ -52,6 +54,70 @@ class PPMissingLayer(torch.nn.Identity):
super().__init__()
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0
def set_cpu_offload_max_bytes(max_bytes: int) -> None:
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = max_bytes
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
device = next(module.parameters()).device
if device == torch.device("cpu"):
return module
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
return module
pin_memory = is_pin_memory_available()
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
for p in module.parameters():
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
# we use per-parameter offloading
# one module might have some parameters offloaded and some not
break
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty(size=p.data.size(),
dtype=p.data.dtype,
layout=p.data.layout,
device='cpu',
pin_memory=pin_memory)
cpu_data.copy_(p.data)
p.data = cpu_data
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
state_dict: Dict[str, torch.Tensor] = module.state_dict()
original_forward = module.forward
def forward(*args, **kwargs):
module.forward = original_forward
device_state = {
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k: v.to(device, non_blocking=True)
for k, v in state_dict.items()
}
output = functional_call(module,
device_state,
args=args,
kwargs=kwargs)
module.forward = forward
return output
module.forward = forward
return module
def make_layers(
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
) -> Tuple[int, int, torch.nn.ModuleList]:
@ -64,9 +130,10 @@ def make_layers(
get_pp_group().rank_in_group,
get_pp_group().world_size)
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] +
[layer_fn() for _ in range(start_layer, end_layer)] +
[PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn())
for _ in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules

View File

@ -39,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
@ -544,6 +545,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config,