[Bugfix] Better FP8 supported defaults

Lucas Wilkinson authored on 2025-02-05 22:22:19 -05:00; committed by GitHub
parent 5b19b93082
commit 76abd0c881
2 changed files with 22 additions and 12 deletions
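
Previously the `cutlass_fp8_supported` and `cutlass_block_fp8_supported` keyword arguments defaulted to a hard-coded `True`, so callers that omitted them implicitly assumed CUTLASS FP8 kernels were available. This change computes the real platform support once at import time (`CUTLASS_FP8_SUPPORTED`, `CUTLASS_BLOCK_FP8_SUPPORTED`) and uses those constants as the defaults. A minimal sketch of the pattern, assuming a simplified capability probe (the real helpers also ask vLLM's compiled ops which kernels exist; the names `_cutlass_fp8_supported_sketch` and `fp8_linear_sketch` are illustrative, not part of this diff):

import torch

def _cutlass_fp8_supported_sketch() -> bool:
    # Hypothetical probe: approximate FP8 support by device capability.
    # vLLM's actual helper also consults the compiled CUTLASS ops.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)

# Evaluated once at import time, then reused as the keyword default below.
CUTLASS_FP8_SUPPORTED = _cutlass_fp8_supported_sketch()

def fp8_linear_sketch(
        x: torch.Tensor,
        w_t: torch.Tensor,
        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED) -> torch.Tensor:
    # Callers that omit the flag now inherit a platform-correct default
    # instead of an unconditional True; an explicit argument still wins.
    if cutlass_fp8_supported:
        return x @ w_t  # stand-in for the CUTLASS scaled-mm path
    return x @ w_t      # stand-in for the torch._scaled_mm fallback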

vllm/model_executor/layers/quantization/utils/fp8_utils.py

@@ -15,7 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -85,12 +85,14 @@
 #   `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
 def apply_fp8_linear_generic(
         input: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
         input_group_shape: Tuple[int, int],
         weight_group_shape: Tuple[int, int],
         input_scale: Optional[torch.Tensor] = None,  # static scale if one
+        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+        cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])
@@ -105,14 +107,18 @@ def apply_fp8_linear_generic(
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
         input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having linear in the it doesn't conform to
         # `torch.nn.functional.linear` which is defined as `input @ weight.T`
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                     (input_group_shape == (1, input.shape[1])))
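
With the new defaults, a caller of `apply_fp8_linear_generic` can drop the support flags entirely. A hedged usage sketch (the import path and the tensor/scale shapes are assumptions for illustration, not taken from this diff; it assumes a CUDA device with FP8 support):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_linear_generic)

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
# Assumed per-128x128-block weight scales for the block-quantized path.
w_scale = torch.rand(4096 // 128, 4096 // 128,
                     device="cuda", dtype=torch.float32)

# cutlass_fp8_supported / cutlass_block_fp8_supported are no longer passed;
# they fall back to the platform-detected module constants.
out = apply_fp8_linear_generic(
    x, w, w_scale,
    input_group_shape=(1, 128),     # per-token-group activation scaling
    weight_group_shape=(128, 128))  # block-quantized weight scaling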

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

@@ -42,6 +42,10 @@ def cutlass_block_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_block_fp8(capability)
 
 
+CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def per_tensor_dequantize(
         tensor: torch.Tensor, inv_scale: Union[float,
                                                torch.Tensor]) -> torch.Tensor:
@@ -109,7 +113,7 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
     use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
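
Because the constants are evaluated once at module import, downstream code can read them directly, and an explicit per-call argument still overrides the default (useful for comparing against the non-CUTLASS fallback). A small sketch; the wrapper name and call-site details are illustrative only:

import torch
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_FP8_SUPPORTED, apply_fp8_linear)

print("CUTLASS FP8 scaled-mm available:", CUTLASS_FP8_SUPPORTED)

def fallback_fp8_linear(x: torch.Tensor, weight: torch.Tensor,
                        weight_scale: torch.Tensor) -> torch.Tensor:
    # Omitting cutlass_fp8_supported uses the platform-detected default;
    # passing False here forces the torch._scaled_mm fallback path.
    return apply_fp8_linear(x, weight, weight_scale,
                            cutlass_fp8_supported=False)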