[Bugfix] Better FP8 supported defaults
parent 5b19b93082
commit 76abd0c881
@@ -15,7 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
 from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -38,7 +38,7 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_block_fp8_supported: bool = True,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
@@ -91,6 +91,8 @@ def apply_fp8_linear_generic(
     input_group_shape: Tuple[int, int],
     weight_group_shape: Tuple[int, int],
     input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
 ) -> torch.Tensor:
     # View input as 2D matrix for fp8 methods
     input = input.view(-1, input.shape[-1])
@@ -105,14 +107,18 @@ def apply_fp8_linear_generic(
     if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
         and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
         input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(input, weight,
-                                           list(weight_group_shape),
-                                           weight_scale)
+        return apply_w8a8_block_fp8_linear(
+            input,
+            weight,
+            list(weight_group_shape),
+            weight_scale,
+            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
     else:
         # Despite having linear in the it doesn't conform to
         # `torch.nn.functional.linear` which is defined as `input @ weight.T`
         # so we explicitly transpose the weight matrix here
         return apply_fp8_linear(input, weight.T, weight_scale.T,
+                                cutlass_fp8_supported=cutlass_fp8_supported,
                                 use_per_token_if_dynamic=\
                                     (input_group_shape == (1, input.shape[1])))
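The hunks above cover the generic FP8 entry points: apply_fp8_linear_generic now accepts explicit CUTLASS flags and forwards them to both dispatch branches, instead of letting the callees fall back to a hard-coded True. The remaining hunks patch w8a8_utils, the module the new import pulls the constants from. As a usage sketch of the new defaults; the module path, the leading positional order (input, weight, weight_scale), and the tensor shapes are assumptions inferred from the hunks, not part of the commit:

    import torch

    # Sketch only: the module path below and the positional order
    # (input, weight, weight_scale) are assumptions, as are the shapes.
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        apply_fp8_linear_generic)

    x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(512, 256, dtype=torch.bfloat16,
                    device="cuda").to(torch.float8_e4m3fn)
    # One scale per 128x128 weight block: (512 / 128) x (256 / 128).
    w_scale = torch.ones(4, 2, dtype=torch.float32, device="cuda")

    # No cutlass_* arguments: the call now inherits the import-time defaults
    # CUTLASS_FP8_SUPPORTED / CUTLASS_BLOCK_FP8_SUPPORTED instead of True.
    out = apply_fp8_linear_generic(x, w, w_scale,
                                   input_group_shape=(1, 128),
                                   weight_group_shape=(128, 128))
    print(out.shape)  # expected torch.Size([4, 512])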
@@ -42,6 +42,10 @@ def cutlass_block_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_block_fp8(capability)


+CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def per_tensor_dequantize(
     tensor: torch.Tensor, inv_scale: Union[float,
                                            torch.Tensor]) -> torch.Tensor:
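These two constants are the heart of the fix: the capability probes run exactly once, at import time, and the apply_* functions reference the results as argument defaults instead of assuming True. A small self-contained sketch of that pattern, with a stand-in probe rather than vLLM's real capability check:

    # Stand-in for a hardware probe such as
    # ops.cutlass_scaled_mm_supports_block_fp8(capability); always False here.
    def _probe_cutlass_fp8() -> bool:
        return False

    # Evaluated once when the module is imported, mirroring
    # CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported() in the hunk above.
    CUTLASS_FP8_SUPPORTED = _probe_cutlass_fp8()

    def apply_linear(use_cutlass: bool = CUTLASS_FP8_SUPPORTED) -> str:
        # Callers that omit the flag get the probed value; with the old
        # hard-coded default they silently got True.
        return "cutlass kernel" if use_cutlass else "fallback kernel"

    print(apply_linear())                  # fallback kernel on this platform
    print(apply_linear(use_cutlass=True))  # explicit override still works

Because a default value binds when the function is defined, the constants have to exist before apply_fp8_linear and the other consumers are defined, which their placement near the top of the module guarantees.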
@@ -109,7 +113,7 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
     use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
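Since the defaults are now resolved when the module loads, checking what a given build will pick is just a matter of reading the constants; the module path below is the one shown in the import hunk at the top of the diff (an illustrative check, not part of the commit):

    # Importing the module runs the one-time capability probes.
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED)

    print("CUTLASS FP8 supported:      ", CUTLASS_FP8_SUPPORTED)
    print("CUTLASS block FP8 supported:", CUTLASS_BLOCK_FP8_SUPPORTED)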