[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing (#12501)
This commit is contained in:
parent 0630d4537a
commit eaa92d4437
@@ -6,7 +6,7 @@ ARG RCCL_BRANCH="648a58d"
 ARG RCCL_REPO="https://github.com/ROCm/rccl"
 ARG TRITON_BRANCH="e5be006"
 ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
-ARG PYTORCH_BRANCH="8d4926e"
+ARG PYTORCH_BRANCH="3a585126"
 ARG PYTORCH_VISION_BRANCH="v0.19.1"
 ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
@@ -1,6 +1,6 @@
 # Installation
 
-vLLM supports AMD GPUs with ROCm 6.2.
+vLLM supports AMD GPUs with ROCm 6.3.
 
 :::{attention}
 There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
@@ -9,7 +9,7 @@ There are no pre-built wheels for this device, so you must either use the pre-bu
 ## Requirements
 
 - GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
-- ROCm 6.2
+- ROCm 6.3
 
 ## Set up using Python
 
@@ -24,9 +24,15 @@ Currently, there are no pre-built ROCm wheels.
 - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html)
 - [PyTorch](https://pytorch.org/)
 
-For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
+For installing PyTorch, you can start from a fresh docker image, e.g., `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0` or `rocm/pytorch-nightly`. If you are using the docker image, you can skip to Step 3.
 
-Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/)
+Alternatively, you can install PyTorch using PyTorch wheels. You can check the PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example:
+
+```console
+# Install PyTorch
+$ pip uninstall torch -y
+$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.3
+```
 
 1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton)
 
@@ -37,7 +43,7 @@ Currently, there are no pre-built ROCm wheels.
 pip uninstall -y triton
 git clone https://github.com/OpenAI/triton.git
 cd triton
-git checkout e192dba
+git checkout e5be006
 cd python
 pip3 install .
 cd ../..
@@ -49,15 +55,15 @@ Currently, there are no pre-built ROCm wheels.
 
 2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile)
 
-Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
+Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
 Alternatively, wheels intended for vLLM use can be accessed under the releases.
 
-For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
+For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
 
 ```console
 git clone https://github.com/ROCm/flash-attention.git
 cd flash-attention
-git checkout 3cea2fb
+git checkout b7d29fb
 git submodule update --init
 GPU_ARCHS="gfx90a" python3 setup.py install
 cd ..
@@ -67,20 +73,16 @@ Currently, there are no pre-built ROCm wheels.
 You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 :::
 
-3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps:
+3. Build vLLM. For example, vLLM on ROCm 6.3 can be built with the following steps:
 
 ```bash
 $ pip install --upgrade pip
 
-# Install PyTorch
-$ pip uninstall torch -y
-$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2
-
 # Build & install AMD SMI
 $ pip install /opt/rocm/share/amd_smi
 
 # Install dependencies
-$ pip install --upgrade numba scipy huggingface-hub[cli]
+$ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm
 $ pip install "numpy<2"
 $ pip install -r requirements-rocm.txt
 
@@ -104,7 +106,7 @@ Currently, there are no pre-built ROCm wheels.
 For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization).
 :::
 
-## Set up using Docker
+## Set up using Docker (Recommended)
 
 ### Pre-built images
 
@@ -120,7 +122,12 @@ for instructions on how to use this prebuilt docker image.
 
 Building the Docker image from source is the recommended way to use vLLM with ROCm.
 
-First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
+#### (Optional) Build an image with ROCm software stack
+
+Build a docker image from <gh-file:Dockerfile.rocm_base>, which sets up the ROCm software stack needed by vLLM.
+**This step is optional, as this rocm_base image is usually prebuilt and stored at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under the tag `rocm/vllm-dev:base` to speed up the user experience.**
+If you choose to build this rocm_base image yourself, the steps are as follows.
+
 It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
 
 ```console
@@ -131,7 +138,26 @@ It is important that the user kicks off the docker build using buildkit. Either
 }
 ```
 
-<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
+To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
+
+```console
+DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm_base -t rocm/vllm-dev:base .
+```
+
+#### Build an image with vLLM
+
+First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
+It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
+
+```console
+{
+    "features": {
+        "buildkit": true
+    }
+}
+```
+
+<gh-file:Dockerfile.rocm> uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2 in older vLLM branches.
 It provides flexibility to customize the build of docker image using the following arguments:
 
 - `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using <gh-file:Dockerfile.rocm_base>
@@ -141,13 +167,13 @@ It provides flexibility to customize the build of docker image using the followi
 
 Their values can be passed in when running `docker build` with `--build-arg` options.
 
-To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
+To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
 
 ```console
 DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
 ```
 
-To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
+To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
 
 ```console
 DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm .
```
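Not part of the diff itself, but as a quick sanity check after either install path in the doc changes above, a short generation script confirms the ROCm build works. This is a sketch using the public `vllm.LLM` API; the small `facebook/opt-125m` checkpoint is only an illustrative choice borrowed from this PR's tests.

```python
# Minimal post-install smoke test (sketch; model choice is illustrative).
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=20))
for out in outputs:
    print(out.outputs[0].text)
```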
@@ -55,10 +55,21 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
 
         assert isinstance(attn.quant_method, Fp8KVCacheMethod)
 
+        if not current_platform.is_rocm():
+            # NOTE: This code path requires validation on Non-CUDA platform
             # NOTE: it is valid for scales to be 1.0 (default value), but
             # we know these checkpoints have scales < 1.0
             assert 0.0 < attn._k_scale < 1.0
             assert 0.0 < attn._v_scale < 1.0
+        else:
+            # NOTE: This code path is for ROCm platform
+            # NOTE: it is valid for scales to be 1.0 (default value), but
+            # we know these checkpoints have scales < 1.0
+            # However on ROCm platform, the _k_scale and _v_scale will be
+            # scaled by a factor of 2 as described in
+            # vllm/model_executor/layers/quantization/kv_cache.py
+            assert 0.0 < attn._k_scale < (1.0 * 2.0)
+            assert 0.0 < attn._v_scale < (1.0 * 2.0)
 
     llm.apply_model(check_model)
 
|
|||||||
assert attn._k_scale == 1.0
|
assert attn._k_scale == 1.0
|
||||||
assert attn._v_scale == 1.0
|
assert attn._v_scale == 1.0
|
||||||
|
|
||||||
if current_platform.has_device_capability(89) and not force_marlin:
|
if current_platform.is_cuda():
|
||||||
|
if current_platform.has_device_capability(
|
||||||
|
89) and not force_marlin:
|
||||||
# For GPUs with hardware support, we keep weights in fp8
|
# For GPUs with hardware support, we keep weights in fp8
|
||||||
assert fc1.weight.dtype == torch.float8_e4m3fn
|
assert fc1.weight.dtype == torch.float8_e4m3fn
|
||||||
else:
|
else:
|
||||||
# For GPUs without hardware support, we pack the fp8 weights
|
# For GPUs without hardware support, we pack the fp8 weights
|
||||||
# for weight-only quantization using Marlin kernels
|
# for weight-only quantization using Marlin kernels
|
||||||
assert fc1.weight.dtype == torch.int32
|
assert fc1.weight.dtype == torch.int32
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
# Only MI300 and above support quantization='fp8'
|
||||||
|
if current_platform.has_device_capability(
|
||||||
|
94) and not force_marlin:
|
||||||
|
# For GPUs with hardware support, we keep weights in fp8
|
||||||
|
assert fc1.weight.dtype == torch.float8_e4m3fnuz
|
||||||
|
else: # unsupported ROCm platform
|
||||||
|
pytest.skip(
|
||||||
|
"Skip `test_load_fp16_model`. "
|
||||||
|
"It only runs on ROCm platform with FP8 compute."
|
||||||
|
" e.g. MI300X and above.")
|
||||||
|
else: # unsupported platform
|
||||||
|
pytest.skip("Skip `test_load_fp16_model`. "
|
||||||
|
"It only runs on CUDA and ROCm platform.")
|
||||||
|
|
||||||
llm.apply_model(check_model)
|
llm.apply_model(check_model)
|
||||||
|
|
||||||
|
tests/quantization/test_ptpc_fp8.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests whether PTPC w8a8 FP8 computation is enabled correctly.
+
+Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
+"""
+import pytest
+import torch
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
+from vllm.model_executor.layers.quantization.ptpc_fp8 import (
+    PTPCFp8LinearMethod)
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
+                    reason="PTPC FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="This test is for ROCm GPU.")
+@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
+def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
+
+    try:
+        with vllm_runner("facebook/opt-125m",
+                         dtype=dtype,
+                         quantization="ptpc_fp8",
+                         kv_cache_dtype=kv_cache_dtype) as llm:
+
+            model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+            fc1 = model.model.decoder.layers[0].fc1
+            assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
+            if kv_cache_dtype == "ptpc_fp8":
+                attn = model.model.decoder.layers[0].self_attn.attn
+                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
+                assert attn._k_scale == 1.0
+                assert attn._v_scale == 1.0
+
+            if current_platform.has_device_capability(94):
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+            else:
+                pytest.skip()
+
+            output = llm.generate_greedy("Hello my name is", max_tokens=20)
+            assert output
+    except AssertionError as e:
+        if str(
+                e
+        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.":  # noqa: E501
+            # If the error message matches, the test passes
+            pass
+        else:
+            # If the error message does not match, re-raise the exception
+            raise
@@ -11,6 +11,7 @@ QUANTIZATION_METHODS: List[str] = [
     "deepspeedfp",
     "tpu_int8",
     "fp8",
+    "ptpc_fp8",
     "fbgemm_fp8",
     "modelopt",
     # The order of gptq methods is important for config.py iteration over
@@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .modelopt import ModelOptFp8Config
     from .moe_wna16 import MoeWNA16Config
     from .neuron_quant import NeuronQuantConfig
+    from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
     from .tpu_int8 import Int8TpuConfig
 
@@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "gptq": GPTQConfig,
         "compressed-tensors": CompressedTensorsConfig,
         "bitsandbytes": BitsAndBytesConfig,
+        "ptpc_fp8": PTPCFp8Config,
         "qqq": QQQConfig,
         "hqq": HQQMarlinConfig,
         "experts_int8": ExpertsInt8Config,
vllm/model_executor/layers/quantization/ptpc_fp8.py (new file, 125 lines)
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                         Fp8KVCacheMethod,
+                                                         Fp8LinearMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear)
+from vllm.platforms import current_platform
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+logger = init_logger(__name__)
+
+
+class PTPCFp8Config(Fp8Config):
+    """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
+
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
+        if not current_platform.is_rocm():
+            raise ValueError(
+                "ptpc_fp8 quantization is supported only on ROCm.")
+
+        if not current_platform.has_device_capability(94):
+            raise ValueError(
+                "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer."  # noqa: E501
+            )
+        if activation_scheme == "static":
+            raise ValueError(
+                "ptpc_fp8 as of now only support dynamic quantization.")
+
+        super().__init__(is_checkpoint_fp8_serialized=False,
+                         activation_scheme=activation_scheme,
+                         ignored_layers=ignored_layers)
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "ptpc_fp8"
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        return cls(activation_scheme=activation_scheme,
+                   ignored_layers=ignored_layers)
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
+            return PTPCFp8LinearMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
+        return None
+
+
+class PTPCFp8LinearMethod(Fp8LinearMethod):
+    """Linear method for Per-Token and Per-Channel FP8 Quantization.
+    Only supports loading quantized BF16 model checkpoints with dynamic
+    activation scaling. To load FP16 model checkpoints, user must specify
+    to convert the FP16 model weight loading into BF16.
+    The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only support float8_e4m3fnuz data type due to the limitation of
+       torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: PTPCFp8Config):
+        super().__init__(quant_config=quant_config)
+        # Force weight quantization
+        self.quant_config.is_checkpoint_fp8_serialized = False
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
+
+        assert layer.weight.data.dtype == torch.bfloat16, \
+            f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified."  # noqa: E501
+        # Quantize the weights.
+        qweight, weight_scale = ops.scaled_fp8_quant(
+            layer.weight, scale=None, use_per_token_if_dynamic=True)
+
+        # Update the layer with the new values.
+        layer.weight = Parameter(
+            qweight.t(), requires_grad=False)  # Pretranspose the weight
+        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        layer.input_scale = None
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        return apply_fp8_linear(input=x,
+                                weight=layer.weight,
+                                weight_scale=layer.weight_scale,
+                                input_scale=None,
+                                input_scale_ub=None,
+                                bias=bias,
+                                cutlass_fp8_supported=False,
+                                use_per_token_if_dynamic=True)
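To show how the new method is selected end to end, the sketch below (not part of the diff) enables PTPC FP8 through the public `vllm.LLM` API on a ROCm MI300-class device. `quantization="ptpc_fp8"` is the name registered by this PR; `dtype="bfloat16"` matches the bfloat16-only assertion in `process_weights_after_loading`; the model and `kv_cache_dtype` choices merely mirror the new test and are illustrative.

```python
# Sketch: running inference with PTPC FP8 quantization on ROCm (MI300+).
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",   # illustrative small checkpoint from the tests
    dtype="bfloat16",            # PTPCFp8LinearMethod asserts bf16 weights
    quantization="ptpc_fp8",     # method name registered in QUANTIZATION_METHODS
    kv_cache_dtype="fp8",        # optional; attention uses Fp8KVCacheMethod
)
print(llm.generate(["Hello my name is"], SamplingParams(max_tokens=20)))
```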
@@ -11,6 +11,13 @@ from vllm.platforms import current_platform
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
+# The condition to determine if it is on a platform that supports
+# torch._scaled_mm rowwise feature.
+# The condition is determined once as the operations
+# are time consuming.
+USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
+                               and current_platform.has_device_capability(94))
+
 
 def sparse_cutlass_supported() -> bool:
     if not current_platform.is_cuda():
@@ -172,6 +179,26 @@ def apply_fp8_linear(
         return torch.narrow(output, 0, 0,
                             input_2d.shape[0]).view(*output_shape)
 
+    elif (use_per_token_if_dynamic and not per_tensor_weights
+          and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
+        # For now validated on ROCm platform
+        # fp8 rowwise scaling in torch._scaled_mm is introduced in
+        # https://github.com/pytorch/pytorch/pull/144432 using
+        # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
+        # For CUDA platform please validate if the
+        # torch._scaled_mm support rowwise scaled GEMM
+        # Fused GEMM_DQ Rowwise GEMM
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale.t(),
+                                  bias=bias)
+
+        output = torch.narrow(output, 0, 0, input_2d.shape[0])
+        output = output.view(*output_shape)
+        return output
+
     else:
         # Fallback for channelwise case, where we use unfused DQ
        # due to limitations with scaled_mm
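The comments above describe the fused rowwise path; the arithmetic it fuses can be written out directly. The sketch below is an unfused reference (not part of the diff) that emulates per-token activation and per-channel weight FP8 quantization in plain PyTorch, upcasting for the matmul so it runs without hipBLASLt. Symmetric max-scaling and the use of `float8_e4m3fn` (rather than the `fnuz` variant used on ROCm) are assumptions made for portability.

```python
# Reference (unfused) math for per-token-activation, per-channel-weight FP8:
#   Y ~= (Xq @ Wq.T) * (x_scale @ w_scale.T)
# This mirrors what the fused rowwise torch._scaled_mm call computes.
import torch


def ptpc_quant(t: torch.Tensor, dim: int):
    """Symmetric max-scaling to float8_e4m3fn, one scale per slice along `dim`."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (t / scale).to(torch.float8_e4m3fn)
    return q, scale.float()


x = torch.randn(4, 64)    # activations: one scale per token (row)
w = torch.randn(32, 64)   # weights: one scale per output channel (row of W)

xq, x_scale = ptpc_quant(x, dim=-1)
wq, w_scale = ptpc_quant(w, dim=-1)

# Dequantize after the GEMM with the outer product of the two scale vectors.
y = (xq.to(torch.float32) @ wq.to(torch.float32).t()) * (x_scale @ w_scale.t())
print("max abs error vs. unquantized GEMM:", (y - x @ w.t()).abs().max().item())
```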
@@ -72,7 +72,7 @@ class RocmPlatform(Platform):
 
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
-        "fbgemm_fp8", "gguf", "quark"
+        "fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
     ]
 
     @classmethod