vllm/tests/quantization/test_quark.py
chaow-amd 9e90c9f73f
[Bugfix] Fix bugs of running Quark quantized models (#16236)
Signed-off-by: chaow <chaow@amd.com>
2025-04-11 10:18:32 -04:00

66 lines
2.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
"""Test model set-up and weight loading for quark-quantized models.
Run `pytest tests/quantization/test_quark.py`.
"""
import pytest
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.platforms import current_platform
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Force the V0 engine for every test in this module.

    The tests below rely on V0 internals, so ``VLLM_USE_V1`` is set to
    ``'0'`` via ``monkeypatch`` before each test function runs (the
    fixture is function-scoped and ``autouse``).
    """
    env_name, env_value = 'VLLM_USE_V1', '0'
    monkeypatch.setenv(env_name, env_value)
@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@pytest.mark.parametrize('tp', [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load the Quark FP8 test checkpoint, inspect it, and run generation.

    Args:
        vllm_runner: Runner factory fixture; called with the model path and
            used as a context manager that yields the loaded ``llm``.
        kv_cache_dtype: KV-cache dtype forwarded to the runner
            (parametrized over ``'auto'`` and ``'fp8'``).
        tp: Tensor-parallel size forwarded to the runner.

    The test passes when the first decoder layer's ``qkv_proj`` uses the
    Quark linear method with the FP8 W8A8 scheme, its input and weight
    scales are zero-dimensional, its weight is stored in the platform's
    FP8 dtype, and greedy generation returns a truthy result.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(model_path,
                     kv_cache_dtype=kv_cache_dtype,
                     tensor_parallel_size=tp) as llm:

        def check_model(model):
            # Only the first decoder layer's fused QKV projection is checked.
            layer = model.model.layers[0]
            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme type is already asserted above, so the checks below
            # run unconditionally instead of behind a redundant isinstance
            # guard that could never be false.
            # A shape of length 0 means the scale is a single scalar value.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        # Smoke-test inference: any non-empty result counts as success.
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
@pytest.mark.parametrize('tp', [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Load the Quark INT8 test checkpoint, inspect it, and run generation.

    Args:
        vllm_runner: Runner factory fixture; called with the model path and
            used as a context manager that yields the loaded ``llm``.
        tp: Tensor-parallel size forwarded to the runner.

    The test passes when the first decoder layer's ``qkv_proj`` uses the
    Quark linear method with the INT8 W8A8 scheme and greedy generation
    returns a truthy result.
    """
    checkpoint = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"

    def verify_qkv_quantization(model):
        # Inspect the fused QKV projection of the first decoder layer.
        proj = model.model.layers[0].self_attn.qkv_proj
        assert isinstance(proj.quant_method, QuarkLinearMethod)
        assert isinstance(proj.scheme, QuarkW8A8Int8)

    with vllm_runner(checkpoint, tensor_parallel_size=tp) as llm:
        llm.apply_model(verify_qkv_quantization)

        # Smoke-test inference: any non-empty result counts as success.
        generated = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert generated