
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm import SamplingParams

MODELS = ["distilbert/distilgpt2"]


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This file tests V0 internals, so set VLLM_USE_V1=0.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, dtype=dtype) as vllm_model:
        tokenizer = vllm_model.model.get_tokenizer()
        repeat_times = 2
        enforced_answers = " vLLM"
        vllm_token_ids = tokenizer.encode(enforced_answers,
                                          add_special_tokens=False)
        max_tokens = len(vllm_token_ids) * repeat_times
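
        # The logits processor is called once per decode step with the token
        # ids generated so far for this request and the logits for the next
        # position. Pinning the chosen logit to the dtype's maximum makes the
        # sampler emit " vLLM" token by token, cycling until the max_tokens
        # budget (exactly repeat_times copies of the answer) is exhausted.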
        def pick_vllm(token_ids, logits):
            token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
            logits[token_id] = torch.finfo(logits.dtype).max
            return logits

        params_with_logprobs = SamplingParams(
            logits_processors=[pick_vllm],
            prompt_logprobs=3,
            max_tokens=max_tokens,
        )
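        # prompt_logprobs=3 also requests the top-3 logprobs for every prompt
        # token, so the first request exercises the path where a logits
        # processor runs alongside prompt-logprob computation.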

        # test logits_processors when prompt_logprobs is not None
        vllm_model.model._add_request(
            example_prompts[0],
            params=params_with_logprobs,
        )

        # test prompt_logprobs is not None
        vllm_model.model._add_request(
            example_prompts[1],
            params=SamplingParams(
                prompt_logprobs=3,
                max_tokens=max_tokens,
            ),
        )

        # test grouped requests
        vllm_model.model._add_request(
            example_prompts[2],
            params=SamplingParams(max_tokens=max_tokens),
        )
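
        # _add_request only queues work; _run_engine then steps the engine
        # until all three queued requests (each with different SamplingParams)
        # finish. Both are private LLM internals, hence the V0 pin above.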
        outputs = vllm_model.model._run_engine(use_tqdm=False)

    assert outputs[0].outputs[0].text == enforced_answers * repeat_times
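
# For context, a rough sketch of the same forcing technique through the
# public API. Illustrative only: it assumes the V0 engine pinned above and
# reuses the model name, pick_vllm, and max_tokens from the test body.
#
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="distilbert/distilgpt2", dtype="half")
#     out = llm.generate(
#         ["Hello"],
#         SamplingParams(logits_processors=[pick_vllm],
#                        max_tokens=max_tokens),
#     )
#     # With repeat_times = 2 the forced completion is " vLLM vLLM".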