[Bugfix] Better FP8 supported defaults

commit 76abd0c881
parent 5b19b93082
Author: Lucas Wilkinson (committed by GitHub)
Date:   2025-02-05 22:22:19 -05:00
2 changed files with 22 additions and 12 deletions

vllm/model_executor/layers/quantization/utils/fp8_utils.py

@@ -15,7 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -91,6 +91,8 @@ def apply_fp8_linear_generic(
     input_group_shape: Tuple[int, int],
     weight_group_shape: Tuple[int, int],
     input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])
@@ -105,14 +107,18 @@ def apply_fp8_linear_generic(
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
             input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having "linear" in the name, it doesn't conform to
         # `torch.nn.functional.linear`, which is defined as `input @ weight.T`,
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                     (input_group_shape == (1, input.shape[1])))
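
Why the new defaults take effect without touching any call site: Python evaluates a default argument value once, when the def statement runs at import time, so the module-level probe result becomes the baked-in default for every caller that omits the flag. A minimal self-contained sketch of that pattern, with illustrative names rather than the real vllm implementation:

# Sketch of the probe-once-at-import pattern this commit applies.
# _probe_cutlass_fp8 is a hypothetical stand-in for the real
# device-capability query in vllm's w8a8_utils.
def _probe_cutlass_fp8() -> bool:
    # Imagine an expensive hardware check here; it runs exactly once,
    # when the module is imported.
    return False

CUTLASS_FP8_SUPPORTED = _probe_cutlass_fp8()

def apply_fp8_linear_sketch(
        x: float,
        # Captured when `def` executes: omitting the flag now picks up
        # the probed value instead of a blind `True`.
        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED) -> str:
    return "cutlass path" if cutlass_fp8_supported else "torch fallback"

print(apply_fp8_linear_sketch(1.0))                              # probed default
print(apply_fp8_linear_sketch(1.0, cutlass_fp8_supported=True))  # explicit override

The old hard-coded True defaults meant a caller that omitted the flag would attempt the CUTLASS path even on hardware without FP8 support; probing once at import makes the omitted-flag case safe.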

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

@@ -42,6 +42,10 @@ def cutlass_block_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_block_fp8(capability)
 
 
+CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def per_tensor_dequantize(
         tensor: torch.Tensor, inv_scale: Union[float,
                                                torch.Tensor]) -> torch.Tensor:
@@ -109,7 +113,7 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
     use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
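
For context, the two probes gate on GPU compute capability through vllm's compiled ops (the `ops.cutlass_scaled_mm_supports_block_fp8(capability)` call above). A hedged plain-PyTorch sketch of what such a probe looks like; the SM threshold below is an assumption for illustration, not vllm's exact rule:

import torch

def cutlass_fp8_supported_sketch() -> bool:
    # Stand-in for vllm's probe; the real check goes through vllm's
    # compiled ops, not this raw capability test.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    # Assumption for the sketch: FP8 CUTLASS kernels need Ada-class
    # (SM 8.9) or newer hardware.
    return (major, minor) >= (8, 9)

# Mirrors the diff: evaluate once at import, reuse as a function default.
CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported_sketch()

if __name__ == "__main__":
    print(f"CUTLASS FP8 supported: {CUTLASS_FP8_SUPPORTED}")

Evaluating the probes at module scope also means any import-time cost (such as CUDA context queries) is paid once per process rather than on every linear-layer call.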