diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py
index 85dc695b..ce918a32 100644
--- a/tests/quantization/test_quark.py
+++ b/tests/quantization/test_quark.py
@@ -4,17 +4,28 @@
 Run `pytest tests/quantization/test_quark.py`.
 """
 
-import torch
+import pytest
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
-    QuarkLinearMethod, QuarkW8A8Fp8)
+    QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
+from vllm.platforms import current_platform
 
 
-def test_quark_fp8(vllm_runner, monkeypatch):
-    # vllm_runner.apply_model() relies on V0 internals.
-    monkeypatch.setenv("VLLM_USE_V1", "0")
+@pytest.fixture(scope="function", autouse=True)
+def use_v0_only(monkeypatch):
+    """
+    This module relies on V0 internals, so set VLLM_USE_V1=0.
+    """
+    monkeypatch.setenv('VLLM_USE_V1', '0')
+
+
+@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
+@pytest.mark.parametrize('tp', [1])
+def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
     model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
-    with vllm_runner(model_path) as llm:
+    with vllm_runner(model_path,
+                     kv_cache_dtype=kv_cache_dtype,
+                     tensor_parallel_size=tp) as llm:
 
         def check_model(model):
             layer = model.model.layers[0]
@@ -26,11 +37,29 @@ def test_quark_fp8(vllm_runner, monkeypatch):
 
             if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
                 assert len(qkv_proj.input_scale.shape) == 0
-                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
-                #assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                 assert len(qkv_proj.weight_scale.shape) == 0
 
         llm.apply_model(check_model)
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+@pytest.mark.parametrize('tp', [1])
+def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
+    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
+    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+            assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)
+
+        llm.apply_model(check_model)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index c161849c..afd4bb72 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
             )
 
             if current_platform.is_fp8_fnuz():
+                input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=max_w_scale,
-                    input_scale=layer.input_scale)
+                    input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
             weight = layer.weight
 
             if current_platform.is_fp8_fnuz():
+                input_scale = getattr(layer, 'input_scale', None)
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=layer.weight_scale,
-                        input_scale=layer.input_scale)
+                        input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         # the newly added parameters
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
index 1bf34b09..da8ed8c0 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
-        self.logical_widths = output_partition_sizes
+        layer.logical_widths = output_partition_sizes
 
         scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
             is_channelwise=(self.qscheme == "per_channel"),
@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
         # WEIGHT SCALE
         if self.qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
-                data=torch.empty((sum(output_partition_sizes), 1),
+                data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
+            ChannelQuantZPParameter = ChannelQuantScaleParameter
+            weight_zero_point = ChannelQuantZPParameter(
+                data=torch.empty((sum(output_partition_sizes)),
+                                 dtype=torch.int8),
+                output_dim=0,
+                weight_loader=weight_loader)
         else:
             assert self.qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
+            PerTensorZPParameter = PerTensorScaleParameter
+            weight_zero_point = PerTensorZPParameter(
+                data=torch.empty(len(output_partition_sizes),
+                                 dtype=torch.int8),
+                weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)
+        layer.register_parameter("weight_zero_point", weight_zero_point)
 
         # INPUT SCALE
         if self.is_static_input_scheme:
@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
                                            weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)
 
-            if not self.input_symmetric:
-                # Note: quark stores the zp using the same dtype
-                # as the weights
-                # AZP loaded as int8 but used as int32
-                input_zero_point = BasevLLMParameter(
-                    data=torch.empty(1, dtype=torch.int8),
-                    weight_loader=weight_loader)
-                layer.register_parameter("input_zero_point", input_zero_point)
+            input_zero_point = BasevLLMParameter(data=torch.empty(
+                1, dtype=torch.int8),
+                                                 weight_loader=weight_loader)
+            layer.register_parameter("input_zero_point", input_zero_point)
 
         self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
                                   w_q_param_name="weight",
@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
     # Checkpoints are serialized in quark format, which is
     # different from the format the kernel may want. Handle repacking here.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.register_parameter("weight_zero_point", None)
+        delattr(layer, 'weight_zero_point')
+        if self.input_symmetric:
+            layer.register_parameter("input_zero_point", None)
+            delattr(layer, 'input_zero_point')
+
         self.kernel.process_weights_after_loading(layer)
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
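
For context on the two loading-time patterns in the hunks above: the int8 scheme registers a `weight_zero_point` placeholder only so Quark checkpoints that serialize zero-points load cleanly, then unregisters it in `process_weights_after_loading`; the fp8 scheme reads `input_scale` via `getattr` because dynamically quantized activations carry no static scale. Below is a minimal standalone sketch of both patterns using made-up names (`ToyW8A8Scheme` is not a vLLM class), not the actual vLLM implementation:

import torch
from torch.nn import Module, Parameter


class ToyW8A8Scheme:
    """Toy stand-in for a quantization scheme; names are illustrative only."""

    def __init__(self, input_symmetric: bool = True):
        self.input_symmetric = input_symmetric

    def create_weights(self, layer: Module, out_features: int,
                       in_features: int) -> None:
        layer.weight = Parameter(torch.zeros(out_features,
                                             in_features,
                                             dtype=torch.int8),
                                 requires_grad=False)
        layer.weight_scale = Parameter(torch.ones(out_features,
                                                  dtype=torch.float32),
                                       requires_grad=False)
        # Placeholder registered only so checkpoints that serialize a
        # zero-point tensor can be loaded; symmetric kernels never read it.
        layer.weight_zero_point = Parameter(torch.zeros(out_features,
                                                        dtype=torch.int8),
                                            requires_grad=False)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Unregister, then delete, so the placeholder does not leak into the
        # kernel's view of the layer.
        layer.register_parameter("weight_zero_point", None)
        delattr(layer, "weight_zero_point")

        # Static-activation checkpoints carry an input_scale; dynamic ones do
        # not, hence the defensive getattr instead of layer.input_scale.
        input_scale = getattr(layer, "input_scale", None)
        if input_scale is not None:
            layer.input_scale = Parameter(input_scale.data,
                                          requires_grad=False)


if __name__ == "__main__":
    layer = Module()
    scheme = ToyW8A8Scheme()
    scheme.create_weights(layer, out_features=8, in_features=16)
    scheme.process_weights_after_loading(layer)
    assert not hasattr(layer, "weight_zero_point")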