[Misc] Pass cutlass_fp8_supported correctly in fbgemm_fp8 (#6871)

This commit is contained in:
Elsa Granger 2024-07-28 23:13:49 +08:00 committed by GitHub
parent b1366a9534
commit 3eeb148f46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -72,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()

     def create_weights(
         self,
@@ -139,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias)

-        return apply_fp8_linear(input=x,
-                                weight=layer.weight,
-                                weight_scale=layer.weight_scale,
-                                input_scale=None,
-                                input_scale_ub=layer.input_scale_ub,
-                                bias=bias,
-                                cutlass_fp8_supported=True,
-                                use_per_token_if_dynamic=True)
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=None,
+            input_scale_ub=layer.input_scale_ub,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=True)