vllm/tests/samplers/test_ignore_eos.py

33 lines
928 B
Python
Raw Normal View History

2024-05-01 21:45:42 +09:00
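"""Make sure `ignore_eos` works: with `SamplingParams(ignore_eos=True)`,
generation must not stop at the EOS token and must produce exactly
`max_tokens` tokens.

Run `pytest tests/samplers/test_ignore_eos.py`.
"""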
import pytest
from vllm import SamplingParams
2024-06-01 21:21:53 -05:00
# Models exercised by the test. Llama is covered in addition to OPT because
# its generation_config specifies the EOS token, which caused a regression
# in the past.
MODELS = [
    "facebook/opt-125m",
    "meta-llama/Llama-2-7b-hf",
]
2024-05-01 21:45:42 +09:00
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [512])
def test_ignore_eos(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Check that ``ignore_eos=True`` makes every completion run to
    exactly ``max_tokens`` tokens.

    Each prompt is generated on its own, and the length of the first
    completion's ``token_ids`` is compared against ``max_tokens``.

    Args:
        vllm_runner: Fixture used to build the runner, called here as
            ``vllm_runner(model, dtype=dtype)``.
        example_prompts: Fixture providing the prompts to iterate over.
        model: Model name, parametrized from ``MODELS``.
        dtype: Dtype string forwarded to ``vllm_runner``.
        max_tokens: Generation budget; with EOS ignored, each completion
            is expected to contain exactly this many tokens.
    """
    vllm_model = vllm_runner(model, dtype=dtype)
    sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)

    for prompt in example_prompts:
        ignore_eos_output = vllm_model.model.generate(
            prompt, sampling_params=sampling_params)
        # generate() receives a single prompt here, so only the first request
        # output and its first completion are inspected.
        output_length = len(ignore_eos_output[0].outputs[0].token_ids)
        # Report model and prompt so a failure is attributable without
        # re-running the whole parametrized matrix.
        assert output_length == max_tokens, (
            f"expected {max_tokens} tokens with ignore_eos=True, got "
            f"{output_length} (model={model!r}, prompt={prompt!r})")