[core][model] yet another cpu offload implementation (#6496)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent 18fecc3559
commit 1c27d25fb5
@@ -140,6 +140,7 @@ steps:
   # install tensorizer for tensorize_vllm_model.py
   - pip install awscli tensorizer
   - python3 offline_inference.py
+  - python3 cpu_offload.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
   - python3 llava_example.py
22 examples/cpu_offload.py Normal file
@@ -0,0 +1,22 @@
+from vllm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10)
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -433,6 +433,7 @@ class CacheConfig:
         num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
         enable_prefix_caching: bool = False,
+        cpu_offload_gb: float = 0,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
@@ -441,6 +442,7 @@ class CacheConfig:
         self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
+        self.cpu_offload_gb = cpu_offload_gb
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
@@ -45,6 +45,7 @@ class EngineArgs:
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = False
     swap_space: int = 4  # GiB
+    cpu_offload_gb: int = 0  # GiB
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
@@ -303,6 +304,20 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU.')
+        parser.add_argument(
+            '--cpu-offload-gb',
+            type=float,
+            default=0,
+            help='The space in GiB to offload to CPU, per GPU. '
+            'Default is 0, which means no offloading. Intuitively, '
+            'this argument can be seen as a virtual way to increase '
+            'the GPU memory size. For example, if you have one 24 GB '
+            'GPU and set this to 10, virtually you can think of it as '
+            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
+            'which requires at least 26GB GPU memory. Note that this '
+            'requires fast CPU-GPU interconnect, as part of the model is '
+            'loaded from CPU memory to GPU memory on the fly in each '
+            'model forward pass.')
         parser.add_argument(
             '--gpu-memory-utilization',
             type=float,
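A quick sanity check of the arithmetic quoted in the help text above (an illustrative sketch added for this writeup, not part of the commit; the 24 GB GPU and the 13B/BF16 figures are the ones named in the help string):

```python
# Back-of-envelope check for the example in the --cpu-offload-gb help text:
# a 13B-parameter model in BF16 needs roughly 26 GB for the weights alone.
num_params = 13e9
bytes_per_param = 2  # BF16 = 2 bytes per parameter
weight_gb = num_params * bytes_per_param / 1e9  # ~26 GB

gpu_gb = 24          # physical GPU memory in the example
cpu_offload_gb = 10  # value passed via --cpu-offload-gb
virtual_gb = gpu_gb + cpu_offload_gb  # "virtually a 34 GB GPU"

print(f"weights: {weight_gb:.0f} GB, virtual capacity: {virtual_gb} GB")
print("fits:", weight_gb <= virtual_gb)  # True
```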
@@ -633,6 +648,11 @@ class EngineArgs:
             raise ValueError(
                 "BitsAndBytes load format and QLoRA adapter only support "
                 f"'bitsandbytes' quantization, but got {self.quantization}")
+
+        assert self.cpu_offload_gb >= 0, (
+            "CPU offload space must be non-negative"
+            f", but got {self.cpu_offload_gb}")
+
         multimodal_config = MultiModalConfig()

         device_config = DeviceConfig(device=self.device)
@@ -666,7 +686,9 @@ class EngineArgs:
             cache_dtype=self.kv_cache_dtype,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
             sliding_window=model_config.get_sliding_window(),
-            enable_prefix_caching=self.enable_prefix_caching)
+            enable_prefix_caching=self.enable_prefix_caching,
+            cpu_offload_gb=self.cpu_offload_gb,
+        )
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -69,6 +69,10 @@ class LLM:
             when their `best_of` sampling parameters are larger than 1. If all
             requests will have `best_of=1`, you can safely set this to 0.
             Otherwise, too small values may cause out-of-memory (OOM) errors.
+        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
+            the model weights. This virtually increases the GPU memory space
+            you can use to hold the model weights, at the cost of CPU-GPU data
+            transfer for every forward pass.
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
             If False, we will use CUDA graph and eager execution in hybrid.
@@ -114,6 +118,7 @@ class LLM:
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
+        cpu_offload_gb: float = 0,
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
@@ -141,6 +146,7 @@ class LLM:
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
+            cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
             max_context_len_to_capture=max_context_len_to_capture,
             max_seq_len_to_capture=max_seq_len_to_capture,
@@ -1,8 +1,10 @@
 from typing import Callable, Dict, List, Tuple

 import torch
+from torch.func import functional_call

 from vllm.multimodal import BatchedTensors
+from vllm.utils import is_pin_memory_available


 def merge_vision_embeddings(input_ids: torch.Tensor,
@@ -52,6 +54,70 @@ class PPMissingLayer(torch.nn.Identity):
         super().__init__()


+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty(size=p.data.size(),
+                               dtype=p.data.dtype,
+                               layout=p.data.layout,
+                               device='cpu',
+                               pin_memory=pin_memory)
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+
+    state_dict: Dict[str, torch.Tensor] = module.state_dict()
+
+    original_forward = module.forward
+
+    def forward(*args, **kwargs):
+        module.forward = original_forward
+        device_state = {
+            # here we blindly call `to(device)`
+            # if the parameter is already on the device, it will be a no-op
+            k: v.to(device, non_blocking=True)
+            for k, v in state_dict.items()
+        }
+        output = functional_call(module,
+                                 device_state,
+                                 args=args,
+                                 kwargs=kwargs)
+        module.forward = forward
+        return output
+
+    module.forward = forward
+
+    return module
+
+
 def make_layers(
         num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
 ) -> Tuple[int, int, torch.nn.ModuleList]:
@@ -64,9 +130,10 @@ def make_layers(
                                             get_pp_group().rank_in_group,
                                             get_pp_group().world_size)
     modules = torch.nn.ModuleList(
-        [PPMissingLayer() for _ in range(start_layer)] +
-        [layer_fn() for _ in range(start_layer, end_layer)] +
-        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+        [PPMissingLayer() for _ in range(start_layer)] + [
+            maybe_offload_to_cpu(layer_fn())
+            for _ in range(start_layer, end_layer)
+        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
     return start_layer, end_layer, modules
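To make the control flow above concrete, here is a minimal standalone sketch (added for this writeup, not part of the diff) of how the two new helpers compose. It assumes a build containing this commit and an available CUDA device; the plain `torch.nn.Linear` is a toy stand-in for the decoder layers that `make_layers` normally builds via `layer_fn()`:

```python
import torch

from vllm.model_executor.models.utils import (maybe_offload_to_cpu,
                                               set_cpu_offload_max_bytes)

# Budget: up to 1 GiB of parameters may be kept in (pinned) CPU memory.
set_cpu_offload_max_bytes(1 * 1024**3)

# Toy stand-in for a decoder layer, constructed on the GPU first,
# just like layer_fn() inside make_layers().
layer = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
layer = maybe_offload_to_cpu(layer)

# The parameters now live on the CPU (up to the configured budget)...
print(next(layer.parameters()).device)  # cpu

# ...but forward still runs on the GPU: the patched forward copies the
# cached state dict back to the original device and dispatches through
# torch.func.functional_call.
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
out = layer(x)
print(out.device)  # cuda:0
```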
@@ -39,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models.interfaces import (supports_lora,
                                                    supports_vision)
+from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                              MultiModalInputs)
 from vllm.prompt_adapter.layers import PromptAdapterMapping
@@ -544,6 +545,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.flashinfer_prefill_workspace_buffer = None
         self.flashinfer_prefill_wrapper = None

+        set_cpu_offload_max_bytes(
+            int(self.cache_config.cpu_offload_gb * 1024**3))
+
     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
             self.model = get_model(model_config=self.model_config,