[Misc] Split up pooling tasks (#10820)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: 40766ca1b8
Commit: 8f10d5e393
@@ -94,6 +94,8 @@ Documentation
    :caption: Models

    models/supported_models
+   models/generative_models
+   models/pooling_models
    models/adding_model
    models/enabling_multimodal_inputs

docs/source/models/generative_models.rst (new file, 146 lines)
@@ -0,0 +1,146 @@
.. _generative_models:

Generative Models
=================

vLLM provides first-class support for generative models, which cover most large language models (LLMs).

In vLLM, generative models implement the :class:`~vllm.model_executor.models.VllmModelForTextGeneration` interface.
Based on the final hidden states of the input, these models output log probabilities of the tokens to generate,
which are then passed through :class:`~vllm.model_executor.layers.Sampler` to obtain the final text.

Offline Inference
-----------------

The :class:`~vllm.LLM` class provides various methods for offline inference.
See :ref:`Engine Arguments <engine_args>` for a list of options when initializing the model.

For generative models, the only supported :code:`task` option is :code:`"generate"`.
Usually, this is automatically inferred, so you don't have to specify it.
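For example, the following two initializations are equivalent for a text-generation-only checkpoint
(a minimal sketch; :code:`facebook/opt-125m` supports only the :code:`"generate"` task, so the explicit argument is optional):

.. code-block:: python

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")                   # task is inferred as "generate"
    llm = LLM(model="facebook/opt-125m", task="generate")  # task is pinned explicitly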
``LLM.generate``
^^^^^^^^^^^^^^^^

The :class:`~vllm.LLM.generate` method is available to all generative models in vLLM.
It is similar to `its counterpart in HF Transformers <https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate>`__,
except that tokenization and detokenization are also performed automatically.

.. code-block:: python

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate("Hello, my name is")

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

You can optionally control the language generation by passing :class:`~vllm.SamplingParams`.
For example, you can use greedy sampling by setting :code:`temperature=0`:

.. code-block:: python

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0)
    outputs = llm.generate("Hello, my name is", params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

A code example can be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.

``LLM.beam_search``
^^^^^^^^^^^^^^^^^^^

The :class:`~vllm.LLM.beam_search` method implements `beam search <https://huggingface.co/docs/transformers/en/generation_strategies#beam-search-decoding>`__ on top of :class:`~vllm.LLM.generate`.
For example, to search using 5 beams and output at most 50 tokens:

.. code-block:: python

    llm = LLM(model="facebook/opt-125m")
    params = BeamSearchParams(beam_width=5, max_tokens=50)
    outputs = llm.generate("Hello, my name is", params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

``LLM.chat``
^^^^^^^^^^^^

The :class:`~vllm.LLM.chat` method implements chat functionality on top of :class:`~vllm.LLM.generate`.
In particular, it accepts input similar to the `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__
and automatically applies the model's `chat template <https://huggingface.co/docs/transformers/en/chat_templating>`__ to format the prompt.

.. important::

    In general, only instruction-tuned models have a chat template.
    Base models may perform poorly as they are not trained to respond to chat conversations.

.. code-block:: python

    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    conversation = [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": "Hello"
        },
        {
            "role": "assistant",
            "content": "Hello! How can I assist you today?"
        },
        {
            "role": "user",
            "content": "Write an essay about the importance of higher education.",
        },
    ]
    outputs = llm.chat(conversation)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

A code example can be found in `examples/offline_inference_chat.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_chat.py>`_.

If the model doesn't have a chat template or you want to specify another one,
you can explicitly pass a chat template:

.. code-block:: python

    from vllm.entrypoints.chat_utils import load_chat_template

    # You can find a list of existing chat templates under `examples/`
    custom_template = load_chat_template(chat_template="<path_to_template>")
    print("Loaded chat template:", custom_template)

    outputs = llm.chat(conversation, chat_template=custom_template)

Online Inference
----------------

Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference.
Please click on the above link for more details on how to launch the server.

Completions API
^^^^^^^^^^^^^^^

Our Completions API is similar to ``LLM.generate`` but only accepts text.
It is compatible with the `OpenAI Completions API <https://platform.openai.com/docs/api-reference/completions>`__
so that you can use the OpenAI client to interact with it.
A code example can be found in `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
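As a rough illustration, a minimal client call could look like the sketch below
(it assumes the server is running locally on the default port and serving :code:`facebook/opt-125m`; adjust the base URL and model name to your deployment):

.. code-block:: python

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = client.completions.create(
        model="facebook/opt-125m",
        prompt="Hello, my name is",
        max_tokens=32,
    )
    print(completion.choices[0].text)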
Chat API
^^^^^^^^

Our Chat API is similar to ``LLM.chat``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`.
It is compatible with the `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__
so that you can use the OpenAI client to interact with it.
A code example can be found in `examples/openai_chat_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client.py>`_.
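As a rough client-side sketch (again assuming a local server on the default port, here serving :code:`meta-llama/Meta-Llama-3-8B-Instruct`):

.. code-block:: python

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    chat_response = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"},
        ],
    )
    print(chat_response.choices[0].message.content)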
docs/source/models/pooling_models.rst (new file, 99 lines)
@@ -0,0 +1,99 @@
.. _pooling_models:

Pooling Models
==============

vLLM also supports pooling models, including embedding, reranking and reward models.

In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
before returning them.

.. note::

    We currently support pooling models primarily as a matter of convenience.
    As shown in the :ref:`Compatibility Matrix <compatibility_matrix>`, most vLLM features are not applicable to
    pooling models, as those features only work on the generation or decode stage, so performance may not improve as much.

Offline Inference
-----------------

The :class:`~vllm.LLM` class provides various methods for offline inference.
See :ref:`Engine Arguments <engine_args>` for a list of options when initializing the model.

For pooling models, we support the following :code:`task` options:

- Embedding (:code:`"embed"` / :code:`"embedding"`)
- Classification (:code:`"classify"`)
- Sentence Pair Scoring (:code:`"score"`)
- Reward Modeling (:code:`"reward"`)

The selected task determines the default :class:`~vllm.model_executor.layers.Pooler` that is used:

- Embedding: Extract only the hidden states corresponding to the last token, and apply normalization.
- Classification: Extract only the hidden states corresponding to the last token, and apply softmax.
- Sentence Pair Scoring: Extract only the hidden states corresponding to the last token, and apply softmax.
- Reward Modeling: Extract all of the hidden states and return them directly.

When loading `Sentence Transformers <https://huggingface.co/sentence-transformers>`__ models,
we attempt to override the default pooler based on its Sentence Transformers configuration file (:code:`modules.json`).

You can customize the model's pooling method via the :code:`override_pooler_config` option,
which takes priority over both the model's and Sentence Transformers' defaults.
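For example, mean pooling could be forced as in the sketch below
(the checkpoint name is only an illustration taken from the supported-models page, and :code:`PoolerConfig` is assumed to be importable from :code:`vllm.config`):

.. code-block:: python

    from vllm import LLM
    from vllm.config import PoolerConfig

    # Override the default pooler with mean pooling.
    llm = LLM(
        model="ssmits/Qwen2-7B-Instruct-embed-base",
        task="embed",
        override_pooler_config=PoolerConfig(pooling_type="MEAN"),
    )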
``LLM.encode``
^^^^^^^^^^^^^^

The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
It returns the aggregated hidden states directly.

.. code-block:: python

    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
    outputs = llm.encode("Hello, my name is")

    for output in outputs:
        embeddings = output.outputs.embedding
        print(f"Embeddings (size={len(embeddings)}): {embeddings!r}")

A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.

``LLM.score``
^^^^^^^^^^^^^

The :class:`~vllm.LLM.score` method outputs similarity scores between sentence pairs.
It is primarily designed for `cross-encoder models <https://www.sbert.net/examples/applications/cross-encoder/README.html>`__.
These types of models serve as rerankers between candidate query-document pairs in RAG systems.

.. note::

    vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
    To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.

You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as a reference.
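As a rough offline sketch (the cross-encoder checkpoint is one of those exercised by the tests linked above; since the exact output layout may vary between versions, the raw outputs are printed here):

.. code-block:: python

    from vllm import LLM

    llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")
    outputs = llm.score(
        "What is the capital of France?",
        ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
    )
    for output in outputs:
        print(output.outputs)  # one relevance score per query-document pair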
Online Inference
----------------

Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference.
Please click on the above link for more details on how to launch the server.

Embeddings API
^^^^^^^^^^^^^^

Our Embeddings API is similar to ``LLM.encode``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`.

The text-only API is compatible with the `OpenAI Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`__
so that you can use the OpenAI client to interact with it.
A code example can be found in `examples/openai_embedding_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_embedding_client.py>`_.
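A minimal client sketch (assuming a local server started with :code:`--task embed` and the model below):

.. code-block:: python

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    response = client.embeddings.create(
        model="intfloat/e5-mistral-7b-instruct",
        input=["Hello, my name is"],
    )
    print(len(response.data[0].embedding))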
The multi-modal API is an extension of the `OpenAI Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`__
that incorporates the `OpenAI Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`__,
so it is not part of the OpenAI standard. Please see :ref:`this page <multimodal_inputs>` for more details on how to use it.

Score API
^^^^^^^^^

Our Score API is similar to ``LLM.score``.
Please see `this page <../serving/openai_compatible_server.html#score-api-for-cross-encoder-models>`__ for more details on how to use it.
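As a rough sketch of a raw HTTP call (the route and field names here are assumptions based on the linked server documentation, and the checkpoint is only an illustration):

.. code-block:: python

    import requests

    response = requests.post(
        "http://localhost:8000/score",
        json={
            "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
            "text_1": "What is the capital of France?",
            "text_2": "The capital of France is Paris.",
        },
    )
    print(response.json())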
@@ -3,11 +3,21 @@
 Supported Models
 ================
 
-vLLM supports a variety of generative and embedding models from `HuggingFace (HF) Transformers <https://huggingface.co/models>`_.
-This page lists the model architectures that are currently supported by vLLM.
+vLLM supports generative and pooling models across various tasks.
+If a model supports more than one task, you can set the task via the :code:`--task` argument.
+
+For each task, we list the model architectures that have been implemented in vLLM.
 Alongside each architecture, we include some popular models that use it.
 
-For other models, you can check the :code:`config.json` file inside the model repository.
+Loading a Model
+^^^^^^^^^^^^^^^
+
+HuggingFace Hub
++++++++++++++++
+
+By default, vLLM loads models from `HuggingFace (HF) Hub <https://huggingface.co/models>`_.
+
+To determine whether a given model is supported, you can check the :code:`config.json` file inside the HF repository.
 If the :code:`"architectures"` field contains a model architecture listed below, then it should be supported in theory.
 
 .. tip::
@@ -17,38 +27,57 @@ If the :code:`"architectures"` field contains a model architecture listed below,
 
     from vllm import LLM
 
-    llm = LLM(model=...)  # Name or path of your model
+    # For generative models (task=generate) only
+    llm = LLM(model=..., task="generate")  # Name or path of your model
     output = llm.generate("Hello, my name is")
     print(output)
 
-If vLLM successfully generates text, it indicates that your model is supported.
+    # For pooling models (task={embed,classify,reward}) only
+    llm = LLM(model=..., task="embed")  # Name or path of your model
+    output = llm.encode("Hello, my name is")
+    print(output)
+
+If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported.
 
 Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
 for instructions on how to implement your model in vLLM.
 Alternatively, you can `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ to request vLLM support.
 
-.. note::
-    To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
+ModelScope
+++++++++++
 
-    .. code-block:: shell
+To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
 
-        $ export VLLM_USE_MODELSCOPE=True
+.. code-block:: shell
 
-    And use with :code:`trust_remote_code=True`.
+   $ export VLLM_USE_MODELSCOPE=True
 
-    .. code-block:: python
+And use with :code:`trust_remote_code=True`.
 
-        from vllm import LLM
+.. code-block:: python
 
-        llm = LLM(model=..., revision=..., trust_remote_code=True)  # Name or path of your model
-        output = llm.generate("Hello, my name is")
-        print(output)
+    from vllm import LLM
 
+    llm = LLM(model=..., revision=..., task=..., trust_remote_code=True)
 
-Text-only Language Models
-^^^^^^^^^^^^^^^^^^^^^^^^^
+    # For generative models (task=generate) only
+    output = llm.generate("Hello, my name is")
+    print(output)
 
-Text Generation
----------------
+    # For pooling models (task={embed,classify,reward}) only
+    output = llm.encode("Hello, my name is")
+    print(output)
+
+List of Text-only Language Models
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Generative Models
++++++++++++++++++
+
+See :ref:`this page <generative_models>` for more information on how to use generative models.
+
+Text Generation (``--task generate``)
+-------------------------------------
 
 .. list-table::
   :widths: 25 25 50 5 5
@@ -328,8 +357,24 @@ Text Generation
 .. note::
     Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 
-Text Embedding
---------------
+Pooling Models
+++++++++++++++
+
+See :ref:`this page <pooling_models>` for more information on how to use pooling models.
+
+.. important::
+    Since some model architectures support both generative and pooling tasks,
+    you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode.
+
+Text Embedding (``--task embed``)
+---------------------------------
+
+Any text generation model can be converted into an embedding model by passing :code:`--task embed`.
+
+.. note::
+    To get the best results, you should use pooling models that are specifically trained as such.
+
+    The following table lists those that are tested in vLLM.
 
 .. list-table::
   :widths: 25 25 50 5 5
@@ -371,13 +416,6 @@ Text Embedding
       -
       -
 
-.. important::
-    Some model architectures support both generation and embedding tasks.
-    In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
-
-.. tip::
-    You can override the model's pooling method by passing :code:`--override-pooler-config`.
-
 .. note::
     :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
     You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`.
@@ -389,8 +427,8 @@ Text Embedding
     On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
     despite being described otherwise on its model card.
 
-Reward Modeling
----------------
+Reward Modeling (``--task reward``)
+-----------------------------------
 
 .. list-table::
   :widths: 25 25 50 5 5
@@ -416,11 +454,8 @@ Reward Modeling
 For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
 e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
 
-.. note::
-    As an interim measure, these models are supported in both offline and online inference via Embeddings API.
-
-Classification
----------------
+Classification (``--task classify``)
+------------------------------------
 
 .. list-table::
   :widths: 25 25 50 5 5
@@ -437,11 +472,8 @@ Classification
       - ✅︎
       - ✅︎
 
-.. note::
-    As an interim measure, these models are supported in both offline and online inference via Embeddings API.
-
-Sentence Pair Scoring
----------------------
+Sentence Pair Scoring (``--task score``)
+----------------------------------------
 
 .. list-table::
   :widths: 25 25 50 5 5
@@ -468,13 +500,10 @@ Sentence Pair Scoring
       -
       -
 
-.. note::
-    These models are supported in both offline and online inference via Score API.
-
 .. _supported_mm_models:
 
-Multimodal Language Models
-^^^^^^^^^^^^^^^^^^^^^^^^^^
+List of Multimodal Language Models
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 The following modalities are supported depending on the model:
 
@@ -491,8 +520,15 @@ On the other hand, modalities separated by :code:`/` are mutually exclusive.
 
 - e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.
 
-Text Generation
----------------
+See :ref:`this page <multimodal_inputs>` on how to pass multi-modal inputs to the model.
+
+Generative Models
++++++++++++++++++
+
+See :ref:`this page <generative_models>` for more information on how to use generative models.
+
+Text Generation (``--task generate``)
+-------------------------------------
 
 .. list-table::
   :widths: 25 25 15 20 5 5 5
@@ -696,8 +732,24 @@ Text Generation
     The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
     For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
 
-Multimodal Embedding
---------------------
+Pooling Models
+++++++++++++++
+
+See :ref:`this page <pooling_models>` for more information on how to use pooling models.
+
+.. important::
+    Since some model architectures support both generative and pooling tasks,
+    you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode.
+
+Text Embedding (``--task embed``)
+---------------------------------
+
+Any text generation model can be converted into an embedding model by passing :code:`--task embed`.
+
+.. note::
+    To get the best results, you should use pooling models that are specifically trained as such.
+
+    The following table lists those that are tested in vLLM.
 
 .. list-table::
   :widths: 25 25 15 25 5 5
@@ -728,12 +780,7 @@ Multimodal Embedding
       -
       - ✅︎
 
-.. important::
-    Some model architectures support both generation and embedding tasks.
-    In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
-
-.. tip::
-    You can override the model's pooling method by passing :code:`--override-pooler-config`.
+----
 
 Model Support Policy
 =====================
||||||
|
@ -39,13 +39,13 @@ Feature x Feature
|
|||||||
- :abbr:`prmpt adptr (Prompt Adapter)`
|
- :abbr:`prmpt adptr (Prompt Adapter)`
|
||||||
- :ref:`SD <spec_decode>`
|
- :ref:`SD <spec_decode>`
|
||||||
- CUDA graph
|
- CUDA graph
|
||||||
- :abbr:`emd (Embedding Models)`
|
- :abbr:`pooling (Pooling Models)`
|
||||||
- :abbr:`enc-dec (Encoder-Decoder Models)`
|
- :abbr:`enc-dec (Encoder-Decoder Models)`
|
||||||
- :abbr:`logP (Logprobs)`
|
- :abbr:`logP (Logprobs)`
|
||||||
- :abbr:`prmpt logP (Prompt Logprobs)`
|
- :abbr:`prmpt logP (Prompt Logprobs)`
|
||||||
- :abbr:`async output (Async Output Processing)`
|
- :abbr:`async output (Async Output Processing)`
|
||||||
- multi-step
|
- multi-step
|
||||||
- :abbr:`mm (Multimodal)`
|
- :abbr:`mm (Multimodal Inputs)`
|
||||||
- best-of
|
- best-of
|
||||||
- beam-search
|
- beam-search
|
||||||
- :abbr:`guided dec (Guided Decoding)`
|
- :abbr:`guided dec (Guided Decoding)`
|
||||||
@ -151,7 +151,7 @@ Feature x Feature
|
|||||||
-
|
-
|
||||||
-
|
-
|
||||||
-
|
-
|
||||||
* - :abbr:`emd (Embedding Models)`
|
* - :abbr:`pooling (Pooling Models)`
|
||||||
- ✗
|
- ✗
|
||||||
- ✗
|
- ✗
|
||||||
- ✗
|
- ✗
|
||||||
@ -253,7 +253,7 @@ Feature x Feature
|
|||||||
-
|
-
|
||||||
-
|
-
|
||||||
-
|
-
|
||||||
* - :abbr:`mm (Multimodal)`
|
* - :abbr:`mm (Multimodal Inputs)`
|
||||||
- ✅
|
- ✅
|
||||||
- `✗ <https://github.com/vllm-project/vllm/pull/8348>`__
|
- `✗ <https://github.com/vllm-project/vllm/pull/8348>`__
|
||||||
- `✗ <https://github.com/vllm-project/vllm/pull/7199>`__
|
- `✗ <https://github.com/vllm-project/vllm/pull/7199>`__
|
||||||
@ -386,7 +386,7 @@ Feature x Hardware
|
|||||||
- ✅
|
- ✅
|
||||||
- ✗
|
- ✗
|
||||||
- ✅
|
- ✅
|
||||||
* - :abbr:`emd (Embedding Models)`
|
* - :abbr:`pooling (Pooling Models)`
|
||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
@ -402,7 +402,7 @@ Feature x Hardware
|
|||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
- ✗
|
- ✗
|
||||||
* - :abbr:`mm (Multimodal)`
|
* - :abbr:`mm (Multimodal Inputs)`
|
||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
- ✅
|
- ✅
|
||||||
|
@ -9,7 +9,12 @@ prompts = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
|
model = LLM(
|
||||||
|
model="intfloat/e5-mistral-7b-instruct",
|
||||||
|
task="embed", # You should pass task="embed" for embedding models
|
||||||
|
enforce_eager=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Generate embedding. The output is a list of PoolingRequestOutputs.
|
# Generate embedding. The output is a list of PoolingRequestOutputs.
|
||||||
outputs = model.encode(prompts)
|
outputs = model.encode(prompts)
|
||||||
# Print the outputs.
|
# Print the outputs.
|
||||||
|
@ -59,7 +59,7 @@ def run_e5_v(query: Query):
|
|||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model="royokong/e5-v",
|
model="royokong/e5-v",
|
||||||
task="embedding",
|
task="embed",
|
||||||
max_model_len=4096,
|
max_model_len=4096,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -88,7 +88,7 @@ def run_vlm2vec(query: Query):
|
|||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model="TIGER-Lab/VLM2Vec-Full",
|
model="TIGER-Lab/VLM2Vec-Full",
|
||||||
task="embedding",
|
task="embed",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
mm_processor_kwargs={"num_crops": 4},
|
mm_processor_kwargs={"num_crops": 4},
|
||||||
)
|
)
|
||||||
|
@ -55,7 +55,7 @@ test_settings = [
|
|||||||
# embedding model
|
# embedding model
|
||||||
TestSetting(
|
TestSetting(
|
||||||
model="BAAI/bge-multilingual-gemma2",
|
model="BAAI/bge-multilingual-gemma2",
|
||||||
model_args=["--task", "embedding"],
|
model_args=["--task", "embed"],
|
||||||
pp_size=1,
|
pp_size=1,
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
attn_backend="FLASHINFER",
|
attn_backend="FLASHINFER",
|
||||||
@ -65,7 +65,7 @@ test_settings = [
|
|||||||
# encoder-based embedding model (BERT)
|
# encoder-based embedding model (BERT)
|
||||||
TestSetting(
|
TestSetting(
|
||||||
model="BAAI/bge-base-en-v1.5",
|
model="BAAI/bge-base-en-v1.5",
|
||||||
model_args=["--task", "embedding"],
|
model_args=["--task", "embed"],
|
||||||
pp_size=1,
|
pp_size=1,
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
attn_backend="XFORMERS",
|
attn_backend="XFORMERS",
|
||||||
|
@ -37,7 +37,7 @@ def test_scheduler_schedule_simple_encoder_decoder():
|
|||||||
num_seq_group = 4
|
num_seq_group = 4
|
||||||
max_model_len = 16
|
max_model_len = 16
|
||||||
scheduler_config = SchedulerConfig(
|
scheduler_config = SchedulerConfig(
|
||||||
task="generate",
|
"generate",
|
||||||
max_num_batched_tokens=64,
|
max_num_batched_tokens=64,
|
||||||
max_num_seqs=num_seq_group,
|
max_num_seqs=num_seq_group,
|
||||||
max_model_len=max_model_len,
|
max_model_len=max_model_len,
|
||||||
|
@ -27,7 +27,7 @@ TEST_IMAGE_URLS = [
|
|||||||
def server():
|
def server():
|
||||||
args = [
|
args = [
|
||||||
"--task",
|
"--task",
|
||||||
"embedding",
|
"embed",
|
||||||
"--dtype",
|
"--dtype",
|
||||||
"bfloat16",
|
"bfloat16",
|
||||||
"--max-model-len",
|
"--max-model-len",
|
||||||
|
@ -54,7 +54,7 @@ def test_models(
|
|||||||
hf_outputs = hf_model.encode(example_prompts)
|
hf_outputs = hf_model.encode(example_prompts)
|
||||||
|
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
task="embedding",
|
task="embed",
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
**vllm_extra_kwargs) as vllm_model:
|
**vllm_extra_kwargs) as vllm_model:
|
||||||
|
@ -35,9 +35,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
|
|||||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||||
hf_outputs = hf_model.predict([text_pair]).tolist()
|
hf_outputs = hf_model.predict([text_pair]).tolist()
|
||||||
|
|
||||||
with vllm_runner(model_name,
|
with vllm_runner(model_name, task="score", dtype=dtype,
|
||||||
task="embedding",
|
|
||||||
dtype=dtype,
|
|
||||||
max_model_len=None) as vllm_model:
|
max_model_len=None) as vllm_model:
|
||||||
vllm_outputs = vllm_model.score(text_pair[0], text_pair[1])
|
vllm_outputs = vllm_model.score(text_pair[0], text_pair[1])
|
||||||
|
|
||||||
@ -58,9 +56,7 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
|
|||||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||||
hf_outputs = hf_model.predict(text_pairs).tolist()
|
hf_outputs = hf_model.predict(text_pairs).tolist()
|
||||||
|
|
||||||
with vllm_runner(model_name,
|
with vllm_runner(model_name, task="score", dtype=dtype,
|
||||||
task="embedding",
|
|
||||||
dtype=dtype,
|
|
||||||
max_model_len=None) as vllm_model:
|
max_model_len=None) as vllm_model:
|
||||||
vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2)
|
vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2)
|
||||||
|
|
||||||
@ -82,9 +78,7 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
|
|||||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||||
hf_outputs = hf_model.predict(text_pairs).tolist()
|
hf_outputs = hf_model.predict(text_pairs).tolist()
|
||||||
|
|
||||||
with vllm_runner(model_name,
|
with vllm_runner(model_name, task="score", dtype=dtype,
|
||||||
task="embedding",
|
|
||||||
dtype=dtype,
|
|
||||||
max_model_len=None) as vllm_model:
|
max_model_len=None) as vllm_model:
|
||||||
vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)
|
vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ def _run_test(
|
|||||||
# if we run HF first, the cuda initialization will be done and it
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
# will hurt multiprocessing backend with fork method (the default method).
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
task="embedding",
|
task="embed",
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
max_model_len=8192) as vllm_model:
|
max_model_len=8192) as vllm_model:
|
||||||
|
@ -47,7 +47,7 @@ def _run_test(
|
|||||||
# if we run HF first, the cuda initialization will be done and it
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
# will hurt multiprocessing backend with fork method (the default method).
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
with vllm_runner(model,
|
with vllm_runner(model,
|
||||||
task="embedding",
|
task="embed",
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
max_model_len=4096,
|
max_model_len=4096,
|
||||||
enforce_eager=True) as vllm_model:
|
enforce_eager=True) as vllm_model:
|
||||||
|
@ -39,7 +39,7 @@ def _run_test(
|
|||||||
# vLLM needs a fresh new process without cuda initialization.
|
# vLLM needs a fresh new process without cuda initialization.
|
||||||
# if we run HF first, the cuda initialization will be done and it
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
# will hurt multiprocessing backend with fork method (the default method).
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
with vllm_runner(model, task="embedding", dtype=dtype,
|
with vllm_runner(model, task="embed", dtype=dtype,
|
||||||
enforce_eager=True) as vllm_model:
|
enforce_eager=True) as vllm_model:
|
||||||
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
|
vllm_outputs = vllm_model.encode(input_texts, images=input_images)
|
||||||
|
|
||||||
|
@ -7,11 +7,17 @@ from vllm.model_executor.layers.pooler import PoolingType
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("model_id", "expected_task"), [
|
@pytest.mark.parametrize(
|
||||||
("facebook/opt-125m", "generate"),
|
("model_id", "expected_runner_type", "expected_task"),
|
||||||
("intfloat/e5-mistral-7b-instruct", "embedding"),
|
[
|
||||||
])
|
("facebook/opt-125m", "generate", "generate"),
|
||||||
def test_auto_task(model_id, expected_task):
|
("intfloat/e5-mistral-7b-instruct", "pooling", "embed"),
|
||||||
|
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
|
||||||
|
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
|
||||||
|
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_auto_task(model_id, expected_runner_type, expected_task):
|
||||||
config = ModelConfig(
|
config = ModelConfig(
|
||||||
model_id,
|
model_id,
|
||||||
task="auto",
|
task="auto",
|
||||||
@ -22,6 +28,7 @@ def test_auto_task(model_id, expected_task):
|
|||||||
dtype="float16",
|
dtype="float16",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert config.runner_type == expected_runner_type
|
||||||
assert config.task == expected_task
|
assert config.task == expected_task
|
||||||
|
|
||||||
|
|
||||||
|
137
vllm/config.py
137
vllm/config.py
@ -45,13 +45,27 @@ else:
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||||
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
|
||||||
|
|
||||||
TaskOption = Literal["auto", "generate", "embedding"]
|
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||||
|
"score", "reward"]
|
||||||
|
|
||||||
# "draft" is only used internally for speculative decoding
|
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
|
||||||
_Task = Literal["generate", "embedding", "draft"]
|
"draft"]
|
||||||
|
|
||||||
|
RunnerType = Literal["generate", "pooling", "draft"]
|
||||||
|
|
||||||
|
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
|
||||||
|
"generate": ["generate"],
|
||||||
|
"pooling": ["embed", "classify", "score", "reward"],
|
||||||
|
"draft": ["draft"],
|
||||||
|
}
|
||||||
|
|
||||||
|
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
|
||||||
|
task: runner
|
||||||
|
for runner, tasks in _RUNNER_TASKS.items() for task in tasks
|
||||||
|
}
|
||||||
|
|
||||||
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
|
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
|
||||||
PretrainedConfig]]
|
PretrainedConfig]]
|
||||||
@ -144,7 +158,7 @@ class ModelConfig:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
task: Union[TaskOption, _Task],
|
task: Union[TaskOption, Literal["draft"]],
|
||||||
tokenizer: str,
|
tokenizer: str,
|
||||||
tokenizer_mode: str,
|
tokenizer_mode: str,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
@ -295,6 +309,7 @@ class ModelConfig:
|
|||||||
supported_tasks, task = self._resolve_task(task, self.hf_config)
|
supported_tasks, task = self._resolve_task(task, self.hf_config)
|
||||||
self.supported_tasks = supported_tasks
|
self.supported_tasks = supported_tasks
|
||||||
self.task: Final = task
|
self.task: Final = task
|
||||||
|
|
||||||
self.pooler_config = self._init_pooler_config(override_pooler_config)
|
self.pooler_config = self._init_pooler_config(override_pooler_config)
|
||||||
|
|
||||||
self._verify_quantization()
|
self._verify_quantization()
|
||||||
@ -323,7 +338,7 @@ class ModelConfig:
|
|||||||
override_pooler_config: Optional["PoolerConfig"],
|
override_pooler_config: Optional["PoolerConfig"],
|
||||||
) -> Optional["PoolerConfig"]:
|
) -> Optional["PoolerConfig"]:
|
||||||
|
|
||||||
if self.task == "embedding":
|
if self.runner_type == "pooling":
|
||||||
user_config = override_pooler_config or PoolerConfig()
|
user_config = override_pooler_config or PoolerConfig()
|
||||||
|
|
||||||
base_config = get_pooling_config(self.model, self.revision)
|
base_config = get_pooling_config(self.model, self.revision)
|
||||||
@ -357,60 +372,90 @@ class ModelConfig:
|
|||||||
"either 'auto', 'slow' or 'mistral'.")
|
"either 'auto', 'slow' or 'mistral'.")
|
||||||
self.tokenizer_mode = tokenizer_mode
|
self.tokenizer_mode = tokenizer_mode
|
||||||
|
|
||||||
|
def _get_preferred_task(
|
||||||
|
self,
|
||||||
|
architectures: List[str],
|
||||||
|
supported_tasks: Set[_ResolvedTask],
|
||||||
|
) -> Optional[_ResolvedTask]:
|
||||||
|
model_id = self.model
|
||||||
|
if get_pooling_config(model_id, self.revision):
|
||||||
|
return "embed"
|
||||||
|
if ModelRegistry.is_cross_encoder_model(architectures):
|
||||||
|
return "score"
|
||||||
|
|
||||||
|
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
|
||||||
|
# Other models follow this pattern
|
||||||
|
("ForCausalLM", "generate"),
|
||||||
|
("ForConditionalGeneration", "generate"),
|
||||||
|
("ForSequenceClassification", "classify"),
|
||||||
|
("ChatModel", "generate"),
|
||||||
|
("LMHeadModel", "generate"),
|
||||||
|
("EmbeddingModel", "embed"),
|
||||||
|
("RewardModel", "reward"),
|
||||||
|
]
|
||||||
|
_, arch = ModelRegistry.inspect_model_cls(architectures)
|
||||||
|
|
||||||
|
for suffix, pref_task in suffix_to_preferred_task:
|
||||||
|
if arch.endswith(suffix) and pref_task in supported_tasks:
|
||||||
|
return pref_task
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def _resolve_task(
|
def _resolve_task(
|
||||||
self,
|
self,
|
||||||
task_option: Union[TaskOption, _Task],
|
task_option: Union[TaskOption, Literal["draft"]],
|
||||||
hf_config: PretrainedConfig,
|
hf_config: PretrainedConfig,
|
||||||
) -> Tuple[Set[_Task], _Task]:
|
) -> Tuple[Set[_ResolvedTask], _ResolvedTask]:
|
||||||
if task_option == "draft":
|
if task_option == "draft":
|
||||||
return {"draft"}, "draft"
|
return {"draft"}, "draft"
|
||||||
|
|
||||||
architectures = getattr(hf_config, "architectures", [])
|
architectures = getattr(hf_config, "architectures", [])
|
||||||
|
|
||||||
task_support: Dict[_Task, bool] = {
|
runner_support: Dict[RunnerType, bool] = {
|
||||||
# NOTE: Listed from highest to lowest priority,
|
# NOTE: Listed from highest to lowest priority,
|
||||||
# in case the model supports multiple of them
|
# in case the model supports multiple of them
|
||||||
"generate": ModelRegistry.is_text_generation_model(architectures),
|
"generate": ModelRegistry.is_text_generation_model(architectures),
|
||||||
"embedding": ModelRegistry.is_pooling_model(architectures),
|
"pooling": ModelRegistry.is_pooling_model(architectures),
|
||||||
}
|
}
|
||||||
supported_tasks_lst: List[_Task] = [
|
supported_runner_types_lst: List[RunnerType] = [
|
||||||
task for task, is_supported in task_support.items() if is_supported
|
runner_type
|
||||||
|
for runner_type, is_supported in runner_support.items()
|
||||||
|
if is_supported
|
||||||
|
]
|
||||||
|
|
||||||
|
supported_tasks_lst: List[_ResolvedTask] = [
|
||||||
|
task for runner_type in supported_runner_types_lst
|
||||||
|
for task in _RUNNER_TASKS[runner_type]
|
||||||
]
|
]
|
||||||
supported_tasks = set(supported_tasks_lst)
|
supported_tasks = set(supported_tasks_lst)
|
||||||
|
|
||||||
if task_option == "auto":
|
if task_option == "auto":
|
||||||
selected_task = next(iter(supported_tasks_lst))
|
selected_task = next(iter(supported_tasks_lst))
|
||||||
|
|
||||||
if len(supported_tasks) > 1:
|
if len(supported_tasks_lst) > 1:
|
||||||
suffix_to_preferred_task: List[Tuple[str, _Task]] = [
|
preferred_task = self._get_preferred_task(
|
||||||
# Hardcode the models that are exceptions
|
architectures, supported_tasks)
|
||||||
("AquilaModel", "generate"),
|
if preferred_task is not None:
|
||||||
("ChatGLMModel", "generate"),
|
selected_task = preferred_task
|
||||||
# Other models follow this pattern
|
|
||||||
("ForCausalLM", "generate"),
|
|
||||||
("ForConditionalGeneration", "generate"),
|
|
||||||
("ChatModel", "generate"),
|
|
||||||
("LMHeadModel", "generate"),
|
|
||||||
("EmbeddingModel", "embedding"),
|
|
||||||
("RewardModel", "embedding"),
|
|
||||||
("ForSequenceClassification", "embedding"),
|
|
||||||
]
|
|
||||||
info, arch = ModelRegistry.inspect_model_cls(architectures)
|
|
||||||
|
|
||||||
for suffix, pref_task in suffix_to_preferred_task:
|
|
||||||
if arch.endswith(suffix) and pref_task in supported_tasks:
|
|
||||||
selected_task = pref_task
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if (arch.endswith("Model")
|
|
||||||
and info.architecture.endswith("ForCausalLM")
|
|
||||||
and "embedding" in supported_tasks):
|
|
||||||
selected_task = "embedding"
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"This model supports multiple tasks: %s. "
|
"This model supports multiple tasks: %s. "
|
||||||
"Defaulting to '%s'.", supported_tasks, selected_task)
|
"Defaulting to '%s'.", supported_tasks, selected_task)
|
||||||
else:
|
else:
|
||||||
|
# Aliases
|
||||||
|
if task_option == "embedding":
|
||||||
|
preferred_task = self._get_preferred_task(
|
||||||
|
architectures, supported_tasks)
|
||||||
|
if preferred_task != "embed":
|
||||||
|
msg = ("The 'embedding' task will be restricted to "
|
||||||
|
"embedding models in a future release. Please "
|
||||||
|
"pass `--task classify`, `--task score`, or "
|
||||||
|
"`--task reward` explicitly for other pooling "
|
||||||
|
"models.")
|
||||||
|
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||||
|
|
||||||
|
task_option = preferred_task or "embed"
|
||||||
|
|
||||||
if task_option not in supported_tasks:
|
if task_option not in supported_tasks:
|
||||||
msg = (
|
msg = (
|
||||||
f"This model does not support the '{task_option}' task. "
|
f"This model does not support the '{task_option}' task. "
|
||||||
@ -533,7 +578,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
# Async postprocessor is not necessary with embedding mode
|
# Async postprocessor is not necessary with embedding mode
|
||||||
# since there is no token generation
|
# since there is no token generation
|
||||||
if self.task == "embedding":
|
if self.runner_type == "pooling":
|
||||||
self.use_async_output_proc = False
|
self.use_async_output_proc = False
|
||||||
|
|
||||||
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
|
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
|
||||||
@ -750,6 +795,14 @@ class ModelConfig:
|
|||||||
architectures = getattr(self.hf_config, "architectures", [])
|
architectures = getattr(self.hf_config, "architectures", [])
|
||||||
return ModelRegistry.is_cross_encoder_model(architectures)
|
return ModelRegistry.is_cross_encoder_model(architectures)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_runner_types(self) -> Set[RunnerType]:
|
||||||
|
return {_TASK_RUNNER[task] for task in self.supported_tasks}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def runner_type(self) -> RunnerType:
|
||||||
|
return _TASK_RUNNER[self.task]
|
||||||
|
|
||||||
|
|
||||||
class CacheConfig:
|
class CacheConfig:
|
||||||
"""Configuration for the KV cache.
|
"""Configuration for the KV cache.
|
||||||
@ -1096,7 +1149,7 @@ class ParallelConfig:
|
|||||||
class SchedulerConfig:
|
class SchedulerConfig:
|
||||||
"""Scheduler configuration."""
|
"""Scheduler configuration."""
|
||||||
|
|
||||||
task: str = "generate" # The task to use the model for.
|
runner_type: str = "generate" # The runner type to launch for the model.
|
||||||
|
|
||||||
# Maximum number of tokens to be processed in a single iteration.
|
# Maximum number of tokens to be processed in a single iteration.
|
||||||
max_num_batched_tokens: int = field(default=None) # type: ignore
|
max_num_batched_tokens: int = field(default=None) # type: ignore
|
||||||
@ -1164,11 +1217,11 @@ class SchedulerConfig:
|
|||||||
# for higher throughput.
|
# for higher throughput.
|
||||||
self.max_num_batched_tokens = max(self.max_model_len, 2048)
|
self.max_num_batched_tokens = max(self.max_model_len, 2048)
|
||||||
|
|
||||||
if self.task == "embedding":
|
if self.runner_type == "pooling":
|
||||||
# For embedding, choose specific value for higher throughput
|
# Choose specific value for higher throughput
|
||||||
self.max_num_batched_tokens = max(
|
self.max_num_batched_tokens = max(
|
||||||
self.max_num_batched_tokens,
|
self.max_num_batched_tokens,
|
||||||
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
|
_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||||
)
|
)
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
# The value needs to be at least the number of multimodal tokens
|
# The value needs to be at least the number of multimodal tokens
|
||||||
|
@ -337,7 +337,7 @@ class Scheduler:
|
|||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
|
|
||||||
version = "selfattn"
|
version = "selfattn"
|
||||||
if (self.scheduler_config.task == "embedding"
|
if (self.scheduler_config.runner_type == "pooling"
|
||||||
or self.cache_config.is_attention_free):
|
or self.cache_config.is_attention_free):
|
||||||
version = "placeholder"
|
version = "placeholder"
|
||||||
|
|
||||||
|
@ -1066,7 +1066,7 @@ class EngineArgs:
|
|||||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||||
and not self.enable_lora
|
and not self.enable_lora
|
||||||
and not self.enable_prompt_adapter
|
and not self.enable_prompt_adapter
|
||||||
and model_config.task != "embedding"):
|
and model_config.runner_type != "pooling"):
|
||||||
self.enable_chunked_prefill = True
|
self.enable_chunked_prefill = True
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Chunked prefill is enabled by default for models with "
|
"Chunked prefill is enabled by default for models with "
|
||||||
@ -1083,7 +1083,8 @@ class EngineArgs:
|
|||||||
"errors during the initial memory profiling phase, or result "
|
"errors during the initial memory profiling phase, or result "
|
||||||
"in low performance due to small KV cache space. Consider "
|
"in low performance due to small KV cache space. Consider "
|
||||||
"setting --max-model-len to a smaller value.", max_model_len)
|
"setting --max-model-len to a smaller value.", max_model_len)
|
||||||
elif self.enable_chunked_prefill and model_config.task == "embedding":
|
elif (self.enable_chunked_prefill
|
||||||
|
and model_config.runner_type == "pooling"):
|
||||||
msg = "Chunked prefill is not supported for embedding models"
|
msg = "Chunked prefill is not supported for embedding models"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
@ -1144,7 +1145,7 @@ class EngineArgs:
|
|||||||
" please file an issue with detailed information.")
|
" please file an issue with detailed information.")
|
||||||
|
|
||||||
scheduler_config = SchedulerConfig(
|
scheduler_config = SchedulerConfig(
|
||||||
task=model_config.task,
|
runner_type=model_config.runner_type,
|
||||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||||
max_num_seqs=self.max_num_seqs,
|
max_num_seqs=self.max_num_seqs,
|
||||||
max_model_len=model_config.max_model_len,
|
max_model_len=model_config.max_model_len,
|
||||||
|
@ -288,7 +288,7 @@ class LLMEngine:
|
|||||||
|
|
||||||
self.model_executor = executor_class(vllm_config=vllm_config, )
|
self.model_executor = executor_class(vllm_config=vllm_config, )
|
||||||
|
|
||||||
if self.model_config.task != "embedding":
|
if self.model_config.runner_type != "pooling":
|
||||||
self._initialize_kv_caches()
|
self._initialize_kv_caches()
|
||||||
|
|
||||||
# If usage stat is enabled, collect relevant info.
|
# If usage stat is enabled, collect relevant info.
|
||||||
@ -1123,7 +1123,7 @@ class LLMEngine:
|
|||||||
seq_group.metrics.model_execute_time = (
|
seq_group.metrics.model_execute_time = (
|
||||||
o.model_execute_time)
|
o.model_execute_time)
|
||||||
|
|
||||||
if self.model_config.task == "embedding":
|
if self.model_config.runner_type == "pooling":
|
||||||
self._process_sequence_group_outputs(seq_group, output)
|
self._process_sequence_group_outputs(seq_group, output)
|
||||||
else:
|
else:
|
||||||
self.output_processor.process_prompt_logprob(seq_group, output)
|
self.output_processor.process_prompt_logprob(seq_group, output)
|
||||||
|
@ -381,19 +381,20 @@ class LLM:
|
|||||||
considered legacy and may be deprecated in the future. You should
|
considered legacy and may be deprecated in the future. You should
|
||||||
instead pass them via the ``inputs`` parameter.
|
instead pass them via the ``inputs`` parameter.
|
||||||
"""
|
"""
|
||||||
task = self.llm_engine.model_config.task
|
runner_type = self.llm_engine.model_config.runner_type
|
||||||
if task != "generate":
|
if runner_type != "generate":
|
||||||
messages = [
|
messages = [
|
||||||
"LLM.generate() is only supported for (conditional) generation "
|
"LLM.generate() is only supported for (conditional) generation "
|
||||||
"models (XForCausalLM, XForConditionalGeneration).",
|
"models (XForCausalLM, XForConditionalGeneration).",
|
||||||
]
|
]
|
||||||
|
|
||||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
supported_runner_types = self.llm_engine.model_config \
|
||||||
if "generate" in supported_tasks:
|
.supported_runner_types
|
||||||
|
if "generate" in supported_runner_types:
|
||||||
messages.append(
|
messages.append(
|
||||||
"Your model supports the 'generate' task, but is "
|
"Your model supports the 'generate' runner, but is "
|
||||||
f"currently initialized for the '{task}' task. Please "
|
f"currently initialized for the '{runner_type}' runner. "
|
||||||
"initialize the model using `--task generate`.")
|
"Please initialize vLLM using `--task generate`.")
|
||||||
|
|
||||||
raise ValueError(" ".join(messages))
|
raise ValueError(" ".join(messages))
|
||||||
|
|
||||||
@ -793,16 +794,18 @@ class LLM:
|
|||||||
considered legacy and may be deprecated in the future. You should
|
considered legacy and may be deprecated in the future. You should
|
||||||
instead pass them via the ``inputs`` parameter.
|
instead pass them via the ``inputs`` parameter.
|
||||||
"""
|
"""
|
||||||
task = self.llm_engine.model_config.task
|
runner_type = self.llm_engine.model_config.runner_type
|
||||||
if task != "embedding":
|
if runner_type != "pooling":
|
||||||
messages = ["LLM.encode() is only supported for embedding models."]
|
messages = ["LLM.encode() is only supported for pooling models."]
|
||||||
|
|
||||||
supported_tasks = self.llm_engine.model_config.supported_tasks
|
supported_runner_types = self.llm_engine.model_config \
|
||||||
if "embedding" in supported_tasks:
|
.supported_runner_types
|
||||||
|
if "pooling" in supported_runner_types:
|
||||||
messages.append(
|
messages.append(
|
||||||
"Your model supports the 'embedding' task, but is "
|
"Your model supports the 'pooling' runner, but is "
|
||||||
f"currently initialized for the '{task}' task. Please "
|
f"currently initialized for the '{runner_type}' runner. "
|
||||||
"initialize the model using `--task embedding`.")
|
"Please initialize vLLM using `--task embed`, "
|
||||||
|
"`--task classify`, `--task score` etc.")
|
||||||
|
|
||||||
raise ValueError(" ".join(messages))
|
raise ValueError(" ".join(messages))
|
||||||
|
|
||||||
@@ -864,21 +867,23 @@ class LLM:
             A list of ``PoolingRequestOutput`` objects containing the
             generated scores in the same order as the input prompts.
         """
-        task = self.llm_engine.model_config.task
-        if task != "embedding":
-            messages = ["LLM.score() is only supported for embedding models."]
+        runner_type = self.llm_engine.model_config.runner_type
+        if runner_type != "pooling":
+            messages = ["LLM.score() is only supported for pooling models."]

-            supported_tasks = self.llm_engine.model_config.supported_tasks
-            if "embedding" in supported_tasks:
+            supported_runner_types = self.llm_engine.model_config \
+                .supported_runner_types
+            if "pooling" in supported_runner_types:
                 messages.append(
-                    "Your model supports the 'embedding' task, but is "
-                    f"currently initialized for the '{task}' task. Please "
-                    "initialize the model using `--task embedding`.")
+                    "Your model supports the 'pooling' runner, but is "
+                    f"currently initialized for the '{runner_type}' runner. "
+                    "Please initialize vLLM using `--task embed`, "
+                    "`--task classify`, `--task score` etc.")

             raise ValueError(" ".join(messages))

         if not self.llm_engine.model_config.is_cross_encoder:
-            raise ValueError("Your model does not support the cross encoding")
+            raise ValueError("Your model does not support cross encoding")

         tokenizer = self.llm_engine.get_tokenizer()

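Since ``LLM.score`` additionally requires a cross-encoder, a hedged usage sketch looks like this (the model name and output attributes are assumptions for illustration only):

.. code-block:: python

    from vllm import LLM

    # Example cross-encoder reranker (assumption); `task="score"` picks the
    # pooling runner with a scoring head.
    llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
    (output,) = llm.score("What is the capital of France?",
                          "The capital of France is Paris.")
    print(output.outputs)  # assumed to carry the relevance score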
@@ -573,7 +573,7 @@ def init_app_state(
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if model_config.task == "generate" else None
+    ) if model_config.runner_type == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -582,7 +582,7 @@ def init_app_state(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    ) if model_config.task == "generate" else None
+    ) if model_config.runner_type == "generate" else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -590,13 +590,13 @@ def init_app_state(
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if model_config.task == "embedding" else None
+    ) if model_config.runner_type == "pooling" else None
     state.openai_serving_scores = OpenAIServingScores(
         engine_client,
         model_config,
         base_model_paths,
         request_logger=request_logger
-    ) if (model_config.task == "embedding" \
+    ) if (model_config.runner_type == "pooling" \
           and model_config.is_cross_encoder) else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
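The repeated ``... if model_config.runner_type == ... else None`` expressions mean each OpenAI-compatible handler is only constructed when the loaded model's runner can actually serve it; endpoints whose handler is ``None`` stay disabled. A minimal standalone sketch of that gating pattern (the class and function names below are placeholders, not vLLM's serving classes):

.. code-block:: python

    from typing import Optional

    class FakeChatHandler:
        """Placeholder for a serving handler such as chat or completions."""

    def build_chat_handler(runner_type: str) -> Optional[FakeChatHandler]:
        # Only generative models get a chat handler; pooling models leave
        # the endpoint disabled (None).
        return FakeChatHandler() if runner_type == "generate" else None

    assert build_chat_handler("generate") is not None
    assert build_chat_handler("pooling") is None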
@@ -224,7 +224,7 @@ async def main(args):
         chat_template=None,
         chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if model_config.task == "generate" else None
+    ) if model_config.runner_type == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
@@ -232,7 +232,7 @@ async def main(args):
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
-    ) if model_config.task == "embedding" else None
+    ) if model_config.runner_type == "pooling" else None

     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
@@ -35,7 +35,7 @@ def get_model_architecture(
         architectures = ["QuantMixtralForCausalLM"]

     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
-    if model_config.task == "embedding":
+    if model_config.runner_type == "pooling":
         model_cls = as_embedding_model(model_cls)

     return model_cls, arch
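Conceptually, ``as_embedding_model`` adapts a generative architecture so that its final hidden states are pooled instead of being sampled from. The sketch below illustrates that adapter idea only; it is a simplified assumption (mean pooling, made-up names), not vLLM's implementation:

.. code-block:: python

    import torch
    import torch.nn as nn

    def as_mean_pooling_model(model_cls: type) -> type:
        """Return a subclass that mean-pools final hidden states (illustrative)."""

        class PoolingAdapter(model_cls):  # type: ignore[misc, valid-type]
            def pool(self, hidden_states: torch.Tensor) -> torch.Tensor:
                # (num_tokens, hidden_size) -> (hidden_size,)
                return hidden_states.mean(dim=0)

        PoolingAdapter.__name__ = model_cls.__name__ + "ForPooling"
        return PoolingAdapter

    class TinyModel(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x  # pretend these are the final hidden states

    Pooled = as_mean_pooling_model(TinyModel)
    print(Pooled().pool(torch.randn(5, 8)).shape)  # torch.Size([8])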
@@ -42,7 +42,7 @@ class EngineCore:
         executor_class: Type[Executor],
         usage_context: UsageContext,
     ):
-        assert vllm_config.model_config.task != "embedding"
+        assert vllm_config.model_config.runner_type != "pooling"

         logger.info("Initializing an LLM engine (v%s) with config: %s",
                     VLLM_VERSION, vllm_config)
@@ -163,7 +163,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
                 not in ["medusa", "mlp_speculator", "eagle"]) \
             else {"return_hidden_states": True}
         ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
-        if self.model_config.task == "embedding":
+        if self.model_config.runner_type == "pooling":
             ModelRunnerClass = CPUPoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = CPUEncoderDecoderModelRunner
@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
             else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
-        if model_config.task == "embedding":
+        if model_config.runner_type == "pooling":
             ModelRunnerClass = PoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
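Both the CPU and GPU workers choose a model-runner class from the runner type with the same precedence: pooling first, then encoder-decoder, then the default runner. A compact standalone sketch of that selection (the runner classes here are placeholders, not the vLLM runners):

.. code-block:: python

    class DefaultRunner: ...
    class PoolingRunner: ...
    class EncoderDecoderRunner: ...

    def select_runner_cls(runner_type: str, is_encoder_decoder: bool) -> type:
        # Pooling takes priority over the encoder-decoder check, matching the
        # if/elif order in the worker diffs above.
        if runner_type == "pooling":
            return PoolingRunner
        if is_encoder_decoder:
            return EncoderDecoderRunner
        return DefaultRunner

    assert select_runner_cls("pooling", True) is PoolingRunner
    assert select_runner_cls("generate", True) is EncoderDecoderRunner
    assert select_runner_cls("generate", False) is DefaultRunner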