[1/n][CI] Load models in CI from S3 instead of HF (#13205)

Signed-off-by: <>
Co-authored-by: EC2 Default User <ec2-user@ip-172-31-20-117.us-west-2.compute.internal>
This commit is contained in:
Kevin H. Luu 2025-02-18 23:34:59 -08:00 committed by GitHub
parent fd84857f64
commit d5d214ac7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 225 additions and 76 deletions

View File

@ -37,3 +37,5 @@ genai_perf==0.0.8
tritonclient==2.51.0
numpy < 2.0.0
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0

View File

@ -171,6 +171,8 @@ huggingface-hub==0.26.2
# tokenizers
# transformers
# vocos
humanize==4.11.0
# via runai-model-streamer
idna==3.10
# via
# anyio
@ -290,6 +292,7 @@ numpy==1.26.4
# patsy
# peft
# rouge-score
# runai-model-streamer
# sacrebleu
# scikit-learn
# scipy
@ -514,6 +517,10 @@ rpds-py==0.20.1
# referencing
rsa==4.7.2
# via awscli
runai-model-streamer==0.11.0
# via -r requirements-test.in
runai-model-streamer-s3==0.11.0
# via -r requirements-test.in
s3transfer==0.10.3
# via
# awscli
@ -594,6 +601,7 @@ torch==2.5.1
# encodec
# lm-eval
# peft
# runai-model-streamer
# sentence-transformers
# tensorizer
# timm

View File

@ -9,6 +9,7 @@ import weakref
import pytest
from vllm import LLM
from vllm.config import LoadFormat
from vllm.platforms import current_platform
from ..conftest import VllmRunner
@ -33,7 +34,7 @@ def v1(run_with_both_engines):
def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted"""
llm = LLM("facebook/opt-125m")
llm = LLM("distilbert/distilgpt2", load_format=LoadFormat.RUNAI_STREAMER)
weak_llm = weakref.ref(llm)
del llm
# If there's any circular reference to vllm, this fails
@ -94,14 +95,14 @@ def test_models(
@pytest.mark.parametrize(
"model, distributed_executor_backend, attention_backend, "
"test_suite", [
("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "mp", "", "L4"),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4"),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4"),
("facebook/opt-125m", "ray", "", "A100"),
("facebook/opt-125m", "mp", "", "A100"),
("facebook/opt-125m", "mp", "FLASHINFER", "A100"),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "FLASHINFER", "A100"),
("distilbert/distilgpt2", "ray", "", "L4"),
("distilbert/distilgpt2", "mp", "", "L4"),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
("meta-llama/Llama-2-7b-hf", "mp", "", "L4"),
("distilbert/distilgpt2", "ray", "", "A100"),
("distilbert/distilgpt2", "mp", "", "A100"),
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
])
def test_models_distributed(
hf_runner,

View File

@ -4,9 +4,11 @@ import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config import LoadFormat
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.utils import GiB_bytes
from ..conftest import MODEL_WEIGHTS_S3_BUCKET
from ..utils import fork_new_process_for_each_test
@ -118,13 +120,18 @@ def test_cumem_with_cudagraph():
@pytest.mark.parametrize(
"model",
[
"meta-llama/Llama-3.2-1B-Instruct", # sleep mode with safetensors
"facebook/opt-125m" # sleep mode with pytorch checkpoint
# sleep mode with safetensors
f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B",
# sleep mode with pytorch checkpoint
"facebook/opt-125m"
])
def test_end_to_end(model):
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
llm = LLM(model, enable_sleep_mode=True)
load_format = LoadFormat.AUTO
if "Llama" in model:
load_format = LoadFormat.RUNAI_STREAMER
llm = LLM(model, load_format=load_format, enable_sleep_mode=True)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
output = llm.generate(prompt, sampling_params)

View File

@ -17,7 +17,7 @@ from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
from ..models.utils import check_outputs_equal
MODELS = [
"facebook/opt-125m",
"distilbert/distilgpt2",
]

View File

@ -24,7 +24,7 @@ from tests.models.utils import (TokensTextLogprobs,
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, TokenizerPoolConfig
from vllm.config import LoadFormat, TaskOption, TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
@ -46,6 +46,21 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
_M = TypeVar("_M")
MODELS_ON_S3 = [
"distilbert/distilgpt2",
"meta-llama/Llama-2-7b-hf",
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Llama-3.2-1B",
"meta-llama/Llama-3.2-1B-Instruct",
"openai-community/gpt2",
"ArthurZ/Ilama-3.2-1B",
"llava-hf/llava-1.5-7b-hf",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
MODEL_WEIGHTS_S3_BUCKET = "s3://vllm-ci-model-weights"
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
PromptImageInput = _PromptMultiModalInput[Image.Image]
@ -677,8 +692,15 @@ class VllmRunner:
enable_chunked_prefill: bool = False,
swap_space: int = 4,
enforce_eager: Optional[bool] = False,
load_format: Optional[LoadFormat] = None,
**kwargs,
) -> None:
if model_name in MODELS_ON_S3 and not load_format:
model_name = (f"s3://vllm-ci-model-weights/"
f"{model_name.split('/')[-1]}")
load_format = LoadFormat.RUNAI_STREAMER
if not load_format:
load_format = LoadFormat.AUTO
self.model = LLM(
model=model_name,
task=task,
@ -693,6 +715,7 @@ class VllmRunner:
max_model_len=max_model_len,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
load_format=load_format,
**kwargs,
)

View File

@ -2,12 +2,15 @@
import pytest
from vllm.config import LoadFormat
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from ..conftest import MODEL_WEIGHTS_S3_BUCKET
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"])
@pytest.mark.parametrize("block_size", [16])
def test_computed_prefix_blocks(model: str, block_size: int):
# This test checks if we are able to run the engine to completion
@ -24,6 +27,7 @@ def test_computed_prefix_blocks(model: str, block_size: int):
"decoration.")
engine_args = EngineArgs(model=model,
load_format=LoadFormat.RUNAI_STREAMER,
block_size=block_size,
enable_prefix_caching=True)

View File

@ -2,11 +2,14 @@
import pytest
from vllm.config import LoadFormat
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams
from ..conftest import MODEL_WEIGHTS_S3_BUCKET
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"])
def test_computed_prefix_blocks(model: str):
# This test checks if the engine generates completions both with and
# without optional detokenization, that detokenization includes text
@ -17,7 +20,7 @@ def test_computed_prefix_blocks(model: str):
"paper clips? Is there an easy to follow video tutorial available "
"online for free?")
llm = LLM(model=model)
llm = LLM(model=model, load_format=LoadFormat.RUNAI_STREAMER)
sampling_params = SamplingParams(max_tokens=10,
temperature=0.0,
detokenize=False)

View File

@ -6,12 +6,17 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import pytest
from vllm.config import LoadFormat
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams
from ..conftest import MODEL_WEIGHTS_S3_BUCKET
RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER
class Mock:
...
@ -33,10 +38,11 @@ class CustomUniExecutor(UniProcExecutor):
CustomUniExecutorAsync = CustomUniExecutor
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"])
def test_custom_executor_type_checking(model):
with pytest.raises(ValueError):
engine_args = EngineArgs(model=model,
load_format=RUNAI_STREAMER_LOAD_FORMAT,
distributed_executor_backend=Mock)
LLMEngine.from_engine_args(engine_args)
with pytest.raises(ValueError):
@ -45,7 +51,7 @@ def test_custom_executor_type_checking(model):
AsyncLLMEngine.from_engine_args(engine_args)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"])
def test_custom_executor(model, tmp_path):
cwd = os.path.abspath(".")
os.chdir(tmp_path)
@ -54,6 +60,7 @@ def test_custom_executor(model, tmp_path):
engine_args = EngineArgs(
model=model,
load_format=RUNAI_STREAMER_LOAD_FORMAT,
distributed_executor_backend=CustomUniExecutor,
enforce_eager=True, # reduce test time
)
@ -68,7 +75,7 @@ def test_custom_executor(model, tmp_path):
os.chdir(cwd)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"])
def test_custom_executor_async(model, tmp_path):
cwd = os.path.abspath(".")
os.chdir(tmp_path)
@ -77,6 +84,7 @@ def test_custom_executor_async(model, tmp_path):
engine_args = AsyncEngineArgs(
model=model,
load_format=RUNAI_STREAMER_LOAD_FORMAT,
distributed_executor_backend=CustomUniExecutorAsync,
enforce_eager=True, # reduce test time
)
@ -95,7 +103,7 @@ def test_custom_executor_async(model, tmp_path):
os.chdir(cwd)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"])
def test_respect_ray(model):
# even for TP=1 and PP=1,
# if users specify ray, we should use ray.
@ -104,6 +112,7 @@ def test_respect_ray(model):
engine_args = EngineArgs(
model=model,
distributed_executor_backend="ray",
load_format=RUNAI_STREAMER_LOAD_FORMAT,
enforce_eager=True, # reduce test time
)
engine = LLMEngine.from_engine_args(engine_args)

View File

@ -2,16 +2,21 @@
import pytest
from vllm.config import LoadFormat
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams
from ..conftest import MODEL_WEIGHTS_S3_BUCKET
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2"])
def test_skip_tokenizer_initialization(model: str):
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(model=model, skip_tokenizer_init=True)
llm = LLM(model=model,
skip_tokenizer_init=True,
load_format=LoadFormat.RUNAI_STREAMER)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
with pytest.raises(ValueError, match="cannot pass text prompts when"):

View File

@ -12,7 +12,7 @@ import transformers
from vllm import SamplingParams
MODEL = "facebook/opt-350m"
MODEL = "distilbert/distilgpt2"
STOP_STR = "."
SEED = 42
MAX_TOKENS = 1024

View File

@ -5,12 +5,17 @@ from typing import List
import pytest
from vllm import LLM
from vllm.config import LoadFormat
from ...conftest import MODEL_WEIGHTS_S3_BUCKET
from ..openai.test_vision import TEST_IMAGE_URLS
RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER
def test_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B-Instruct",
load_format=RUNAI_STREAMER_LOAD_FORMAT)
prompt1 = "Explain the concept of entropy."
messages = [
@ -28,7 +33,8 @@ def test_chat():
def test_multi_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B-Instruct",
load_format=RUNAI_STREAMER_LOAD_FORMAT)
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."
@ -65,7 +71,8 @@ def test_multi_chat():
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
model=f"{MODEL_WEIGHTS_S3_BUCKET}/Phi-3.5-vision-instruct",
load_format=RUNAI_STREAMER_LOAD_FORMAT,
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,

View File

@ -28,7 +28,7 @@ def test_collective_rpc(tp_size, backend):
def echo_rank(self):
return self.rank
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
llm = LLM(model="s3://vllm-ci-model-weights/Llama-3.2-1B-Instruct",
enforce_eager=True,
load_format="dummy",
tensor_parallel_size=tp_size,

View File

@ -6,9 +6,10 @@ from typing import List
import pytest
from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm.config import LoadFormat
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
MODEL_NAME = "s3://vllm-ci-model-weights/e5-mistral-7b-instruct"
PROMPTS = [
"Hello, my name is",
@ -32,6 +33,7 @@ def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
load_format=LoadFormat.RUNAI_STREAMER,
max_num_batched_tokens=32768,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,

View File

@ -6,9 +6,10 @@ from typing import List
import pytest
from vllm import LLM, RequestOutput, SamplingParams
from vllm.config import LoadFormat
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "facebook/opt-125m"
MODEL_NAME = "s3://vllm-ci-model-weights/distilgpt2"
PROMPTS = [
"Hello, my name is",
@ -30,6 +31,7 @@ def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
load_format=LoadFormat.RUNAI_STREAMER,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,

View File

@ -7,10 +7,11 @@ import pytest
from huggingface_hub import snapshot_download
from vllm import LLM
from vllm.config import LoadFormat
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "s3://vllm-ci-model-weights/zephyr-7b-beta"
PROMPTS = [
"Hello, my name is",
@ -27,6 +28,7 @@ def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
load_format=LoadFormat.RUNAI_STREAMER,
tensor_parallel_size=1,
max_model_len=8192,
enable_lora=True,

View File

@ -7,12 +7,13 @@ import weakref
import jsonschema
import pytest
from vllm.config import LoadFormat
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
MODEL_NAME = "s3://vllm-ci-model-weights/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
@ -20,7 +21,9 @@ GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME, max_model_len=1024)
llm = LLM(model=MODEL_NAME,
load_format=LoadFormat.RUNAI_STREAMER,
max_model_len=1024)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

View File

@ -6,10 +6,11 @@ from contextlib import nullcontext
from vllm_test_utils import BlameResult, blame
from vllm import LLM, SamplingParams
from vllm.config import LoadFormat
from vllm.distributed import cleanup_dist_env_and_memory
def run_normal():
def run_normal_opt125m():
prompts = [
"Hello, my name is",
"The president of the United States is",
@ -33,9 +34,35 @@ def run_normal():
cleanup_dist_env_and_memory()
def run_normal():
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM without guided decoding as a baseline.
llm = LLM(model="s3://vllm-ci-model-weights/distilgpt2",
load_format=LoadFormat.RUNAI_STREAMER,
enforce_eager=True,
gpu_memory_utilization=0.3)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()
def run_lmfe(sample_regex):
# Create an LLM with guided decoding enabled.
llm = LLM(model="facebook/opt-125m",
llm = LLM(model="s3://vllm-ci-model-weights/distilgpt2",
load_format=LoadFormat.RUNAI_STREAMER,
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3)

View File

@ -3,6 +3,7 @@
import pytest
from vllm import LLM
from vllm.config import LoadFormat
@pytest.fixture(autouse=True)
@ -14,13 +15,17 @@ def v1(run_with_both_engines):
def test_empty_prompt():
llm = LLM(model="gpt2", enforce_eager=True)
llm = LLM(model="s3://vllm-ci-model-weights/gpt2",
load_format=LoadFormat.RUNAI_STREAMER,
enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'):
llm.generate([""])
@pytest.mark.skip_v1
def test_out_of_vocab_token():
llm = LLM(model="gpt2", enforce_eager=True)
llm = LLM(model="s3://vllm-ci-model-weights/gpt2",
load_format=LoadFormat.RUNAI_STREAMER,
enforce_eager=True)
with pytest.raises(ValueError, match='out of vocabulary'):
llm.generate({"prompt_token_ids": [999999]})

View File

@ -8,16 +8,21 @@ import ray
from prometheus_client import REGISTRY
from vllm import EngineArgs, LLMEngine
from vllm.config import LoadFormat
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams
from ..conftest import MODEL_WEIGHTS_S3_BUCKET
MODELS = [
"facebook/opt-125m",
"distilbert/distilgpt2",
]
RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@ -141,8 +146,9 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
metrics_tag_content = stat_logger.labels["model_name"]
if served_model_name is None or served_model_name == []:
assert metrics_tag_content == model, (
f"Metrics tag model_name is wrong! expect: {model!r}\n"
actual_model_name = f"{MODEL_WEIGHTS_S3_BUCKET}/{model.split('/')[-1]}"
assert metrics_tag_content == actual_model_name, (
f"Metrics tag model_name is wrong! expect: {actual_model_name!r}\n"
f"actual: {metrics_tag_content!r}")
else:
assert metrics_tag_content == served_model_name[0], (
@ -170,7 +176,8 @@ async def test_async_engine_log_metrics_regression(
"""
engine_args = AsyncEngineArgs(model=model,
dtype=dtype,
disable_log_stats=disable_log_stats)
disable_log_stats=disable_log_stats,
load_format=RUNAI_STREAMER_LOAD_FORMAT)
async_engine = AsyncLLMEngine.from_engine_args(engine_args)
for i, prompt in enumerate(example_prompts):
results = async_engine.generate(
@ -199,7 +206,8 @@ def test_engine_log_metrics_regression(
) -> None:
engine_args = EngineArgs(model=model,
dtype=dtype,
disable_log_stats=disable_log_stats)
disable_log_stats=disable_log_stats,
load_format=RUNAI_STREAMER_LOAD_FORMAT)
engine = LLMEngine.from_engine_args(engine_args)
for i, prompt in enumerate(example_prompts):
engine.add_request(
@ -283,7 +291,8 @@ def test_metric_spec_decode_interval(
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
enforce_eager=True)
enforce_eager=True,
load_format=RUNAI_STREAMER_LOAD_FORMAT)
engine = LLMEngine.from_engine_args(engine_args)

