[Misc] Pass cutlass_fp8_supported correctly in fbgemm_fp8 (#6871)
parent b1366a9534
commit 3eeb148f46
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -72,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
 
     def create_weights(
         self,
@@ -139,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        return apply_fp8_linear(input=x,
+        return apply_fp8_linear(
+            input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=None,
             input_scale_ub=layer.input_scale_ub,
             bias=bias,
-            cutlass_fp8_supported=True,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
             use_per_token_if_dynamic=True)
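Below is a minimal sketch (not vllm's actual implementation) of the pattern this change adopts: probe CUTLASS FP8 support once when the linear method is constructed, cache the result, and pass the cached flag through on every apply call instead of hardcoding cutlass_fp8_supported=True. The helper names probe_cutlass_fp8, fp8_linear_sketch, and Fp8LinearMethodSketch, as well as the compute-capability threshold, are illustrative assumptions, not vllm code.

import torch


def probe_cutlass_fp8() -> bool:
    # Assumed probe: treat CUTLASS FP8 kernels as available only on
    # Ada/Hopper-class GPUs (compute capability >= 8.9); the real check lives in
    # vllm.model_executor.layers.quantization.fp8.cutlass_fp8_supported().
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)


def fp8_linear_sketch(x, weight, weight_scale, *, cutlass_fp8_supported):
    # Stand-in for apply_fp8_linear: the flag would select the CUTLASS kernel
    # over the torch fallback; here both paths reduce to a scaled matmul.
    return torch.nn.functional.linear(x, weight) * weight_scale


class Fp8LinearMethodSketch:
    def __init__(self):
        # Cache the probe result once, mirroring
        # self.cutlass_fp8_supported = cutlass_fp8_supported() in the diff.
        self.cutlass_fp8_supported = probe_cutlass_fp8()

    def apply(self, x, weight, weight_scale):
        # Plumb the cached flag through rather than a literal True, so hosts
        # without CUTLASS FP8 support take the fallback path.
        return fp8_linear_sketch(x, weight, weight_scale,
                                 cutlass_fp8_supported=self.cutlass_fp8_supported)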