[Bugfix] Fix bugs when running Quark-quantized models (#16236)

Signed-off-by: chaow <chaow@amd.com>
chaow-amd 2025-04-11 22:18:32 +08:00 committed by GitHub
parent e9528f6dc6
commit 9e90c9f73f
3 changed files with 67 additions and 22 deletions
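
The three diffs below touch the Quark FP8/INT8 test and the two W8A8 scheme implementations. For orientation, here is a minimal offline-inference sketch (not part of the commit) that exercises the fixed path; the checkpoint name is the FP8 test model used in the diff, and kv_cache_dtype="fp8" mirrors the new test parametrization (tensor parallelism is left at its default of 1, matching the test's tp=1):

    # Minimal sketch: load a Quark FP8 checkpoint with vLLM and decode greedily.
    from vllm import LLM, SamplingParams

    llm = LLM(model="amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
              kv_cache_dtype="fp8")                    # "auto" is also covered by the test
    params = SamplingParams(temperature=0.0, max_tokens=20)  # greedy, like the test
    print(llm.generate(["Hello my name is"], params)[0].outputs[0].text)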

tests/quantization/test_quark.py

@@ -4,17 +4,28 @@
 Run `pytest tests/quantization/test_quark.py`.
 """
-import torch
+import pytest
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
-    QuarkLinearMethod, QuarkW8A8Fp8)
+    QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
+from vllm.platforms import current_platform
 
 
-def test_quark_fp8(vllm_runner, monkeypatch):
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+@pytest.fixture(scope="function", autouse=True)
+def use_v0_only(monkeypatch):
+    """
+    This module relies on V0 internals, so set VLLM_USE_V1=0.
+    """
+    monkeypatch.setenv('VLLM_USE_V1', '0')
+
+
+@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
+@pytest.mark.parametrize('tp', [1])
+def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
     model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
-    with vllm_runner(model_path) as llm:
+    with vllm_runner(model_path,
+                     kv_cache_dtype=kv_cache_dtype,
+                     tensor_parallel_size=tp) as llm:
 
         def check_model(model):
             layer = model.model.layers[0]
@@ -26,11 +37,29 @@ def test_quark_fp8(vllm_runner, monkeypatch):
             if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
                 assert len(qkv_proj.input_scale.shape) == 0
-                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
-                #assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                 assert len(qkv_proj.weight_scale.shape) == 0
 
         llm.apply_model(check_model)
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+@pytest.mark.parametrize('tp', [1])
+def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
+    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
+    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
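
The dtype assertion was switched from a hard-coded torch.float8_e4m3fn to current_platform.fp8_dtype() because the FP8 storage format is platform-dependent (the fnuz variant is used on some AMD GPUs). A short illustration of the two helpers the test and the FP8 scheme now rely on:

    # Resolve the platform's FP8 storage dtype instead of hard-coding it.
    from vllm.platforms import current_platform

    print(current_platform.fp8_dtype())     # torch.float8_e4m3fn or torch.float8_e4m3fnuz
    print(current_platform.is_fp8_fnuz())   # True when the fnuz variant is in use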

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
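
The constructor change above disables per-token dynamic activation quantization for this scheme. As background, a plain-PyTorch sketch (not vLLM's kernel code) of the difference between the two dynamic scaling granularities the flag chooses between:

    import torch

    x = torch.randn(4, 8)                               # [tokens, hidden]
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # Per-tensor dynamic: one scale for the whole activation tensor.
    scale_tensor = x.abs().max() / fp8_max

    # Per-token dynamic: one scale per row (token).
    scale_token = x.abs().amax(dim=-1, keepdim=True) / fp8_max

    print(scale_tensor.shape, scale_token.shape)        # torch.Size([]) torch.Size([4, 1])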
@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
             )
 
             if current_platform.is_fp8_fnuz():
+                input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=max_w_scale,
-                    input_scale=layer.input_scale)
+                    input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
             weight = layer.weight
 
             if current_platform.is_fp8_fnuz():
+                input_scale = getattr(layer, 'input_scale', None)
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=layer.weight_scale,
-                        input_scale=layer.input_scale)
+                        input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
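
Both fnuz-normalization call sites previously read layer.input_scale unconditionally, but that parameter is only registered for static input schemes, so checkpoints with dynamic activation quantization failed on fnuz platforms when the attribute was missing. A toy illustration of the defensive lookup (nn.Linear stands in for the real quantized layer):

    import torch.nn as nn

    layer = nn.Linear(8, 8)          # stand-in: a layer without an input_scale param

    # Old behaviour: layer.input_scale raises AttributeError for such layers.
    # New behaviour: fall back to None and let normalize_e4m3fn_to_e4m3fnuz skip it.
    scale = getattr(layer, "input_scale", None)
    print(scale)                     # None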
@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         # the newly added parameters
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
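
The per-channel weight scale is now allocated as a 1-D tensor of length sum(output_partition_sizes) rather than a 2-D (N, 1) tensor, presumably so the buffer matches the shape of the per-channel scales as serialized in Quark checkpoints; the INT8 scheme below receives the same fix. The shape difference, spelled out:

    import torch

    output_partition_sizes = [4096, 4096, 1024]        # e.g. a fused qkv_proj

    old = torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32)
    new = torch.empty((sum(output_partition_sizes)), dtype=torch.float32)
    print(old.shape, new.shape)                        # torch.Size([9216, 1]) torch.Size([9216])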

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py

@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-        self.logical_widths = output_partition_sizes
+        layer.logical_widths = output_partition_sizes
 
         scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
             is_channelwise=(self.qscheme == "per_channel"),
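
Moving logical_widths from the scheme object onto the layer matters because a single scheme instance is shared by every layer with the same quantization config: per-layer partition widths stored on self get overwritten by the last layer created, and the value is not available on the layer where later weight processing looks for it. A toy illustration (not vLLM code):

    # One scheme instance serves many layers, so per-layer state must live on the layer.
    class Scheme:
        def create_weights(self, layer, output_partition_sizes):
            layer.logical_widths = output_partition_sizes   # fixed: stored per layer

    class Layer:
        pass

    scheme = Scheme()
    a, b = Layer(), Layer()
    scheme.create_weights(a, [4096, 4096, 1024])
    scheme.create_weights(b, [11008])
    print(a.logical_widths, b.logical_widths)               # each layer keeps its own widths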
@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
         # WEIGHT SCALE
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
+            ChannelQuantZPParameter = ChannelQuantScaleParameter
+            weight_zero_point = ChannelQuantZPParameter(
+                data=torch.empty((sum(output_partition_sizes)),
+                                 dtype=torch.int8),
+                output_dim=0,
+                weight_loader=weight_loader)
         else:
             assert self.qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
+            PerTensorZPParameter = PerTensorScaleParameter
+            weight_zero_point = PerTensorZPParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.int8),
+                weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)
+        layer.register_parameter("weight_zero_point", weight_zero_point)
 
         # INPUT SCALE
         if self.is_static_input_scheme:
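
Registering weight_zero_point placeholders (reusing the scale parameter classes) gives the weight loader a destination for the zero-point tensors that symmetric Quark INT8 checkpoints can still contain; without a registered parameter, loading such checkpoints fails. The placeholder shapes simply mirror the corresponding scale parameters, with int8 storage as noted in the removed comment ("quark stores the zp using the same dtype as the weights"):

    import torch

    output_partition_sizes = [4096, 4096, 1024]

    # per_channel: one zero point per output channel, int8 like the weights
    zp_per_channel = torch.empty(sum(output_partition_sizes), dtype=torch.int8)

    # per_tensor: one zero point per fused partition
    zp_per_tensor = torch.empty(len(output_partition_sizes), dtype=torch.int8)
    print(zp_per_channel.shape, zp_per_tensor.shape)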
@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
                                                   weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)
 
-            if not self.input_symmetric:
-                # Note: quark stores the zp using the same dtype
-                # as the weights
-                # AZP loaded as int8 but used as int32
-                input_zero_point = BasevLLMParameter(
-                    data=torch.empty(1, dtype=torch.int8),
-                    weight_loader=weight_loader)
-                layer.register_parameter("input_zero_point", input_zero_point)
+            input_zero_point = BasevLLMParameter(data=torch.empty(
+                1, dtype=torch.int8),
+                                                 weight_loader=weight_loader)
+            layer.register_parameter("input_zero_point", input_zero_point)
 
         self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
                                   w_q_param_name="weight",
@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
     # Checkpoints are serialized in quark format, which is
     # different from the format the kernel may want. Handle repacking here.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.register_parameter("weight_zero_point", None)
+        delattr(layer, 'weight_zero_point')
+        if self.input_symmetric:
+            layer.register_parameter("input_zero_point", None)
+            delattr(layer, 'input_zero_point')
+
         self.kernel.process_weights_after_loading(layer)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
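
The placeholder zero points only exist so that checkpoint loading succeeds; for symmetric quantization they carry no information, so they are dropped again before the kernel repacks the layer. A minimal sketch of that removal pattern on a plain nn.Module (illustration only, not the scheme code itself):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)
    layer.register_parameter(
        "weight_zero_point",
        nn.Parameter(torch.zeros(4, dtype=torch.int8), requires_grad=False))

    layer.register_parameter("weight_zero_point", None)   # detach from _parameters
    delattr(layer, "weight_zero_point")                    # remove the attribute entirely
    print(hasattr(layer, "weight_zero_point"))             # False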