
# SPDX-License-Identifier: Apache-2.0
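
# Checks that a per-request logits processor can force the engine's output,
# both when combined with prompt logprobs and when batched with other
# requests.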

import pytest
import torch

from vllm import SamplingParams

MODELS = ["distilbert/distilgpt2"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, dtype=dtype) as vllm_model:
        tokenizer = vllm_model.model.get_tokenizer()
        repeat_times = 2
        enforced_answers = " vLLM"
        vllm_token_ids = tokenizer.encode(enforced_answers,
                                          add_special_tokens=False)
        max_tokens = len(vllm_token_ids) * repeat_times

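        # The processor is called once per decoding step with the token ids
        # generated so far and the raw logits. Setting the target token's
        # logit to the dtype's maximum forces it to be sampled, cycling
        # through the tokens of `enforced_answers`.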
        def pick_vllm(token_ids, logits):
            token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
            logits[token_id] = torch.finfo(logits.dtype).max
            return logits

        params_with_logprobs = SamplingParams(
            logits_processors=[pick_vllm],
            prompt_logprobs=3,
            max_tokens=max_tokens,
        )

        # test logits_processors when prompt_logprobs is not None
        vllm_model.model._add_request(
            example_prompts[0],
            params=params_with_logprobs,
        )

        # test prompt_logprobs is not None
        vllm_model.model._add_request(
            example_prompts[1],
            params=SamplingParams(
                prompt_logprobs=3,
                max_tokens=max_tokens,
            ),
        )

        # test grouped requests
        vllm_model.model._add_request(
            example_prompts[2],
            params=SamplingParams(max_tokens=max_tokens),
        )

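        # Drain the engine: all three queued requests are processed together
        # in this single run.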
        outputs = vllm_model.model._run_engine(use_tqdm=False)

        # The first request used the forcing processor, so its completion
        # must be exactly `enforced_answers` repeated `repeat_times` times.
        assert outputs[0].outputs[0].text == enforced_answers * repeat_times