Update test_ignore_eos (#4898)

Simon Mo 2024-06-01 21:21:53 -05:00 committed by GitHub
parent 044793d8df
commit ed59a7ed23


@@ -7,25 +7,26 @@ import pytest
 
 from vllm import SamplingParams
 
-MODELS = ["facebook/opt-125m"]
+# We also test with llama because it has generation_config to specify EOS
+# (past regression).
+MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"]
 
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [1024])
-def test_beam_search_single_input(
+@pytest.mark.parametrize("max_tokens", [512])
+def test_ignore_eos(
     vllm_runner,
     example_prompts,
     model: str,
     dtype: str,
     max_tokens: int,
 ) -> None:
-    example_prompts = "1 + 1 is"
     vllm_model = vllm_runner(model, dtype=dtype)
     sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
-    ignore_eos_output = vllm_model.model.generate(
-        example_prompts, sampling_params=sampling_params)
-    print(len(ignore_eos_output[0].outputs[0].token_ids))
-    assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10
-    assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0
+    for prompt in example_prompts:
+        ignore_eos_output = vllm_model.model.generate(
+            prompt, sampling_params=sampling_params)
+        output_length = len(ignore_eos_output[0].outputs[0].token_ids)
+        assert output_length == max_tokens
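For context, a minimal standalone sketch of the behavior this test exercises, using vLLM's public LLM API instead of the test-suite fixtures (the model and prompt here are illustrative, not from the commit):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    # With ignore_eos=True, generation does not stop at the EOS token,
    # so each output should contain exactly max_tokens tokens.
    params = SamplingParams(max_tokens=512, ignore_eos=True)
    outputs = llm.generate(["1 + 1 is"], sampling_params=params)
    assert len(outputs[0].outputs[0].token_ids) == 512

This is also why the new assertion checks strict equality: the old tolerance-based checks (within 10 tokens of max_tokens) could mask a regression where a model's generation_config-specified EOS token terminated generation early despite ignore_eos=True.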