[Bugfix] Fix bugs of running Quark quantized models (#16236)

Signed-off-by: chaow <chaow@amd.com>
This commit is contained in:
chaow-amd 2025-04-11 22:18:32 +08:00 committed by GitHub
parent e9528f6dc6
commit 9e90c9f73f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 22 deletions

View File

@ -4,17 +4,28 @@
Run `pytest tests/quantization/test_quark.py`.
"""
import torch
import pytest
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8)
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.platforms import current_platform
def test_quark_fp8(vllm_runner, monkeypatch):
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@pytest.mark.parametrize('tp', [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with vllm_runner(model_path) as llm:
with vllm_runner(model_path,
kv_cache_dtype=kv_cache_dtype,
tensor_parallel_size=tp) as llm:
def check_model(model):
layer = model.model.layers[0]
@ -26,11 +37,29 @@ def test_quark_fp8(vllm_runner, monkeypatch):
if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
assert len(qkv_proj.weight_scale.shape) == 0
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@pytest.mark.parametrize('tp', [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
with vllm_runner(model_path, tensor_parallel_size=tp) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

View File

@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.out_dtype = torch.get_default_dtype()
@classmethod
@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
)
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
weight = layer.weight
if current_platform.is_fp8_fnuz():
input_scale = getattr(layer, 'input_scale', None)
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# the newly added parameters
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)

View File

@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.qscheme == "per_channel"),
@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
# WEIGHT SCALE
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
ChannelQuantZPParameter = ChannelQuantScaleParameter
weight_zero_point = ChannelQuantZPParameter(
data=torch.empty((sum(output_partition_sizes)),
dtype=torch.int8),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
PerTensorZPParameter = PerTensorScaleParameter
weight_zero_point = PerTensorZPParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE
if self.is_static_input_scheme:
@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: quark stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
input_zero_point = BasevLLMParameter(data=torch.empty(
1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.register_parameter("weight_zero_point", None)
delattr(layer, 'weight_zero_point')
if self.input_symmetric:
layer.register_parameter("input_zero_point", None)
delattr(layer, 'input_zero_point')
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,