# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import pytest

from vllm import LLM, SamplingParams, envs

MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200


def _test_stopping(llm: LLM,
                   expected_output: str,
                   expected_reason: Any,
                   stop: Optional[list[str]] = None,
                   stop_token_ids: Optional[list[int]] = None,
                   include_in_output: bool = False) -> None:
    output = llm.generate(
        "A story about vLLM:\n",
        SamplingParams(
            temperature=0.0,
            max_tokens=MAX_TOKENS,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_in_output,
        ))[0].outputs[0]

    assert output is not None
    assert output.text == expected_output
    assert output.stop_reason == expected_reason


def _set_async_mode(llm, is_async):
    llm.llm_engine.scheduler[0].use_async_output_proc = is_async


def _stop_basic(llm):
    _test_stopping(llm,
                   stop=["."],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=".")

    _test_stopping(llm,
                   stop=["."],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization.",
                   expected_reason=".")


def _stop_multi_tokens(llm):
    _test_stopping(
        llm,
        stop=["group of peo", "short"],
        include_in_output=False,
        expected_output="VLLM is a 100% volunteer organization. We are a ",
        expected_reason="group of peo")

    _test_stopping(
        llm,
        stop=["group of peo", "short"],
        include_in_output=True,
        expected_output=
        "VLLM is a 100% volunteer organization. We are a group of peo",
        expected_reason="group of peo")


def _stop_partial_token(llm):
    _test_stopping(llm,
                   stop=["gani"],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer or",
                   expected_reason="gani")

    _test_stopping(llm,
                   stop=["gani"],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organi",
                   expected_reason="gani")


def _stop_token_id(llm):
    # token id 13013 => " organization"

    _test_stopping(llm,
                   stop_token_ids=[13013],
                   include_in_output=False,
                   expected_output="VLLM is a 100% volunteer",
                   expected_reason=13013)

    _test_stopping(llm,
                   stop_token_ids=[13013],
                   include_in_output=True,
                   expected_output="VLLM is a 100% volunteer organization",
                   expected_reason=13013)


@pytest.mark.skip_global_cleanup
def test_stop_strings():
    # If V0, must set enforce_eager=False since we use
    # async output processing below.
    vllm_model = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1)

    if envs.VLLM_USE_V1:
        _stop_basic(vllm_model)
    else:
        _set_async_mode(vllm_model, True)
        _stop_basic(vllm_model)

        _set_async_mode(vllm_model, False)
        _stop_basic(vllm_model)

    if envs.VLLM_USE_V1:
        _stop_multi_tokens(vllm_model)
    else:
        _set_async_mode(vllm_model, True)
        _stop_multi_tokens(vllm_model)

        _set_async_mode(vllm_model, False)
        _stop_multi_tokens(vllm_model)

    if envs.VLLM_USE_V1:
        _stop_partial_token(vllm_model)
    else:
        _set_async_mode(vllm_model, True)
        _stop_partial_token(vllm_model)

        _set_async_mode(vllm_model, False)
        _stop_partial_token(vllm_model)

    if envs.VLLM_USE_V1:
        # FIXME: this does not respect include_in_output=False
        # _stop_token_id(vllm_model)
        pass
    else:
        _set_async_mode(vllm_model, True)
        _stop_token_id(vllm_model)

        _set_async_mode(vllm_model, False)
        _stop_token_id(vllm_model)
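

# ---------------------------------------------------------------------------
# Illustrative sketch, not vLLM internals: the stop-string semantics asserted
# above (generation ends at the earliest stop-string match, and
# include_stop_str_in_output decides whether the matched string is kept) can
# be modeled with plain string operations. `_truncate_at_stop` is a
# hypothetical helper written for this file; it is not part of the vLLM API.


def _truncate_at_stop(text: str,
                      stop: list[str],
                      include_in_output: bool = False
                      ) -> tuple[str, Optional[str]]:
    """Return (truncated_text, stop_reason) for the earliest stop match."""
    # Find the stop string whose match starts earliest in the text, mirroring
    # the "first stop condition hit wins" behavior the tests above rely on.
    best: Optional[tuple[int, str]] = None
    for s in stop:
        idx = text.find(s)
        if idx != -1 and (best is None or idx < best[0]):
            best = (idx, s)
    if best is None:
        return text, None
    idx, matched = best
    end = idx + len(matched) if include_in_output else idx
    return text[:end], matched


def test_stop_string_semantics_sketch():
    # Reproduces the expectations in _stop_basic and _stop_partial_token
    # without an engine: same truncation points, same stop_reason values.
    full = "VLLM is a 100% volunteer organization. We are a group of people"

    assert _truncate_at_stop(full, stop=["."]) == \
        ("VLLM is a 100% volunteer organization", ".")
    assert _truncate_at_stop(full, stop=["."], include_in_output=True) == \
        ("VLLM is a 100% volunteer organization.", ".")

    # A stop string may begin mid-token ("gani" inside "organization"),
    # which is why the engine must match against detokenized text.
    assert _truncate_at_stop(full, stop=["gani"]) == \
        ("VLLM is a 100% volunteer or", "gani")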