[ROCm] Apply FP8 weights padding to values not divisible by 512 bytes on ROCm (#13231)

Gregory Shtrasberg authored on 2025-02-22 08:54:38 -05:00, committed by GitHub
parent 558db8083c
commit c904fdddf6
3 changed files with 20 additions and 1 deletion


@@ -74,6 +74,7 @@ if TYPE_CHECKING:
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
+    VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -507,6 +508,9 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),

+    # Pad the fp8 weights to 256 bytes for ROCm
+    "VLLM_ROCM_FP8_PADDING":
+    lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),

     # Divisor for dynamic key scale factor calculation for FP8 KV Cache
     "K_SCALE_CONSTANT":
     lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),


@@ -3,6 +3,7 @@
 from typing import Any, Callable, Dict, List, Optional

 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
@@ -251,6 +252,17 @@ class Fp8LinearMethod(LinearMethodBase):
             else:
                 layer.register_parameter("input_scale", None)

+    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on the ROCm platform,
+        # which can benefit from tensors located far enough from one another
+        # in memory.
+        if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
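To see what `add_padding_to_weight` actually does to a tensor, here is a self-contained sketch (my illustration, not from the commit; it uses int8 as a stand-in for the 1-byte fp8 dtype and skips the ROCm/env checks). `F.pad` appends 256 bytes to every row and the trailing slice drops them again, so the logical shape is unchanged while the row stride grows by 256 bytes, spacing consecutive rows further apart in memory:

    import torch
    import torch.nn.functional as F

    # Stand-in for an fp8 weight: element_size() == 1, row stride 4096 bytes (divisible by 512).
    weight = torch.zeros(4096, 4096, dtype=torch.int8)
    num_pad = 256 // weight.element_size()        # 256 elements == 256 bytes of padding per row

    padded = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

    assert padded.shape == weight.shape                       # logical shape is unchanged
    assert padded.stride(-1) == 1                             # each row is still dense
    assert padded.stride(-2) == weight.stride(-2) + num_pad   # rows are now 256 bytes further apart
    assert not padded.is_contiguous()                         # hence the relaxed assert in the last hunk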
@@ -264,6 +276,8 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight.data
             weight_scale_inv = layer.weight_scale_inv.data

+            weight = self.add_padding_to_weight(weight)
+
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = Parameter(weight_scale_inv,
@@ -327,6 +341,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 logical_widths=layer.logical_widths,
             )

+            weight = self.add_padding_to_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)


@@ -494,7 +494,7 @@ def w8a8_block_fp8_matmul(
     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
     M = A.numel() // A.shape[-1]

-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert B.ndim == 2 and Bs.ndim == 2
     N, K = B.shape
     assert triton.cdiv(N, block_n) == Bs.shape[0]
     assert triton.cdiv(K, block_k) == Bs.shape[1]
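The dropped `B.is_contiguous()` check lines up with the padding above: the padded weight is a strided view, so it is no longer contiguous even though every row is still densely packed and usable by the kernel. A tiny sketch (illustration only, same int8 stand-in as above):

    import torch
    import torch.nn.functional as F

    B = F.pad(torch.zeros(512, 512, dtype=torch.int8), (0, 256))[..., :-256]
    assert B.ndim == 2 and B.stride(-1) == 1   # still a valid row-major matrix
    assert not B.is_contiguous()               # the removed assertion would have rejected it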