[Frontend] Chat-based Embeddings API (#9759)

Cyrus Leung 2024-11-01 16:13:35 +08:00 committed by GitHub
parent d3aa2a8b2f
commit 06386a64dd
GPG Key ID: B5690EEEBB952194
21 changed files with 846 additions and 408 deletions

View File

@@ -13,5 +13,7 @@ torch
 py-cpuinfo
 transformers
 mistral_common >= 1.3.4
+aiohttp
+starlette
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args

View File

@@ -96,7 +96,6 @@ def setup(app):
 # Mock out external dependencies here, otherwise the autodoc pages may be blank.
 autodoc_mock_imports = [
-    "aiohttp",
     "compressed_tensors",
     "cpuinfo",
     "cv2",
@@ -143,6 +142,7 @@ intersphinx_mapping = {
     "python": ("https://docs.python.org/3", None),
     "typing_extensions":
     ("https://typing-extensions.readthedocs.io/en/latest", None),
+    "aiohttp": ("https://docs.aiohttp.org/en/stable", None),
     "pillow": ("https://pillow.readthedocs.io/en/stable", None),
     "numpy": ("https://numpy.org/doc/stable", None),
     "torch": ("https://pytorch.org/docs/stable", None),

View File

@@ -0,0 +1,5 @@
+Pooling Parameters
+==================
+
+.. autoclass:: vllm.PoolingParams
+    :members:

View File

@@ -138,10 +138,10 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
 A more detailed client example can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`__.

-OpenAI Chat API with vLLM
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+OpenAI Chat Completions API with vLLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-vLLM is designed to also support the OpenAI Chat API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
+vLLM is designed to also support the OpenAI Chat Completions API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.

 You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to interact with the model:
@@ -157,7 +157,7 @@ You can use the `create chat completion <https://platform.openai.com/docs/api-re
     $ ]
     $ }'

-Alternatively, you can use the `openai` python package:
+Alternatively, you can use the ``openai`` python package:

 .. code-block:: python

View File

@@ -134,6 +134,7 @@ Documentation
   :caption: Developer Documentation

   dev/sampling_params
+  dev/pooling_params
   dev/offline_inference/offline_index
   dev/engine/engine_index
   dev/kernel/paged_attention

View File

@@ -185,7 +185,7 @@ Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruc
       --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2

 .. important::
-    Since OpenAI Vision API is based on `Chat Completions <https://platform.openai.com/docs/api-reference/chat>`_ API,
+    Since OpenAI Vision API is based on `Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`_,
     a chat template is **required** to launch the API server.

     Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it.
@@ -243,6 +243,10 @@ To consume the server, you can use the OpenAI client like in the example below:

 A full code example can be found in `examples/openai_api_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_api_client_for_multimodal.py>`_.

+.. tip::
+    There is no need to place image placeholders in the text content of the API request - they are already represented by the image content.
+    In fact, you can place image placeholders in the middle of the text by interleaving text and image content.
+
 .. note::

     By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable:
@@ -251,5 +255,49 @@ A full code example can be found in `examples/openai_api_client_for_multimodal.p

     $ export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>

-.. note::
-    There is no need to format the prompt in the API request since it will be handled by the server.
+Chat Embeddings API
+^^^^^^^^^^^^^^^^^^^
+
+vLLM's Chat Embeddings API is a superset of OpenAI's `Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`_,
+where a list of ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models.
+
+.. tip::
+    The schema of ``messages`` is exactly the same as in Chat Completions API.
+
+In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
+
+.. code-block:: bash
+
+    vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
+        --trust-remote-code --max-model-len 4096
+
+.. important::
+
+    Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
+    to run this model in embedding mode instead of text generation mode.
+
+Since this schema is not defined by the OpenAI client, we post a request to the server using the lower-level ``requests`` library:
+
+.. code-block:: python
+
+    import requests
+
+    image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+
+    response = requests.post(
+        "http://localhost:8000/v1/embeddings",
+        json={
+            "model": "TIGER-Lab/VLM2Vec-Full",
+            "messages": [{
+                "role": "user",
+                "content": [
+                    {"type": "image_url", "image_url": {"url": image_url}},
+                    {"type": "text", "text": "Represent the given image."},
+                ],
+            }],
+            "encoding_format": "float",
+        },
+    )
+    response.raise_for_status()
+    response_json = response.json()
+
+    print("Embedding output:", response_json["data"][0]["embedding"])

View File

@@ -26,13 +26,26 @@ print(completion.choices[0].message)
 ```

 ## API Reference
-Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-reference) for more information on the API. We support all parameters except:
-- Chat: `tools`, and `tool_choice`.
-- Completions: `suffix`.

-vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst).
+We currently support the following OpenAI APIs:
+- [Completions API](https://platform.openai.com/docs/api-reference/completions)
+  - *Note: `suffix` parameter is not supported.*
+- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
+  - [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Using VLMs](../models/vlm.rst).
+    - *Note: `image_url.detail` parameter is not supported.*
+    - We also support `audio_url` content type for audio files.
+      - Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
+      - *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).*
+  - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
+- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
+  - Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API),
+    which will be treated as a single prompt to the model according to its chat template.
+  - This enables multi-modal inputs to be passed to embedding models; see [Using VLMs](../models/vlm.rst). A text-only request sketch follows this list.
+  - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
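As a hedged illustration of the `messages`-based Embeddings request described in the last bullet (an editorial sketch, not part of this commit's diff), the snippet below embeds a short text-only conversation. It assumes a vLLM server on `localhost:8000` started with `--task embedding` and a chat template; the model name is only an example.

```python
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "messages": [
            {"role": "user", "content": "The cat sat on the mat."},
            {"role": "assistant", "content": "A feline was resting on a rug."},
            {"role": "user", "content": "Stars twinkle brightly in the night sky."},
        ],
        "encoding_format": "float",
    },
)
response.raise_for_status()
print(len(response.json()["data"][0]["embedding"]))
```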
 ## Extra Parameters
 vLLM supports a set of parameters that are not part of the OpenAI API.
 In order to use them, you can pass them as extra parameters in the OpenAI client,
 or merge them directly into the JSON payload if you are calling the HTTP API directly.
@@ -49,7 +62,26 @@ completion = client.chat.completions.create(
 )
 ```

-### Extra Parameters for Chat API
+### Extra Parameters for Completions API
+
+The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
+
+```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-completion-sampling-params
+:end-before: end-completion-sampling-params
+```
+
+The following extra parameters are supported:
+
+```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-completion-extra-params
+:end-before: end-completion-extra-params
+```
+
+### Extra Parameters for Chat Completions API
+
 The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
@@ -66,21 +98,22 @@ The following extra parameters are supported:
 :end-before: end-chat-completion-extra-params
 ```

-### Extra Parameters for Completions API
+### Extra Parameters for Embeddings API

-The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
+The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported.

 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python
-:start-after: begin-completion-sampling-params
-:end-before: end-completion-sampling-params
+:start-after: begin-embedding-pooling-params
+:end-before: end-embedding-pooling-params
 ```

 The following extra parameters are supported:

 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python
-:start-after: begin-completion-extra-params
-:end-before: end-completion-extra-params
+:start-after: begin-embedding-extra-params
+:end-before: end-embedding-extra-params
 ```

 ## Chat Template
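To make the extra-parameter mechanism above concrete, here is a hedged sketch (editorial, not part of this commit) of passing vLLM-specific fields to the Embeddings API through the official `openai` client. It assumes a server on `localhost:8000` serving an embedding model; `truncate_prompt_tokens` and `priority` are the extra fields documented in `protocol.py`.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

result = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input=["The chef prepared a delicious meal."],
    encoding_format="float",
    # Fields the OpenAI schema does not know about are merged into the JSON
    # payload that vLLM receives.
    extra_body={"truncate_prompt_tokens": 10, "priority": 0},
)
print(len(result.data[0].embedding))
```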

View File

@@ -1,7 +1,6 @@
 from http import HTTPStatus
 from typing import List

-import openai
 import pytest
 import pytest_asyncio
 import requests
@@ -83,10 +82,8 @@ async def client(server):
     indirect=True,
 )
 @pytest.mark.asyncio
-async def test_show_version(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
-    response = requests.get(base_url + "/version")
+async def test_show_version(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("version"))
     response.raise_for_status()

     assert response.json() == {"version": VLLM_VERSION}
@@ -102,9 +99,7 @@ async def test_show_version(client: openai.AsyncOpenAI):
     indirect=True,
 )
 @pytest.mark.asyncio
-async def test_check_health(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
-    response = requests.get(base_url + "/health")
+async def test_check_health(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("health"))

     assert response.status_code == HTTPStatus.OK

View File

@@ -4,14 +4,18 @@ import numpy as np
 import openai
 import pytest
 import pytest_asyncio
+import requests

+from vllm.transformers_utils.tokenizer import get_tokenizer
+
 from ...utils import RemoteOpenAIServer

-EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501


 @pytest.fixture(scope="module")
-def embedding_server():
+def server():
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -19,31 +23,29 @@ def embedding_server():
         "--enforce-eager",
         "--max-model-len",
         "8192",
+        "--chat-template",
+        DUMMY_CHAT_TEMPLATE,
     ]

-    with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server


 @pytest_asyncio.fixture
-async def embedding_client(embedding_server):
-    async with embedding_server.get_async_client() as async_client:
+async def client(server):
+    async with server.get_async_client() as async_client:
         yield async_client


 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
-                                model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
     input_texts = [
         "The chef prepared a delicious meal.",
     ]

     # test single embedding
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         encoding_format="float",
@@ -57,7 +59,7 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,

     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         encoding_format="float",
@@ -71,18 +73,14 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,


 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
-                               model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
     # test List[str]
     input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
         "Stars twinkle brightly in the night sky."
     ]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         encoding_format="float",
@@ -90,11 +88,14 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
     assert embeddings.id is not None
     assert len(embeddings.data) == 3
     assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 32
+    assert embeddings.usage.total_tokens == 32

     # test List[List[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                     [25, 32, 64, 77]]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         encoding_format="float",
@@ -108,22 +109,70 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,


 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_conversation_embedding(server: RemoteOpenAIServer,
+                                      client: openai.AsyncOpenAI,
+                                      model_name: str):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    chat_response = requests.post(server.url_for("v1/embeddings"),
+                                  json={
+                                      "model": model_name,
+                                      "messages": messages,
+                                      "encoding_format": "float",
+                                  })
+    chat_response.raise_for_status()
+    chat_embeddings = chat_response.json()
+
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        chat_template=DUMMY_CHAT_TEMPLATE,
+        add_generation_prompt=True,
+        continue_final_message=False,
+        tokenize=False,
+    )
+    completion_response = await client.embeddings.create(
+        model=model_name,
+        input=prompt,
+        encoding_format="float",
+        # To be consistent with chat
+        extra_body={"add_special_tokens": False},
+    )
+    completion_embeddings = completion_response.model_dump(mode="json")
+
+    assert chat_embeddings.pop("id") is not None
+    assert completion_embeddings.pop("id") is not None
+    assert chat_embeddings.pop("created") <= completion_embeddings.pop(
+        "created")
+    assert chat_embeddings == completion_embeddings
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
                                       model_name: str):
     input_texts = [
         "Hello my name is",
         "The best thing about vLLM is that it supports many different models"
     ]

-    responses_float = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name, encoding_format="float")
+    responses_float = await client.embeddings.create(input=input_texts,
+                                                     model=model_name,
+                                                     encoding_format="float")

-    responses_base64 = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name, encoding_format="base64")
+    responses_base64 = await client.embeddings.create(input=input_texts,
+                                                      model=model_name,
+                                                      encoding_format="base64")

     decoded_responses_base64_data = []
     for data in responses_base64.data:
@@ -137,8 +186,8 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
                                                           1]

     # Default response is float32 decoded from base64 by OpenAI Client
-    responses_default = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name)
+    responses_default = await client.embeddings.create(input=input_texts,
+                                                       model=model_name)

     assert responses_float.data[0].embedding == responses_default.data[
         0].embedding
@@ -147,18 +196,15 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,

 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding_truncation(
-        embedding_client: openai.AsyncOpenAI, model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
+                                           model_name: str):
     input_texts = [
         "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
     ]

     # test single embedding
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         extra_body={"truncate_prompt_tokens": 10})
@@ -173,7 +219,7 @@ async def test_single_embedding_truncation(
         1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
         9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
     ]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         extra_body={"truncate_prompt_tokens": 10})
@@ -187,18 +233,15 @@ async def test_single_embedding_truncation(

 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding_truncation_invalid(
-        embedding_client: openai.AsyncOpenAI, model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
+                                                    model_name: str):
     input_texts = [
         "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
     ]

     with pytest.raises(openai.BadRequestError):
-        embeddings = await embedding_client.embeddings.create(
+        embeddings = await client.embeddings.create(
             model=model_name,
             input=input_texts,
             extra_body={"truncate_prompt_tokens": 8193})

View File

@@ -79,9 +79,8 @@ EXPECTED_VALUES = {

 @pytest.mark.asyncio
-async def test_metrics_counts(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
+async def test_metrics_counts(server: RemoteOpenAIServer,
+                              client: openai.AsyncClient):
     for _ in range(_NUM_REQUESTS):
         # sending a request triggers the metrics to be logged.
         await client.completions.create(
@@ -89,7 +88,7 @@ async def test_metrics_counts(client: openai.AsyncOpenAI):
             prompt=_TOKENIZED_PROMPT,
             max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)

-    response = requests.get(base_url + "/metrics")
+    response = requests.get(server.url_for("metrics"))
     print(response.text)
     assert response.status_code == HTTPStatus.OK
@@ -170,16 +169,15 @@ EXPECTED_METRICS = [

 @pytest.mark.asyncio
-async def test_metrics_exist(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
+async def test_metrics_exist(server: RemoteOpenAIServer,
+                             client: openai.AsyncClient):
     # sending a request triggers the metrics to be logged.
     await client.completions.create(model=MODEL_NAME,
                                     prompt="Hello, my name is",
                                     max_tokens=5,
                                     temperature=0.0)

-    response = requests.get(base_url + "/metrics")
+    response = requests.get(server.url_for("metrics"))
     assert response.status_code == HTTPStatus.OK

     for metric in EXPECTED_METRICS:

View File

@@ -1,4 +1,3 @@
-import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
 import requests
@@ -55,9 +54,11 @@ async def client(server):
     [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
     indirect=["tokenizer_name"],
 )
-async def test_tokenize_completions(client: openai.AsyncOpenAI,
-                                    model_name: str, tokenizer_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
+async def test_tokenize_completions(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
     tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                               tokenizer_mode="fast")
@@ -65,7 +66,7 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
         prompt = "vllm1 This is a test prompt."
         tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

-        response = requests.post(base_url + "/tokenize",
+        response = requests.post(server.url_for("tokenize"),
                                  json={
                                      "add_special_tokens": add_special,
                                      "model": model_name,
@@ -86,9 +87,11 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
     [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
     indirect=["tokenizer_name"],
 )
-async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
-                             tokenizer_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
+async def test_tokenize_chat(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
     tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                               tokenizer_mode="fast")
@@ -121,7 +124,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
             tokens = tokenizer.encode(prompt,
                                       add_special_tokens=add_special)

-            response = requests.post(base_url + "/tokenize",
+            response = requests.post(server.url_for("tokenize"),
                                      json={
                                          "add_generation_prompt":
                                          add_generation,
@@ -146,17 +149,18 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
     [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
     indirect=["tokenizer_name"],
 )
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
-                          tokenizer_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
+async def test_detokenize(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
     tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                               tokenizer_mode="fast")

     prompt = "This is a test prompt. vllm1"
     tokens = tokenizer.encode(prompt, add_special_tokens=False)

-    print(f"CALLING {base_url} FOR {model_name}")
-    response = requests.post(base_url + "/detokenize",
+    response = requests.post(server.url_for("detokenize"),
                              json={
                                  "model": model_name,
                                  "tokens": tokens

View File

@@ -0,0 +1,94 @@
+from typing import Dict
+
+import pytest
+import pytest_asyncio
+import requests
+
+from vllm.multimodal.utils import encode_image_base64, fetch_image
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
+
+MAXIMUM_IMAGES = 2
+
+# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
+TEST_IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
+    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+]
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--task",
+        "embedding",
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "2048",
+        "--max-num-seqs",
+        "5",
+        "--enforce-eager",
+        "--trust-remote-code",
+        "--limit-mm-per-prompt",
+        f"image={MAXIMUM_IMAGES}",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.fixture(scope="session")
+def base64_encoded_image() -> Dict[str, str]:
+    return {
+        image_url: encode_image_base64(fetch_image(image_url))
+        for image_url in TEST_IMAGE_URLS
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
+                               image_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "Represent the given image."
+            },
+        ],
+    }]
+
+    response = requests.post(server.url_for("v1/embeddings"),
+                             json={
+                                 "model": model_name,
+                                 "messages": messages,
+                                 "encoding_format": "float"
+                             })
+    response.raise_for_status()
+
+    embeddings = response.json()
+    assert embeddings["id"] is not None
+    assert len(embeddings["data"]) == 1
+    assert len(embeddings["data"][0]["embedding"]) == 3072
+    assert embeddings["usage"]["completion_tokens"] == 0
+    assert embeddings["usage"]["prompt_tokens"] == 771
+    assert embeddings["usage"]["total_tokens"] == 771

View File

@@ -11,7 +11,7 @@ from argparse import Namespace
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import AsyncIterator, Set
+from typing import AsyncIterator, Optional, Set

 import uvloop
 from fastapi import APIRouter, FastAPI, Request
@@ -51,7 +51,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -248,22 +248,27 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)


-def chat(request: Request) -> OpenAIServingChat:
+def base(request: Request) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(request)
+
+
+def chat(request: Request) -> Optional[OpenAIServingChat]:
     return request.app.state.openai_serving_chat


-def completion(request: Request) -> OpenAIServingCompletion:
+def completion(request: Request) -> Optional[OpenAIServingCompletion]:
     return request.app.state.openai_serving_completion


+def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
+    return request.app.state.openai_serving_embedding
+
+
 def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization


-def embedding(request: Request) -> OpenAIServingEmbedding:
-    return request.app.state.openai_serving_embedding
-
-
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
@@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response:

 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_tokenize(request)
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):

 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_detokenize(request)
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):

 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
-    models = await completion(raw_request).show_available_models()
+    handler = base(raw_request)
+
+    models = await handler.show_available_models()
     return JSONResponse(content=models.model_dump())
@@ -314,9 +325,12 @@ async def show_version():

 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
+    handler = chat(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Chat Completions API")

-    generator = await chat(raw_request).create_chat_completion(
-        request, raw_request)
+    generator = await handler.create_chat_completion(request, raw_request)

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest,

 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await completion(raw_request).create_completion(
-        request, raw_request)
+    handler = completion(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Completions API")
+
+    generator = await handler.create_completion(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -343,8 +361,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await embedding(raw_request).create_embedding(
-        request, raw_request)
+    handler = embedding(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Embeddings API")
+
+    generator = await handler.create_embedding(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -382,12 +404,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     @router.post("/v1/load_lora_adapter")
     async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                 raw_request: Request):
-        response = await chat(raw_request).load_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
-
-        response = await completion(raw_request).load_lora_adapter(request)
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.load_lora_adapter(request)
                 if isinstance(response, ErrorResponse):
                     return JSONResponse(content=response.model_dump(),
                                         status_code=response.code)
@@ -397,12 +417,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     @router.post("/v1/unload_lora_adapter")
     async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                   raw_request: Request):
-        response = await chat(raw_request).unload_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
-
-        response = await completion(raw_request).unload_lora_adapter(request)
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.unload_lora_adapter(request)
                 if isinstance(response, ErrorResponse):
                     return JSONResponse(content=response.model_dump(),
                                         status_code=response.code)
@@ -501,7 +519,8 @@ def init_app_state(
         chat_template=args.chat_template,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
-        tool_parser=args.tool_call_parser)
+        tool_parser=args.tool_call_parser,
+    ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -510,13 +529,14 @@ def init_app_state(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    )
+    ) if model_config.task == "generate" else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
         base_model_paths,
         request_logger=request_logger,
-    )
+        chat_template=args.chat_template,
+    ) if model_config.task == "embedding" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
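A consequence of the optional handlers above: when a model is served with `--task embedding`, the chat and completion handlers are `None` and those routes return an error body instead of failing unexpectedly. The sketch below (editorial, not part of this commit) shows how a client might detect that; it assumes the error body follows vLLM's `ErrorResponse` shape (`object`, `message`, `code`), which may vary between versions.

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
body = resp.json()
if body.get("object") == "error":
    # Expected message: "The model does not support Chat Completions API"
    print("Request rejected:", body.get("message"))
else:
    print(body["choices"][0]["message"]["content"])
```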

View File

@@ -708,7 +708,7 @@ class CompletionRequest(OpenAIBaseModel):
         return data


-class EmbeddingRequest(OpenAIBaseModel):
+class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
     model: str
@@ -720,10 +720,15 @@ class EmbeddingRequest(OpenAIBaseModel):
     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
     # doc: end-embedding-pooling-params

     # doc: begin-embedding-extra-params
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
+    )
     priority: int = Field(
         default=0,
         description=(
@@ -737,6 +742,82 @@ class EmbeddingRequest(OpenAIBaseModel):
         return PoolingParams(additional_data=self.additional_data)


+class EmbeddingChatRequest(OpenAIBaseModel):
+    model: str
+    messages: List[ChatCompletionMessageParam]
+
+    encoding_format: Literal["float", "base64"] = "float"
+    dimensions: Optional[int] = None
+    user: Optional[str] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+
+    # doc: begin-chat-embedding-pooling-params
+    additional_data: Optional[Any] = None
+    # doc: end-chat-embedding-pooling-params
+
+    # doc: begin-chat-embedding-extra-params
+    add_generation_prompt: bool = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))
+    # doc: end-chat-embedding-extra-params
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
+    def to_pooling_params(self):
+        return PoolingParams(additional_data=self.additional_data)
+
+
+EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
+
+
 class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -799,7 +880,7 @@ class EmbeddingResponseData(OpenAIBaseModel):

 class EmbeddingResponse(OpenAIBaseModel):
-    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
     object: str = "list"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
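To see how the `EmbeddingRequest` union distinguishes the two payload shapes, here is a self-contained, simplified analogue (editorial sketch; the stub models below are not the real vLLM classes): a body carrying `input` resolves to the completion-style request, while a body carrying `messages` resolves to the chat-style request.

```python
from typing import Dict, List, Union

from pydantic import BaseModel, TypeAdapter


class EmbeddingCompletionStub(BaseModel):
    model: str
    input: Union[str, List[str]]


class EmbeddingChatStub(BaseModel):
    model: str
    messages: List[Dict[str, str]]


adapter = TypeAdapter(Union[EmbeddingCompletionStub, EmbeddingChatStub])

# Resolves to EmbeddingCompletionStub.
print(type(adapter.validate_python({"model": "m", "input": "hello"})).__name__)
# Resolves to EmbeddingChatStub.
print(type(adapter.validate_python({
    "model": "m",
    "messages": [{"role": "user", "content": "hi"}],
})).__name__)
```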

View File

@@ -217,13 +217,14 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
-    )
+    ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
         base_model_paths,
         request_logger=request_logger,
-    )
+        chat_template=None,
+    ) if model_config.task == "embedding" else None

     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
@@ -240,14 +241,31 @@ async def main(args):
         # Determine the type of request and run it.
         if request.url == "/v1/chat/completions":
-            response_futures.append(
-                run_request(openai_serving_chat.create_chat_completion,
-                            request, tracker))
+            handler_fn = (None if openai_serving_chat is None else
+                          openai_serving_chat.create_chat_completion)
+            if handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg=
+                        "The model does not support Chat Completions API",
+                    ))
+                continue
+
+            response_futures.append(run_request(handler_fn, request, tracker))
             tracker.submitted()
         elif request.url == "/v1/embeddings":
-            response_futures.append(
-                run_request(openai_serving_embedding.create_embedding, request,
-                            tracker))
+            handler_fn = (None if openai_serving_embedding is None else
+                          openai_serving_embedding.create_embedding)
+            if handler_fn is None:
+                response_futures.append(
+                    make_async_error_request_output(
+                        request,
+                        error_msg="The model does not support Embeddings API",
+                    ))
+                continue
+
+            response_futures.append(run_request(handler_fn, request, tracker))
             tracker.submitted()
         else:
             response_futures.append(
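For context on the new embeddings branch above, a batch input line for `run_batch` might look like the sketch below (editorial, not part of this commit; it mirrors the OpenAI batch JSONL layout of `custom_id`/`method`/`url`/`body` — verify the exact field names and CLI flags against your vLLM version).

```python
import json

request_line = {
    "custom_id": "embedding-request-1",
    "method": "POST",
    "url": "/v1/embeddings",
    "body": {
        "model": "intfloat/e5-mistral-7b-instruct",
        "input": "The chef prepared a delicious meal.",
        "encoding_format": "float",
    },
}

with open("batch_input.jsonl", "w") as f:
    f.write(json.dumps(request_line) + "\n")

# Then, roughly:
#   python -m vllm.entrypoints.openai.run_batch \
#       -i batch_input.jsonl -o batch_output.jsonl \
#       --model intfloat/e5-mistral-7b-instruct
```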

View File

@ -10,11 +10,7 @@ from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProb, ChatCompletionLogProbs,
@ -27,16 +23,12 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (BaseModelPath, from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath, LoRAModulePath,
OpenAIServing, OpenAIServing,
PromptAdapterPath, PromptAdapterPath)
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation from vllm.utils import iterate_with_cancellation
@ -94,12 +86,12 @@ class OpenAIServingChat(OpenAIServing):
raw_request: Optional[Request] = None, raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]: ErrorResponse]:
"""Completion API similar to OpenAI's API. """
Chat Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI for the API specification. This API mimics the OpenAI
Chat Completion API. Chat Completion API.
""" """
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
@ -118,49 +110,8 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
tool_parser = self.tool_parser
conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer)
tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
]
prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
logger.exception("Error in applying chat template from request")
return self.create_error_response(str(e))
try:
mm_data = await mm_data_future
except Exception as e:
logger.exception("Error in loading multi-modal data")
return self.create_error_response(str(e))
# validation for OpenAI tools # validation for OpenAI tools
# tool_choice = "required" is not supported # tool_choice = "required" is not supported
@ -168,45 +119,55 @@ class OpenAIServingChat(OpenAIServing):
return self.create_error_response( return self.create_error_response(
"tool_choice = \"required\" is not supported!") "tool_choice = \"required\" is not supported!")
if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( if (request.tool_choice == "auto" and
self.enable_auto_tools and self.tool_parser is not None): not (self.enable_auto_tools and tool_parser is not None)
and not isinstance(tokenizer, MistralTokenizer)):
# for hf tokenizers, "auto" tools requires # for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser # --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response( return self.create_error_response(
"\"auto\" tool choice requires " "\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set") "--enable-auto-tool-choice and --tool-call-parser to be set"
)
request_id = f"chat-{request.request_id}" tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
]
(
conversation,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=tool_dicts,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
request_id = f"chatcmpl-{request.request_id}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
-if self.enable_auto_tools and self.tool_parser:
for i, engine_prompt in enumerate(engine_prompts):
-request = self.tool_parser(tokenizer).adjust_request(
-request=request)
-if isinstance(prompt, str):
-prompt_inputs = self._tokenize_prompt_input(
-request,
-tokenizer,
-prompt,
-truncate_prompt_tokens=request.truncate_prompt_tokens,
-add_special_tokens=request.add_special_tokens,
-)
-else:
-assert isinstance(prompt, list) and isinstance(
-prompt[0], int
-), "Prompt has to be either a string or a list of token ids"
-prompt_inputs = TextTokensPrompt(
-prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
-assert prompt_inputs is not None
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
-prompt_inputs["prompt_token_ids"])
engine_prompt["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
@@ -215,35 +176,24 @@ class OpenAIServingChat(OpenAIServing):
default_max_tokens)
self._log_inputs(request_id,
-prompt_inputs,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
-engine_inputs = TokensPrompt(
trace_headers = (None if raw_request is None else await
-prompt_token_ids=prompt_inputs["prompt_token_ids"])
self._get_trace_headers(raw_request.headers))
-if mm_data is not None:
-engine_inputs["multi_modal_data"] = mm_data
-is_tracing_enabled = (await
-self.engine_client.is_tracing_enabled())
-trace_headers = None
-if is_tracing_enabled and raw_request:
-trace_headers = extract_trace_headers(raw_request.headers)
-if (not is_tracing_enabled and raw_request
-and contains_trace_headers(raw_request.headers)):
-log_tracing_disabled_warning()
if isinstance(sampling_params, BeamSearchParams):
-result_generator = self.engine_client.beam_search(
generator = self.engine_client.beam_search(
-prompt=engine_inputs,
prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
else:
-result_generator = self.engine_client.generate(
generator = self.engine_client.generate(
-engine_inputs,
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
@@ -251,10 +201,15 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert len(generators) == 1
result_generator, = generators
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
@@ -626,6 +581,9 @@ class OpenAIServingChat(OpenAIServing):
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert final_res is not None
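For context on the hand-off above: `iterate_with_cancellation` wraps the single result generator so iteration stops once the HTTP client disconnects, which the handler then maps to a "Client disconnected" error response. A minimal sketch of that pattern (illustrative only, not vLLM's actual implementation; the name `iterate_with_cancellation_sketch` is made up here):

import asyncio
from typing import AsyncGenerator, Awaitable, Callable, TypeVar

T = TypeVar("T")


async def iterate_with_cancellation_sketch(
        results: AsyncGenerator[T, None],
        is_disconnected: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    # Forward items from the wrapped generator, but stop and raise
    # CancelledError as soon as the client is reported as disconnected.
    async for item in results:
        if await is_disconnected():
            raise asyncio.CancelledError("Client disconnected")
        yield item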
View File
@@ -1,7 +1,6 @@
import asyncio
import time
-from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
-Optional)
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast
@@ -30,18 +29,11 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
-from vllm.tracing import (contains_trace_headers, extract_trace_headers,
-log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
-TypeTokenIDs = List[int]
-TypeTopLogProbs = List[Optional[Dict[int, float]]]
-TypeCreateLogProbsFn = Callable[
-[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
class OpenAIServingCompletion(OpenAIServing):
@@ -101,8 +93,6 @@ class OpenAIServingCompletion(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
-# Schedule the request and get the result generator.
-generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
(
lora_request,
@@ -111,19 +101,24 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
-prompts = list(
request_prompts, engine_prompts = self._preprocess_completion(
-self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
-))
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
-for i, prompt_inputs in enumerate(prompts):
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
-prompt_inputs["prompt_token_ids"])
engine_prompt["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
@@ -134,36 +129,24 @@ class OpenAIServingCompletion(OpenAIServing):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
-prompt_inputs,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
-is_tracing_enabled = (await
trace_headers = (await
-self.engine_client.is_tracing_enabled())
self._get_trace_headers(raw_request.headers))
-trace_headers = None
-if is_tracing_enabled:
-trace_headers = extract_trace_headers(raw_request.headers)
-if not is_tracing_enabled and contains_trace_headers(
-raw_request.headers):
-log_tracing_disabled_warning()
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
-prompt={
prompt=engine_prompt,
-"prompt_token_ids":
-prompt_inputs["prompt_token_ids"]
-},
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
else:
generator = self.engine_client.generate(
-{
engine_prompt,
-"prompt_token_ids":
-prompt_inputs["prompt_token_ids"]
-},
sampling_params,
request_id_item,
lora_request=lora_request,
@@ -180,6 +163,8 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
# beam search.
@@ -195,16 +180,22 @@ class OpenAIServingCompletion(OpenAIServing):
request_id,
created_time,
model_name,
-num_prompts=len(prompts),
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata)
# Non-streaming response
-final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
try:
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
@@ -212,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
-final_res.prompt = prompts[i]["prompt"]
final_res.prompt = request_prompts[i]["prompt"]
final_res_batch_checked = cast(List[RequestOutput],
final_res_batch)
@@ -226,8 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer,
request_metadata,
)
-except asyncio.CancelledError:
-return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
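The `async for i, res in result_generator` loop above relies on `merge_async_iterators` tagging every item with the index of the generator it came from, so each result can be written back into `final_res_batch` for its prompt. A rough sketch of that idea (illustrative only, not the vLLM implementation; cancellation and error propagation are omitted):

import asyncio
from typing import AsyncGenerator, AsyncIterator, Tuple, TypeVar

T = TypeVar("T")


async def merge_with_index_sketch(
        *iterators: AsyncGenerator[T, None]) -> AsyncIterator[Tuple[int, T]]:
    # Drain every generator concurrently and yield (index, item) pairs as
    # they arrive, so the caller can route each item to its prompt slot.
    queue: "asyncio.Queue[Tuple[int, T]]" = asyncio.Queue()

    async def drain(i: int, it: AsyncGenerator[T, None]) -> None:
        async for item in it:
            await queue.put((i, item))

    tasks = [asyncio.create_task(drain(i, it))
             for i, it in enumerate(iterators)]
    while not (all(t.done() for t in tasks) and queue.empty()):
        try:
            item = await asyncio.wait_for(queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            continue
        yield item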
View File
@@ -9,8 +9,10 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
@@ -21,8 +23,6 @@ from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
-TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
@@ -76,6 +76,7 @@ class OpenAIServingEmbedding(OpenAIServing):
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine_client=engine_client,
model_config=model_config,
@@ -83,21 +84,20 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
-self._enabled = self._check_embedding_mode(
-model_config.task == "embedding")
self.chat_template = load_chat_template(chat_template)
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
-"""Completion API similar to OpenAI's API.
"""
Embedding API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
-if not self._enabled:
-return self.create_error_response("Embedding API disabled")
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
@@ -122,8 +122,6 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
" Please, select a smaller truncation size.")
-# Schedule the request and get the result generator.
-generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
(
lora_request,
@@ -132,32 +130,60 @@ class OpenAIServingEmbedding(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(request, EmbeddingChatRequest):
(
_,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
-prompts = list(
for i, engine_prompt in enumerate(engine_prompts):
-self._tokenize_prompt_input_or_inputs(request, tokenizer,
-request.input,
-truncate_prompt_tokens))
-for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
-prompt_inputs,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
-if prompt_adapter_request is not None:
trace_headers = (None if raw_request is None else await
-raise NotImplementedError(
self._get_trace_headers(raw_request.headers))
-"Prompt adapter is not supported "
-"for embedding models")
generator = self.engine_client.encode(
-{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
@@ -171,13 +197,18 @@ class OpenAIServingEmbedding(OpenAIServing):
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
-final_res_batch = [None] * len(prompts)
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
try:
for final_res in final_res_batch:
assert final_res is not None
@@ -187,18 +218,8 @@ class OpenAIServingEmbedding(OpenAIServing):
response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
-except asyncio.CancelledError:
-return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
-def _check_embedding_mode(self, embedding_mode: bool) -> bool:
-if not embedding_mode:
-logger.warning(
-"embedding_mode is False. Embedding API will not work.")
-else:
-logger.info("Activating the server engine with embedding enabled.")
-return embedding_mode
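With the changes above, the Embeddings handler accepts either the classic `input` field (`EmbeddingCompletionRequest`) or a chat-style `messages` list (`EmbeddingChatRequest`) rendered through the chat template. Assuming the server exposes this through the usual `/v1/embeddings` route on localhost, a chat-format request might look roughly like the sketch below; the model name and payload values are illustrative, and the field names are taken from the request classes shown in this commit:

import requests

# Hypothetical example: embed a rendered conversation instead of raw text.
response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",  # any embedding-task model
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
        ],
        # Chat-template controls mirrored from EmbeddingChatRequest:
        "add_generation_prompt": True,
        "encoding_format": "float",
    },
)
embedding = response.json()["data"][0]["embedding"]
print(len(embedding))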
View File
@@ -2,28 +2,38 @@ import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
-from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Optional, Sequence, Tuple, TypedDict, Union)
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
-EmbeddingRequest, ErrorResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
-TokenizeRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -31,8 +41,10 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
-from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
-from vllm.utils import AtomicCounter
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of
logger = init_logger(__name__)
@@ -56,8 +68,14 @@ class LoRAModulePath:
base_model_name: Optional[str] = None
-AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
-EmbeddingRequest, TokenizeRequest]
EmbeddingCompletionRequest,
TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
class TextTokensPrompt(TypedDict):
@@ -65,6 +83,9 @@ class TextTokensPrompt(TypedDict):
prompt_token_ids: List[int]
RequestPrompt = Union[List[int], str, TextTokensPrompt]
class OpenAIServing:
def __init__(
@@ -246,7 +267,8 @@ class OpenAIServing:
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
-if isinstance(request, EmbeddingRequest):
if isinstance(request,
(EmbeddingChatRequest, EmbeddingCompletionRequest)):
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
@@ -373,10 +395,115 @@ class OpenAIServing:
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _preprocess_completion(
self,
request: CompletionLikeRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
request_prompts = [
request_prompt
for request_prompt in self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
]
engine_prompts = [
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
for request_prompt in request_prompts
]
return request_prompts, engine_prompts
async def _preprocess_chat(
self,
request: ChatLikeRequest,
tokenizer: AnyTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tool_dicts: Optional[List[Dict[str, Any]]] = None,
documents: Optional[List[Dict[str, str]]] = None,
chat_template_kwargs: Optional[Dict[str, Any]] = None,
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = False,
) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
List[TokensPrompt]]:
conversation, mm_data_future = parse_chat_messages_futures(
messages,
self.model_config,
tokenizer,
)
request_prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
request_prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tool_dicts,
documents=documents,
**(chat_template_kwargs or {}),
)
else:
request_prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tool_dicts,
documents=documents,
**(chat_template_kwargs or {}),
)
mm_data = await mm_data_future
if tool_parser is not None:
if not isinstance(request, ChatCompletionRequest):
msg = "Tool usage is only supported for Chat Completions API"
raise NotImplementedError(msg)
request = tool_parser(tokenizer).adjust_request(request=request)
if isinstance(request_prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
request_prompt,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
# For MistralTokenizer
assert is_list_of(request_prompt, int), (
"Prompt has to be either a string or a list of token ids")
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
return conversation, [request_prompt], [engine_prompt]
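The two helpers above centralize prompt handling: `_preprocess_completion` turns raw text or token inputs into parallel lists of loggable prompts and `TokensPrompt` objects, while `_preprocess_chat` additionally applies the chat template and resolves multi-modal data. A small illustration of the parallel structures they return (the token IDs and text below are made up; this assumes vllm is importable):

from vllm.inputs import TokensPrompt

# What _log_inputs receives: the human-readable prompt plus its token IDs.
request_prompt = {"prompt": "Hello!", "prompt_token_ids": [9906, 0]}

# What the engine receives: token IDs only; in the chat path, multi-modal
# data is attached to this dict as well.
engine_prompt = TokensPrompt(
    prompt_token_ids=request_prompt["prompt_token_ids"])
print(engine_prompt)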
def _log_inputs(
self,
request_id: str,
-inputs: Union[str, List[int], TextTokensPrompt],
inputs: RequestPrompt,
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
@@ -404,6 +531,20 @@ class OpenAIServing:
prompt_adapter_request=prompt_adapter_request,
)
async def _get_trace_headers(
self,
headers: Headers,
) -> Optional[Mapping[str, str]]:
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
if is_tracing_enabled:
return extract_trace_headers(headers)
if contains_trace_headers(headers):
log_tracing_disabled_warning()
return None
@staticmethod
def _get_decoded_token(logprob: Logprob,
token_id: int,
View File
@@ -2,10 +2,7 @@ from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
from vllm.entrypoints.chat_utils import load_chat_template
-apply_mistral_chat_template,
-load_chat_template,
-parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -20,7 +17,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@@ -62,6 +58,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{random_uuid()}"
try:
(
lora_request,
prompt_adapter_request,
@@ -69,52 +66,43 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
-prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest):
-model_config = self.model_config
(
_,
-conversation, mm_data_future = parse_chat_messages_futures(
request_prompts,
-request.messages, model_config, tokenizer)
engine_prompts,
) = await self._preprocess_chat(
-mm_data = await mm_data_future
request,
-if mm_data:
-logger.warning(
-"Multi-modal inputs are ignored during tokenization")
-if isinstance(tokenizer, MistralTokenizer):
-prompt = apply_mistral_chat_template(
tokenizer,
-messages=request.messages,
request.messages,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
)
else:
-prompt = apply_hf_chat_template(
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
-conversation=conversation,
request.prompt,
-chat_template=self.chat_template,
add_special_tokens=request.add_special_tokens,
-add_generation_prompt=request.add_generation_prompt,
-continue_final_message=request.continue_final_message,
)
-else:
except ValueError as e:
-prompt = request.prompt
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
input_ids: List[int] = []
for i, engine_prompt in enumerate(engine_prompts):
self._log_inputs(request_id,
-prompt,
request_prompts[i],
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
-# Silently ignore prompt adapter since it does not affect tokenization
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
-prompt_input = self._tokenize_prompt_input(
input_ids.extend(engine_prompt["prompt_token_ids"])
-request,
-tokenizer,
-prompt,
-add_special_tokens=request.add_special_tokens,
-)
-input_ids = prompt_input["prompt_token_ids"]
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
@@ -143,9 +131,8 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
-if prompt_adapter_request is not None:
# Silently ignore prompt adapter since it does not affect tokenization
-raise NotImplementedError("Prompt adapter is not supported "
# (Unlike in Embeddings API where an error is raised)
-"for tokenization")
prompt_input = self._tokenize_prompt_input(
request,
View File
@@ -7,7 +7,7 @@ class PoolingParams(
msgspec.Struct,
omit_defaults=True,  # type: ignore[call-arg]
array_like=True):  # type: ignore[call-arg]
-"""Pooling parameters for pooling.
"""Pooling parameters for embeddings API.
Attributes:
additional_data: Any additional data needed for pooling.
@@ -16,7 +16,7 @@ class PoolingParams(
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
-return PoolingParams(additional_data=self.additional_data, )
return PoolingParams(additional_data=self.additional_data)
def __repr__(self) -> str:
return (f"PoolingParams("