Update test_ignore_eos (#4898)
This commit is contained in:
parent
044793d8df
commit
ed59a7ed23
@ -7,25 +7,26 @@ import pytest
|
|||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
MODELS = ["facebook/opt-125m"]
|
# We also test with llama because it has generation_config to specify EOS
|
||||||
|
# (past regression).
|
||||||
|
MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
@pytest.mark.parametrize("max_tokens", [1024])
|
@pytest.mark.parametrize("max_tokens", [512])
|
||||||
def test_beam_search_single_input(
|
def test_ignore_eos(
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
example_prompts,
|
example_prompts,
|
||||||
model: str,
|
model: str,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
example_prompts = "1 + 1 is"
|
|
||||||
|
|
||||||
vllm_model = vllm_runner(model, dtype=dtype)
|
vllm_model = vllm_runner(model, dtype=dtype)
|
||||||
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
|
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
|
||||||
|
|
||||||
|
for prompt in example_prompts:
|
||||||
ignore_eos_output = vllm_model.model.generate(
|
ignore_eos_output = vllm_model.model.generate(
|
||||||
example_prompts, sampling_params=sampling_params)
|
prompt, sampling_params=sampling_params)
|
||||||
print(len(ignore_eos_output[0].outputs[0].token_ids))
|
output_length = len(ignore_eos_output[0].outputs[0].token_ids)
|
||||||
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10
|
assert output_length == max_tokens
|
||||||
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user