[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing (#12501)
This commit is contained in:
parent 0630d4537a
commit eaa92d4437
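This change adds a `ptpc_fp8` quantization method for ROCm: per-channel weight scales and per-token activation scales, both computed dynamically. A minimal usage sketch, assuming an MI300-class ROCm GPU and vLLM built from this branch; the model name and prompt are illustrative only (borrowed from the new test added below):

```python
from vllm import LLM, SamplingParams

# PTPC FP8: weights quantized per output channel, activations per token,
# with activation scales computed dynamically at runtime.
llm = LLM(model="facebook/opt-125m",   # any bf16-capable checkpoint
          dtype="bfloat16",            # the linear method asserts bf16 weights
          quantization="ptpc_fp8")

out = llm.generate("Hello my name is", SamplingParams(max_tokens=20))
print(out[0].outputs[0].text)
```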
@@ -6,7 +6,7 @@ ARG RCCL_BRANCH="648a58d"
 ARG RCCL_REPO="https://github.com/ROCm/rccl"
 ARG TRITON_BRANCH="e5be006"
 ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
-ARG PYTORCH_BRANCH="8d4926e"
+ARG PYTORCH_BRANCH="3a585126"
 ARG PYTORCH_VISION_BRANCH="v0.19.1"
 ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
@@ -1,6 +1,6 @@
 # Installation

-vLLM supports AMD GPUs with ROCm 6.2.
+vLLM supports AMD GPUs with ROCm 6.3.

 :::{attention}
 There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
@@ -9,7 +9,7 @@ There are no pre-built wheels for this device, so you must either use the pre-bu
 ## Requirements

 - GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
-- ROCm 6.2
+- ROCm 6.3

 ## Set up using Python

@@ -24,9 +24,15 @@ Currently, there are no pre-built ROCm wheels.
 - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html)
 - [PyTorch](https://pytorch.org/)

-For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
+For installing PyTorch, you can start from a fresh docker image, e.g., `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0` or `rocm/pytorch-nightly`. If you are using a docker image, you can skip to Step 3.

-Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/)
+Alternatively, you can install PyTorch using PyTorch wheels. You can check the PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example:
+
+```console
+# Install PyTorch
+$ pip uninstall torch -y
+$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.3
+```

 1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton)

@@ -37,7 +43,7 @@ Currently, there are no pre-built ROCm wheels.
 pip uninstall -y triton
 git clone https://github.com/OpenAI/triton.git
 cd triton
-git checkout e192dba
+git checkout e5be006
 cd python
 pip3 install .
 cd ../..
@@ -49,15 +55,15 @@ Currently, there are no pre-built ROCm wheels.

 2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile)

-Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
+Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
 Alternatively, wheels intended for vLLM use can be accessed under the releases.

-For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
+For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.

 ```console
 git clone https://github.com/ROCm/flash-attention.git
 cd flash-attention
-git checkout 3cea2fb
+git checkout b7d29fb
 git submodule update --init
 GPU_ARCHS="gfx90a" python3 setup.py install
 cd ..
@@ -67,20 +73,16 @@ Currently, there are no pre-built ROCm wheels.
 You might need to downgrade the "ninja" version to 1.10 as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 :::

-3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps:
+3. Build vLLM. For example, vLLM on ROCm 6.3 can be built with the following steps:

 ```bash
 $ pip install --upgrade pip

-# Install PyTorch
-$ pip uninstall torch -y
-$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2
-
 # Build & install AMD SMI
 $ pip install /opt/rocm/share/amd_smi

 # Install dependencies
-$ pip install --upgrade numba scipy huggingface-hub[cli]
+$ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm
 $ pip install "numpy<2"
 $ pip install -r requirements-rocm.txt

@@ -104,7 +106,7 @@ Currently, there are no pre-built ROCm wheels.
 For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization).
 :::

-## Set up using Docker
+## Set up using Docker (Recommended)

 ### Pre-built images

@@ -120,7 +122,12 @@ for instructions on how to use this prebuilt docker image.

 Building the Docker image from source is the recommended way to use vLLM with ROCm.

-First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
+#### (Optional) Build an image with ROCm software stack
+
+Build a docker image from <gh-file:Dockerfile.rocm_base>, which sets up the ROCm software stack needed by vLLM.
+**This step is optional, as this rocm_base image is usually prebuilt and stored at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under the tag `rocm/vllm-dev:base` to speed up the user experience.**
+If you choose to build this rocm_base image yourself, the steps are as follows.
+
 It is important that the user kicks off the docker build using buildkit. Either set `DOCKER_BUILDKIT=1` as an environment variable when calling the docker build command, or set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:

 ```console
@@ -131,7 +138,26 @@ It is important that the user kicks off the docker build using buildkit. Either
 }
 ```

-<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
+To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
+
+```console
+DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm_base -t rocm/vllm-dev:base .
+```
+
+#### Build an image with vLLM
+
+First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
+It is important that the user kicks off the docker build using buildkit. Either set `DOCKER_BUILDKIT=1` as an environment variable when calling the docker build command, or set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
+
+```console
+{
+  "features": {
+    "buildkit": true
+  }
+}
+```
+
+<gh-file:Dockerfile.rocm> uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2 in older vLLM branches.
 It provides flexibility to customize the build of docker image using the following arguments:

 - `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using <gh-file:Dockerfile.rocm_base>
@@ -141,13 +167,13 @@ It provides flexibility to customize the build of docker image using the followi

 Their values can be passed in when running `docker build` with `--build-arg` options.

-To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
+To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:

 ```console
 DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
 ```

-To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
+To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:

 ```console
 DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm .
@@ -55,10 +55,21 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):

         assert isinstance(attn.quant_method, Fp8KVCacheMethod)

-        # NOTE: it is valid for scales to be 1.0 (default value), but
-        # we know these checkpoints have scales < 1.0
-        assert 0.0 < attn._k_scale < 1.0
-        assert 0.0 < attn._v_scale < 1.0
+        if not current_platform.is_rocm():
+            # NOTE: This code path requires validation on Non-CUDA platform
+            # NOTE: it is valid for scales to be 1.0 (default value), but
+            # we know these checkpoints have scales < 1.0
+            assert 0.0 < attn._k_scale < 1.0
+            assert 0.0 < attn._v_scale < 1.0
+        else:
+            # NOTE: This code path is for ROCm platform
+            # NOTE: it is valid for scales to be 1.0 (default value), but
+            # we know these checkpoints have scales < 1.0
+            # However on ROCm platform, the _k_scale and _v_scale will be
+            # scaled by a factor of 2 as described in
+            # vllm/model_executor/layers/quantization/kv_cache.py
+            assert 0.0 < attn._k_scale < (1.0 * 2.0)
+            assert 0.0 < attn._v_scale < (1.0 * 2.0)

     llm.apply_model(check_model)

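For context on the factor of 2 in the ROCm branch above (an illustration, not part of the diff): ROCm uses the `float8_e4m3fnuz` format, whose representable range is roughly half that of `float8_e4m3fn`, so checkpoint scales are doubled when converted, as the referenced kv_cache.py comment describes.

```python
import torch

# e4m3fn vs. e4m3fnuz dynamic range; the fnuz max is roughly half,
# so a scale calibrated for e4m3fn is stored as 2 * scale on ROCm,
# which is why the test bounds the scales by (1.0 * 2.0).
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```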
@@ -91,13 +102,29 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
         assert attn._k_scale == 1.0
         assert attn._v_scale == 1.0

-        if current_platform.has_device_capability(89) and not force_marlin:
-            # For GPUs with hardware support, we keep weights in fp8
-            assert fc1.weight.dtype == torch.float8_e4m3fn
-        else:
-            # For GPUs without hardware support, we pack the fp8 weights
-            # for weight-only quantization using Marlin kernels
-            assert fc1.weight.dtype == torch.int32
+        if current_platform.is_cuda():
+            if current_platform.has_device_capability(
+                    89) and not force_marlin:
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fn
+            else:
+                # For GPUs without hardware support, we pack the fp8 weights
+                # for weight-only quantization using Marlin kernels
+                assert fc1.weight.dtype == torch.int32
+        elif current_platform.is_rocm():
+            # Only MI300 and above support quantization='fp8'
+            if current_platform.has_device_capability(
+                    94) and not force_marlin:
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+            else:  # unsupported ROCm platform
+                pytest.skip(
+                    "Skip `test_load_fp16_model`. "
+                    "It only runs on ROCm platform with FP8 compute."
+                    " e.g. MI300X and above.")
+        else:  # unsupported platform
+            pytest.skip("Skip `test_load_fp16_model`. "
+                        "It only runs on CUDA and ROCm platform.")

     llm.apply_model(check_model)


tests/quantization/test_ptpc_fp8.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests whether PTPC w8a8 FP8 computation is enabled correctly.
+
+Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
+"""
+import pytest
+import torch
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
+from vllm.model_executor.layers.quantization.ptpc_fp8 import (
+    PTPCFp8LinearMethod)
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
+                    reason="PTPC FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="This test is for ROCm GPU.")
+@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
+def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
+
+    try:
+        with vllm_runner("facebook/opt-125m",
+                         dtype=dtype,
+                         quantization="ptpc_fp8",
+                         kv_cache_dtype=kv_cache_dtype) as llm:
+
+            model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+            fc1 = model.model.decoder.layers[0].fc1
+            assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
+            if kv_cache_dtype == "ptpc_fp8":
+                attn = model.model.decoder.layers[0].self_attn.attn
+                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
+                assert attn._k_scale == 1.0
+                assert attn._v_scale == 1.0
+
+            if current_platform.has_device_capability(94):
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+            else:
+                pytest.skip()
+
+            output = llm.generate_greedy("Hello my name is", max_tokens=20)
+            assert output
+    except AssertionError as e:
+        if str(
+                e
+        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.":  # noqa: E501
+            # If the error message matches, the test passes
+            pass
+        else:
+            # If the error message does not match, re-raise the exception
+            raise
@@ -11,6 +11,7 @@ QUANTIZATION_METHODS: List[str] = [
     "deepspeedfp",
     "tpu_int8",
     "fp8",
+    "ptpc_fp8",
     "fbgemm_fp8",
     "modelopt",
     # The order of gptq methods is important for config.py iteration over
@@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .modelopt import ModelOptFp8Config
     from .moe_wna16 import MoeWNA16Config
     from .neuron_quant import NeuronQuantConfig
+    from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
     from .tpu_int8 import Int8TpuConfig

@@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "gptq": GPTQConfig,
         "compressed-tensors": CompressedTensorsConfig,
         "bitsandbytes": BitsAndBytesConfig,
+        "ptpc_fp8": PTPCFp8Config,
         "qqq": QQQConfig,
         "hqq": HQQMarlinConfig,
         "experts_int8": ExpertsInt8Config,
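With the registry entries above in place, the new method resolves by name like any other quantization config; a small lookup sketch (illustrative only):

```python
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)

# "ptpc_fp8" is now listed and maps to the PTPCFp8Config class added below.
assert "ptpc_fp8" in QUANTIZATION_METHODS
config_cls = get_quantization_config("ptpc_fp8")
print(config_cls.get_name())  # "ptpc_fp8"
```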

vllm/model_executor/layers/quantization/ptpc_fp8.py (new file, 125 lines)
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                         Fp8KVCacheMethod,
+                                                         Fp8LinearMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear)
+from vllm.platforms import current_platform
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = init_logger(__name__)
+
+
+class PTPCFp8Config(Fp8Config):
+    """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
+
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
+        if not current_platform.is_rocm():
+            raise ValueError(
+                "ptpc_fp8 quantization is supported only on ROCm.")
+
+        if not current_platform.has_device_capability(94):
+            raise ValueError(
+                "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer."  # noqa: E501
+            )
+        if activation_scheme == "static":
+            raise ValueError(
+                "ptpc_fp8 as of now only support dynamic quantization.")
+
+        super().__init__(is_checkpoint_fp8_serialized=False,
+                         activation_scheme=activation_scheme,
+                         ignored_layers=ignored_layers)
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "ptpc_fp8"
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        return cls(activation_scheme=activation_scheme,
+                   ignored_layers=ignored_layers)
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return PTPCFp8LinearMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
+        return None
+
+
+class PTPCFp8LinearMethod(Fp8LinearMethod):
+    """Linear method for Per-Token and Per-Channel FP8 Quantization.
+    Only supports loading quantized BF16 model checkpoints with dynamic
+    activation scaling. To load FP16 model checkpoints, user must specify
+    to convert the FP16 model weight loading into BF16.
+    The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only support float8_e4m3fnuz data type due to the limitation of
+       torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: PTPCFp8Config):
+        super().__init__(quant_config=quant_config)
+        # Force weight quantization
+        self.quant_config.is_checkpoint_fp8_serialized = False
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
+
+        assert layer.weight.data.dtype == torch.bfloat16, \
+            f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified."  # noqa: E501
+        # Quantize the weights.
+        qweight, weight_scale = ops.scaled_fp8_quant(
+            layer.weight, scale=None, use_per_token_if_dynamic=True)
+
+        # Update the layer with the new values.
+        layer.weight = Parameter(
+            qweight.t(), requires_grad=False)  # Pretranspose the weight
+        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        layer.input_scale = None
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        return apply_fp8_linear(input=x,
+                                weight=layer.weight,
+                                weight_scale=layer.weight_scale,
+                                input_scale=None,
+                                input_scale_ub=None,
+                                bias=bias,
+                                cutlass_fp8_supported=False,
+                                use_per_token_if_dynamic=True)
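To make the numerics concrete, here is a small plain-PyTorch reference of what per-token-activation, per-channel-weight FP8 computes (my sketch, not vLLM code; the fused path added to w8a8_utils.py below performs the same contraction via `torch._scaled_mm` rowwise scaling):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max  # ROCm fp8 format used by this method


def ptpc_reference(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Per-channel weight scales: one scale per output row of weight[out, in].
    w_scale = weight.abs().amax(dim=1, keepdim=True) / FP8_MAX      # [out, 1]
    qweight = (weight / w_scale).to(torch.float8_e4m3fnuz)

    # Per-token activation scales, computed dynamically: one per input row.
    x_scale = x.abs().amax(dim=-1, keepdim=True) / FP8_MAX          # [tokens, 1]
    qx = (x / x_scale).to(torch.float8_e4m3fnuz)

    # Unfused dequantized GEMM: y = (qx * x_scale) @ (qweight * w_scale).T
    return (qx.float() * x_scale) @ (qweight.float() * w_scale).t()


x = torch.randn(4, 64)
w = torch.randn(128, 64)
print(ptpc_reference(x, w).shape)  # torch.Size([4, 128])
```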
@@ -11,6 +11,13 @@ from vllm.platforms import current_platform
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)

+# The condition to determine if it is on a platform that supports
+# torch._scaled_mm rowwise feature.
+# The condition is determined once as the operations
+# are time consuming.
+USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
+                               and current_platform.has_device_capability(94))
+

 def sparse_cutlass_supported() -> bool:
     if not current_platform.is_cuda():
@@ -172,6 +179,26 @@ def apply_fp8_linear(
         return torch.narrow(output, 0, 0,
                             input_2d.shape[0]).view(*output_shape)

+    elif (use_per_token_if_dynamic and not per_tensor_weights
+          and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
+        # For now validated on ROCm platform
+        # fp8 rowwise scaling in torch._scaled_mm is introduced in
+        # https://github.com/pytorch/pytorch/pull/144432 using
+        # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
+        # For CUDA platform please validate if the
+        # torch._scaled_mm support rowwise scaled GEMM
+        # Fused GEMM_DQ Rowwise GEMM
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale.t(),
+                                  bias=bias)
+
+        output = torch.narrow(output, 0, 0, input_2d.shape[0])
+        output = output.view(*output_shape)
+        return output
+
     else:
         # Fallback for channelwise case, where we use unfused DQ
         # due to limitations with scaled_mm
@@ -72,7 +72,7 @@ class RocmPlatform(Platform):

     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
-        "fbgemm_fp8", "gguf", "quark"
+        "fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
     ]

     @classmethod