[Bugfix] Better FP8 supported defaults
parent 5b19b93082
commit 76abd0c881
@@ -15,7 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)

@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods

@@ -85,12 +85,14 @@ def apply_w8a8_block_fp8_linear(
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
 def apply_fp8_linear_generic(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    input_group_shape: Tuple[int, int],
-    weight_group_shape: Tuple[int, int],
-    input_scale: Optional[torch.Tensor] = None, # static scale if one
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_group_shape: Tuple[int, int],
+    weight_group_shape: Tuple[int, int],
+    input_scale: Optional[torch.Tensor] = None, # static scale if one
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])

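For context on the group-shape arguments of apply_fp8_linear_generic, the following is an illustrative, standalone sketch rather than vLLM code; the dimensions and the shape arithmetic are assumptions made for the example. A weight group shape of (128, 128) means one scale per 128x128 weight tile, and an input group shape of (1, 128) means one scale per token per 128-wide slice of the hidden dimension.

    # Standalone illustration of the group-shape convention
    # (made-up dimensions; not vLLM code).
    from math import ceil

    N, K, T = 512, 1024, 7            # out features, in features, tokens
    weight_group_shape = (128, 128)   # one scale per 128x128 weight tile
    input_group_shape = (1, 128)      # one scale per token per 128-wide K slice

    weight_scale_shape = (ceil(N / weight_group_shape[0]),
                          ceil(K / weight_group_shape[1]))
    input_scale_shape = (T, ceil(K / input_group_shape[1]))

    print(weight_scale_shape)  # (4, 8)
    print(input_scale_shape)   # (7, 8)
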
@@ -105,14 +107,18 @@ def apply_fp8_linear_generic(
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
         input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having linear in the name it doesn't conform to
         # `torch.nn.functional.linear` which is defined as `input @ weight.T`
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                     (input_group_shape == (1, input.shape[1])))
 

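The dispatch above amounts to: take the block-FP8 path only when the weight is genuinely blocked along both dimensions and the input groups line up with the weight's column blocks; otherwise fall back to apply_fp8_linear on the transposed weight. The sketch below is one plausible, simplified reading of that condition; the real is_dim_blocked helper lives elsewhere in the file and may differ in detail, and the shapes are invented for the example.

    # Simplified re-statement of the dispatch condition, for illustration only.
    from typing import Tuple

    def is_dim_blocked(dim: int, shape: Tuple[int, int], group: int) -> bool:
        # A dimension counts as "blocked" when the group covers neither a
        # single element nor the whole dimension, i.e. it is tiled in chunks.
        return group != 1 and group != shape[dim]

    def choose_path(weight_shape: Tuple[int, int],
                    weight_group_shape: Tuple[int, int],
                    input_group_shape: Tuple[int, int]) -> str:
        if (is_dim_blocked(0, weight_shape, weight_group_shape[0])
                and is_dim_blocked(1, weight_shape, weight_group_shape[1])
                and input_group_shape == (1, weight_group_shape[1])):
            return "apply_w8a8_block_fp8_linear"      # blocked weight path
        return "apply_fp8_linear (weight transposed)"  # per-tensor/per-token

    print(choose_path((512, 1024), (128, 128), (1, 128)))     # blocked path
    print(choose_path((512, 1024), (512, 1024), (1, 1024)))   # fallback path
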
@@ -42,6 +42,10 @@ def cutlass_block_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_block_fp8(capability)
 
 
+CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def per_tensor_dequantize(
     tensor: torch.Tensor, inv_scale: Union[float,
                                            torch.Tensor]) -> torch.Tensor:

@@ -109,7 +113,7 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
     use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.

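Taken together, the change replaces hard-coded True defaults with module-level flags computed once at import time, so a caller that omits cutlass_fp8_supported / cutlass_block_fp8_supported automatically gets the path the platform actually supports, while explicit overrides keep working. A minimal, self-contained sketch of that pattern follows; the names are stand-ins, not the vLLM API.

    # Sketch of the "probe once at import, use as keyword default" pattern.
    def _probe_cutlass_fp8_support() -> bool:
        # Stand-in for a real capability query (e.g. checking the GPU's
        # compute capability); hard-coded False so the example runs anywhere.
        return False

    CUTLASS_FP8_SUPPORTED = _probe_cutlass_fp8_support()  # evaluated at import

    def apply_fp8_linear_sketch(
            cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED) -> str:
        # Callers that omit the flag inherit the probed value; callers that
        # pass it explicitly still override it.
        return "cutlass path" if cutlass_fp8_supported else "fallback path"

    print(apply_fp8_linear_sketch())      # "fallback path" on this machine
    print(apply_fp8_linear_sketch(True))  # explicit override still honored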