[Misc] Split up pooling tasks (#10820)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-11 17:28:00 +08:00 committed by GitHub
parent 40766ca1b8
commit 8f10d5e393
27 changed files with 527 additions and 168 deletions

View File

@ -94,6 +94,8 @@ Documentation
:caption: Models
models/supported_models
models/generative_models
models/pooling_models
models/adding_model
models/enabling_multimodal_inputs

View File

@ -0,0 +1,146 @@
.. _generative_models:
Generative Models
=================
vLLM provides first-class support for generative models, which covers most LLMs.
In vLLM, generative models implement the :class:`~vllm.model_executor.models.VllmModelForTextGeneration` interface.
Based on the final hidden states of the input, these models output log probabilities of the tokens to generate,
which are then passed through :class:`~vllm.model_executor.layers.Sampler` to obtain the final text.
Offline Inference
-----------------
The :class:`~vllm.LLM` class provides various methods for offline inference.
See :ref:`Engine Arguments <engine_args>` for a list of options when initializing the model.
For generative models, the only supported :code:`task` option is :code:`"generate"`.
Usually, this is automatically inferred so you don't have to specify it.
``LLM.generate``
^^^^^^^^^^^^^^^^
The :class:`~vllm.LLM.generate` method is available to all generative models in vLLM.
It is similar to `its counterpart in HF Transformers <https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate>`__,
except that tokenization and detokenization are also performed automatically.
.. code-block:: python
from vllm import LLM

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate("Hello, my name is")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
You can optionally control the language generation by passing :class:`~vllm.SamplingParams`.
For example, you can use greedy sampling by setting :code:`temperature=0`:
.. code-block:: python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0)
outputs = llm.generate("Hello, my name is", params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
A code example can be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.
``LLM.beam_search``
^^^^^^^^^^^^^^^^^^^
The :class:`~vllm.LLM.beam_search` method implements `beam search <https://huggingface.co/docs/transformers/en/generation_strategies#beam-search-decoding>`__ on top of :class:`~vllm.LLM.generate`.
For example, to search using 5 beams and output at most 50 tokens:
.. code-block:: python
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")
params = BeamSearchParams(beam_width=5, max_tokens=50)
outputs = llm.beam_search([{"prompt": "Hello, my name is "}], params)

for output in outputs:
    generated_text = output.sequences[0].text
    print(f"Generated text: {generated_text!r}")
``LLM.chat``
^^^^^^^^^^^^
The :class:`~vllm.LLM.chat` method implements chat functionality on top of :class:`~vllm.LLM.generate`.
In particular, it accepts input similar to `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__
and automatically applies the model's `chat template <https://huggingface.co/docs/transformers/en/chat_templating>`__ to format the prompt.
.. important::
In general, only instruction-tuned models have a chat template.
Base models may perform poorly because they have not been trained to follow chat conversations.
.. code-block:: python
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
conversation = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": "Hello"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "Write an essay about the importance of higher education.",
},
]
outputs = llm.chat(conversation)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
A code example can be found in `examples/offline_inference_chat.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_chat.py>`_.
If the model doesn't have a chat template or you want to specify another one,
you can explicitly pass a chat template:
.. code-block:: python
from vllm.entrypoints.chat_utils import load_chat_template
# You can find a list of existing chat templates under `examples/`
custom_template = load_chat_template(chat_template="<path_to_template>")
print("Loaded chat template:", custom_template)
outputs = llm.chat(conversation, chat_template=custom_template)
Online Inference
----------------
Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference.
See the linked page for details on how to launch the server.
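For example, assuming vLLM is installed, a server hosting a generative model can be started with the :code:`vllm serve` command (a minimal sketch; the model name is only an illustration):

.. code-block:: shell

$ vllm serve facebook/opt-125m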
Completions API
^^^^^^^^^^^^^^^
Our Completions API is similar to ``LLM.generate`` but only accepts text.
It is compatible with `OpenAI Completions API <https://platform.openai.com/docs/api-reference/completions>`__
so that you can use the OpenAI client to interact with it.
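For example, the sketch below queries a running server with the official :code:`openai` Python client (this assumes the server was started on the default port 8000 and is serving :code:`facebook/opt-125m`):

.. code-block:: python

from openai import OpenAI

# Assumption: a vLLM OpenAI-compatible server is running at this address.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="facebook/opt-125m",
    prompt="Hello, my name is",
)
print(completion.choices[0].text)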
A code example can be found in `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
Chat API
^^^^^^^^
Our Chat API is similar to ``LLM.chat``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`.
It is compatible with `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__
so that you can use the OpenAI client to interact with it.
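For example, a minimal sketch using the official :code:`openai` Python client (assuming the server is running locally on port 8000 and serving an instruction-tuned model):

.. code-block:: python

from openai import OpenAI

# Assumption: a vLLM OpenAI-compatible server is running at this address.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
)
print(chat_response.choices[0].message.content)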
A code example can be found in `examples/openai_chat_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client.py>`_.

View File

@ -0,0 +1,99 @@
.. _pooling_models:
Pooling Models
==============
vLLM also supports pooling models, including embedding, reranking and reward models.
In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
before returning them.
.. note::
We currently support pooling models primarily as a matter of convenience.
As shown in the :ref:`Compatibility Matrix <compatibility_matrix>`, most vLLM features do not apply to
pooling models because those features only take effect during the generation (decode) stage, so the performance benefit of vLLM is smaller for these models.
Offline Inference
-----------------
The :class:`~vllm.LLM` class provides various methods for offline inference.
See :ref:`Engine Arguments <engine_args>` for a list of options when initializing the model.
For pooling models, we support the following :code:`task` options:
- Embedding (:code:`"embed"` / :code:`"embedding"`)
- Classification (:code:`"classify"`)
- Sentence Pair Scoring (:code:`"score"`)
- Reward Modeling (:code:`"reward"`)
The selected task determines the default :class:`~vllm.model_executor.layers.Pooler` that is used:
- Embedding: Extract only the hidden states corresponding to the last token, and apply normalization.
- Classification: Extract only the hidden states corresponding to the last token, and apply softmax.
- Sentence Pair Scoring: Extract only the hidden states corresponding to the last token, and apply softmax.
- Reward Modeling: Extract all of the hidden states and return them directly.
When loading `Sentence Transformers <https://huggingface.co/sentence-transformers>`__ models,
we attempt to override the default pooler based on the model's Sentence Transformers configuration file (:code:`modules.json`).
You can customize the model's pooling method via the :code:`override_pooler_config` option,
which takes priority over both the model's defaults and the Sentence Transformers configuration.
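For example, a minimal sketch of overriding the pooler from Python (the :code:`PoolerConfig` fields shown here are assumptions based on the defaults described above; adjust them for your model):

.. code-block:: python

from vllm import LLM
from vllm.config import PoolerConfig

# Sketch: force mean pooling with normalization instead of the model's default.
llm = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    task="embed",
    override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=True),
)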
``LLM.encode``
^^^^^^^^^^^^^^
The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
It returns the aggregated hidden states directly.
.. code-block:: python
from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
outputs = llm.encode("Hello, my name is")

for output in outputs:
    embeddings = output.outputs.embedding
    print(f"Embeddings (size={len(embeddings)}): {embeddings!r}")
A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.
``LLM.score``
^^^^^^^^^^^^^
The :class:`~vllm.LLM.score` method outputs similarity scores between sentence pairs.
It is primarily designed for `cross-encoder models <https://www.sbert.net/examples/applications/cross-encoder/README.html>`__.
Such models act as rerankers for candidate query-document pairs in RAG systems.
.. note::
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.
You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference.
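For illustration, a minimal sketch of offline scoring (the model name is just an example of a cross-encoder; the exact output field may differ across vLLM versions):

.. code-block:: python

from vllm import LLM

# Sketch: load a cross-encoder reranker in scoring mode.
llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")

(output,) = llm.score("What is the capital of France?",
                      "The capital of France is Paris.")
# The similarity score is contained in the pooling output
# (the exact field name may differ across vLLM versions).
print(output.outputs)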
Online Inference
----------------
Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference.
See the linked page for details on how to launch the server.
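For example, assuming vLLM is installed, a server hosting a pooling model can be started by passing the corresponding :code:`--task` option (a minimal sketch):

.. code-block:: shell

$ vllm serve intfloat/e5-mistral-7b-instruct --task embed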
Embeddings API
^^^^^^^^^^^^^^
Our Embeddings API is similar to ``LLM.encode``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`.
The text-only API is compatible with `OpenAI Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`__
so that you can use the OpenAI client to interact with it.
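For example, a minimal sketch using the official :code:`openai` Python client (assuming the server above is running locally on port 8000):

.. code-block:: python

from openai import OpenAI

# Assumption: a vLLM OpenAI-compatible server is running at this address.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input="Hello, my name is",
)
print(len(response.data[0].embedding))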
A code example can be found in `examples/openai_embedding_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_embedding_client.py>`_.
The multi-modal API is an extension of the `OpenAI Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`__
that incorporates `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__,
so it is not part of the OpenAI standard. Please see :ref:`this page <multimodal_inputs>` for more details on how to use it.
Score API
^^^^^^^^^
Our Score API is similar to ``LLM.score``.
Please see `this page <../serving/openai_compatible_server.html#score-api-for-cross-encoder-models>`__ for more details on how to use it.

View File

@ -3,11 +3,21 @@
Supported Models
================
vLLM supports a variety of generative and embedding models from `HuggingFace (HF) Transformers <https://huggingface.co/models>`_.
This page lists the model architectures that are currently supported by vLLM.
vLLM supports generative and pooling models across various tasks.
If a model supports more than one task, you can set the task via the :code:`--task` argument.
For each task, we list the model architectures that have been implemented in vLLM.
Alongside each architecture, we include some popular models that use it.
For other models, you can check the :code:`config.json` file inside the model repository.
Loading a Model
^^^^^^^^^^^^^^^
HuggingFace Hub
+++++++++++++++
By default, vLLM loads models from `HuggingFace (HF) Hub <https://huggingface.co/models>`_.
To determine whether a given model is supported, you can check the :code:`config.json` file inside the HF repository.
If the :code:`"architectures"` field contains a model architecture listed below, then it should be supported in theory.
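For instance, you can inspect the :code:`"architectures"` field programmatically (a sketch; the model name is only an illustration):

.. code-block:: python

from transformers import AutoConfig

config = AutoConfig.from_pretrained("intfloat/e5-mistral-7b-instruct")
print(config.architectures)  # e.g. ['MistralModel']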
.. tip::
@ -17,38 +27,57 @@ If the :code:`"architectures"` field contains a model architecture listed below,
from vllm import LLM
llm = LLM(model=...) # Name or path of your model
# For generative models (task=generate) only
llm = LLM(model=..., task="generate") # Name or path of your model
output = llm.generate("Hello, my name is")
print(output)
If vLLM successfully generates text, it indicates that your model is supported.
# For pooling models (task={embed,classify,reward}) only
llm = LLM(model=..., task="embed") # Name or path of your model
output = llm.encode("Hello, my name is")
print(output)
If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
for instructions on how to implement your model in vLLM.
Alternatively, you can `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ to request vLLM support.
.. note::
To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
ModelScope
++++++++++
.. code-block:: shell
To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
.. code-block:: shell
$ export VLLM_USE_MODELSCOPE=True
And use with :code:`trust_remote_code=True`.
And use with :code:`trust_remote_code=True`.
.. code-block:: python
.. code-block:: python
from vllm import LLM
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
llm = LLM(model=..., revision=..., task=..., trust_remote_code=True)
# For generative models (task=generate) only
output = llm.generate("Hello, my name is")
print(output)
Text-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^
# For pooling models (task={embed,classify,reward}) only
output = llm.encode("Hello, my name is")
print(output)
Text Generation
---------------
List of Text-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Generative Models
+++++++++++++++++
See :ref:`this page <generative_models>` for more information on how to use generative models.
Text Generation (``--task generate``)
-------------------------------------
.. list-table::
:widths: 25 25 50 5 5
@ -328,8 +357,24 @@ Text Generation
.. note::
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
Text Embedding
--------------
Pooling Models
++++++++++++++
See :ref:`this page <pooling_models>` for more information on how to use pooling models.
.. important::
Since some model architectures support both generative and pooling tasks,
you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode.
Text Embedding (``--task embed``)
---------------------------------
Any text generation model can be converted into an embedding model by passing :code:`--task embed`.
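For example, the sketch below (the model name is chosen only for illustration) loads a generative checkpoint in embedding mode:

.. code-block:: python

from vllm import LLM

# Any text generation model can be run as an embedding model via task="embed".
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", task="embed")
outputs = llm.encode("Hello, my name is")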
.. note::
To get the best results, you should use pooling models that are specifically trained as such.
The following table lists those that are tested in vLLM.
.. list-table::
:widths: 25 25 50 5 5
@ -371,13 +416,6 @@ Text Embedding
-
-
.. important::
Some model architectures support both generation and embedding tasks.
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.
.. note::
:code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
@ -389,8 +427,8 @@ Text Embedding
On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
despite being described otherwise on its model card.
Reward Modeling
---------------
Reward Modeling (``--task reward``)
-----------------------------------
.. list-table::
:widths: 25 25 50 5 5
@ -416,11 +454,8 @@ Reward Modeling
For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
.. note::
As an interim measure, these models are supported in both offline and online inference via Embeddings API.
Classification
---------------
Classification (``--task classify``)
------------------------------------
.. list-table::
:widths: 25 25 50 5 5
@ -437,11 +472,8 @@ Classification
- ✅︎
- ✅︎
.. note::
As an interim measure, these models are supported in both offline and online inference via Embeddings API.
Sentence Pair Scoring
---------------------
Sentence Pair Scoring (``--task score``)
----------------------------------------
.. list-table::
:widths: 25 25 50 5 5
@ -468,13 +500,10 @@ Sentence Pair Scoring
-
-
.. note::
These models are supported in both offline and online inference via Score API.
.. _supported_mm_models:
Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^
List of Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The following modalities are supported depending on the model:
@ -491,8 +520,15 @@ On the other hand, modalities separated by :code:`/` are mutually exclusive.
- e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.
Text Generation
---------------
See :ref:`this page <multimodal_inputs>` on how to pass multi-modal inputs to the model.
Generative Models
+++++++++++++++++
See :ref:`this page <generative_models>` for more information on how to use generative models.
Text Generation (``--task generate``)
-------------------------------------
.. list-table::
:widths: 25 25 15 20 5 5 5
@ -696,8 +732,24 @@ Text Generation
The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
Multimodal Embedding
--------------------
Pooling Models
++++++++++++++
See :ref:`this page <pooling_models>` for more information on how to use pooling models.
.. important::
Since some model architectures support both generative and pooling tasks,
you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode.
Text Embedding (``--task embed``)
---------------------------------
Any text generation model can be converted into an embedding model by passing :code:`--task embed`.
.. note::
To get the best results, you should use pooling models that are specifically trained as such.
The following table lists those that are tested in vLLM.
.. list-table::
:widths: 25 25 15 25 5 5
@ -728,12 +780,7 @@ Multimodal Embedding
-
- ✅︎
.. important::
Some model architectures support both generation and embedding tasks.
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
.. tip::
You can override the model's pooling method by passing :code:`--override-pooler-config`.
----
Model Support Policy
=====================

View File

@ -39,13 +39,13 @@ Feature x Feature
- :abbr:`prmpt adptr (Prompt Adapter)`
- :ref:`SD <spec_decode>`
- CUDA graph
- :abbr:`emd (Embedding Models)`
- :abbr:`pooling (Pooling Models)`
- :abbr:`enc-dec (Encoder-Decoder Models)`
- :abbr:`logP (Logprobs)`
- :abbr:`prmpt logP (Prompt Logprobs)`
- :abbr:`async output (Async Output Processing)`
- multi-step
- :abbr:`mm (Multimodal)`
- :abbr:`mm (Multimodal Inputs)`
- best-of
- beam-search
- :abbr:`guided dec (Guided Decoding)`
@ -151,7 +151,7 @@ Feature x Feature
-
-
-
* - :abbr:`emd (Embedding Models)`
* - :abbr:`pooling (Pooling Models)`
- ✗
- ✗
- ✗
@ -253,7 +253,7 @@ Feature x Feature
-
-
-
* - :abbr:`mm (Multimodal)`
* - :abbr:`mm (Multimodal Inputs)`
- ✅
- `✗ <https://github.com/vllm-project/vllm/pull/8348>`__
- `✗ <https://github.com/vllm-project/vllm/pull/7199>`__
@ -386,7 +386,7 @@ Feature x Hardware
- ✅
- ✗
- ✅
* - :abbr:`emd (Embedding Models)`
* - :abbr:`pooling (Pooling Models)`
- ✅
- ✅
- ✅
@ -402,7 +402,7 @@ Feature x Hardware
- ✅
- ✅
- ✗
* - :abbr:`mm (Multimodal)`
* - :abbr:`mm (Multimodal Inputs)`
- ✅
- ✅
- ✅

View File

@ -9,7 +9,12 @@ prompts = [
]
# Create an LLM.
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
model = LLM(
model="intfloat/e5-mistral-7b-instruct",
task="embed", # You should pass task="embed" for embedding models
enforce_eager=True,
)
# Generate embedding. The output is a list of PoolingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.

View File

@ -59,7 +59,7 @@ def run_e5_v(query: Query):
llm = LLM(
model="royokong/e5-v",
task="embedding",
task="embed",
max_model_len=4096,
)
@ -88,7 +88,7 @@ def run_vlm2vec(query: Query):
llm = LLM(
model="TIGER-Lab/VLM2Vec-Full",
task="embedding",
task="embed",
trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4},
)

View File

@ -55,7 +55,7 @@ test_settings = [
# embedding model
TestSetting(
model="BAAI/bge-multilingual-gemma2",
model_args=["--task", "embedding"],
model_args=["--task", "embed"],
pp_size=1,
tp_size=1,
attn_backend="FLASHINFER",
@ -65,7 +65,7 @@ test_settings = [
# encoder-based embedding model (BERT)
TestSetting(
model="BAAI/bge-base-en-v1.5",
model_args=["--task", "embedding"],
model_args=["--task", "embed"],
pp_size=1,
tp_size=1,
attn_backend="XFORMERS",

View File

@ -37,7 +37,7 @@ def test_scheduler_schedule_simple_encoder_decoder():
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(
task="generate",
"generate",
max_num_batched_tokens=64,
max_num_seqs=num_seq_group,
max_model_len=max_model_len,

View File

@ -27,7 +27,7 @@ TEST_IMAGE_URLS = [
def server():
args = [
"--task",
"embedding",
"embed",
"--dtype",
"bfloat16",
"--max-model-len",

View File

@ -54,7 +54,7 @@ def test_models(
hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model,
task="embedding",
task="embed",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:

View File

@ -35,9 +35,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
hf_outputs = hf_model.predict([text_pair]).tolist()
with vllm_runner(model_name,
task="embedding",
dtype=dtype,
with vllm_runner(model_name, task="score", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.score(text_pair[0], text_pair[1])
@ -58,9 +56,7 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
hf_outputs = hf_model.predict(text_pairs).tolist()
with vllm_runner(model_name,
task="embedding",
dtype=dtype,
with vllm_runner(model_name, task="score", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2)
@ -82,9 +78,7 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
hf_outputs = hf_model.predict(text_pairs).tolist()
with vllm_runner(model_name,
task="embedding",
dtype=dtype,
with vllm_runner(model_name, task="score", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)

View File

@ -93,7 +93,7 @@ def _run_test(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
task="embedding",
task="embed",
dtype=dtype,
enforce_eager=True,
max_model_len=8192) as vllm_model:

View File

@ -47,7 +47,7 @@ def _run_test(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
task="embedding",
task="embed",
dtype=dtype,
max_model_len=4096,
enforce_eager=True) as vllm_model:

View File

@ -39,7 +39,7 @@ def _run_test(
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model, task="embedding", dtype=dtype,
with vllm_runner(model, task="embed", dtype=dtype,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.encode(input_texts, images=input_images)

View File

@ -7,11 +7,17 @@ from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@pytest.mark.parametrize(("model_id", "expected_task"), [
("facebook/opt-125m", "generate"),
("intfloat/e5-mistral-7b-instruct", "embedding"),
])
def test_auto_task(model_id, expected_task):
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("facebook/opt-125m", "generate", "generate"),
("intfloat/e5-mistral-7b-instruct", "pooling", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
task="auto",
@ -22,6 +28,7 @@ def test_auto_task(model_id, expected_task):
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task

View File

@ -45,13 +45,27 @@ else:
logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding"]
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft"]
RunnerType = Literal["generate", "pooling", "draft"]
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"draft": ["draft"],
}
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
task: runner
for runner, tasks in _RUNNER_TASKS.items() for task in tasks
}
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
@ -144,7 +158,7 @@ class ModelConfig:
def __init__(
self,
model: str,
task: Union[TaskOption, _Task],
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
@ -295,6 +309,7 @@ class ModelConfig:
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self.pooler_config = self._init_pooler_config(override_pooler_config)
self._verify_quantization()
@ -323,7 +338,7 @@ class ModelConfig:
override_pooler_config: Optional["PoolerConfig"],
) -> Optional["PoolerConfig"]:
if self.task == "embedding":
if self.runner_type == "pooling":
user_config = override_pooler_config or PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
@ -357,60 +372,90 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode
def _get_preferred_task(
self,
architectures: List[str],
supported_tasks: Set[_ResolvedTask],
) -> Optional[_ResolvedTask]:
model_id = self.model
if get_pooling_config(model_id, self.revision):
return "embed"
if ModelRegistry.is_cross_encoder_model(architectures):
return "score"
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ForSequenceClassification", "classify"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
_, arch = ModelRegistry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
return pref_task
return None
def _resolve_task(
self,
task_option: Union[TaskOption, _Task],
task_option: Union[TaskOption, Literal["draft"]],
hf_config: PretrainedConfig,
) -> Tuple[Set[_Task], _Task]:
) -> Tuple[Set[_ResolvedTask], _ResolvedTask]:
if task_option == "draft":
return {"draft"}, "draft"
architectures = getattr(hf_config, "architectures", [])
task_support: Dict[_Task, bool] = {
runner_support: Dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_pooling_model(architectures),
"pooling": ModelRegistry.is_pooling_model(architectures),
}
supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported
supported_runner_types_lst: List[RunnerType] = [
runner_type
for runner_type, is_supported in runner_support.items()
if is_supported
]
supported_tasks_lst: List[_ResolvedTask] = [
task for runner_type in supported_runner_types_lst
for task in _RUNNER_TASKS[runner_type]
]
supported_tasks = set(supported_tasks_lst)
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
if len(supported_tasks) > 1:
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
# Hardcode the models that are exceptions
("AquilaModel", "generate"),
("ChatGLMModel", "generate"),
# Other models follow this pattern
("ForCausalLM", "generate"),
("ForConditionalGeneration", "generate"),
("ChatModel", "generate"),
("LMHeadModel", "generate"),
("EmbeddingModel", "embedding"),
("RewardModel", "embedding"),
("ForSequenceClassification", "embedding"),
]
info, arch = ModelRegistry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix) and pref_task in supported_tasks:
selected_task = pref_task
break
else:
if (arch.endswith("Model")
and info.architecture.endswith("ForCausalLM")
and "embedding" in supported_tasks):
selected_task = "embedding"
if len(supported_tasks_lst) > 1:
preferred_task = self._get_preferred_task(
architectures, supported_tasks)
if preferred_task is not None:
selected_task = preferred_task
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
# Aliases
if task_option == "embedding":
preferred_task = self._get_preferred_task(
architectures, supported_tasks)
if preferred_task != "embed":
msg = ("The 'embedding' task will be restricted to "
"embedding models in a future release. Please "
"pass `--task classify`, `--task score`, or "
"`--task reward` explicitly for other pooling "
"models.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
task_option = preferred_task or "embed"
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
@ -533,7 +578,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.task == "embedding":
if self.runner_type == "pooling":
self.use_async_output_proc = False
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
@ -750,6 +795,14 @@ class ModelConfig:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_cross_encoder_model(architectures)
@property
def supported_runner_types(self) -> Set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}
@property
def runner_type(self) -> RunnerType:
return _TASK_RUNNER[self.task]
class CacheConfig:
"""Configuration for the KV cache.
@ -1096,7 +1149,7 @@ class ParallelConfig:
class SchedulerConfig:
"""Scheduler configuration."""
task: str = "generate" # The task to use the model for.
runner_type: str = "generate" # The runner type to launch for the model.
# Maximum number of tokens to be processed in a single iteration.
max_num_batched_tokens: int = field(default=None) # type: ignore
@ -1164,11 +1217,11 @@ class SchedulerConfig:
# for higher throughput.
self.max_num_batched_tokens = max(self.max_model_len, 2048)
if self.task == "embedding":
# For embedding, choose specific value for higher throughput
if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens

View File

@ -337,7 +337,7 @@ class Scheduler:
self.lora_config = lora_config
version = "selfattn"
if (self.scheduler_config.task == "embedding"
if (self.scheduler_config.runner_type == "pooling"
or self.cache_config.is_attention_free):
version = "placeholder"

View File

@ -1066,7 +1066,7 @@ class EngineArgs:
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
and model_config.task != "embedding"):
and model_config.runner_type != "pooling"):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models with "
@ -1083,7 +1083,8 @@ class EngineArgs:
"errors during the initial memory profiling phase, or result "
"in low performance due to small KV cache space. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
elif self.enable_chunked_prefill and model_config.task == "embedding":
elif (self.enable_chunked_prefill
and model_config.runner_type == "pooling"):
msg = "Chunked prefill is not supported for embedding models"
raise ValueError(msg)
@ -1144,7 +1145,7 @@ class EngineArgs:
" please file an issue with detailed information.")
scheduler_config = SchedulerConfig(
task=model_config.task,
runner_type=model_config.runner_type,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,

View File

@ -288,7 +288,7 @@ class LLMEngine:
self.model_executor = executor_class(vllm_config=vllm_config, )
if self.model_config.task != "embedding":
if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
@ -1123,7 +1123,7 @@ class LLMEngine:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.task == "embedding":
if self.model_config.runner_type == "pooling":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)

View File

@ -381,19 +381,20 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
task = self.llm_engine.model_config.task
if task != "generate":
runner_type = self.llm_engine.model_config.runner_type
if runner_type != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).",
]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "generate" in supported_tasks:
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "generate" in supported_runner_types:
messages.append(
"Your model supports the 'generate' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task generate`.")
"Your model supports the 'generate' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task generate`.")
raise ValueError(" ".join(messages))
@ -793,16 +794,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.encode() is only supported for embedding models."]
runner_type = self.llm_engine.model_config.runner_type
if runner_type != "pooling":
messages = ["LLM.encode() is only supported for pooling models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "pooling" in supported_runner_types:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
"Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task embed`, "
"`--task classify`, `--task score` etc.")
raise ValueError(" ".join(messages))
@ -864,21 +867,23 @@ class LLM:
A list of ``PoolingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.score() is only supported for embedding models."]
runner_type = self.llm_engine.model_config.runner_type
if runner_type != "pooling":
messages = ["LLM.score() is only supported for pooling models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
supported_runner_types = self.llm_engine.model_config \
.supported_runner_types
if "pooling" in supported_runner_types:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
"Your model supports the 'pooling' runner, but is "
f"currently initialized for the '{runner_type}' runner. "
"Please initialize vLLM using `--task embed`, "
"`--task classify`, `--task score` etc.")
raise ValueError(" ".join(messages))
if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support the cross encoding")
raise ValueError("Your model does not support cross encoding")
tokenizer = self.llm_engine.get_tokenizer()

View File

@ -573,7 +573,7 @@ def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
@ -582,7 +582,7 @@ def init_app_state(
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.task == "generate" else None
) if model_config.runner_type == "generate" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
@ -590,13 +590,13 @@ def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embedding" else None
) if model_config.runner_type == "pooling" else None
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger
) if (model_config.task == "embedding" \
) if (model_config.runner_type == "pooling" \
and model_config.is_cross_encoder) else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,

View File

@ -224,7 +224,7 @@ async def main(args):
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
) if model_config.runner_type == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
@ -232,7 +232,7 @@ async def main(args):
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embedding" else None
) if model_config.runner_type == "pooling" else None
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)

View File

@ -35,7 +35,7 @@ def get_model_architecture(
architectures = ["QuantMixtralForCausalLM"]
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embedding":
if model_config.runner_type == "pooling":
model_cls = as_embedding_model(model_cls)
return model_cls, arch

View File

@ -42,7 +42,7 @@ class EngineCore:
executor_class: Type[Executor],
usage_context: UsageContext,
):
assert vllm_config.model_config.task != "embedding"
assert vllm_config.model_config.runner_type != "pooling"
logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)

View File

@ -163,7 +163,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.task == "embedding":
if self.model_config.runner_type == "pooling":
ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner

View File

@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.task == "embedding":
if model_config.runner_type == "pooling":
ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner