[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

This is the second PR for #4532.

This PR adds support for loading FP8 kv-cache scaling factors from an FP8 checkpoint (via the per-layer `.kv_scale` parameter).
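
For reference, the new test in this diff exercises the feature roughly as in the sketch below. The checkpoint name comes from the updated test; the prompt and sampling settings are illustrative and not part of this PR.

```python
from vllm import LLM, SamplingParams

# FP8 checkpoint that ships per-layer .kv_scale parameters (the model used by
# the updated test); prompt and sampling settings are illustrative only.
llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
          quantization="fp8",
          kv_cache_dtype="fp8",  # fp8 == fp8_e4m3 on CUDA
          enforce_eager=True)

outputs = llm.generate(["Briefly explain KV cache quantization."],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```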
Authored by Cody Yu on 2024-05-22 13:28:20 -07:00, committed via GitHub.
commit a3a73ab069 (parent 8674f9880e)
40 changed files with 284 additions and 158 deletions

@@ -153,15 +153,13 @@ if __name__ == '__main__':
                         action='store_true',
                         help='enforce eager mode and disable CUDA graph')
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
         type=str,
-        choices=['auto', 'fp8'],
-        default='auto',
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
-        'instead supported for common inference criteria.')
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
+        default="auto",
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,

@@ -323,15 +323,13 @@ if __name__ == "__main__":
                         action="store_true",
                         help="enforce eager execution")
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
        type=str,
-        choices=["auto", "fp8"],
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,

@@ -183,13 +183,11 @@ if __name__ == '__main__':
     parser.add_argument(
         "--kv-cache-dtype",
         type=str,
-        choices=["auto", "fp8"],
+        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help="Data type for kv cache storage. If 'auto', will use model "
+        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
+        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")

     args = parser.parse_args()
     print(args)

@@ -16,31 +16,55 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 MAX_MODEL_LEN = 1024

 MODELS = [
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
     "meta-llama/Meta-Llama-3-8B-Instruct",
 ]

 EXPECTED_STRS_MAP = {
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
-        'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
-        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-        'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
-        'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
-        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-        'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
-    ],
-    "meta-llama/Meta-Llama-3-8B-Instruct": [
-        'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-        'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-        'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
-        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-        'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
-    ],
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
+        "auto": [
+            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
+            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
+            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system made up of several basic components that work together to enable it to',
+            'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
+        ]
+    },
+    "meta-llama/Meta-Llama-3-8B-Instruct": {
+        "auto": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
+            'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
+            'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
+        ]
+    },
 }

 capability = torch.cuda.get_device_capability()
@@ -52,14 +76,14 @@ fp8_not_supported = (capability <
 @pytest.mark.skipif(fp8_not_supported,
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_name", MODELS)
-def test_models(
-    example_prompts,
-    model_name,
-) -> None:
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
     model = LLM(model=model_name,
                 max_model_len=MAX_MODEL_LEN,
+                trust_remote_code=True,
                 enforce_eager=True,
-                quantization="fp8")
+                quantization="fp8",
+                kv_cache_dtype=kv_cache_dtype)

     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -81,8 +105,8 @@ def test_models(
         generations.append(outputs[0].outputs[0].text)
     del model

-    print(generations)
-    expected_strs = EXPECTED_STRS_MAP[model_name]
+    print(model_name, kv_cache_dtype, generations)
+    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
     for i in range(len(example_prompts)):
         generated_str = generations[i]
         expected_str = expected_strs[i]

@@ -7,6 +7,8 @@ import torch.nn as nn
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)


 class Attention(nn.Module):
@@ -30,6 +32,7 @@ class Attention(nn.Module):
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         if cache_config is not None:
@@ -40,6 +43,27 @@ class Attention(nn.Module):
             block_size = 16
         if num_kv_heads is None:
             num_kv_heads = num_heads
+
+        # The default kv_scale is set to 1.0. This is ignored
+        # when kv-cache is not fp8, and should be used with
+        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+        # expect the pre-quantized kv_scale to be loaded along
+        # with the model weights.
+        self.kv_cache_dtype = kv_cache_dtype
+        self._kv_scale = 1.0
+
+        quant_method = quant_config.get_quant_method(
+            self) if quant_config else None
+        if quant_method is not None:
+            if self.kv_cache_dtype == "fp8_e5m2":
+                raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                 "fp8 checkpoints.")
+            # When FP8 quantization is enabled, we make a parameter
+            # "kv_scale" so that it can be loaded from FP8 checkpoint.
+            # The kv_scale will then be converted back
+            # to self._kv_scale in a native float32 value after weight loading.
+            self.quant_method = quant_method
+            self.quant_method.create_weights(self)
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -57,10 +81,9 @@ class Attention(nn.Module):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        kv_scale: float = 1.0,
     ) -> torch.Tensor:
         return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                                 kv_scale)
+                                 self._kv_scale)

     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore

@@ -355,14 +355,12 @@ class CacheConfig:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype == "fp8":
+        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
-                "But it may cause slight accuracy drop without scaling "
-                "factors. FP8_E5M2 (without scaling) is only supported on "
-                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
-                "is instead supported for common inference criteria.")
+                "Meanwhile, it may cause accuracy drop without a proper "
+                "scaling factor")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

@@ -191,12 +191,11 @@ class EngineArgs:
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
-            choices=['auto', 'fp8'],
+            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
-            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
-            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
-            'supported for common inference criteria.')
+            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
         parser.add_argument(
             '--quantization-param-path',
             type=nullable_str,

@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once

 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
             activation_scheme=activation_scheme)

     def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
+            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
+        if isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase):
         return torch.narrow(output, 0, 0, x.shape[0])


+class Fp8KVCacheMethod(QuantizeMethodBase):
+    """Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module):
+        """Create "weight" (aka kv_scale) for an attention layer.
+
+        Args:
+            layer: The layer that is using the QuantizeMethodBase factory.
+        """
+        # Initialize the KV cache scale to 1.0 as the default value.
+        # If the kv_scale appears in the checkpoint, it will be
+        # overwritten when loading weights.
+        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
+
+    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
+        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
+        # regardless whether the kv-scale is available in the checkpoint.
+        if layer.kv_cache_dtype != "auto":
+            kv_scale = layer.kv_scale.to("cpu").tolist()
+            if not isinstance(kv_scale, float):
+                raise ValueError("Only support per-tensor scaling factor "
+                                 "for fp8 KV cache")
+            layer._kv_scale = kv_scale
+            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
+                print_warning_once(
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
+                    "cause accuracy issues. Please make sure kv-cache scaling "
+                    "factor is available in the fp8 checkpoint.")
+        del layer.kv_scale
+
+
 def all_close_1d(x: torch.Tensor) -> bool:
     assert len(x.shape) == 1
     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
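
The `Fp8KVCacheMethod` above boils down to a simple pattern: register a placeholder `kv_scale` parameter so the weight loader can overwrite it, then collapse it into a plain Python float once loading is done. A toy illustration of that pattern with a stand-in module (not vLLM's actual `Attention` class):

```python
import torch
from torch.nn import Module, Parameter

class ToyAttention(Module):
    """Stand-in for vLLM's Attention layer, only to show the loading pattern."""

    def __init__(self):
        super().__init__()
        self.kv_cache_dtype = "fp8"
        self._kv_scale = 1.0
        # Placeholder created by create_weights(); it is overwritten if the
        # checkpoint contains a matching "kv_scale" entry.
        self.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

layer = ToyAttention()
layer.load_state_dict({"kv_scale": torch.tensor(0.023)})  # value from checkpoint

# Equivalent of process_weights_after_loading(): keep a plain float at runtime.
layer._kv_scale = float(layer.kv_scale)
del layer.kv_scale
print(layer._kv_scale)
```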

@@ -268,7 +268,8 @@ class ArcticAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -154,7 +154,8 @@ class BaiChuanAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   scaling,
-                                  alibi_slopes=alibi_slopes)
+                                  alibi_slopes=alibi_slopes,
+                                  quant_config=quant_config)
         else:
             self.rotary_emb = get_rope(
                 self.head_dim,
@@ -166,7 +167,8 @@ class BaiChuanAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
-                                  cache_config=cache_config)
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)

     def forward(
         self,

@@ -111,7 +111,8 @@ class BloomAttention(nn.Module):
                               self.head_dim,
                               scaling,
                               alibi_slopes=alibi_slopes,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -86,13 +86,12 @@ class GLMAttention(nn.Module):
             base=10000 * rope_ratio,
             is_neox_style=False,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -177,13 +177,12 @@ class CohereAttention(nn.Module):
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                  self.head_dim),

@@ -218,13 +218,12 @@ class DbrxAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -232,7 +232,8 @@ class DeepseekAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -153,7 +153,8 @@ class FalconAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   self.inv_norm_factor,
-                                  num_kv_heads=self.num_kv_heads)
+                                  num_kv_heads=self.num_kv_heads,
+                                  quant_config=quant_config)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
@@ -165,13 +166,15 @@ class FalconAttention(nn.Module):
                                   self.head_dim,
                                   self.inv_norm_factor,
                                   num_kv_heads=self.num_kv_heads,
-                                  alibi_slopes=alibi_slopes)
+                                  alibi_slopes=alibi_slopes,
+                                  quant_config=quant_config)
         else:
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   scale=self.inv_norm_factor,
                                   num_kv_heads=self.num_kv_heads,
-                                  cache_config=cache_config)
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)

     def forward(
         self,

@@ -157,7 +157,8 @@ class GemmaAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -75,7 +75,8 @@ class GPT2Attention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scale,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module):
                               self.head_dim,
                               scale=self.scale,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -88,7 +88,8 @@ class GPTJAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -89,7 +89,8 @@ class GPTNeoXAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -117,7 +117,8 @@ class InternLM2Attention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -105,13 +105,12 @@ class JAISAttention(nn.Module):
         head_end = (tp_rank + 1) * self.num_heads
         alibi_slopes = _get_alibi_slopes(total_num_heads)
         alibi_slopes = alibi_slopes[head_start:head_end]
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            scale=self.scale,
-            alibi_slopes=alibi_slopes,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scale,
+                              alibi_slopes=alibi_slopes,
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
-from vllm.utils import is_hip
+from vllm.utils import is_hip, print_warning_once


 class LlamaMLP(nn.Module):
@@ -119,15 +119,6 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
-
-        # This will be overwritten by model initialization if we are using it.
-        # N.B. currently we only support per tensor scalar scaling factors
-        # & only applicable to ROCm (AMD GPU).
-        # The scaling factor convention we are assuming is
-        # quantized_value * scaling_factor ~= true_value
-        # which is consistent with the practice of setting
-        # scaling_factor = tensor_amax / FPtype_max
-        self.kv_scale = 1.0

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,
@@ -155,7 +146,8 @@ class LlamaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             sliding_window=sliding_window,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,
@@ -167,8 +159,7 @@ class LlamaAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-                                self.kv_scale)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -421,6 +412,19 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            f"Found kv scale in the checkpoint (e.g. {name}), "
+                            "but not found the expected name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
+                            "not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -445,7 +449,7 @@ class LlamaForCausalLM(nn.Module):
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
                 if hasattr(layer_self_attn, "kv_scale"):
-                    layer_self_attn.kv_scale = scaling_factor
+                    layer_self_attn.attn._kv_scale = scaling_factor
                 else:
                     raise RuntimeError("Self attention has no KV cache scaling "
                                        "factor attribute!")

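To check which scaling factors a given checkpoint actually ships, its `.kv_scale` entries can be listed straight from the safetensors shards; the shard filename below is a placeholder, and the printed mapping mirrors the remapping done in `load_weights` above.

```python
from safetensors import safe_open

# Placeholder shard filename; real checkpoints may split weights differently.
with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
    kv_scale_names = [name for name in f.keys() if name.endswith(".kv_scale")]

for name in kv_scale_names:
    # e.g. model.layers.0.self_attn.kv_scale -> model.layers.0.self_attn.attn.kv_scale
    print(name, "->", name.replace(".kv_scale", ".attn.kv_scale"))
```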
@@ -236,7 +236,8 @@ class MiniCPMAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -308,14 +308,13 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,
@@ -581,6 +580,20 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            "Found kv scale in the checkpoint "
+                            f"(e.g. {name}), but not found the expected "
+                            f"name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). "
+                            "kv-scale is not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

@@ -213,14 +213,13 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -110,7 +110,8 @@ class MPTAttention(nn.Module):
                               scaling,
                               alibi_slopes=alibi_slopes,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -96,7 +96,8 @@ class OlmoAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

         # Attention output projection.
         self.o_proj = RowParallelLinear(

@@ -91,7 +91,8 @@ class OPTAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -121,7 +121,8 @@ class OrionAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -110,7 +110,8 @@ class PhiAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -106,7 +106,8 @@ class QWenAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -141,7 +141,8 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             sliding_window=self.sliding_window,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -241,7 +241,8 @@ class Qwen2MoeAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -127,7 +127,8 @@ class StablelmAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_key_value_heads,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -97,14 +97,13 @@ class Starcoder2Attention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)

     def forward(
         self,

@@ -135,7 +135,8 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             sliding_window=sliding_window,
-            cache_config=cache_config)
+            cache_config=cache_config,
+            quant_config=quant_config)

     def forward(
         self,

@@ -31,6 +31,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "fp8": torch.uint8,
+    "fp8_e4m3": torch.uint8,
+    "fp8_e5m2": torch.uint8,
 }
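
All three fp8 strings map to `torch.uint8` because the cache tensors are allocated as raw bytes and only interpreted as an 8-bit float format inside the attention kernels. A small illustration of that reinterpretation (assumes a torch build with `float8_e4m3fn`; the real conversion happens on the GPU side):

```python
import torch

kv_cache_block = torch.zeros(16, 256, dtype=torch.uint8)  # storage dtype from the table above
as_fp8 = kv_cache_block.view(torch.float8_e4m3fn)         # same bytes, fp8 interpretation
print(kv_cache_block.dtype, as_fp8.dtype, as_fp8.shape)
```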

@@ -1,4 +1,5 @@
 import time
+import warnings
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union

 import numpy as np
@@ -168,11 +169,21 @@ class ModelRunner:
             self.model = self.lora_manager.create_lora_manager(self.model)

         if self.kv_cache_dtype == "fp8" and is_hip():
-            # Currently scaled KV cache is only enabled on ROCm
+            # Currently only ROCm accepts kv-cache scaling factors
+            # via quantization_param_path and this will be deprecated
+            # in the future.
             if self.model_config.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    warnings.warn(
+                        "Loading kv cache scaling factor from JSON is "
+                        "deprecated and will be removed. Please include "
+                        "kv cache scaling factors in the model checkpoint.",
+                        FutureWarning,
+                        stacklevel=2)
                     self.model.load_kv_cache_scales(
                         self.model_config.quantization_param_path)
+                    logger.info("Loaded KV cache scaling factors from %s",
+                                self.model_config.quantization_param_path)
                 else:
                     raise RuntimeError(
                         "Using FP8 KV cache and scaling factors provided but "
@@ -183,10 +194,6 @@ class ModelRunner:
                         "Using FP8 KV cache but no scaling factors "
                         "provided. Defaulting to scaling factors of 1.0. "
                         "This may lead to less accurate results!")
-            elif self.model_config.quantization_param_path is not None:
-                logger.warning("KV cache scaling factors provided, "
-                               "but the KV cache data type is not FP8. "
-                               "KV cache scaling factors will not be used.")

     def save_sharded_state(
         self,