[FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (#14390)
Signed-off-by: luka <luka@neuralmagic.com>
parent dae6896977
commit e1744502c2
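
For readers skimming the diff: call sites that previously passed capability and quantization flags to the free function apply_fp8_linear on every call now construct an Fp8LinearOp once (typically in __init__) and call its apply method with only the tensors. A minimal sketch of the new pattern, assuming the w8a8_utils import shown in the hunks below; the method class and layer attributes here are illustrative only, not code from this commit:

    # Sketch only: mirrors the call-site changes in this diff; names are illustrative.
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

    class ExampleFp8LinearMethod:

        def __init__(self):
            # cutlass support, per-token quantization and output padding are
            # decided once here instead of on every forward call.
            self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

        def apply(self, layer, x, bias=None):
            return self.fp8_linear.apply(input=x,
                                         weight=layer.weight,
                                         weight_scale=layer.weight_scale,
                                         input_scale=layer.input_scale,
                                         bias=bias)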
@@ -13,7 +13,7 @@ from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
+    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
 
 from .backend import TestBackend
 
@@ -34,26 +34,20 @@ class TestModel(torch.nn.Module):
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_enabled,
+            use_per_token_if_dynamic=True)
 
     def forward(self, x):
         resid = torch.sqrt(x)
         y = self.norm[0](x)
 
-        x2 = apply_fp8_linear(y,
-                              self.w[0],
-                              self.wscale[0],
-                              self.scale[0],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+        x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
-        x3 = apply_fp8_linear(y2,
-                              self.w[1],
-                              self.wscale[1],
-                              self.scale[1],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+        x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
+                                   self.scale[1])
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
 
@@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
@@ -1057,6 +1057,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
         self.triton_fa_func = triton_attention
+        self.fp8_linear_generic = Fp8LinearGenericOp()
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -1071,7 +1072,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
-                output_parallel = apply_fp8_linear_generic(
+                output_parallel = self.fp8_linear_generic.apply(
                     x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape)
@@ -1091,7 +1092,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_Q_UK):
-                return apply_fp8_linear_generic(
+                return self.fp8_linear_generic.apply(
                     x, self.W_Q_UK, self.W_Q_UK_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape).view(
@@ -9,8 +9,8 @@ from torch.nn import Parameter
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
-    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
+    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
+    requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -24,7 +24,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -140,11 +140,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
@@ -11,14 +11,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, maybe_create_device_identity,
-    normalize_e4m3fn_to_e4m3fnuz)
+    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
@@ -37,6 +35,7 @@ class FBGEMMFp8Config(QuantizationConfig):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = not current_platform.has_device_capability(89)
+        self.fp8_linear = Fp8LinearOp()
 
     @classmethod
     def get_name(cls) -> str:
@@ -73,7 +72,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     def create_weights(
         self,
@@ -159,12 +158,9 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=None,
-            input_scale_ub=layer.input_scale_ub,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=None,
+                                     input_scale_ub=layer.input_scale_ub,
+                                     bias=bias)
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    all_close_1d, apply_fp8_linear, convert_to_channelwise,
+    Fp8LinearOp, all_close_1d, convert_to_channelwise,
     cutlass_block_fp8_supported, cutlass_fp8_supported,
     maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize, requantize_with_max_scale)
@@ -137,7 +137,6 @@ class Fp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
@@ -153,6 +152,10 @@ class Fp8LinearMethod(LinearMethodBase):
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False
 
+        self.fp8_linear = Fp8LinearOp(
+            # Default to using per_token quantization if cutlass is supported
+            use_per_token_if_dynamic=cutlass_fp8_supported())
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -381,15 +384,11 @@
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
             )
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            # Default to using per_token quantization if cutlass is supported
-            use_per_token_if_dynamic=self.cutlass_fp8_supported)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
 
 
 class Fp8MoEMethod(FusedMoEMethodBase):
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
+    Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)
 
@@ -95,7 +95,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: ModelOptFp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp()
 
     def create_weights(
         self,
@@ -157,10 +157,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    Fp8LinearOp)
 from vllm.platforms import current_platform
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -93,6 +93,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
         super().__init__(quant_config=quant_config)
         # Force weight quantization
         self.quant_config.is_checkpoint_fp8_serialized = False
+        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False,
+                                      use_per_token_if_dynamic=True)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
@@ -115,11 +117,9 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        return apply_fp8_linear(input=x,
-                                weight=layer.weight,
-                                weight_scale=layer.weight_scale,
-                                input_scale=None,
-                                input_scale_ub=None,
-                                bias=bias,
-                                cutlass_fp8_supported=False,
-                                use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=None,
+                                     input_scale_ub=None,
+                                     bias=bias)
@@ -7,8 +7,7 @@ from torch.nn import Parameter
 
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
+    Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -22,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -132,11 +131,8 @@ class QuarkW8A8Fp8(QuarkScheme):
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
@@ -15,7 +15,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported,
+    cutlass_fp8_supported)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
@@ -32,6 +33,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
 
 
+# TODO fix ROCm->Triton custom path:
+# https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -49,6 +52,7 @@ def apply_w8a8_block_fp8_linear(
     shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
                                   and weight.shape[1] % 128 == 0)
     if current_platform.is_rocm():
+        # TODO this is never used, as cutlass_block_fp8_supported is False
         scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
                          input_2d.shape[:-1])[::-1]
         scale_b_shape = (weight_scale.view(-1, 1)
@@ -104,43 +108,55 @@ direct_register_custom_op(
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
-def apply_fp8_linear_generic(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_group_shape: Tuple[int, int],
-        weight_group_shape: Tuple[int, int],
-        input_scale: Optional[torch.Tensor] = None,  # static scale if one
-        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
-        cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
-) -> torch.Tensor:
-    # View input as 2D matrix for fp8 methods
-    input = input.view(-1, input.shape[-1])
-
-    weight_group_shape = _normalize_quant_group_shape(\
-        weight, weight_group_shape)
-    input_group_shape = _normalize_quant_group_shape(input, input_group_shape)
-
-    def is_dim_blocked(dim, shape, group_shape):
-        return group_shape < shape[dim] and group_shape > 1
-
-    if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
-        and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
-            input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(
-            input,
-            weight,
-            list(weight_group_shape),
-            weight_scale,
-            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
-    else:
-        # Despite having linear in the it doesn't conform to
-        # `torch.nn.functional.linear` which is defined as `input @ weight.T`
-        # so we explicitly transpose the weight matrix here
-        return apply_fp8_linear(input, weight.T, weight_scale.T,
-                                cutlass_fp8_supported=cutlass_fp8_supported,
-                                use_per_token_if_dynamic=\
-                                    (input_group_shape == (1, input.shape[1])))
+# TODO(luka): unify this better
+# https://github.com/vllm-project/vllm/issues/14397
+class Fp8LinearGenericOp:
+
+    def __init__(
+        self,
+        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+        cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
+    ):
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_supported)
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_group_shape: Tuple[int, int],
+        weight_group_shape: Tuple[int, int],
+        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    ) -> torch.Tensor:
+        # View input as 2D matrix for fp8 methods
+        input = input.view(-1, input.shape[-1])
+
+        weight_group_shape = _normalize_quant_group_shape( \
+            weight, weight_group_shape)
+        input_group_shape = _normalize_quant_group_shape(
+            input, input_group_shape)
+
+        def is_dim_blocked(dim, shape, group_shape):
+            return group_shape < shape[dim] and group_shape > 1
+
+        if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
+            and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
+                input_group_shape == (1, weight_group_shape[1]):
+            return apply_w8a8_block_fp8_linear(
+                input,
+                weight,
+                list(weight_group_shape),
+                weight_scale,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
+        else:
+            # Despite having linear in the name it doesn't conform to
+            # `torch.nn.functional.linear` which is defined as
+            # `input @ weight.T` so we explicitly transpose the weight matrix
+            return self.fp8_linear.apply(input, weight.T, weight_scale.T,
+                                         use_per_token_if_dynamic=\
+                                             (input_group_shape == (1, input.shape[1])))
 
 
 def input_to_float8(
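The generic wrapper keeps the old call shape minus the capability flags: the group shapes passed to apply decide whether the block-wise kernel or the plain Fp8LinearOp path (with a transposed weight) is taken. A hedged usage sketch; the tensors, shapes and device below are illustrative and not taken from this diff:

    # Sketch only: illustrative shapes; capabilities are resolved in __init__.
    import torch
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        Fp8LinearGenericOp)

    fp8_linear_generic = Fp8LinearGenericOp()

    x = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(1024, 512, device="cuda").to(torch.float8_e4m3fn)
    w_scale = torch.rand(1024 // 128, 512 // 128, device="cuda")  # block scales

    # (1, 128) input groups with (128, 128) weight groups route to the
    # block-wise path; other group shapes fall through to Fp8LinearOp.
    y = fp8_linear_generic.apply(x, w, w_scale,
                                 input_group_shape=(1, 128),
                                 weight_group_shape=(128, 128))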
@@ -121,134 +121,162 @@ def maybe_create_device_identity():
         TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
 
-def apply_fp8_linear(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
-        use_per_token_if_dynamic: bool = False,
-) -> torch.Tensor:
-    # ops.scaled_fp8_quant supports both dynamic and static quant.
-    # If dynamic, layer.input_scale is None and x_scale computed from x.
-    # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-    # View input as 2D matrix for fp8 methods
-    input_2d = input.view(-1, input.shape[-1])
-    output_shape = [*input.shape[:-1], weight.shape[1]]
-
-    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-    if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(
-            input_2d,
-            input_scale,
-            scale_ub=input_scale_ub,
-            use_per_token_if_dynamic=use_per_token_if_dynamic)
-
-        # Fused GEMM_DQ
-        output = ops.cutlass_scaled_mm(qinput,
-                                       weight,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale,
-                                       bias=bias)
-        return output.view(*output_shape)
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        config = get_current_vllm_config().compilation_config
-        do_pad = config.level < CompilationLevel.PIECEWISE
-        qinput, x_scale = ops.scaled_fp8_quant(
-            input_2d,
-            input_scale,
-            num_token_padding=17 if do_pad else None,
-            use_per_token_if_dynamic=use_per_token_if_dynamic)
-
-        per_tensor_weights = (weight_scale.numel() == 1)
-        per_tensor_activations = (x_scale.numel() == 1)
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      out_dtype=input.dtype,
-                                      scale_a=x_scale,
-                                      scale_b=weight_scale,
-                                      bias=bias)
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0,
-                                input_2d.shape[0]).view(*output_shape)
-
-        elif (use_per_token_if_dynamic and not per_tensor_weights
-              and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using
-            # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      out_dtype=input.dtype,
-                                      scale_a=x_scale,
-                                      scale_b=weight_scale.t(),
-                                      bias=bias)
-
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            output = output.view(*output_shape)
-            return output
-
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            # C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            # C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      scale_a=TORCH_DEVICE_IDENTITY,
-                                      scale_b=TORCH_DEVICE_IDENTITY,
-                                      out_dtype=torch.float32)
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
+# TODO(luka): follow similar pattern for marlin and block-fp8-linear
+# https://github.com/vllm-project/vllm/issues/14397
+class Fp8LinearOp:
+    """
+    This class executes a FP8 linear layer using cutlass if supported and
+    torch.scaled_mm otherwise.
+    It needs to be a class instead of a method so that config can be read
+    in the __init__ method, as reading config is not allowed inside forward.
+    """
+
+    def __init__(self,
+                 cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+                 use_per_token_if_dynamic: bool = False,
+                 pad_output: Optional[bool] = None):
+        self.cutlass_fp8_supported = cutlass_fp8_supported
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
+
+        # Note: we pad the input because torch._scaled_mm is more performant
+        # for matrices with batch dimension > 16.
+        # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        if pad_output is None:
+            config = get_current_vllm_config().compilation_config
+            pad_output = config.level < CompilationLevel.PIECEWISE
+        self.output_padding = 17 if pad_output else None
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None,
+        input_scale_ub: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
+        # TODO(luka) remove this parameter in favor of __init__
+        use_per_token_if_dynamic: Optional[bool] = None
+    ) -> torch.Tensor:
+        # ops.scaled_fp8_quant supports both dynamic and static quant.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
+
+        # View input as 2D matrix for fp8 methods
+        input_2d = input.view(-1, input.shape[-1])
+        output_shape = [*input.shape[:-1], weight.shape[1]]
+
+        # TODO(luka) this is here because currently MLA only decides this
+        # during the forward method instead of in __init__.
+        if use_per_token_if_dynamic is None:
+            use_per_token_if_dynamic = self.use_per_token_if_dynamic
+
+        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
+        if self.cutlass_fp8_supported:
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                scale_ub=input_scale_ub,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+
+            # Fused GEMM_DQ
+            output = ops.cutlass_scaled_mm(qinput,
+                                           weight,
+                                           out_dtype=input.dtype,
+                                           scale_a=x_scale,
+                                           scale_b=weight_scale,
+                                           bias=bias)
+            return output.view(*output_shape)
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        else:
+            # Maybe apply padding to output, see comment in __init__
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                num_token_padding=self.output_padding,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+
+            per_tensor_weights = (weight_scale.numel() == 1)
+            per_tensor_activations = (x_scale.numel() == 1)
+
+            if per_tensor_weights and per_tensor_activations:
+                # Fused GEMM_DQ
+                output = torch._scaled_mm(qinput,
+                                          weight,
+                                          out_dtype=input.dtype,
+                                          scale_a=x_scale,
+                                          scale_b=weight_scale,
+                                          bias=bias)
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+
+                return torch.narrow(output, 0, 0,
+                                    input_2d.shape[0]).view(*output_shape)
+
+            elif (use_per_token_if_dynamic and not per_tensor_weights
+                  and not per_tensor_activations
+                  and USE_ROWWISE_TORCH_SCALED_MM):
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(qinput,
+                                          weight,
+                                          out_dtype=input.dtype,
+                                          scale_a=x_scale,
+                                          scale_b=weight_scale.t(),
+                                          bias=bias)
+
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                output = output.view(*output_shape)
+                return output
+
+            else:
+                # Fallback for channelwise case, where we use unfused DQ
+                # due to limitations with scaled_mm
+
+                # Symmetric quantized GEMM by definition computes the following:
+                # C = (s_x * X) (s_w * W) + bias
+                # This is equivalent to dequantizing the weights and activations
+                # before applying a GEMM.
+                #
+                # In order to compute quantized operands, a quantized kernel
+                # will rewrite the above like so:
+                # C = s_w * s_x * (X * W) + bias
+                #
+                # For the scaled_mm fallback case, we break this down, since it
+                # does not support s_w being a vector.
+
+                # GEMM
+                # This computes C = (X * W).
+                # Output in fp32 to allow subsequent ops to happen in-place
+                output = torch._scaled_mm(qinput,
+                                          weight,
+                                          scale_a=TORCH_DEVICE_IDENTITY,
+                                          scale_b=TORCH_DEVICE_IDENTITY,
+                                          out_dtype=torch.float32)
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+                # Unpad (undo num_token_padding)
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
+
+                # DQ
+                # C = sw * sx * (X * W) + bias
+                output = output * x_scale * weight_scale.t()
+                if bias is not None:
+                    output = output + bias
+                return output.to(dtype=input.dtype).view(*output_shape)
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
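As the new docstring notes, Fp8LinearOp is a class rather than a function so that the vLLM config (here, the compilation level that controls input padding) can be read once in __init__ instead of inside forward. A small hedged sketch of the constructor options introduced in this hunk; the variable names are illustrative:

    # Sketch only: demonstrates the constructor knobs added above.
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp

    # Padding (num_token_padding=17) is applied in the torch._scaled_mm fallback
    # only when pad_output resolves to True; by default it is derived from the
    # compilation config, but it can be pinned explicitly:
    op_padded = Fp8LinearOp(use_per_token_if_dynamic=True, pad_output=True)
    op_unpadded = Fp8LinearOp(pad_output=False)  # e.g. under torch.compile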
@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
@@ -640,6 +640,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
+        self.fp8_linear_generic = Fp8LinearGenericOp()
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -653,7 +654,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
-                output_parallel = apply_fp8_linear_generic(
+                output_parallel = self.fp8_linear_generic.apply(
                     x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape)
@@ -673,7 +674,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_Q_UK):
-                return apply_fp8_linear_generic(
+                return self.fp8_linear_generic.apply(
                     x, self.W_Q_UK, self.W_Q_UK_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape).view(