2024-04-19 21:28:57 -07:00
|
|
|
"""Tests whether FP8 computation is enabled correctly.
|
|
|
|
|
|
|
|
Run `pytest tests/quantization/test_fp8.py --forked`.
|
|
|
|
"""
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
2024-06-13 11:18:08 -04:00
|
|
|
from tests.quantization.utils import is_quant_method_supported
|
2024-07-03 13:38:00 -04:00
|
|
|
from vllm import _custom_ops as ops
|
2024-07-16 18:31:32 -04:00
|
|
|
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
|
|
|
|
Fp8LinearMethod)
|
2024-08-07 14:23:12 -04:00
|
|
|
from vllm.platforms import current_platform
|
2024-04-19 21:28:57 -07:00
|
|
|
|
2024-06-30 19:06:27 -04:00
|
|
|
# Checkpoints that test_model_load_and_run loads and runs a short greedy
# generation on (each is combined with force_marlin=False/True).
MODELS = [
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    "nm-testing/Phi-3-mini-128k-instruct-FP8",
    "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV",
]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
                            monkeypatch) -> None:
    """Smoke-test loading an FP8 checkpoint and generating a few tokens.

    When ``force_marlin`` is true, ``VLLM_TEST_FORCE_FP8_MARLIN`` is set to
    "1" before the model is constructed. Output quality is not checked here
    (the lm-eval tests cover accuracy); the test only requires that greedy
    generation completes.
    """
    if force_marlin:
        # Must be set before vllm_runner builds the model.
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    prompt_batch = ["Hello my name is"]
    with vllm_runner(model_id) as llm:
        greedy_results = llm.generate_greedy(prompts=prompt_batch,
                                             max_tokens=10)
        # Element [1] of each result is printed for manual inspection.
        first_completion = greedy_results[0][1]
        print(first_completion)
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoints that ship fp8 KV-cache scales, one per supported layout.
KV_CACHE_MODELS = [
    # Older (deprecated) AutoFP8 layout: a single `.kv_scale` per layer.
    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
    # AutoFP8 layout with distinct `.k_scale` and `.v_scale` tensors.
    "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
    """Load a checkpoint with an fp8 KV cache and verify its scales.

    Checks that the first layer's attention uses ``Fp8KVCacheMethod`` and
    that both ``_k_scale`` and ``_v_scale`` lie strictly inside (0, 1). It
    then runs a short greedy generation as a smoke test.
    """
    with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
        executor = llm.model.llm_engine.model_executor
        model = executor.driver_worker.model_runner.model
        first_attn = model.model.layers[0].self_attn.attn
        assert isinstance(first_attn.quant_method, Fp8KVCacheMethod)

        # A scale of exactly 1.0 (the default) would also be legal in
        # general, but these particular checkpoints are known to store
        # scales below 1.0, which proves the values were loaded.
        assert 0.0 < first_attn._k_scale < 1.0
        assert 0.0 < first_attn._v_scale < 1.0

        # Smoke test only: accuracy is covered by the lm-eval tests.
        greedy_results = llm.generate_greedy(prompts=["Hello my name is"],
                                             max_tokens=10)
        print(greedy_results[0][1])
|
|
|
|
|
2024-04-19 21:28:57 -07:00
|
|
|
|
2024-06-13 11:18:08 -04:00
|
|
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                         monkeypatch) -> None:
    """Quantize ``facebook/opt-125m`` to fp8 at load time and inspect it.

    The test checks three things:
    - The first decoder layer's ``fc1`` uses ``Fp8LinearMethod``.
    - With ``kv_cache_dtype="fp8"``, the attention uses ``Fp8KVCacheMethod``
      with both scales equal to 1.0.
    - The stored weight dtype matches the device.

    On devices with compute capability >= 8.9, when Marlin is not forced,
    weights stay in ``float8_e4m3fn``. Otherwise they are packed into
    ``int32`` for the Marlin weight-only path.
    """
    if force_marlin:
        # Must be set before vllm_runner builds the model.
        monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

    with vllm_runner("facebook/opt-125m",
                     quantization="fp8",
                     kv_cache_dtype=kv_cache_dtype) as llm:
        executor = llm.model.llm_engine.model_executor
        model = executor.driver_worker.model_runner.model
        first_layer = model.model.decoder.layers[0]

        fc1 = first_layer.fc1
        assert isinstance(fc1.quant_method, Fp8LinearMethod)

        if kv_cache_dtype == "fp8":
            first_attn = first_layer.self_attn.attn
            assert isinstance(first_attn.quant_method, Fp8KVCacheMethod)
            # No scales in an fp16 checkpoint, so the defaults remain.
            assert first_attn._k_scale == 1.0
            assert first_attn._v_scale == 1.0

        cap = current_platform.get_device_capability()
        has_native_fp8 = cap[0] * 10 + cap[1] >= 89
        if has_native_fp8 and not force_marlin:
            # For GPUs with hardware support, we keep weights in fp8
            expected_dtype = torch.float8_e4m3fn
        else:
            # For GPUs without hardware support, we pack the fp8 weights
            # for weight-only quantization using Marlin kernels
            expected_dtype = torch.int32
        assert fc1.weight.dtype == expected_dtype
|
2024-06-12 14:07:26 -07:00
|
|
|
|
|
|
|
|
2024-06-13 11:18:08 -04:00
|
|
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant(dtype) -> None:
    """Compare ``ops.scaled_fp8_quant`` against a pure-PyTorch reference.

    Three modes are covered:
    - Dynamic: the kernel computes the scale.
    - Static: the caller supplies the scale.
    - Padded: ``num_token_padding`` grows the output along dim 0.

    Outputs are compared after dequantizing back to ``dtype``.
    """

    def quantize_ref(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested: multiply by 1/inv_scale in fp32,
        # clamp to the representable float8_e4m3fn range, then cast.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min,
                                                            max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def per_tensor_dequantize(tensor, inv_scale, out_dtype):
        # Undo the per-tensor quantization: cast up, then multiply.
        fake_qweight = tensor.to(out_dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 4 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 4.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Dynamic quantization: the kernel computes the scale itself.
    # (Previously this kernel output was misleadingly named `ref_y`.)
    kernel_q, inv_scale = ops.scaled_fp8_quant(x, None)
    kernel_dq = per_tensor_dequantize(kernel_q, inv_scale, dtype)

    # Reference dynamic quantization, using the scale the kernel produced.
    ref_q = quantize_ref(x, inv_scale)
    torch.testing.assert_close(kernel_dq,
                               per_tensor_dequantize(ref_q, inv_scale, dtype))

    # Static quantization: reusing the same scale must reproduce the
    # dynamic kernel result.
    static_q, _ = ops.scaled_fp8_quant(x, inv_scale)
    torch.testing.assert_close(
        kernel_dq, per_tensor_dequantize(static_q, inv_scale, dtype))

    # Padding: the output gains extra rows along dim 0; the first
    # x.shape[0] rows must match the unpadded result.
    padded_q, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
    assert padded_q.shape[0] == 17
    torch.testing.assert_close(
        kernel_dq,
        per_tensor_dequantize(torch.narrow(padded_q, 0, 0, x.shape[0]),
                              inv_scale, dtype))
|