# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Union

import torch

from vllm.platforms import current_platform

# Using the default value (240.0) from pytorch will cause accuracy
# issues on dynamic quantization models. Use 224.0 on ROCm instead.
ROCM_FP8_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()
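# For context (illustrative note, assuming the ROCm fp8 dtype resolves to
# torch.float8_e4m3fnuz): torch.finfo(torch.float8_e4m3fnuz).max is 240.0,
# the PyTorch default mentioned above, while the CUDA fp8 dtype
# torch.float8_e4m3fn has a max of 448.0.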


def as_float32_tensor(x: Union[float, torch.Tensor]) -> torch.Tensor:
    return torch.as_tensor(x, dtype=torch.float32, device='cuda')


def ref_dynamic_per_token_quant(x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                scale_ub: Optional[torch.Tensor] = None) \
        -> tuple[torch.Tensor, torch.Tensor]:
    """Reference dynamic per-token quantization to int8 or fp8.

    Returns the quantized tensor and the per-token scales, shaped so they
    broadcast over the last (hidden) dimension.
    """

    assert quant_dtype in [torch.int8, FP8_DTYPE]
    if scale_ub is not None:
        assert quant_dtype == FP8_DTYPE

    qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
            else torch.finfo(quant_dtype)
    qtype_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
            else qtype_traits.max
    qtype_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
            else qtype_traits.min
    qtype_max = as_float32_tensor(qtype_traits_max)
    s_1 = as_float32_tensor(1.0)
    s_512 = as_float32_tensor(512.0)

    # For fp8, to match the CUDA kernel output we must perform exactly the
    # same operations as the corresponding fp8 kernel to avoid rounding
    # errors.

    # Compute scales
    x_token_max, _ = x.abs().max(dim=-1)
    x_token_max = as_float32_tensor(x_token_max)
    if scale_ub is not None:
        x_token_max = x_token_max.clamp(max=scale_ub)
    scales = (x_token_max / qtype_max)[:, None]

    # Quantize
    if quant_dtype == torch.int8:
        iscales = as_float32_tensor(s_1 / scales)
        torch_out = as_float32_tensor(x) * iscales
        torch_out = torch_out.round()
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)
    else:
        assert quant_dtype == FP8_DTYPE
        min_scaling_factor = s_1 / (qtype_max * s_512)
        scales = scales.clamp(min=min_scaling_factor)
        torch_out = as_float32_tensor(x) / scales
        torch_out = torch_out.clamp(qtype_traits_min,
                                    qtype_traits_max).to(quant_dtype)

    return torch_out, scales
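

# Illustrative usage sketch (an addition for demonstration, not part of the
# reference implementation): it exercises ref_dynamic_per_token_quant for
# both supported dtypes and inverts the quantization with the returned
# per-token scales. The helper name, tensor shapes, and printed diagnostics
# are arbitrary choices; a CUDA/ROCm device is assumed because
# as_float32_tensor places tensors on 'cuda'.
def _demo_per_token_quant() -> None:
    torch.manual_seed(0)
    x = torch.randn(4, 16, dtype=torch.float16, device='cuda')
    for dtype in (torch.int8, FP8_DTYPE):
        q, scales = ref_dynamic_per_token_quant(x, dtype)
        # Dequantize: the per-token scales (shape (4, 1)) broadcast over the
        # hidden dimension.
        x_hat = q.to(torch.float32) * scales
        max_err = (x_hat - x.to(torch.float32)).abs().max().item()
        print(f"{dtype}: out shape {tuple(q.shape)}, "
              f"scales shape {tuple(scales.shape)}, max abs err {max_err:.4f}")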


# The int8 version is very similar. Incorporate an int8 path here, as in
# ref_dynamic_per_token_quant, once a dynamic per-tensor int8 quant kernel
# exists.
def ref_dynamic_per_tensor_fp8_quant(x: torch.Tensor) \
        -> tuple[torch.Tensor, torch.Tensor]:
    """Reference dynamic per-tensor fp8 quantization.

    Returns the fp8-quantized tensor and a one-element float32 scale tensor.
    """

    fp8_traits = torch.finfo(FP8_DTYPE)
    fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
            else fp8_traits.max
    fp8_traits_min = -ROCM_FP8_MAX if current_platform.is_rocm() \
            else fp8_traits.min
    fp8_max = as_float32_tensor(fp8_traits_max)
    one = as_float32_tensor(1.0)

    # For fp8, to match the CUDA kernel output we must perform exactly the
    # same operations as the corresponding fp8 kernel to avoid rounding
    # errors.

    x_max = as_float32_tensor(x.abs().max())
    ref_scale = x_max / fp8_max
    ref_iscale = one / ref_scale
    ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
    return ref_out, ref_scale.view((1, ))
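

# Minimal smoke-test sketch (also an illustrative addition): running this
# module directly prints the per-tensor fp8 round-trip error and then the
# per-token demo above. It assumes a CUDA or ROCm device is available, since
# the helpers hard-code device='cuda'.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 32, dtype=torch.float16, device='cuda')
    q, scale = ref_dynamic_per_tensor_fp8_quant(x)
    x_hat = q.to(torch.float32) * scale
    max_err = (x_hat - x.to(torch.float32)).abs().max().item()
    print(f"per-tensor fp8: out dtype {q.dtype}, scale {scale.item():.6f}, "
          f"max abs err {max_err:.4f}")
    _demo_per_token_quant()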