[Bugfix] Fix bugs of running Quark quantized models (#16236)
Signed-off-by: chaow <chaow@amd.com>
Parent: e9528f6dc6
Commit: 9e90c9f73f
tests/quantization/test_quark.py

@@ -4,17 +4,28 @@
 Run `pytest tests/quantization/test_quark.py`.
 """
 
-import torch
+import pytest
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
-    QuarkLinearMethod, QuarkW8A8Fp8)
+    QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
+from vllm.platforms import current_platform
 
 
-def test_quark_fp8(vllm_runner, monkeypatch):
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+@pytest.fixture(scope="function", autouse=True)
+def use_v0_only(monkeypatch):
+    """
+    This module relies on V0 internals, so set VLLM_USE_V1=0.
+    """
+    monkeypatch.setenv('VLLM_USE_V1', '0')
+
+
+@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
+@pytest.mark.parametrize('tp', [1])
+def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
     model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
-    with vllm_runner(model_path) as llm:
+    with vllm_runner(model_path,
+                     kv_cache_dtype=kv_cache_dtype,
+                     tensor_parallel_size=tp) as llm:
 
         def check_model(model):
             layer = model.model.layers[0]
@@ -26,11 +37,29 @@ def test_quark_fp8(vllm_runner, monkeypatch):
 
             if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
                 assert len(qkv_proj.input_scale.shape) == 0
-                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
-                #assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                 assert len(qkv_proj.weight_scale.shape) == 0
 
         llm.apply_model(check_model)
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+@pytest.mark.parametrize('tp', [1])
+def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
+    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
+    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
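The test now asserts the weight dtype against `current_platform.fp8_dtype()` instead of a hard-coded `torch.float8_e4m3fn`, since the expected fp8 format depends on the platform (fnuz variants on some ROCm GPUs). A minimal illustrative check, not part of this commit:

# Illustrative only: current_platform.fp8_dtype() resolves to the platform's
# native fp8 format, so the assertion works on both e4m3fn and e4m3fnuz hardware.
import torch
from vllm.platforms import current_platform

print(current_platform.fp8_dtype())  # torch.float8_e4m3fn or torch.float8_e4m3fnuz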
@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
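The scheme now constructs `Fp8LinearOp` with `use_per_token_if_dynamic=False`, which lines up with the per-tensor activation quantization these Quark configs describe (the test above is named `*_w_per_tensor_a_per_tensor`). A rough sketch of the distinction, illustrative only and not the `Fp8LinearOp` implementation:

# Illustrative only: one dynamic per-tensor activation scale vs. one scale per token.
import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max                  # 448.0
x = torch.randn(4, 8)                                           # [num_tokens, hidden_size]
per_tensor_scale = x.abs().max() / fp8_max                      # shape []
per_token_scale = x.abs().amax(dim=-1, keepdim=True) / fp8_max  # shape [4, 1]
print(per_tensor_scale.shape, per_token_scale.shape)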
@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
             )
 
             if current_platform.is_fp8_fnuz():
+                input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=max_w_scale,
-                    input_scale=layer.input_scale)
+                    input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
             weight = layer.weight
 
             if current_platform.is_fp8_fnuz():
+                input_scale = getattr(layer, 'input_scale', None)
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=layer.weight_scale,
-                        input_scale=layer.input_scale)
+                        input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
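Both fnuz-normalization branches above previously read `layer.input_scale` unconditionally, which raises an AttributeError when the checkpoint uses dynamic activation quantization and no input scale was ever created. The `getattr(layer, 'input_scale', None)` guard covers both the static and dynamic cases. A minimal sketch of the pattern (stand-in layer, not vLLM internals):

# Illustrative only: tolerate a missing input_scale (dynamic activation quantization).
import torch

layer = torch.nn.Linear(8, 8)                      # stand-in for a quantized linear layer
input_scale = getattr(layer, "input_scale", None)  # None when scales are computed at runtime
if input_scale is not None:
    layer.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)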
@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         # the newly added parameters
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
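Here, and again in the int8 scheme below, the per-channel weight scale buffer is now allocated as a 1-D tensor of length `sum(output_partition_sizes)` rather than shape `(N, 1)`, presumably so it matches the shape of the per-channel scales serialized in Quark checkpoints. Shape comparison with illustrative values:

# Illustrative only: (N, 1) vs. 1-D scale buffers for a fused output dimension.
import torch

output_partition_sizes = [1024, 1024, 512]              # hypothetical fused partitions
n = sum(output_partition_sizes)
old_layout = torch.empty((n, 1), dtype=torch.float32)   # previous allocation
new_layout = torch.empty(n, dtype=torch.float32)        # matches 1-D checkpoint scales
print(old_layout.shape, new_layout.shape)               # torch.Size([2560, 1]) torch.Size([2560])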
@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-        self.logical_widths = output_partition_sizes
+        layer.logical_widths = output_partition_sizes
 
         scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
             is_channelwise=(self.qscheme == "per_channel"),
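`create_weights` now records the partition widths on the layer itself (`layer.logical_widths`) instead of on the scheme object, so later per-layer processing can read them back from the layer it is handed. A sketch of the pattern with hypothetical helpers, not vLLM internals:

# Illustrative only: keep per-layer metadata on the layer module itself.
import torch

def create_weights(layer: torch.nn.Module, output_partition_sizes: list) -> None:
    layer.logical_widths = output_partition_sizes      # stays with this layer

def process_weights_after_loading(layer: torch.nn.Module) -> None:
    print(sum(layer.logical_widths))                   # read back from the same layer

layer = torch.nn.Module()
create_weights(layer, [1024, 1024, 512])
process_weights_after_loading(layer)                   # prints 2560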
@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
         # WEIGHT SCALE
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
+            ChannelQuantZPParameter = ChannelQuantScaleParameter
+            weight_zero_point = ChannelQuantZPParameter(
+                data=torch.empty((sum(output_partition_sizes)),
+                                 dtype=torch.int8),
+                output_dim=0,
+                weight_loader=weight_loader)
         else:
             assert self.qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
+            PerTensorZPParameter = PerTensorScaleParameter
+            weight_zero_point = PerTensorZPParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.int8),
+                weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)
+        layer.register_parameter("weight_zero_point", weight_zero_point)
 
         # INPUT SCALE
         if self.is_static_input_scheme:
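The int8 scheme now registers `weight_zero_point` (and, in the next hunk, always registers `input_zero_point`) even though the symmetric kernel does not use them, presumably because Quark checkpoints serialize these zero-point tensors and weight loading needs a registered parameter to receive them. A minimal sketch of registering such a placeholder (illustrative size, bare `nn.Module` stand-in):

# Illustrative only: a registered int8 placeholder gives a serialized zero point
# somewhere to load into, even if the kernel never reads it afterwards.
import torch

layer = torch.nn.Module()
layer.register_parameter(
    "weight_zero_point",
    torch.nn.Parameter(torch.empty(2560, dtype=torch.int8), requires_grad=False))
print("weight_zero_point" in dict(layer.named_parameters()))   # True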
@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
                                                   weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)
 
-            if not self.input_symmetric:
-                # Note: quark stores the zp using the same dtype
-                # as the weights
-                # AZP loaded as int8 but used as int32
-                input_zero_point = BasevLLMParameter(
-                    data=torch.empty(1, dtype=torch.int8),
-                    weight_loader=weight_loader)
-                layer.register_parameter("input_zero_point", input_zero_point)
+            input_zero_point = BasevLLMParameter(data=torch.empty(
+                1, dtype=torch.int8),
+                                                 weight_loader=weight_loader)
+            layer.register_parameter("input_zero_point", input_zero_point)
 
         self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
                                   w_q_param_name="weight",
@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
     # Checkpoints are serialized in quark format, which is
     # different from the format the kernel may want. Handle repacking here.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.register_parameter("weight_zero_point", None)
+        delattr(layer, 'weight_zero_point')
+        if self.input_symmetric:
+            layer.register_parameter("input_zero_point", None)
+            delattr(layer, 'input_zero_point')
+
         self.kernel.process_weights_after_loading(layer)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
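`process_weights_after_loading` now drops those placeholders once the checkpoint has been consumed: `weight_zero_point` is always removed, and `input_zero_point` is removed when input quantization is symmetric (the asymmetric path still needs it). A sketch of the removal pattern on a bare module:

# Illustrative only: detach a placeholder parameter after loading.
import torch

layer = torch.nn.Module()
layer.register_parameter(
    "weight_zero_point",
    torch.nn.Parameter(torch.zeros(4, dtype=torch.int8), requires_grad=False))

layer.register_parameter("weight_zero_point", None)  # clear the parameter slot
delattr(layer, "weight_zero_point")                  # remove the attribute entirely
print(hasattr(layer, "weight_zero_point"))           # False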