Fix assertion failure in Qwen 1.5 with prefix caching enabled (#3373)

Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
陈序 2024-03-15 04:56:57 +08:00 committed by GitHub
parent dfc77408bd
commit 54be8a0be2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 55 additions and 2 deletions

43
tests/test_config.py Normal file
View File

@ -0,0 +1,43 @@
from vllm.config import ModelConfig
def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
# Test that the sliding window is correctly computed.
# For Qwen1.5/Qwen2, get_sliding_window() should be None
# when use_sliding_window is False.
qwen2_model_config = ModelConfig(
"Qwen/Qwen1.5-7B",
"Qwen/Qwen1.5-7B",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
)
qwen2_model_config.hf_config.use_sliding_window = False
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert qwen2_model_config.get_sliding_window() is None
qwen2_model_config.hf_config.use_sliding_window = True
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
mistral_model_config = ModelConfig(
"mistralai/Mistral-7B-v0.1",
"mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
)
mistral_model_config.hf_config.sliding_window = None
assert mistral_model_config.get_sliding_window() is None
mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

View File

@ -103,6 +103,7 @@ class ModelConfig:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
if not os.path.exists(model):
model_path = snapshot_download(model_id=model,
cache_dir=download_dir,
@ -139,7 +140,7 @@ class ModelConfig:
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format \'{load_format}\' is not supported in ROCm. "
f"load format '{load_format}' is not supported in ROCm. "
f"Supported load format are "
f"{rocm_supported_load_format}")
@ -232,6 +233,15 @@ class ModelConfig:
f"({pipeline_parallel_size}).")
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
if (hasattr(self.hf_config, "use_sliding_window")
and not self.hf_config.use_sliding_window):
return None
return getattr(self.hf_config, "sliding_window", None)
def get_vocab_size(self) -> int:
@ -624,7 +634,7 @@ def _get_and_verify_dtype(
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
]
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}")
# Verify the dtype.