[AMD][FP8] Using MI300 FP8 format on ROCm for block_quant (#12134)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent 54cacf008f
commit b5b57e301e
@@ -247,6 +247,15 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale,
+                                                   requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
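Note: MI300 hardware implements the e4m3fnuz ("NANOO") FP8 variant rather than the OCP e4m3fn format, which is why block-quantized checkpoint weights and scales are rewritten once at load time on ROCm. The `normalize_e4m3fn_to_e4m3fnuz` helper itself is not part of this diff; below is a minimal sketch of the conversion the new branch relies on, assuming the usual relationship between the two encodings. The `_sketch` name and details are illustrative, not the exact vLLM helper.

from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Illustrative only: convert OCP e4m3fn tensors to MI300's e4m3fnuz."""
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8)
    # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz; clear it before the
    # reinterpretation so no spurious NaNs appear.
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # e4m3fnuz uses an exponent bias one larger, so the same bits decode to
    # half the value; double the scales to keep dequantized values equal.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale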
@@ -495,6 +504,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale_inv, requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
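The MoE path applies the same conversion to each expert weight set (w13 and w2) and its scales. A quick self-contained check of the invariant behind the scale doubling, assuming a PyTorch build with float8 dtype support; the bit patterns are chosen to avoid the special encodings.

import torch

# Same bit patterns, interpreted in both FP8 formats.
bits = torch.tensor([0x3C, 0x45, 0x12, 0x96], dtype=torch.uint8)
fn_vals = bits.view(torch.float8_e4m3fn).float()
fnuz_vals = bits.view(torch.float8_e4m3fnuz).float()

scale = torch.tensor(0.02)
# Dequantized values match once the scale is doubled for e4m3fnuz.
assert torch.allclose(fn_vals * scale, fnuz_vals * (scale * 2.0))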
@@ -5,6 +5,8 @@ import torch
 import triton
 import triton.language as tl
 
+from vllm.platforms import current_platform
+
 
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
@@ -33,11 +35,14 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
     x: torch.Tensor,
-    dtype: torch.dtype = torch.float8_e4m3fn
+    dtype: Optional[torch.dtype] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz
+                 if current_platform.is_rocm() else torch.float8_e4m3fn)
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
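Making `dtype` optional lets existing call sites pick up `torch.float8_e4m3fnuz` automatically on ROCm while keeping `torch.float8_e4m3fn` elsewhere. For readability, here is a rough standalone version of the tensor-wise quantization this helper performs; it is an approximation under the hypothetical `_sketch` name, not the file's exact code.

from typing import Tuple

import torch


def input_to_float8_sketch(
    x: torch.Tensor,
    dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # One scale for the whole tensor (tensor-wise quantization).
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    # Saturate to the representable range, then cast to FP8.
    x_q = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    # Return the dequantization scale (reciprocal of the quantization scale).
    return x_q, scale.float().reciprocal()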
@@ -125,7 +130,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -140,6 +145,9 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
             scaling factor for quantization.
     """
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz
+                 if current_platform.is_rocm() else torch.float8_e4m3fn)
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")
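The grouped quantization path gets the same platform-dependent default. As a reading aid for the Triton kernel's contract, below is a pure-PyTorch approximation of per-token-group quantization, with one scale per contiguous `group_size` slice of the last dimension; the `_reference` helper is illustrative only and not the kernel in this file.

from typing import Tuple

import torch


def per_token_group_quant_fp8_reference(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(dtype)
    # One scale per group of `group_size` consecutive values on the last dim.
    x_grouped = x.reshape(*x.shape[:-1], -1, group_size).float()
    amax = x_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    x_s = amax / finfo.max
    x_q = (x_grouped / x_s).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_q.reshape(x.shape), x_s.squeeze(-1)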