[Bugfix] Support cpu offloading with fp8 quantization (#6960)

Author: Michael Goin
Date: 2024-07-31 15:47:46 -04:00 (committed via GitHub)
Parent: bd70013407
Commit: 460c1884e3
3 changed files with 114 additions and 31 deletions
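
Judging from the diff, the bug was that `--cpu-offload-gb` combined with fp8 quantization failed during weight post-processing, because quant methods ran while some parameters were still offloaded to the CPU. A minimal sketch of the now-working combination via the offline `LLM` entry point (assumes a CUDA GPU with fp8 support; the flag values are illustrative):

import torch  # noqa: F401 (vLLM pulls in torch; shown for context)
from vllm import LLM, SamplingParams

# Offload ~2 GiB of weights to CPU RAM while quantizing the checkpoint
# to fp8 on the fly -- the combination this commit fixes.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
          quantization="fp8",
          cpu_offload_gb=2)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)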

File 1 of 3:

@@ -1,4 +1,6 @@
-from vllm.utils import is_hip
+import pytest
+
+from tests.quantization.utils import is_quant_method_supported
 
 from ..utils import compare_two_settings
@@ -6,8 +8,37 @@ from ..utils import compare_two_settings
 def test_cpu_offload():
     compare_two_settings("meta-llama/Llama-2-7b-hf", [],
                          ["--cpu-offload-gb", "4"])
-    if not is_hip():
-        # compressed-tensors quantization is currently not supported in ROCm.
-        compare_two_settings(
-            "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
-            ["--cpu-offload-gb", "1"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="fp8 is not supported on this GPU type.")
+def test_cpu_offload_fp8():
+    # Test quantization of an unquantized checkpoint
+    compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct",
+                         ["--quantization", "fp8"],
+                         ["--quantization", "fp8", "--cpu-offload-gb", "2"])
+    # Test loading a quantized checkpoint
+    compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [],
+                         ["--cpu-offload-gb", "2"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("awq"),
+                    reason="awq is not supported on this GPU type.")
+def test_cpu_offload_awq():
+    compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [],
+                         ["--cpu-offload-gb", "2"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+                    reason="gptq_marlin is not supported on this GPU type.")
+def test_cpu_offload_compressed_tensors():
+    # Test wNa16
+    compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [],
+                         ["--cpu-offload-gb", "1"])
+    # Test w4a16_marlin24
+    compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
+                         [], ["--cpu-offload-gb", "1"])
+    # Test w8a8
+    compare_two_settings(
+        "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [],
+        ["--cpu-offload-gb", "1"])

File 2 of 3:

@@ -7,6 +7,7 @@ import json
 import math
 import os
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type
 
 import huggingface_hub
@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_tpu
+from vllm.utils import is_pin_memory_available, is_tpu
+
+
+@contextmanager
+def device_loading_context(module: torch.nn.Module,
+                           target_device: torch.device):
+    if target_device.type == "cpu":
+        # If target is CPU, no need to move anything
+        yield module
+        return
+
+    original_device_states: Dict[str, torch.device] = {}
+
+    # Store original device states and move parameters to GPU if they're on CPU
+    for name, p in module.named_parameters():
+        if p.device.type == "cpu":
+            original_device_states[name] = p.device
+            p.data = p.data.to(target_device)
+        # Parameters already on target device are not touched
+
+    try:
+        yield module
+
+    finally:
+        # Restore parameters to their original devices, ignoring new parameters
+        pin_memory = is_pin_memory_available()
+        for name, p in module.named_parameters():
+            if name in original_device_states:
+                original_device: torch.device = original_device_states[name]
+                if original_device.type == "cpu":
+                    # `torch.empty_like` does not support `pin_memory` argument
+                    cpu_data = torch.empty_strided(size=p.data.size(),
+                                                   stride=p.data.stride(),
+                                                   dtype=p.data.dtype,
+                                                   layout=p.data.layout,
+                                                   device="cpu",
+                                                   pin_memory=pin_memory)
+                    cpu_data.copy_(p.data)
+                    p.data = cpu_data
+                else:
+                    p.data = p.data.to(original_device)
+            # New parameters or parameters already on target device are untouched
+
+
 logger = init_logger(__name__)
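
To make the contract of `device_loading_context` concrete, here is a toy usage sketch (not from the commit; it assumes a CUDA device is available and that the context manager above is in scope): parameters that start on the CPU are visible on the GPU inside the scope and are copied back, into pinned memory when possible, on exit.

import torch

# device_loading_context as defined in the diff above
lin = torch.nn.Linear(8, 8)  # all parameters start on the CPU

if torch.cuda.is_available():
    target = torch.device("cuda:0")
    with device_loading_context(lin, target):
        # Inside the scope, CUDA-only processing (e.g. fp8 repacking)
        # can safely touch the weights.
        assert lin.weight.device == target
        lin.weight.data *= 2  # in-place processing survives the round trip
    assert lin.weight.device.type == "cpu"  # moved back after the scope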
@@ -275,8 +318,9 @@ class DefaultModelLoader(BaseModelLoader):
                    parallel_config: ParallelConfig,
                    scheduler_config: SchedulerConfig,
                    cache_config: CacheConfig) -> nn.Module:
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, multimodal_config,
                                           cache_config, scheduler_config)
@@ -291,6 +335,12 @@ class DefaultModelLoader(BaseModelLoader):
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
-                    quant_method.process_weights_after_loading(module)
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
         return model.eval()
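
The comment above is the crux of the bug: with `--cpu-offload-gb`, some parameters live on the CPU by the time `process_weights_after_loading` runs, while the kernels it invokes (fp8 scaling, Marlin repacking, etc.) expect everything on the target device. A standalone sketch of the kind of device-mismatch failure the wrapping prevents (illustrative, not the actual fp8 code path):

import torch

if torch.cuda.is_available():
    cpu_weight = torch.randn(4, 4)                   # offloaded parameter
    cuda_scale = torch.ones(4, 4, device="cuda:0")   # lives on target device
    try:
        _ = cpu_weight * cuda_scale  # mixing devices in one op
    except RuntimeError as err:
        # e.g. "Expected all tensors to be on the same device..."
        print(f"device mismatch: {err}")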

File 3 of 3:

@@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
     # offload parameters to CPU
     # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
     for p in module.parameters():
         if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
             # we use per-parameter offloading
@@ -94,7 +95,8 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
             break
 
         # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty(size=p.data.size(),
-                               dtype=p.data.dtype,
-                               device='cpu',
-                               pin_memory=pin_memory)
+        cpu_data = torch.empty_strided(size=p.data.size(),
+                                       stride=p.data.stride(),
+                                       dtype=p.data.dtype,
+                                       layout=p.data.layout,
+                                       device='cpu',
+                                       pin_memory=pin_memory)
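
Why `torch.empty_strided` rather than `torch.empty`: after a quant method processes the weights, a parameter may be a non-contiguous view (a transpose, say), and a plain `empty` + `copy_` would silently give the CPU copy a different memory layout than the original. A small, self-contained illustration (the assumption being that downstream code cares about strides):

import torch

w = torch.randn(4, 8).t()      # non-contiguous view: size (8, 4), stride (1, 8)

plain = torch.empty(w.size())  # contiguous buffer: stride (4, 1), layout lost
preserved = torch.empty_strided(size=w.size(), stride=w.stride())
preserved.copy_(w)

assert plain.stride() != w.stride()
assert preserved.stride() == w.stride()  # layout survives the CPU round trip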
@@ -102,9 +104,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
         cpu_data.copy_(p.data)
         p.data = cpu_data
         _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
 
-    state_dict: Dict[str, torch.Tensor] = module.state_dict()
-
-    original_forward = module.forward
+    if offloaded_parameters:
+        original_forward = module.forward
 
         def forward(*args, **kwargs):
@@ -113,7 +115,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
                 # here we blindly call `to(device)`
                 # if the parameter is already on the device, it will be a no-op
                 k: v.to(device, non_blocking=True)
-                for k, v in state_dict.items()
+                for k, v in module.state_dict().items()
             }
             output = functional_call(module,
                                      device_state,
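
The last change is subtle but load-bearing: the old code captured `state_dict` once, at offload time, but `process_weights_after_loading` may replace `p.data` afterwards (as `device_loading_context` does when it copies weights back to CPU), leaving the snapshot pointing at stale tensors. Re-reading `module.state_dict()` on every forward picks up the current weights. A quick demonstration of the staleness, in plain PyTorch, independent of vLLM:

import torch

m = torch.nn.Linear(2, 2)
snapshot = m.state_dict()          # captured once, like the old code

m.weight.data = torch.zeros(2, 2)  # weight tensor replaced after loading, as
                                   # in process_weights_after_loading

print(snapshot["weight"].abs().sum().item())        # != 0: stale reference
print(m.state_dict()["weight"].abs().sum().item())  # == 0: fresh read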