[Bugfix] Support cpu offloading with fp8 quantization (#6960)
This commit is contained in:
parent
bd70013407
commit
460c1884e3
@ -1,4 +1,6 @@
|
||||
from vllm.utils import is_hip
|
||||
import pytest
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from ..utils import compare_two_settings
|
||||
|
||||
@ -6,8 +8,37 @@ from ..utils import compare_two_settings
|
||||
def test_cpu_offload():
|
||||
compare_two_settings("meta-llama/Llama-2-7b-hf", [],
|
||||
["--cpu-offload-gb", "4"])
|
||||
if not is_hip():
|
||||
# compressed-tensors quantization is currently not supported in ROCm.
|
||||
compare_two_settings(
|
||||
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="fp8 is not supported on this GPU type.")
|
||||
def test_cpu_offload_fp8():
|
||||
# Test quantization of an unquantized checkpoint
|
||||
compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
["--quantization", "fp8"],
|
||||
["--quantization", "fp8", "--cpu-offload-gb", "2"])
|
||||
# Test loading a quantized checkpoint
|
||||
compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [],
|
||||
["--cpu-offload-gb", "2"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("awq"),
|
||||
reason="awq is not supported on this GPU type.")
|
||||
def test_cpu_offload_awq():
|
||||
compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [],
|
||||
["--cpu-offload-gb", "2"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="gptq_marlin is not supported on this GPU type.")
|
||||
def test_cpu_offload_compressed_tensors():
|
||||
# Test wNa16
|
||||
compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [],
|
||||
["--cpu-offload-gb", "1"])
|
||||
# Test w4a16_marlin24
|
||||
compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
|
||||
[], ["--cpu-offload-gb", "1"])
|
||||
# Test w8a8
|
||||
compare_two_settings(
|
||||
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [],
|
||||
["--cpu-offload-gb", "1"])
|
||||
|
@ -7,6 +7,7 @@ import json
|
||||
import math
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
import huggingface_hub
|
||||
@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
|
||||
supports_vision)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_tpu
|
||||
from vllm.utils import is_pin_memory_available, is_tpu
|
||||
|
||||
|
||||
@contextmanager
|
||||
def device_loading_context(module: torch.nn.Module,
|
||||
target_device: torch.device):
|
||||
if target_device.type == "cpu":
|
||||
# If target is CPU, no need to move anything
|
||||
yield module
|
||||
return
|
||||
|
||||
original_device_states: Dict[str, torch.device] = {}
|
||||
|
||||
# Store original device states and move parameters to GPU if they're on CPU
|
||||
for name, p in module.named_parameters():
|
||||
if p.device.type == "cpu":
|
||||
original_device_states[name] = p.device
|
||||
p.data = p.data.to(target_device)
|
||||
# Parameters already on target device are not touched
|
||||
|
||||
try:
|
||||
yield module
|
||||
|
||||
finally:
|
||||
# Restore parameters to their original devices, ignoring new parameters
|
||||
pin_memory = is_pin_memory_available()
|
||||
for name, p in module.named_parameters():
|
||||
if name in original_device_states:
|
||||
original_device: torch.device = original_device_states[name]
|
||||
if original_device.type == "cpu":
|
||||
# `torch.empty_like` does not support `pin_memory` argument
|
||||
cpu_data = torch.empty_strided(size=p.data.size(),
|
||||
stride=p.data.stride(),
|
||||
dtype=p.data.dtype,
|
||||
layout=p.data.layout,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
cpu_data.copy_(p.data)
|
||||
p.data = cpu_data
|
||||
else:
|
||||
p.data = p.data.to(original_device)
|
||||
# New parameters or parameters already on target device are untouched
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -275,8 +318,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig) -> nn.Module:
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
with target_device:
|
||||
model = _initialize_model(model_config, self.load_config,
|
||||
lora_config, multimodal_config,
|
||||
cache_config, scheduler_config)
|
||||
@ -291,6 +335,12 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
# When quant methods need to process weights after loading
|
||||
# (for repacking, quantizing, etc), they expect parameters
|
||||
# to be on the global target device. This scope is for the
|
||||
# case where cpu offloading is used, where we will move the
|
||||
# parameters onto device for processing and back off after.
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
return model.eval()
|
||||
|
||||
|
@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
|
||||
# offload parameters to CPU
|
||||
# use pin_memory if possible, which helps cudagraph capture speed
|
||||
offloaded_parameters = False
|
||||
for p in module.parameters():
|
||||
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
|
||||
# we use per-parameter offloading
|
||||
@ -94,7 +95,8 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
break
|
||||
|
||||
# `torch.empty_like` does not support `pin_memory` argument
|
||||
cpu_data = torch.empty(size=p.data.size(),
|
||||
cpu_data = torch.empty_strided(size=p.data.size(),
|
||||
stride=p.data.stride(),
|
||||
dtype=p.data.dtype,
|
||||
layout=p.data.layout,
|
||||
device='cpu',
|
||||
@ -102,9 +104,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
cpu_data.copy_(p.data)
|
||||
p.data = cpu_data
|
||||
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
|
||||
offloaded_parameters = True
|
||||
|
||||
state_dict: Dict[str, torch.Tensor] = module.state_dict()
|
||||
|
||||
if offloaded_parameters:
|
||||
original_forward = module.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
@ -113,7 +115,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
# here we blindly call `to(device)`
|
||||
# if the parameter is already on the device, it will be a no-op
|
||||
k: v.to(device, non_blocking=True)
|
||||
for k, v in state_dict.items()
|
||||
for k, v in module.state_dict().items()
|
||||
}
|
||||
output = functional_call(module,
|
||||
device_state,
|
||||
|
Loading…
x
Reference in New Issue
Block a user