Fix assertion failure in Qwen 1.5 with prefix caching enabled (#3373)
Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
parent
dfc77408bd
commit
54be8a0be2
43
tests/test_config.py
Normal file
43
tests/test_config.py
Normal file
@ -0,0 +1,43 @@
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
|
||||
def test_get_sliding_window():
|
||||
TEST_SLIDING_WINDOW = 4096
|
||||
# Test that the sliding window is correctly computed.
|
||||
# For Qwen1.5/Qwen2, get_sliding_window() should be None
|
||||
# when use_sliding_window is False.
|
||||
qwen2_model_config = ModelConfig(
|
||||
"Qwen/Qwen1.5-7B",
|
||||
"Qwen/Qwen1.5-7B",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
)
|
||||
|
||||
qwen2_model_config.hf_config.use_sliding_window = False
|
||||
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||
assert qwen2_model_config.get_sliding_window() is None
|
||||
|
||||
qwen2_model_config.hf_config.use_sliding_window = True
|
||||
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||
|
||||
mistral_model_config = ModelConfig(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
)
|
||||
mistral_model_config.hf_config.sliding_window = None
|
||||
assert mistral_model_config.get_sliding_window() is None
|
||||
|
||||
mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
@ -103,6 +103,7 @@ class ModelConfig:
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
|
||||
|
||||
if not os.path.exists(model):
|
||||
model_path = snapshot_download(model_id=model,
|
||||
cache_dir=download_dir,
|
||||
@ -139,7 +140,7 @@ class ModelConfig:
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
]
|
||||
raise ValueError(
|
||||
f"load format \'{load_format}\' is not supported in ROCm. "
|
||||
f"load format '{load_format}' is not supported in ROCm. "
|
||||
f"Supported load format are "
|
||||
f"{rocm_supported_load_format}")
|
||||
|
||||
@ -232,6 +233,15 @@ class ModelConfig:
|
||||
f"({pipeline_parallel_size}).")
|
||||
|
||||
def get_sliding_window(self) -> Optional[int]:
|
||||
"""Get the sliding window size, or None if disabled.
|
||||
"""
|
||||
|
||||
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
|
||||
# addition to sliding window size. We check if that field is present
|
||||
# and if it's False, return None.
|
||||
if (hasattr(self.hf_config, "use_sliding_window")
|
||||
and not self.hf_config.use_sliding_window):
|
||||
return None
|
||||
return getattr(self.hf_config, "sliding_window", None)
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
@ -624,7 +634,7 @@ def _get_and_verify_dtype(
|
||||
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
|
||||
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
||||
]
|
||||
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
|
||||
raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
|
||||
f"Supported dtypes are {rocm_supported_dtypes}")
|
||||
|
||||
# Verify the dtype.
|
||||
|
Loading…
x
Reference in New Issue
Block a user