# SPDX-License-Identifier: Apache-2.0
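"""Tests for V1 engine argument parsing and usage-context defaults."""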

import pytest

from vllm import envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser

if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )


def test_prefix_caching_from_cli():
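    """V1 enables prefix caching by default; CLI flags can toggle it."""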
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert (vllm_config.cache_config.enable_prefix_caching
            ), "V1 turns on prefix caching by default."

    # Turn it off with the flag.
    args = parser.parse_args(["--no-enable-prefix-caching"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert not vllm_config.cache_config.enable_prefix_caching

    # Turn it back on with the flag.
    args = parser.parse_args(["--enable-prefix-caching"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.enable_prefix_caching


def test_defaults_with_usage_context():
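    """Scheduler defaults depend on the usage context."""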
    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config: VllmConfig = engine_args.create_engine_config(
        UsageContext.LLM_CLASS)
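
    # The expected defaults depend on the detected device, so derive them
    # from the platform at runtime instead of hard-coding a single value.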
    from vllm.platforms import current_platform
    device_name = current_platform.get_device_name().lower()
    if "h100" in device_name or "h200" in device_name:
        # For H100 and H200, we use larger default values.
        default_llm_tokens = 16384
        default_server_tokens = 8192
    else:
        default_llm_tokens = 8192
        default_server_tokens = 2048

    assert vllm_config.scheduler_config.max_num_seqs == 1024
    assert vllm_config.scheduler_config.max_num_batched_tokens == default_llm_tokens  # noqa: E501
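
    # The OpenAI API server context defaults to a smaller token budget.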
    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config = engine_args.create_engine_config(
        UsageContext.OPENAI_API_SERVER)
    assert vllm_config.scheduler_config.max_num_seqs == 1024
    assert vllm_config.scheduler_config.max_num_batched_tokens == default_server_tokens  # noqa: E501