View File

@ -173,7 +173,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
trust_remote_code=True),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct"),
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct",
extras={"2.5": "Qwen/Qwen2.5-7B-Instruct"}), # noqa: E501
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b",
is_available_online=False),

View File

@ -7,6 +7,7 @@ from transformers import PretrainedConfig
from vllm import LLM
from ..conftest import MODELS_ON_S3
from .registry import HF_EXAMPLE_MODELS
@ -42,8 +43,11 @@ def test_can_initialize(model_arch):
with patch.object(LLM.get_engine_class(), "_initialize_kv_caches",
_initialize_kv_caches):
model_name = model_info.default
if model_name in MODELS_ON_S3:
model_name = f"s3://vllm-ci-model-weights/{model_name.split('/')[-1]}"
LLM(
model_info.default,
model_name,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
speculative_model=model_info.speculative_model,

View File

@ -10,8 +10,8 @@ import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, load_format="runai_streamer")
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
EXPECTED_TOKENS = 250

View File

@ -21,8 +21,10 @@ from vllm.lora.request import LoRARequest
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL,
load_format="runai_streamer",
enforce_eager=True)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"

View File

@ -10,12 +10,14 @@ import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
MODEL = "s3://vllm-ci-model-weights/gemma-1.1-2b-it"
NUM_EXPECTED_TOKENS = 10
NUM_REQUESTS = 10000
# Scenarios to test for num generated token.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
ENGINE_ARGS = AsyncEngineArgs(model=MODEL,
load_format="runai_streamer",
disable_log_requests=True)
@pytest.fixture(scope="function")

