# SPDX-License-Identifier: Apache-2.0
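"""Tests for out-of-tree (OOT) model registration via vLLM's plugin system.

With the ``register_dummy_model`` plugin disabled, loading the dummy model
should fail; with it enabled through the ``VLLM_PLUGINS`` environment
variable, the dummy text-generation, embedding, and multimodal models should
load and run through the regular ``LLM`` entrypoint.
"""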

import pytest

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

from ..utils import create_new_process_for_each_test
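
# The ``register_dummy_model`` plugin enabled by the tests below is expected
# to follow vLLM's out-of-tree registration pattern: the plugin package
# exposes an entry point in the ``vllm.general_plugins`` group, and its
# callable registers the dummy model classes with ``ModelRegistry``. A rough
# sketch (module and class names are illustrative, not the actual test
# plugin):
#
#     from vllm import ModelRegistry
#
#     def register():
#         from my_dummy_models import MyDummyForCausalLM  # hypothetical
#         ModelRegistry.register_model("MyDummyForCausalLM",
#                                      MyDummyForCausalLM)
#
# ``VLLM_PLUGINS`` then selects which of the installed plugins are loaded.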


@create_new_process_for_each_test()
def test_plugin(
    monkeypatch: pytest.MonkeyPatch,
    dummy_opt_path: str,
):
    # The V1 engine shuts down rather than raising an error here, so force
    # the V0 engine to get an exception we can assert on.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        # With no plugins enabled, the dummy model architecture is unknown
        # to vLLM, so loading it should fail.
        m.setenv("VLLM_PLUGINS", "")

        with pytest.raises(Exception) as excinfo:
            LLM(model=dummy_opt_path, load_format="dummy")
        error_msg = "has no vLLM implementation and the Transformers implementation is not compatible with vLLM"  # noqa: E501
        assert error_msg in str(excinfo.value)


@create_new_process_for_each_test()
def test_oot_registration_text_generation(
    monkeypatch: pytest.MonkeyPatch,
    dummy_opt_path: str,
):
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "register_dummy_model")
        prompts = ["Hello, my name is", "The text does not matter"]
        sampling_params = SamplingParams(temperature=0)
        llm = LLM(model=dummy_opt_path, load_format="dummy")
        first_token = llm.get_tokenizer().decode(0)
        outputs = llm.generate(prompts, sampling_params)

        for output in outputs:
            generated_text = output.outputs[0].text
            # Make sure only the first token (token id 0, which the dummy
            # model is expected to always emit) is generated.
            rest = generated_text.replace(first_token, "")
            assert rest == ""


@create_new_process_for_each_test()
def test_oot_registration_embedding(
    monkeypatch: pytest.MonkeyPatch,
    dummy_gemma2_embedding_path: str,
):
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "register_dummy_model")
        prompts = ["Hello, my name is", "The text does not matter"]
        llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
        outputs = llm.embed(prompts)

        for output in outputs:
            # The dummy embedding model is expected to return all-zero
            # embedding vectors.
            assert all(v == 0 for v in output.outputs.embedding)


image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
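

# The dummy LLaVA model registered by ``register_dummy_model`` is presumably
# wired up like the other dummy models via ``ModelRegistry``; as a multimodal
# model it additionally needs to implement vLLM's multimodal interface
# (``SupportsMultiModal``) so that the ``multi_modal_data`` image inputs used
# below are accepted.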
@create_new_process_for_each_test()
def test_oot_registration_multimodal(
    monkeypatch: pytest.MonkeyPatch,
    dummy_llava_path: str,
):
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "register_dummy_model")
        prompts = [{
            "prompt": "What's in the image?<image>",
            "multi_modal_data": {
                "image": image
            },
        }, {
            "prompt": "Describe the image<image>",
            "multi_modal_data": {
                "image": image
            },
        }]

        sampling_params = SamplingParams(temperature=0)
        llm = LLM(model=dummy_llava_path,
                  load_format="dummy",
                  max_num_seqs=1,
                  trust_remote_code=True,
                  gpu_memory_utilization=0.98,
                  max_model_len=4096,
                  enforce_eager=True,
                  limit_mm_per_prompt={"image": 1})

        first_token = llm.get_tokenizer().decode(0)
        outputs = llm.generate(prompts, sampling_params)

        for output in outputs:
            generated_text = output.outputs[0].text
            # Make sure only the first token (token id 0, which the dummy
            # model is expected to always emit) is generated.
            rest = generated_text.replace(first_token, "")
            assert rest == ""