32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
![]() |
from vllm import SamplingParams
|
||
|
from vllm.config import LoadConfig, LoadFormat
|
||
|
from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader,
|
||
|
get_model_loader)
|
||
|
|
||
|
# Model identifier handed to the ``vllm_runner`` fixture below.
test_model = "openai-community/gpt2"

# Prompts passed to ``llm.generate`` in the generation test.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling configuration used for generation; the seed is pinned to 0.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    seed=0,
)
|
||
|
|
||
|
|
||
|
def get_runai_model_loader():
    """Return the model loader chosen for the RUNAI_STREAMER load format.

    Builds a ``LoadConfig`` whose ``load_format`` is
    ``LoadFormat.RUNAI_STREAMER`` and returns whatever
    ``get_model_loader`` produces for it.
    """
    return get_model_loader(
        LoadConfig(load_format=LoadFormat.RUNAI_STREAMER))
|
||
|
|
||
|
|
||
|
def test_get_model_loader_with_runai_flag():
    """Check that the RUNAI_STREAMER config yields the streamer loader."""
    loader = get_runai_model_loader()
    assert isinstance(loader, RunaiModelStreamerLoader)
|
||
|
|
||
|
|
||
|
def test_runai_model_loader_download_files(vllm_runner):
    """Generate with the RUNAI_STREAMER load format and expect output.

    Opens ``vllm_runner`` on ``test_model`` with
    ``load_format=LoadFormat.RUNAI_STREAMER``, runs ``generate`` on the
    module-level ``prompts`` with ``sampling_params``, and asserts that
    the result is truthy.
    """
    runner = vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER)
    with runner as llm:
        outputs = llm.generate(prompts, sampling_params)
        # Checked inside the context, while the runner is still open.
        assert outputs
|