View File

@ -553,7 +553,8 @@ def test_find_mm_placeholders(
assert result == expected
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
"model_id", ["s3://vllm-ci-model-weights/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("limit", "num_supported", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
@ -592,7 +593,8 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
profiler.get_dummy_data(model_config.max_model_len)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
"model_id", ["s3://vllm-ci-model-weights/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
@ -661,7 +663,7 @@ class _ProcessorProxy:
return dict(exists=exists)
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-7B-Instruct"]) # Dummy
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy
# yapf: disable
@pytest.mark.parametrize(
("call_kwargs", "expected_kwargs"),

View File

@ -10,7 +10,7 @@ from vllm import SamplingParams
# We also test with llama because it has generation_config to specify EOS
# (past regression).
MODELS = ["facebook/opt-125m", "meta-llama/Llama-3.2-1B-Instruct"]
MODELS = ["distilbert/distilgpt2", "meta-llama/Llama-3.2-1B"]
@pytest.mark.parametrize("model", MODELS)

View File

@ -5,7 +5,7 @@ import torch
from vllm import SamplingParams
MODELS = ["facebook/opt-125m"]
MODELS = ["distilbert/distilgpt2"]
@pytest.mark.parametrize("model", MODELS)

View File

@ -9,7 +9,7 @@ from vllm import SamplingParams
from ..conftest import VllmRunner
MODELS = ["facebook/opt-125m"]
MODELS = ["distilbert/distilgpt2"]
@pytest.mark.parametrize("model", MODELS)

View File

@ -76,7 +76,7 @@ class TestOneTokenBadWord:
class TestTwoTokenBadWord:
# Another model (with a different tokenizer behaviour)
MODEL = "openai-community/gpt2"
MODEL = "distilbert/distilgpt2"
PROMPT = "How old are you? I am 10"
TARGET_TOKEN1 = "years"

View File

@ -4,7 +4,7 @@ import pytest
from vllm import SamplingParams
MODELS = ["facebook/opt-125m"]
MODELS = ["distilbert/distilgpt2"]
@pytest.mark.parametrize("model", MODELS)

View File

@ -8,14 +8,19 @@ from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
from .conftest import MODEL_WEIGHTS_S3_BUCKET
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("facebook/opt-125m", "generate", "generate"),
("intfloat/e5-mistral-7b-instruct", "pooling", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
(f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2", "generate", "generate"),
(f"{MODEL_WEIGHTS_S3_BUCKET}/e5-mistral-7b-instruct", "pooling",
"embed"),
(f"{MODEL_WEIGHTS_S3_BUCKET}/Qwen2.5-1.5B-apeach", "pooling",
"classify"),
(f"{MODEL_WEIGHTS_S3_BUCKET}/ms-marco-MiniLM-L-6-v2", "pooling",
"score"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
("openai/whisper-small", "transcription", "transcription"),
],

View File

@ -10,6 +10,9 @@ import gc
import torch
from vllm import LLM, SamplingParams
from vllm.config import LoadFormat
from .conftest import MODEL_WEIGHTS_S3_BUCKET
def test_duplicated_ignored_sequence_group():
@ -18,7 +21,8 @@ def test_duplicated_ignored_sequence_group():
sampling_params = SamplingParams(temperature=0.01,
top_p=0.1,
max_tokens=256)
llm = LLM(model="facebook/opt-125m",
llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2",
load_format=LoadFormat.RUNAI_STREAMER,
max_num_batched_tokens=4096,
tensor_parallel_size=1)
prompts = ["This is a short prompt", "This is a very long prompt " * 1000]
@ -31,7 +35,8 @@ def test_max_tokens_none():
sampling_params = SamplingParams(temperature=0.01,
top_p=0.1,
max_tokens=None)
llm = LLM(model="facebook/opt-125m",
llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2",
load_format=LoadFormat.RUNAI_STREAMER,
max_num_batched_tokens=4096,
tensor_parallel_size=1)
prompts = ["Just say hello!"]
@ -41,7 +46,9 @@ def test_max_tokens_none():
def test_gc():
llm = LLM("facebook/opt-125m", enforce_eager=True)
llm = LLM(model=f"{MODEL_WEIGHTS_S3_BUCKET}/distilgpt2",
load_format=LoadFormat.RUNAI_STREAMER,
enforce_eager=True)
del llm
gc.collect()

View File

@ -10,7 +10,7 @@ from vllm.worker.worker import Worker
def test_swap() -> None:
# Configure the engine.
engine_args = EngineArgs(model="facebook/opt-125m",
engine_args = EngineArgs(model="s3://vllm-ci-model-weights/distilgpt2",
dtype="half",
load_format="dummy")
engine_config = engine_args.create_engine_config()

View File

@ -409,7 +409,8 @@ class ModelConfig:
if is_s3(model) or is_s3(tokenizer):
if is_s3(model):
s3_model = S3Model()
s3_model.pull_files(model, allow_pattern=["*config.json"])
s3_model.pull_files(
model, allow_pattern=["*.model", "*.py", "*.json"])
self.model_weights = self.model
self.model = s3_model.dir

View File

@ -1327,6 +1327,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
@ -1340,7 +1341,6 @@ class RunaiModelStreamerLoader(BaseModelLoader):
revision,
ignore_patterns=self.load_config.ignore_patterns,
))
if is_s3_path:
hf_weights_files = s3_glob(path=hf_folder,
allow_pattern=[safetensors_pattern])

View File

@ -27,6 +27,8 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig,
from vllm.platforms import current_platform
from vllm.utils import PlaceholderModule
logger = init_logger(__name__)
try:
from runai_model_streamer import SafetensorsStreamer
except (ImportError, OSError):
@ -37,8 +39,6 @@ except (ImportError, OSError):
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
"SafetensorsStreamer")
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the

View File

@ -144,7 +144,6 @@ def file_exists(
revision: Optional[str] = None,
token: Union[str, bool, None] = None,
) -> bool:
file_list = list_repo_files(repo_id,
repo_type=repo_type,
revision=revision,
@ -498,7 +497,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
if encoder_dict:
break
if not encoder_dict:
if not encoder_dict and not model.startswith("/"):
try:
# If model is on HuggingfaceHub, get the repo files
repo_files = list_repo_files(model,

View File

@ -46,6 +46,8 @@ def glob(s3=None,
"""
if s3 is None:
s3 = boto3.client("s3")
if not path.endswith("/"):
path = path + "/"
bucket_name, _, paths = list_files(s3,
path=path,
allow_pattern=allow_pattern)
@ -109,6 +111,7 @@ class S3Model:
for sig in (signal.SIGINT, signal.SIGTERM):
existing_handler = signal.getsignal(sig)
signal.signal(sig, self._close_by_signal(existing_handler))
self.dir = tempfile.mkdtemp()
def __del__(self):
@ -140,6 +143,9 @@ class S3Model:
ignore_pattern: A list of patterns of which files not to pull.
"""
if not s3_model_path.endswith("/"):
s3_model_path = s3_model_path + "/"
bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
allow_pattern,
ignore_pattern)
@ -147,8 +153,9 @@ class S3Model:
return
for file in files:
destination_file = os.path.join(self.dir,
file.removeprefix(base_dir))
destination_file = os.path.join(
self.dir,
file.removeprefix(base_dir).lstrip("/"))
local_dir = Path(destination_file).parent
os.makedirs(local_dir, exist_ok=True)
self.s3.download_file(bucket_name, file, destination_file)