
[CI/Test] improve robustness of test by replacing del with context manager (vllm_runner) (#5357)
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
import pytest
|
|
|
|
from vllm import SamplingParams
|
|
|
|
# Model names fed to the `model` parameter of test_ranks via parametrize.
MODELS = ["facebook/opt-125m"]
|
|
|
|
|
|
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    """Check the ``rank`` attached to the logprob of every generated token.

    Two batches are generated from ``example_prompts`` with the same runner:

    * greedy (``temperature=0.0``): each chosen token must have rank 1;
    * sampled (``temperature=1.0``): each chosen token must have rank >= 1.

    In both cases the chosen token must be present in the logprob mapping
    returned for its position.

    Args:
        vllm_runner: Fixture called as ``vllm_runner(model, dtype=...,
            max_logprobs=...)`` and used as a context manager.
        model: Model name supplied by the ``MODELS`` parametrization.
        dtype: Dtype string supplied by parametrization.
        example_prompts: Prompts passed to ``generate_w_logprobs``.
    """
    max_tokens = 5
    num_top_logprobs = 5
    num_prompt_logprobs = 5

    with vllm_runner(model, dtype=dtype,
                     max_logprobs=num_top_logprobs) as vllm_model:

        # Greedy decoding: temperature 0.0.
        greedy_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=max_tokens,
            logprobs=num_top_logprobs,
            prompt_logprobs=num_prompt_logprobs)
        greedy_results = vllm_model.generate_w_logprobs(
            example_prompts, greedy_params)

        # Non-greedy decoding: temperature 1.0.
        sampled_params = SamplingParams(
            temperature=1.0,
            top_p=1.0,
            max_tokens=max_tokens,
            logprobs=num_top_logprobs,
            prompt_logprobs=num_prompt_logprobs)
        sampled_results = vllm_model.generate_w_logprobs(
            example_prompts, sampled_params)

    # The checks below only need the collected results, so they run after
    # the runner context has exited.
    # NOTE(review): result[0] is presumably the generated token ids and
    # result[2] the per-position {token: logprob-info} mappings -- confirm
    # against generate_w_logprobs.
    for result in greedy_results:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # Under greedy decoding every chosen token must be ranked first.
        for token, logprobs in zip(result[0], result[2]):
            assert token in logprobs
            assert logprobs[token].rank == 1

    for result in sampled_results:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # Under sampling the chosen token can have any rank, but it must
        # still be reported and its rank must be a valid 1-based position.
        for token, logprobs in zip(result[0], result[2]):
            # Guard the lookup so a missing token fails as a readable
            # assertion (matching the greedy loop) instead of a KeyError.
            assert token in logprobs
            assert logprobs[token].rank >= 1
|