
# SPDX-License-Identifier: Apache-2.0
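"""Test LoRA requests with base model weights streamed from S3 via the
RunAI model streamer load format."""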

import weakref

import pytest
# download the LoRA adapter used to build the LoRA requests under test
from huggingface_hub import snapshot_download

from vllm import LLM
from vllm.config import LoadFormat
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest

MODEL_NAME = "s3://vllm-ci-model-weights/zephyr-7b-beta"

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
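    # stream the base model weights straight from S3 with the RunAI model
    # streamer instead of downloading the checkpoint to local disk first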
    llm = LLM(model=MODEL_NAME,
              load_format=LoadFormat.RUNAI_STREAMER,
              tensor_parallel_size=1,
              max_model_len=8192,
              enable_lora=True,
              max_loras=4,
              max_lora_rank=64,
              max_num_seqs=128,
              enforce_eager=True)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup_dist_env_and_memory()


@pytest.fixture(scope="module")
def zephyr_lora_files():
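    # fetch the LoRA adapter files from the Hugging Face Hub
    # (module-scoped so the download happens only once per test module)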
    return snapshot_download(repo_id=LORA_NAME)


@pytest.mark.skip_global_cleanup
def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
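    # one LoRARequest per prompt: each gets a unique adapter name and id but
    # points at the same downloaded adapter files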
    lora_request = [
        LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files)
        for idx in range(len(PROMPTS))
    ]
    # A list of LoRA requests should be matched one-to-one with the prompts
    outputs = llm.generate(PROMPTS, lora_request=lora_request)
    assert len(PROMPTS) == len(outputs)

    # A ValueError is raised if the number of LoRA requests does not match
    # the number of prompts
    with pytest.raises(ValueError):
        outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])

    # A single LoRARequest should be applied to every prompt
    single_lora_request = lora_request[0]
    outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
    assert len(PROMPTS) == len(outputs)
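
# Hypothetical local invocation (assumes a GPU plus read access to the
# vllm-ci-model-weights S3 bucket):
#   pytest -k test_multiple_lora_requests <path-to-this-file>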