[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)
This is the second PR for #4532. It adds support for loading FP8 kv-cache scaling factors from an FP8 checkpoint (via a .kv_scale parameter).
This commit is contained in:
parent 8674f9880e
commit a3a73ab069
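For context, here is a minimal usage sketch of what this change enables. It is an illustration only: the model name and the quantization="fp8" / kv_cache_dtype="fp8" arguments are taken from the test updated in this diff, while the prompt and output handling are assumptions.

from vllm import LLM, SamplingParams

# Sketch: an FP8 checkpoint that ships per-tensor .kv_scale parameters
# (model name taken from the updated test in this diff).
llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
          quantization="fp8",
          kv_cache_dtype="fp8")  # fp8 (= fp8_e4m3); scales are loaded from the checkpoint
outputs = llm.generate(["A neural network is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)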
@@ -153,15 +153,13 @@ if __name__ == '__main__':
         action='store_true',
         help='enforce eager mode and disable CUDA graph')
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
         type=str,
-        choices=['auto', 'fp8'],
-        default='auto',
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
-        'instead supported for common inference criteria.')
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
+        default="auto",
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,
@@ -323,15 +323,13 @@ if __name__ == "__main__":
         action="store_true",
         help="enforce eager execution")
     parser.add_argument(
-        "--kv-cache-dtype",
+        '--kv-cache-dtype',
         type=str,
-        choices=["auto", "fp8"],
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
     parser.add_argument(
         '--quantization-param-path',
         type=str,
@@ -183,13 +183,11 @@ if __name__ == '__main__':
     parser.add_argument(
         "--kv-cache-dtype",
         type=str,
-        choices=["auto", "fp8"],
+        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type. '
-        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
-        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
-        'common inference criteria.')
+        help="Data type for kv cache storage. If 'auto', will use model "
+        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
+        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
     args = parser.parse_args()
     print(args)
 
@@ -16,22 +16,35 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 MAX_MODEL_LEN = 1024
 
 MODELS = [
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
     "meta-llama/Meta-Llama-3-8B-Instruct",
 ]
 
 EXPECTED_STRS_MAP = {
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
-        'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
-        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-        'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
-        'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
-        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-        'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o'
-    ],
-    "meta-llama/Meta-Llama-3-8B-Instruct": [
-        'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-        'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
+        "auto": [
+            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
+            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
+            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system made up of several basic components that work together to enable it to',
+            'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
+        ]
+    },
+    "meta-llama/Meta-Llama-3-8B-Instruct": {
+        "auto": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',

@@ -41,6 +54,17 @@ EXPECTED_STRS_MAP = {
-        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-        'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
-    ],
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
+        ],
+        "fp8": [
+            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
+            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
+            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
+            'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
+            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
+            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+            'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
+        ]
+    },
 }
 
 capability = torch.cuda.get_device_capability()
@@ -52,14 +76,14 @@ fp8_not_supported = (capability <
 @pytest.mark.skipif(fp8_not_supported,
                     reason="fp8 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_name", MODELS)
-def test_models(
-    example_prompts,
-    model_name,
-) -> None:
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
     model = LLM(model=model_name,
                 max_model_len=MAX_MODEL_LEN,
+                trust_remote_code=True,
                 enforce_eager=True,
-                quantization="fp8")
+                quantization="fp8",
+                kv_cache_dtype=kv_cache_dtype)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     formatted_prompts = [

@@ -81,8 +105,8 @@ def test_models(
         generations.append(outputs[0].outputs[0].text)
     del model
 
-    print(generations)
-    expected_strs = EXPECTED_STRS_MAP[model_name]
+    print(model_name, kv_cache_dtype, generations)
+    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
     for i in range(len(example_prompts)):
         generated_str = generations[i]
         expected_str = expected_strs[i]
@@ -7,6 +7,8 @@ import torch.nn as nn
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 
 
 class Attention(nn.Module):

@@ -30,6 +32,7 @@ class Attention(nn.Module):
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         if cache_config is not None:

@@ -40,6 +43,27 @@ class Attention(nn.Module):
             block_size = 16
         if num_kv_heads is None:
             num_kv_heads = num_heads
+
+        # The default kv_scale is set to 1.0. This is ignored
+        # when kv-cache is not fp8, and should be used with
+        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+        # expect the pre-quantized kv_scale to be loaded along
+        # with the model weights.
+        self.kv_cache_dtype = kv_cache_dtype
+        self._kv_scale = 1.0
+        quant_method = quant_config.get_quant_method(
+            self) if quant_config else None
+        if quant_method is not None:
+            if self.kv_cache_dtype == "fp8_e5m2":
+                raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                                 "fp8 checkpoints.")
+            # When FP8 quantization is enabled, we make a parameter
+            # "kv_scale" so that it can be loaded from FP8 checkpoint.
+            # The kv_scale will then be converted back
+            # to self._kv_scale in a native float32 value after weight loading.
+            self.quant_method = quant_method
+            self.quant_method.create_weights(self)
+
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()

@@ -57,10 +81,9 @@ class Attention(nn.Module):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        kv_scale: float = 1.0,
     ) -> torch.Tensor:
         return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                                 kv_scale)
+                                 self._kv_scale)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
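To make the flow above concrete, here is a minimal, self-contained sketch of the kv_scale handoff. ToyAttention is a hypothetical stand-in, not the vLLM Attention or Fp8KVCacheMethod code; it only mirrors the create_weights / process_weights_after_loading steps shown in this diff.

import torch
from torch.nn import Module, Parameter


class ToyAttention(Module):
    """Hypothetical stand-in mirroring the kv_scale handling above."""

    def __init__(self, kv_cache_dtype: str = "fp8"):
        super().__init__()
        self.kv_cache_dtype = kv_cache_dtype
        self._kv_scale = 1.0  # plain float handed to the attention kernel
        # create_weights(): expose a loadable "kv_scale" parameter (default 1.0)
        self.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

    def process_weights_after_loading(self) -> None:
        # After the checkpoint's .kv_scale was loaded into self.kv_scale,
        # convert it back to a native float and drop the parameter.
        if self.kv_cache_dtype != "auto":
            self._kv_scale = float(self.kv_scale)
        del self.kv_scale


attn = ToyAttention()
attn.kv_scale.data.fill_(0.042)  # stands in for loading .kv_scale from a checkpoint
attn.process_weights_after_loading()
print(attn._kv_scale)  # ~0.042, used as the per-tensor KV cache scale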
@@ -355,14 +355,12 @@ class CacheConfig:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype == "fp8":
+        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
-                "But it may cause slight accuracy drop without scaling "
-                "factors. FP8_E5M2 (without scaling) is only supported on "
-                "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
-                "is instead supported for common inference criteria.")
+                "Meanwhile, it may cause accuracy drop without a proper "
+                "scaling factor")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
@@ -191,12 +191,11 @@ class EngineArgs:
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
-            choices=['auto', 'fp8'],
+            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
-            'data type. FP8_E5M2 (without scaling) is only supported on cuda '
-            'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
-            'supported for common inference criteria.')
+            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
         parser.add_argument(
             '--quantization-param-path',
             type=nullable_str,
@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 

@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
             activation_scheme=activation_scheme)
 
     def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
+            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
+        if isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:

@@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase):
         return torch.narrow(output, 0, 0, x.shape[0])
 
 
+class Fp8KVCacheMethod(QuantizeMethodBase):
+    """Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: Fp8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module):
+        """Create "weight" (aka kv_scale) for an attention layer.
+
+        Args:
+            layer: The layer that is using the QuantizeMethodBase factory.
+        """
+        # Initialize the KV cache scale to 1.0 as the default value.
+        # If the kv_scale appears in the checkpoint, it will be
+        # overwritten when loading weights.
+        layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
+
+    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
+        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
+        # regardless whether the kv-scale is available in the checkpoint.
+        if layer.kv_cache_dtype != "auto":
+            kv_scale = layer.kv_scale.to("cpu").tolist()
+            if not isinstance(kv_scale, float):
+                raise ValueError("Only support per-tensor scaling factor "
+                                 "for fp8 KV cache")
+            layer._kv_scale = kv_scale
+            if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
+                print_warning_once(
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
+                    "cause accuracy issues. Please make sure kv-cache scaling "
+                    "factor is available in the fp8 checkpoint.")
+        del layer.kv_scale
+
+
 def all_close_1d(x: torch.Tensor) -> bool:
     assert len(x.shape) == 1
     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
@@ -268,7 +268,8 @@ class ArcticAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -154,7 +154,8 @@ class BaiChuanAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   scaling,
-                                  alibi_slopes=alibi_slopes)
+                                  alibi_slopes=alibi_slopes,
+                                  quant_config=quant_config)
         else:
             self.rotary_emb = get_rope(
                 self.head_dim,

@@ -166,7 +167,8 @@ class BaiChuanAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
-                                  cache_config=cache_config)
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)
 
     def forward(
         self,

@@ -111,7 +111,8 @@ class BloomAttention(nn.Module):
                               self.head_dim,
                               scaling,
                               alibi_slopes=alibi_slopes,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -86,13 +86,12 @@ class GLMAttention(nn.Module):
             base=10000 * rope_ratio,
             is_neox_style=False,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -177,13 +177,12 @@ class CohereAttention(nn.Module):
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
                                                  self.head_dim),

@@ -218,13 +218,12 @@ class DbrxAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -232,7 +232,8 @@ class DeepseekAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -153,7 +153,8 @@ class FalconAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   self.inv_norm_factor,
-                                  num_kv_heads=self.num_kv_heads)
+                                  num_kv_heads=self.num_kv_heads,
+                                  quant_config=quant_config)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads

@@ -165,13 +166,15 @@ class FalconAttention(nn.Module):
                                   self.head_dim,
                                   self.inv_norm_factor,
                                   num_kv_heads=self.num_kv_heads,
-                                  alibi_slopes=alibi_slopes)
+                                  alibi_slopes=alibi_slopes,
+                                  quant_config=quant_config)
         else:
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   scale=self.inv_norm_factor,
                                   num_kv_heads=self.num_kv_heads,
-                                  cache_config=cache_config)
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)
 
     def forward(
         self,

@@ -157,7 +157,8 @@ class GemmaAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -75,7 +75,8 @@ class GPT2Attention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scale,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module):
                               self.head_dim,
                               scale=self.scale,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -88,7 +88,8 @@ class GPTJAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -89,7 +89,8 @@ class GPTNeoXAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -117,7 +117,8 @@ class InternLM2Attention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -105,13 +105,12 @@ class JAISAttention(nn.Module):
         head_end = (tp_rank + 1) * self.num_heads
         alibi_slopes = _get_alibi_slopes(total_num_heads)
         alibi_slopes = alibi_slopes[head_start:head_end]
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            scale=self.scale,
-            alibi_slopes=alibi_slopes,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scale,
+                              alibi_slopes=alibi_slopes,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,
@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
-from vllm.utils import is_hip
+from vllm.utils import is_hip, print_warning_once
 
 
 class LlamaMLP(nn.Module):

@@ -119,15 +119,6 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        # This will be overwritten by model initialization if we are using it.
-        # N.B. currently we only support per tensor scalar scaling factors
-        # & only applicable to ROCm (AMD GPU).
-        # The scaling factor convention we are assuming is
-        # quantized_value * scaling_factor ~= true_value
-        # which is consistent with the practice of setting
-        # scaling_factor = tensor_amax / FPtype_max
-        self.kv_scale = 1.0
-
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,

@@ -155,7 +146,8 @@ class LlamaAttention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               sliding_window=sliding_window,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -167,8 +159,7 @@ class LlamaAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-                                self.kv_scale)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 

@@ -421,6 +412,19 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            f"Found kv scale in the checkpoint (e.g. {name}), "
+                            "but not found the expected name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
+                            "not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

@@ -445,7 +449,7 @@ class LlamaForCausalLM(nn.Module):
             # scaling_factor = tensor_amax / FPtype_max
             scaling_factor *= 2
             if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.kv_scale = scaling_factor
+                layer_self_attn.attn._kv_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
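The remapping above is needed because the checkpoint stores the scale directly under the attention module, while the loadable vLLM parameter lives one level deeper. A hedged sketch of inspecting those checkpoint entries with safetensors (the shard file name is hypothetical, and the key pattern is inferred from the ".kv_scale" to ".attn.kv_scale" replacement above):

from safetensors import safe_open

# Hypothetical shard name; adjust to the actual checkpoint files.
with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
    kv_scale_keys = [k for k in f.keys() if k.endswith("kv_scale")]

# A checkpoint key such as "model.layers.0.self_attn.kv_scale" is remapped and
# loaded into the model parameter "model.layers.0.self_attn.attn.kv_scale".
print(kv_scale_keys)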
@@ -236,7 +236,8 @@ class MiniCPMAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -308,14 +308,13 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -581,6 +580,20 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            "Found kv scale in the checkpoint "
+                            f"(e.g. {name}), but not found the expected "
+                            f"name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). "
+                            "kv-scale is not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -213,14 +213,13 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -110,7 +110,8 @@ class MPTAttention(nn.Module):
                               scaling,
                               alibi_slopes=alibi_slopes,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -96,7 +96,8 @@ class OlmoAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
         # Attention output projection.
         self.o_proj = RowParallelLinear(

@@ -91,7 +91,8 @@ class OPTAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -121,7 +121,8 @@ class OrionAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -110,7 +110,8 @@ class PhiAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_size,
                               scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -106,7 +106,8 @@ class QWenAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -141,7 +141,8 @@ class Qwen2Attention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               sliding_window=self.sliding_window,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -241,7 +241,8 @@ class Qwen2MoeAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -127,7 +127,8 @@ class StablelmAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_key_value_heads,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -97,14 +97,13 @@ class Starcoder2Attention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            sliding_window=self.sliding_window,
-            cache_config=cache_config,
-        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,

@@ -135,7 +135,8 @@ class XverseAttention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               sliding_window=sliding_window,
-                              cache_config=cache_config)
+                              cache_config=cache_config,
+                              quant_config=quant_config)
 
     def forward(
         self,
@@ -31,6 +31,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "fp8": torch.uint8,
+    "fp8_e4m3": torch.uint8,
+    "fp8_e5m2": torch.uint8,
 }
 
 
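All fp8 cache variants above are backed by uint8 storage with a single per-tensor scale. The scaling convention referenced in the diff (quantized_value * scaling_factor ~= true_value, with scaling_factor = tensor_amax / FPtype_max) can be sketched as follows; this is illustrative only, assumes torch >= 2.1 for float8_e4m3fn, and is not the vLLM cache kernel:

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def compute_kv_scale(kv: torch.Tensor) -> float:
    # scaling_factor = tensor_amax / FPtype_max
    return (kv.abs().max() / FP8_E4M3_MAX).item()


kv = torch.randn(4, 8, 16)
scale = compute_kv_scale(kv)
kv_fp8 = (kv / scale).to(torch.float8_e4m3fn)    # what the fp8 cache stores
kv_restored = kv_fp8.to(torch.float32) * scale   # dequantized for attention
print((kv - kv_restored).abs().max())            # small quantization error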
@@ -1,4 +1,5 @@
 import time
+import warnings
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np

@@ -168,11 +169,21 @@ class ModelRunner:
             self.model = self.lora_manager.create_lora_manager(self.model)
 
         if self.kv_cache_dtype == "fp8" and is_hip():
-            # Currently scaled KV cache is only enabled on ROCm
+            # Currently only ROCm accepts kv-cache scaling factors
+            # via quantization_param_path and this will be deprecated
+            # in the future.
             if self.model_config.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    warnings.warn(
+                        "Loading kv cache scaling factor from JSON is "
+                        "deprecated and will be removed. Please include "
+                        "kv cache scaling factors in the model checkpoint.",
+                        FutureWarning,
+                        stacklevel=2)
                     self.model.load_kv_cache_scales(
                         self.model_config.quantization_param_path)
+                    logger.info("Loaded KV cache scaling factors from %s",
+                                self.model_config.quantization_param_path)
                 else:
                     raise RuntimeError(
                         "Using FP8 KV cache and scaling factors provided but "

@@ -183,10 +194,6 @@ class ModelRunner:
                     "Using FP8 KV cache but no scaling factors "
                     "provided. Defaulting to scaling factors of 1.0. "
                     "This may lead to less accurate results!")
-            elif self.model_config.quantization_param_path is not None:
-                logger.warning("KV cache scaling factors provided, "
-                               "but the KV cache data type is not FP8. "
-                               "KV cache scaling factors will not be used.")
 
     def save_sharded_state(
         self,