[core][model] yet another cpu offload implementation (#6496)

Co-authored-by: Michael Goin <michael@neuralmagic.com>

parent 18fecc3559
commit 1c27d25fb5
@@ -140,6 +140,7 @@ steps:
   # install tensorizer for tensorize_vllm_model.py
   - pip install awscli tensorizer
   - python3 offline_inference.py
+  - python3 cpu_offload.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
   - python3 llava_example.py
examples/cpu_offload.py (new file, 22 lines)

@@ -0,0 +1,22 @@
+from vllm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10)
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
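Editor's note (not part of the diff): a rough way to pick a value for cpu_offload_gb is to offload whatever portion of the weights will not fit next to the KV cache. The sketch below is illustrative only; suggested_cpu_offload_gb and its reserve_gb headroom are hypothetical helpers, and the numbers mirror the new flag's help text (a 13B BF16 model needs roughly 26 GB of weights).

    def suggested_cpu_offload_gb(num_params_billion: float,
                                 bytes_per_param: int,
                                 gpu_mem_gb: float,
                                 reserve_gb: float = 4.0) -> float:
        """Estimate how many GiB of weights must live on the CPU.

        reserve_gb approximates memory kept free for the KV cache and
        activations; tune it for your workload.
        """
        weight_gb = num_params_billion * bytes_per_param  # e.g. 13 * 2 = 26 GB
        deficit = weight_gb - (gpu_mem_gb - reserve_gb)
        return max(0.0, deficit)

    # A 13B BF16 model (~26 GB of weights) on a 24 GB GPU:
    print(suggested_cpu_offload_gb(13, 2, 24))  # 6.0 GB minimum; the example uses 10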
@@ -433,6 +433,7 @@ class CacheConfig:
         num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
         enable_prefix_caching: bool = False,
+        cpu_offload_gb: float = 0,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
@@ -441,6 +442,7 @@ class CacheConfig:
         self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
+        self.cpu_offload_gb = cpu_offload_gb
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
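Editor's note (not part of the diff): the new field rides along the existing CacheConfig constructor. A minimal keyword-argument construction might look like the sketch below; argument names other than cpu_offload_gb are taken from the surrounding hunks and from EngineArgs, so treat them as assumptions about this version of the API.

    from vllm.config import CacheConfig

    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=4,                 # GiB of CPU space for KV-cache swapping
        cache_dtype="auto",
        num_gpu_blocks_override=None,
        sliding_window=None,
        enable_prefix_caching=False,
        cpu_offload_gb=10,            # GiB of model weights allowed to live on CPU
    )
    print(cache_config.cpu_offload_gb)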
@@ -45,6 +45,7 @@ class EngineArgs:
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = False
     swap_space: int = 4  # GiB
+    cpu_offload_gb: int = 0  # GiB
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
@@ -303,6 +304,20 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU.')
+        parser.add_argument(
+            '--cpu-offload-gb',
+            type=float,
+            default=0,
+            help='The space in GiB to offload to CPU, per GPU. '
+            'Default is 0, which means no offloading. Intuitively, '
+            'this argument can be seen as a virtual way to increase '
+            'the GPU memory size. For example, if you have one 24 GB '
+            'GPU and set this to 10, virtually you can think of it as '
+            'a 34 GB GPU. Then you can load a 13B model with BF16 weights, '
+            'which requires at least 26 GB of GPU memory. Note that this '
+            'requires a fast CPU-GPU interconnect, as part of the model is '
+            'loaded from CPU memory to GPU memory on the fly in each '
+            'model forward pass.')
         parser.add_argument(
             '--gpu-memory-utilization',
             type=float,
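Editor's note (not part of the diff): the flag is wired through EngineArgs' argparse integration, so a quick way to exercise it is the sketch below; the argv values are made up, and it assumes add_cli_args/from_cli_args keep their usual signatures in this version.

    import argparse

    from vllm.engine.arg_utils import EngineArgs

    parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
    args = parser.parse_args(["--model", "meta-llama/Llama-2-13b-chat-hf",
                              "--cpu-offload-gb", "10"])
    engine_args = EngineArgs.from_cli_args(args)
    print(engine_args.cpu_offload_gb)  # 10.0, parsed as float per the new argument

The same flag should also be usable by server entrypoints that build their CLI from EngineArgs, such as the OpenAI-compatible API server.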
@@ -633,6 +648,11 @@ class EngineArgs:
                 raise ValueError(
                     "BitsAndBytes load format and QLoRA adapter only support "
                     f"'bitsandbytes' quantization, but got {self.quantization}")
+
+        assert self.cpu_offload_gb >= 0, (
+            "CPU offload space must be non-negative"
+            f", but got {self.cpu_offload_gb}")
+
         multimodal_config = MultiModalConfig()

         device_config = DeviceConfig(device=self.device)
@@ -666,7 +686,9 @@ class EngineArgs:
             cache_dtype=self.kv_cache_dtype,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
             sliding_window=model_config.get_sliding_window(),
-            enable_prefix_caching=self.enable_prefix_caching)
+            enable_prefix_caching=self.enable_prefix_caching,
+            cpu_offload_gb=self.cpu_offload_gb,
+        )
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -69,6 +69,10 @@ class LLM:
             when their `best_of` sampling parameters are larger than 1. If all
             requests will have `best_of=1`, you can safely set this to 0.
             Otherwise, too small values may cause out-of-memory (OOM) errors.
+        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
+            the model weights. This virtually increases the GPU memory space
+            you can use to hold the model weights, at the cost of CPU-GPU data
+            transfer for every forward pass.
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
             If False, we will use CUDA graph and eager execution in hybrid.
@@ -114,6 +118,7 @@ class LLM:
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
+        cpu_offload_gb: float = 0,
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
@@ -141,6 +146,7 @@ class LLM:
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
+            cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
             max_context_len_to_capture=max_context_len_to_capture,
             max_seq_len_to_capture=max_seq_len_to_capture,
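Editor's note (not part of the diff): swap_space and cpu_offload_gb both consume host memory but serve different purposes — the former reserves CPU space for swapping KV-cache blocks (relevant e.g. with best_of > 1), while the latter keeps part of the model weights resident in (pinned, when available) CPU memory and streams them to the GPU on every forward pass. A call setting both might look like this sketch; the model and sizes are illustrative.

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        swap_space=4,        # GiB of CPU space for KV-cache swapping
        cpu_offload_gb=10,   # GiB of weights allowed to stay on the CPU
        gpu_memory_utilization=0.9,
    )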
@@ -1,8 +1,10 @@
 from typing import Callable, Dict, List, Tuple

 import torch
+from torch.func import functional_call

 from vllm.multimodal import BatchedTensors
+from vllm.utils import is_pin_memory_available


 def merge_vision_embeddings(input_ids: torch.Tensor,
@@ -52,6 +54,70 @@ class PPMissingLayer(torch.nn.Identity):
         super().__init__()


+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty(size=p.data.size(),
+                               dtype=p.data.dtype,
+                               layout=p.data.layout,
+                               device='cpu',
+                               pin_memory=pin_memory)
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+
+    state_dict: Dict[str, torch.Tensor] = module.state_dict()
+
+    original_forward = module.forward
+
+    def forward(*args, **kwargs):
+        module.forward = original_forward
+        device_state = {
+            # here we blindly call `to(device)`
+            # if the parameter is already on the device, it will be a no-op
+            k: v.to(device, non_blocking=True)
+            for k, v in state_dict.items()
+        }
+        output = functional_call(module,
+                                 device_state,
+                                 args=args,
+                                 kwargs=kwargs)
+        module.forward = forward
+        return output
+
+    module.forward = forward
+
+    return module
+
+
 def make_layers(
     num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
 ) -> Tuple[int, int, torch.nn.ModuleList]:
@@ -64,9 +130,10 @@ def make_layers(
                                             get_pp_group().rank_in_group,
                                             get_pp_group().world_size)
     modules = torch.nn.ModuleList(
-        [PPMissingLayer() for _ in range(start_layer)] +
-        [layer_fn() for _ in range(start_layer, end_layer)] +
-        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+        [PPMissingLayer() for _ in range(start_layer)] + [
+            maybe_offload_to_cpu(layer_fn())
+            for _ in range(start_layer, end_layer)
+        ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
     return start_layer, end_layer, modules

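Editor's note (not part of the diff): maybe_offload_to_cpu is applied per decoder layer inside make_layers. It works by keeping the stored Parameters on (pinned) CPU memory and swapping module.forward for a closure that builds a GPU view of the state dict and runs the module functionally against it. Below is a self-contained sketch of the same forward-swap pattern, independent of vLLM; TinyBlock and offload_weights_to_cpu are made-up names.

    import torch
    from torch.func import functional_call


    class TinyBlock(torch.nn.Module):
        """A made-up stand-in for a decoder layer."""

        def __init__(self, dim: int = 64):
            super().__init__()
            self.proj = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.proj(x))


    def offload_weights_to_cpu(module: torch.nn.Module,
                               device: torch.device) -> torch.nn.Module:
        """Keep parameters resident on CPU; stream them to `device` per call."""
        for p in module.parameters():
            cpu_data = p.data.to("cpu")
            if torch.cuda.is_available():
                # Pinned memory speeds up the per-forward host-to-device copies.
                cpu_data = cpu_data.pin_memory()
            p.data = cpu_data

        state_dict = module.state_dict()
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Restore the original forward so functional_call does not recurse
            # back into this wrapper.
            module.forward = original_forward
            device_state = {
                k: v.to(device, non_blocking=True)
                for k, v in state_dict.items()
            }
            output = functional_call(module, device_state, args=args, kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward
        return module


    if __name__ == "__main__":
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        block = offload_weights_to_cpu(TinyBlock().to(dev), dev)
        out = block(torch.randn(2, 64, device=dev))
        print(out.shape, next(block.parameters()).device)  # params stay on CPU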
@@ -39,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models.interfaces import (supports_lora,
                                                    supports_vision)
+from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                              MultiModalInputs)
 from vllm.prompt_adapter.layers import PromptAdapterMapping
@@ -544,6 +545,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.flashinfer_prefill_workspace_buffer = None
         self.flashinfer_prefill_wrapper = None

+        set_cpu_offload_max_bytes(
+            int(self.cache_config.cpu_offload_gb * 1024**3))
+
     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
             self.model = get_model(model_config=self.model_config,
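Editor's note (not part of the diff): the runner turns the GiB setting into a byte budget, and maybe_offload_to_cpu charges each offloaded parameter numel() * element_size() bytes against it until the budget is exhausted. A tiny illustrative calculation (numbers are made up):

    import torch

    cpu_offload_gb = 10
    budget_bytes = int(cpu_offload_gb * 1024**3)      # what the runner passes in

    # One BF16 weight matrix of a hypothetical 5120-wide layer:
    w = torch.empty(5120, 5120, dtype=torch.bfloat16)
    per_param_bytes = w.numel() * w.element_size()    # 5120 * 5120 * 2 = 52,428,800

    print(budget_bytes // per_param_bytes)            # ~204 such matrices fit in the budget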