[Bugfix] Support cpu offloading with fp8 quantization (#6960)

Michael Goin, 2024-07-31 15:47:46 -04:00 (committed by GitHub)
parent bd70013407
commit 460c1884e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 114 additions and 31 deletions


@@ -1,4 +1,6 @@
-from vllm.utils import is_hip
+import pytest
 
+from tests.quantization.utils import is_quant_method_supported
+
 from ..utils import compare_two_settings
 
@@ -6,8 +8,37 @@ from ..utils import compare_two_settings
 def test_cpu_offload():
     compare_two_settings("meta-llama/Llama-2-7b-hf", [],
                          ["--cpu-offload-gb", "4"])
-    if not is_hip():
-        # compressed-tensors quantization is currently not supported in ROCm.
-        compare_two_settings(
-            "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
-            ["--cpu-offload-gb", "1"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="fp8 is not supported on this GPU type.")
+def test_cpu_offload_fp8():
+    # Test quantization of an unquantized checkpoint
+    compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct",
+                         ["--quantization", "fp8"],
+                         ["--quantization", "fp8", "--cpu-offload-gb", "2"])
+    # Test loading a quantized checkpoint
+    compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [],
+                         ["--cpu-offload-gb", "2"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("awq"),
+                    reason="awq is not supported on this GPU type.")
+def test_cpu_offload_awq():
+    compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [],
+                         ["--cpu-offload-gb", "2"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+                    reason="gptq_marlin is not supported on this GPU type.")
+def test_cpu_offload_compressed_tensors():
+    # Test wNa16
+    compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [],
+                         ["--cpu-offload-gb", "1"])
+    # Test w4a16_marlin24
+    compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
+                         [], ["--cpu-offload-gb", "1"])
+    # Test w8a8
+    compare_two_settings(
+        "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [],
+        ["--cpu-offload-gb", "1"])


@@ -7,6 +7,7 @@ import json
 import math
 import os
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from typing import Any, Dict, Generator, List, Optional, Tuple, Type
 
 import huggingface_hub
@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import is_tpu
+from vllm.utils import is_pin_memory_available, is_tpu
+
+
+@contextmanager
+def device_loading_context(module: torch.nn.Module,
+                           target_device: torch.device):
+    if target_device.type == "cpu":
+        # If target is CPU, no need to move anything
+        yield module
+        return
+
+    original_device_states: Dict[str, torch.device] = {}
+
+    # Store original device states and move parameters to GPU if they're on CPU
+    for name, p in module.named_parameters():
+        if p.device.type == "cpu":
+            original_device_states[name] = p.device
+            p.data = p.data.to(target_device)
+        # Parameters already on target device are not touched
+
+    try:
+        yield module
+
+    finally:
+        # Restore parameters to their original devices, ignoring new parameters
+        pin_memory = is_pin_memory_available()
+        for name, p in module.named_parameters():
+            if name in original_device_states:
+                original_device: torch.device = original_device_states[name]
+                if original_device.type == "cpu":
+                    # `torch.empty_like` does not support `pin_memory` argument
+                    cpu_data = torch.empty_strided(size=p.data.size(),
+                                                   stride=p.data.stride(),
+                                                   dtype=p.data.dtype,
+                                                   layout=p.data.layout,
+                                                   device="cpu",
+                                                   pin_memory=pin_memory)
+                    cpu_data.copy_(p.data)
+                    p.data = cpu_data
+                else:
+                    p.data = p.data.to(original_device)
+            # New parameters or parameters already on target device are untouched
 
 logger = init_logger(__name__)
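device_loading_context is an ordinary context manager, so its contract can be pictured with a toy module. A minimal sketch, assuming a CUDA device is present and that the function added above is importable from the modified loader module; the nn.Linear stands in for a real vLLM layer whose weights were offloaded to CPU:

import torch
import torch.nn as nn

# `device_loading_context` as added in the hunk above (import path assumed).
layer = nn.Linear(8, 8)          # parameters start on CPU, as if offloaded
target = torch.device("cuda:0")

with device_loading_context(layer, target):
    # Inside the context the formerly-CPU parameters sit on the target device,
    # so device-only work (repacking, FP8 scaling, ...) can run here.
    assert all(p.device == target for p in layer.parameters())

# On exit they are copied back into (pinned, if available) CPU buffers;
# parameters that were already on the GPU are left untouched.
assert all(p.device.type == "cpu" for p in layer.parameters())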
@@ -275,8 +318,9 @@ class DefaultModelLoader(BaseModelLoader):
                    parallel_config: ParallelConfig,
                    scheduler_config: SchedulerConfig,
                    cache_config: CacheConfig) -> nn.Module:
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, multimodal_config,
                                           cache_config, scheduler_config)
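Hoisting the device into target_device leans on the fact that a torch.device works as a context manager in recent PyTorch (2.0+): tensors created inside the with block default to that device, which is what the original `with torch.device(device_config.device):` already relied on. A small illustration, assuming a CUDA device:

import torch

target_device = torch.device("cuda:0")
with target_device:
    w = torch.empty(4, 4)        # allocated on cuda:0, not on the CPU default
assert w.device == target_device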
@@ -291,7 +335,13 @@ class DefaultModelLoader(BaseModelLoader):
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
-                    quant_method.process_weights_after_loading(module)
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
         return model.eval()
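The comment block above is the heart of the fix: for FP8, process_weights_after_loading typically casts the fp16/bf16 weight to an 8-bit float format and computes scales on the GPU, but with --cpu-offload-gb some parameters are still sitting on the CPU at that point. The toy sketch below is not vLLM's implementation, only an illustration of the kind of device-side work such a hook performs (per-tensor scaling to float8_e4m3fn, assuming a PyTorch build with FP8 dtypes and a GPU):

import torch

def quantize_fp8_per_tensor(weight: torch.Tensor):
    # Per-tensor scale so the largest magnitude maps onto the fp8 max value.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale

w = torch.randn(128, 128, dtype=torch.float16, device="cuda")  # assumes a GPU
qw, s = quantize_fp8_per_tensor(w)
# With CPU offloading, `w` would still be a CPU tensor at this point unless it
# is first moved onto the target device, which is what device_loading_context does.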


@@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
     # offload parameters to CPU
     # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
     for p in module.parameters():
         if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
             # we use per-parameter offloading
@@ -94,35 +95,36 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
             break
 
         # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty(size=p.data.size(),
-                               dtype=p.data.dtype,
-                               layout=p.data.layout,
-                               device='cpu',
-                               pin_memory=pin_memory)
+        cpu_data = torch.empty_strided(size=p.data.size(),
+                                       stride=p.data.stride(),
+                                       dtype=p.data.dtype,
+                                       layout=p.data.layout,
+                                       device='cpu',
+                                       pin_memory=pin_memory)
         cpu_data.copy_(p.data)
         p.data = cpu_data
         _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
 
-    state_dict: Dict[str, torch.Tensor] = module.state_dict()
-
-    original_forward = module.forward
-
-    def forward(*args, **kwargs):
-        module.forward = original_forward
-        device_state = {
-            # here we blindly call `to(device)`
-            # if the parameter is already on the device, it will be a no-op
-            k: v.to(device, non_blocking=True)
-            for k, v in state_dict.items()
-        }
-        output = functional_call(module,
-                                 device_state,
-                                 args=args,
-                                 kwargs=kwargs)
-        module.forward = forward
-        return output
-
-    module.forward = forward
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module,
+                                     device_state,
+                                     args=args,
+                                     kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
 
     return module
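For context, the surrounding function implements vLLM's per-parameter CPU offloading: weights are copied into CPU buffers (pinned when possible) until the configured byte budget is exhausted, and forward is swapped for a wrapper that materializes a device copy of the state dict only for the duration of the call via torch.func.functional_call. Reading module.state_dict() inside the wrapper, instead of capturing it once up front as the old code did, is what lets the wrapper see tensors that process_weights_after_loading may have replaced (for example FP8-quantized weights and their scales). Below is a self-contained sketch of the same trick on a toy module; the names and the missing byte-budget bookkeeping are illustrative, not vLLM's code:

import torch
import torch.nn as nn
from torch.func import functional_call

def offload_to_cpu(module: nn.Module, device: torch.device) -> nn.Module:
    # Park the parameters in pinned CPU memory (assumes CUDA is available).
    for p in module.parameters():
        p.data = p.data.to("cpu").pin_memory()

    original_forward = module.forward

    def forward(*args, **kwargs):
        module.forward = original_forward
        # Re-read state_dict() on every call so later weight replacements are seen.
        device_state = {k: v.to(device, non_blocking=True)
                        for k, v in module.state_dict().items()}
        output = functional_call(module, device_state, args=args, kwargs=kwargs)
        module.forward = forward
        return output

    module.forward = forward
    return module

layer = offload_to_cpu(nn.Linear(16, 16), torch.device("cuda:0"))
y = layer(torch.randn(2, 16, device="cuda:0"))  # weights streamed in per call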