[Frontend] Chat-based Embeddings API (#9759)
This commit is contained in:
parent
d3aa2a8b2f
commit
06386a64dd
@ -13,5 +13,7 @@ torch
|
||||
py-cpuinfo
|
||||
transformers
|
||||
mistral_common >= 1.3.4
|
||||
aiohttp
|
||||
starlette
|
||||
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
||||
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
@ -96,7 +96,6 @@ def setup(app):
|
||||
|
||||
# Mock out external dependencies here, otherwise the autodoc pages may be blank.
|
||||
autodoc_mock_imports = [
|
||||
"aiohttp",
|
||||
"compressed_tensors",
|
||||
"cpuinfo",
|
||||
"cv2",
|
||||
@ -143,6 +142,7 @@ intersphinx_mapping = {
|
||||
"python": ("https://docs.python.org/3", None),
|
||||
"typing_extensions":
|
||||
("https://typing-extensions.readthedocs.io/en/latest", None),
|
||||
"aiohttp": ("https://docs.aiohttp.org/en/stable", None),
|
||||
"pillow": ("https://pillow.readthedocs.io/en/stable", None),
|
||||
"numpy": ("https://numpy.org/doc/stable", None),
|
||||
"torch": ("https://pytorch.org/docs/stable", None),
|
||||
|
5
docs/source/dev/pooling_params.rst
Normal file
5
docs/source/dev/pooling_params.rst
Normal file
@ -0,0 +1,5 @@
|
||||
Pooling Parameters
|
||||
==================
|
||||
|
||||
.. autoclass:: vllm.PoolingParams
|
||||
:members:
|
@ -138,10 +138,10 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
||||
|
||||
A more detailed client example can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`__.
|
||||
|
||||
OpenAI Chat API with vLLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
OpenAI Chat Completions API with vLLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
vLLM is designed to also support the OpenAI Chat API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
|
||||
vLLM is designed to also support the OpenAI Chat Completions API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
|
||||
|
||||
You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to interact with the model:
|
||||
|
||||
@ -157,7 +157,7 @@ You can use the `create chat completion <https://platform.openai.com/docs/api-re
|
||||
$ ]
|
||||
$ }'
|
||||
|
||||
Alternatively, you can use the `openai` python package:
|
||||
Alternatively, you can use the ``openai`` python package:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -134,6 +134,7 @@ Documentation
|
||||
:caption: Developer Documentation
|
||||
|
||||
dev/sampling_params
|
||||
dev/pooling_params
|
||||
dev/offline_inference/offline_index
|
||||
dev/engine/engine_index
|
||||
dev/kernel/paged_attention
|
||||
|
@ -185,7 +185,7 @@ Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruc
|
||||
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
|
||||
|
||||
.. important::
|
||||
Since OpenAI Vision API is based on `Chat Completions <https://platform.openai.com/docs/api-reference/chat>`_ API,
|
||||
Since OpenAI Vision API is based on `Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`_,
|
||||
a chat template is **required** to launch the API server.
|
||||
|
||||
Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it.
|
||||
@ -243,6 +243,10 @@ To consume the server, you can use the OpenAI client like in the example below:
|
||||
|
||||
A full code example can be found in `examples/openai_api_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_api_client_for_multimodal.py>`_.
|
||||
|
||||
.. tip::
|
||||
There is no need to place image placeholders in the text content of the API request - they are already represented by the image content.
|
||||
In fact, you can place image placeholders in the middle of the text by interleaving text and image content.
|
||||
|
||||
.. note::
|
||||
|
||||
By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable:
|
||||
@ -251,5 +255,49 @@ A full code example can be found in `examples/openai_api_client_for_multimodal.p
|
||||
|
||||
$ export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
|
||||
|
||||
.. note::
|
||||
There is no need to format the prompt in the API request since it will be handled by the server.
|
||||
Chat Embeddings API
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
vLLM's Chat Embeddings API is a superset of OpenAI's `Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`_,
|
||||
where a list of ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models.
|
||||
|
||||
.. tip::
|
||||
The schema of ``messages`` is exactly the same as in Chat Completions API.
|
||||
|
||||
In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
|
||||
--trust-remote-code --max-model-len 4096
|
||||
|
||||
.. important::
|
||||
|
||||
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
|
||||
to run this model in embedding mode instead of text generation mode.
|
||||
|
||||
Since this schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:8000/v1/embeddings",
|
||||
json={
|
||||
"model": "TIGER-Lab/VLM2Vec-Full",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": "Represent the given image."},
|
||||
],
|
||||
}],
|
||||
"encoding_format": "float",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
print("Embedding output:", response_json["data"][0]["embedding"])
|
||||
|
@ -26,13 +26,26 @@ print(completion.choices[0].message)
|
||||
```
|
||||
|
||||
## API Reference
|
||||
Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-reference) for more information on the API. We support all parameters except:
|
||||
- Chat: `tools`, and `tool_choice`.
|
||||
- Completions: `suffix`.
|
||||
|
||||
vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst).
|
||||
We currently support the following OpenAI APIs:
|
||||
|
||||
- [Completions API](https://platform.openai.com/docs/api-reference/completions)
|
||||
- *Note: `suffix` parameter is not supported.*
|
||||
- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
|
||||
- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Using VLMs](../models/vlm.rst).
|
||||
- *Note: `image_url.detail` parameter is not supported.*
|
||||
- We also support `audio_url` content type for audio files.
|
||||
- Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
|
||||
- *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).*
|
||||
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
|
||||
- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
|
||||
- Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API),
|
||||
which will be treated as a single prompt to the model according to its chat template.
|
||||
- This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
|
||||
- *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
|
||||
|
||||
## Extra Parameters
|
||||
|
||||
vLLM supports a set of parameters that are not part of the OpenAI API.
|
||||
In order to use them, you can pass them as extra parameters in the OpenAI client.
|
||||
Or directly merge them into the JSON payload if you are using HTTP call directly.
|
||||
@ -49,7 +62,26 @@ completion = client.chat.completions.create(
|
||||
)
|
||||
```
|
||||
|
||||
### Extra Parameters for Chat API
|
||||
### Extra Parameters for Completions API
|
||||
|
||||
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
|
||||
|
||||
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
|
||||
:language: python
|
||||
:start-after: begin-completion-sampling-params
|
||||
:end-before: end-completion-sampling-params
|
||||
```
|
||||
|
||||
The following extra parameters are supported:
|
||||
|
||||
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
|
||||
:language: python
|
||||
:start-after: begin-completion-extra-params
|
||||
:end-before: end-completion-extra-params
|
||||
```
|
||||
|
||||
### Extra Parameters for Chat Completions API
|
||||
|
||||
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
|
||||
|
||||
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
|
||||
@ -66,21 +98,22 @@ The following extra parameters are supported:
|
||||
:end-before: end-chat-completion-extra-params
|
||||
```
|
||||
|
||||
### Extra Parameters for Completions API
|
||||
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
|
||||
### Extra Parameters for Embeddings API
|
||||
|
||||
The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported.
|
||||
|
||||
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
|
||||
:language: python
|
||||
:start-after: begin-completion-sampling-params
|
||||
:end-before: end-completion-sampling-params
|
||||
:start-after: begin-embedding-pooling-params
|
||||
:end-before: end-embedding-pooling-params
|
||||
```
|
||||
|
||||
The following extra parameters are supported:
|
||||
|
||||
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
|
||||
:language: python
|
||||
:start-after: begin-completion-extra-params
|
||||
:end-before: end-completion-extra-params
|
||||
:start-after: begin-embedding-extra-params
|
||||
:end-before: end-embedding-extra-params
|
||||
```
|
||||
|
||||
## Chat Template
|
||||
|
@ -1,7 +1,6 @@
|
||||
from http import HTTPStatus
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
@ -83,10 +82,8 @@ async def client(server):
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_version(client: openai.AsyncOpenAI):
|
||||
base_url = str(client.base_url)[:-3].strip("/")
|
||||
|
||||
response = requests.get(base_url + "/version")
|
||||
async def test_show_version(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("version"))
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {"version": VLLM_VERSION}
|
||||
@ -102,9 +99,7 @@ async def test_show_version(client: openai.AsyncOpenAI):
|
||||
indirect=True,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_health(client: openai.AsyncOpenAI):
|
||||
base_url = str(client.base_url)[:-3].strip("/")
|
||||
|
||||
response = requests.get(base_url + "/health")
|
||||
async def test_check_health(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("health"))
|
||||
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
@ -4,14 +4,18 @@ import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
|
||||
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
|
||||
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def embedding_server():
|
||||
def server():
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
@ -19,31 +23,29 @@ def embedding_server():
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--chat-template",
|
||||
DUMMY_CHAT_TEMPLATE,
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def embedding_client(embedding_server):
|
||||
async with embedding_server.get_async_client() as async_client:
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
)
|
||||
async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
# test single embedding
|
||||
embeddings = await embedding_client.embeddings.create(
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
encoding_format="float",
|
||||
@ -57,7 +59,7 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
|
||||
# test using token IDs
|
||||
input_tokens = [1, 1, 1, 1, 1]
|
||||
embeddings = await embedding_client.embeddings.create(
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
encoding_format="float",
|
||||
@ -71,18 +73,14 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
)
|
||||
async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test List[str]
|
||||
input_texts = [
|
||||
"The cat sat on the mat.", "A feline was resting on a rug.",
|
||||
"Stars twinkle brightly in the night sky."
|
||||
]
|
||||
embeddings = await embedding_client.embeddings.create(
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
encoding_format="float",
|
||||
@ -90,11 +88,14 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 3
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 32
|
||||
assert embeddings.usage.total_tokens == 32
|
||||
|
||||
# test List[List[int]]
|
||||
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||
[25, 32, 64, 77]]
|
||||
embeddings = await embedding_client.embeddings.create(
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
encoding_format="float",
|
||||
@ -108,22 +109,70 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
)
|
||||
async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_conversation_embedding(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
chat_response = requests.post(server.url_for("v1/embeddings"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
})
|
||||
chat_response.raise_for_status()
|
||||
chat_embeddings = chat_response.json()
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
chat_template=DUMMY_CHAT_TEMPLATE,
|
||||
add_generation_prompt=True,
|
||||
continue_final_message=False,
|
||||
tokenize=False,
|
||||
)
|
||||
completion_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=prompt,
|
||||
encoding_format="float",
|
||||
# To be consistent with chat
|
||||
extra_body={"add_special_tokens": False},
|
||||
)
|
||||
completion_embeddings = completion_response.model_dump(mode="json")
|
||||
|
||||
assert chat_embeddings.pop("id") is not None
|
||||
assert completion_embeddings.pop("id") is not None
|
||||
assert chat_embeddings.pop("created") <= completion_embeddings.pop(
|
||||
"created")
|
||||
assert chat_embeddings == completion_embeddings
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
input_texts = [
|
||||
"Hello my name is",
|
||||
"The best thing about vLLM is that it supports many different models"
|
||||
]
|
||||
|
||||
responses_float = await embedding_client.embeddings.create(
|
||||
input=input_texts, model=model_name, encoding_format="float")
|
||||
responses_float = await client.embeddings.create(input=input_texts,
|
||||
model=model_name,
|
||||
encoding_format="float")
|
||||
|
||||
responses_base64 = await embedding_client.embeddings.create(
|
||||
input=input_texts, model=model_name, encoding_format="base64")
|
||||
responses_base64 = await client.embeddings.create(input=input_texts,
|
||||
model=model_name,
|
||||
encoding_format="base64")
|
||||
|
||||
decoded_responses_base64_data = []
|
||||
for data in responses_base64.data:
|
||||
@ -137,8 +186,8 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
1]
|
||||
|
||||
# Default response is float32 decoded from base64 by OpenAI Client
|
||||
responses_default = await embedding_client.embeddings.create(
|
||||
input=input_texts, model=model_name)
|
||||
responses_default = await client.embeddings.create(input=input_texts,
|
||||
model=model_name)
|
||||
|
||||
assert responses_float.data[0].embedding == responses_default.data[
|
||||
0].embedding
|
||||
@ -147,18 +196,15 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
)
|
||||
async def test_single_embedding_truncation(
|
||||
embedding_client: openai.AsyncOpenAI, model_name: str):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
input_texts = [
|
||||
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
|
||||
]
|
||||
|
||||
# test single embedding
|
||||
embeddings = await embedding_client.embeddings.create(
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
extra_body={"truncate_prompt_tokens": 10})
|
||||
@ -173,7 +219,7 @@ async def test_single_embedding_truncation(
|
||||
1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
|
||||
9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
|
||||
]
|
||||
embeddings = await embedding_client.embeddings.create(
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
extra_body={"truncate_prompt_tokens": 10})
|
||||
@ -187,18 +233,15 @@ async def test_single_embedding_truncation(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
)
|
||||
async def test_single_embedding_truncation_invalid(
|
||||
embedding_client: openai.AsyncOpenAI, model_name: str):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
input_texts = [
|
||||
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
embeddings = await embedding_client.embeddings.create(
|
||||
embeddings = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
extra_body={"truncate_prompt_tokens": 8193})
|
||||
|
@ -79,9 +79,8 @@ EXPECTED_VALUES = {
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_counts(client: openai.AsyncOpenAI):
|
||||
base_url = str(client.base_url)[:-3].strip("/")
|
||||
|
||||
async def test_metrics_counts(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncClient):
|
||||
for _ in range(_NUM_REQUESTS):
|
||||
# sending a request triggers the metrics to be logged.
|
||||
await client.completions.create(
|
||||
@ -89,7 +88,7 @@ async def test_metrics_counts(client: openai.AsyncOpenAI):
|
||||
prompt=_TOKENIZED_PROMPT,
|
||||
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)
|
||||
|
||||
response = requests.get(base_url + "/metrics")
|
||||
response = requests.get(server.url_for("metrics"))
|
||||
print(response.text)
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
@ -170,16 +169,15 @@ EXPECTED_METRICS = [
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_exist(client: openai.AsyncOpenAI):
|
||||
base_url = str(client.base_url)[:-3].strip("/")
|
||||
|
||||
async def test_metrics_exist(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncClient):
|
||||
# sending a request triggers the metrics to be logged.
|
||||
await client.completions.create(model=MODEL_NAME,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
response = requests.get(base_url + "/metrics")
|
||||
response = requests.get(server.url_for("metrics"))
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
for metric in EXPECTED_METRICS:
|
||||
|
@ -1,4 +1,3 @@
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
@ -55,9 +54,11 @@ async def client(server):
|
||||
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
|
||||
indirect=["tokenizer_name"],
|
||||
)
|
||||
async def test_tokenize_completions(client: openai.AsyncOpenAI,
|
||||
model_name: str, tokenizer_name: str):
|
||||
base_url = str(client.base_url)[:-3].strip("/")
|
||||
async def test_tokenize_completions(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
|
||||
@ -65,7 +66,7 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
|
||||
prompt = "vllm1 This is a test prompt."
|
||||
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
|
||||
|
||||
response = requests.post(base_url + "/tokenize",
|
||||
response = requests.post(server.url_for("tokenize"),
|
||||
json={
|
||||
"add_special_tokens": add_special,
|
||||
"model": model_name,
|
||||
@ -86,9 +87,11 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
|
||||
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
|
||||
indirect=["tokenizer_name"],
|
||||
)
|
||||
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
|
||||
tokenizer_name: str):
|
||||
base_url = str(client.base_url)[:-3].strip("/")
|
||||
async def test_tokenize_chat(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
|
||||
@ -121,7 +124,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
|
||||
tokens = tokenizer.encode(prompt,
|
||||
add_special_tokens=add_special)
|
||||
|
||||
response = requests.post(base_url + "/tokenize",
|
||||
response = requests.post(server.url_for("tokenize"),
|
||||
json={
|
||||
"add_generation_prompt":
|
||||
add_generation,
|
||||
@ -146,17 +149,18 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
|
||||
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
|
||||
indirect=["tokenizer_name"],
|
||||
)
|
||||
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
|
||||
tokenizer_name: str):
|
||||
base_url = str(client.base_url)[:-3].strip("/")
|
||||
async def test_detokenize(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
|
||||
tokenizer_mode="fast")
|
||||
|
||||
prompt = "This is a test prompt. vllm1"
|
||||
tokens = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
|
||||
print(f"CALLING {base_url} FOR {model_name}")
|
||||
response = requests.post(base_url + "/detokenize",
|
||||
response = requests.post(server.url_for("detokenize"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"tokens": tokens
|
||||
|
94
tests/entrypoints/openai/test_vision_embedding.py
Normal file
94
tests/entrypoints/openai/test_vision_embedding.py
Normal file
@ -0,0 +1,94 @@
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
|
||||
MAXIMUM_IMAGES = 2
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_URLS = [
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--task",
|
||||
"embedding",
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"5",
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--limit-mm-per-prompt",
|
||||
f"image={MAXIMUM_IMAGES}",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def base64_encoded_image() -> Dict[str, str]:
|
||||
return {
|
||||
image_url: encode_image_base64(fetch_image(image_url))
|
||||
for image_url in TEST_IMAGE_URLS
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
|
||||
image_url: str):
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Represent the given image."
|
||||
},
|
||||
],
|
||||
}]
|
||||
|
||||
response = requests.post(server.url_for("v1/embeddings"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"encoding_format": "float"
|
||||
})
|
||||
response.raise_for_status()
|
||||
|
||||
embeddings = response.json()
|
||||
assert embeddings["id"] is not None
|
||||
assert len(embeddings["data"]) == 1
|
||||
assert len(embeddings["data"][0]["embedding"]) == 3072
|
||||
assert embeddings["usage"]["completion_tokens"] == 0
|
||||
assert embeddings["usage"]["prompt_tokens"] == 771
|
||||
assert embeddings["usage"]["total_tokens"] == 771
|
@ -11,7 +11,7 @@ from argparse import Namespace
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Set
|
||||
from typing import AsyncIterator, Optional, Set
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
@ -51,7 +51,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
@ -248,22 +248,27 @@ def mount_metrics(app: FastAPI):
|
||||
app.routes.append(metrics_route)
|
||||
|
||||
|
||||
def chat(request: Request) -> OpenAIServingChat:
|
||||
def base(request: Request) -> OpenAIServing:
|
||||
# Reuse the existing instance
|
||||
return tokenization(request)
|
||||
|
||||
|
||||
def chat(request: Request) -> Optional[OpenAIServingChat]:
|
||||
return request.app.state.openai_serving_chat
|
||||
|
||||
|
||||
def completion(request: Request) -> OpenAIServingCompletion:
|
||||
def completion(request: Request) -> Optional[OpenAIServingCompletion]:
|
||||
return request.app.state.openai_serving_completion
|
||||
|
||||
|
||||
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def embedding(request: Request) -> OpenAIServingEmbedding:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response:
|
||||
|
||||
@router.post("/tokenize")
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_tokenize(request)
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
generator = await handler.create_tokenize(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
|
||||
@router.post("/detokenize")
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_detokenize(request)
|
||||
handler = tokenization(raw_request)
|
||||
|
||||
generator = await handler.create_detokenize(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models(raw_request: Request):
|
||||
models = await completion(raw_request).show_available_models()
|
||||
handler = base(raw_request)
|
||||
|
||||
models = await handler.show_available_models()
|
||||
return JSONResponse(content=models.model_dump())
|
||||
|
||||
|
||||
@ -314,9 +325,12 @@ async def show_version():
|
||||
@router.post("/v1/chat/completions")
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
handler = chat(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Chat Completions API")
|
||||
|
||||
generator = await chat(raw_request).create_chat_completion(
|
||||
request, raw_request)
|
||||
generator = await handler.create_chat_completion(request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
|
||||
@router.post("/v1/completions")
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
generator = await completion(raw_request).create_completion(
|
||||
request, raw_request)
|
||||
handler = completion(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Completions API")
|
||||
|
||||
generator = await handler.create_completion(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -343,8 +361,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
generator = await embedding(raw_request).create_embedding(
|
||||
request, raw_request)
|
||||
handler = embedding(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Embeddings API")
|
||||
|
||||
generator = await handler.create_embedding(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -382,12 +404,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
@router.post("/v1/load_lora_adapter")
|
||||
async def load_lora_adapter(request: LoadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await completion(raw_request).load_lora_adapter(request)
|
||||
for route in [chat, completion, embedding]:
|
||||
handler = route(raw_request)
|
||||
if handler is not None:
|
||||
response = await handler.load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
@ -397,12 +417,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
@router.post("/v1/unload_lora_adapter")
|
||||
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await completion(raw_request).unload_lora_adapter(request)
|
||||
for route in [chat, completion, embedding]:
|
||||
handler = route(raw_request)
|
||||
if handler is not None:
|
||||
response = await handler.unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
@ -501,7 +519,8 @@ def init_app_state(
|
||||
chat_template=args.chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser)
|
||||
tool_parser=args.tool_call_parser,
|
||||
) if model_config.task == "generate" else None
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine_client,
|
||||
model_config,
|
||||
@ -510,13 +529,14 @@ def init_app_state(
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
)
|
||||
) if model_config.task == "generate" else None
|
||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
chat_template=args.chat_template,
|
||||
) if model_config.task == "embedding" else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
|
@ -708,7 +708,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingRequest(OpenAIBaseModel):
|
||||
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str
|
||||
@ -720,10 +720,15 @@ class EmbeddingRequest(OpenAIBaseModel):
|
||||
|
||||
# doc: begin-embedding-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
|
||||
# doc: end-embedding-pooling-params
|
||||
|
||||
# doc: begin-embedding-extra-params
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
@ -737,6 +742,82 @@ class EmbeddingRequest(OpenAIBaseModel):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
|
||||
encoding_format: Literal["float", "base64"] = "float"
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
|
||||
# doc: begin-chat-embedding-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
# doc: end-chat-embedding-pooling-params
|
||||
|
||||
# doc: begin-chat-embedding-extra-params
|
||||
add_generation_prompt: bool = Field(
|
||||
default=True,
|
||||
description=
|
||||
("If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."),
|
||||
)
|
||||
continue_final_message: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
("If this is set, the chat will be formatted so that the final "
|
||||
"message in the chat is open-ended, without any EOS tokens. The "
|
||||
"model will continue this message rather than starting a new one. "
|
||||
"This allows you to \"prefill\" part of the model's response for it. "
|
||||
"Cannot be used at the same time as `add_generation_prompt`."),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."),
|
||||
)
|
||||
chat_template: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."),
|
||||
)
|
||||
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
# doc: end-chat-embedding-extra-params
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get(
|
||||
"add_generation_prompt"):
|
||||
raise ValueError("Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True.")
|
||||
return data
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
||||
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
@ -799,7 +880,7 @@ class EmbeddingResponseData(OpenAIBaseModel):
|
||||
|
||||
|
||||
class EmbeddingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
|
@ -217,13 +217,14 @@ async def main(args):
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
)
|
||||
) if model_config.task == "generate" else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
chat_template=None,
|
||||
) if model_config.task == "embedding" else None
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
@ -240,14 +241,31 @@ async def main(args):
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
handler_fn = (None if openai_serving_chat is None else
|
||||
openai_serving_chat.create_chat_completion)
|
||||
if handler_fn is None:
|
||||
response_futures.append(
|
||||
run_request(openai_serving_chat.create_chat_completion,
|
||||
request, tracker))
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=
|
||||
"The model does not support Chat Completions API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
handler_fn = (None if openai_serving_embedding is None else
|
||||
openai_serving_embedding.create_embedding)
|
||||
if handler_fn is None:
|
||||
response_futures.append(
|
||||
run_request(openai_serving_embedding.create_embedding, request,
|
||||
tracker))
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Embeddings API",
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
|
@ -10,11 +10,7 @@ from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||
@ -27,16 +23,12 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath,
|
||||
TextTokensPrompt)
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import iterate_with_cancellation
|
||||
|
||||
@ -94,12 +86,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
|
||||
ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
"""
|
||||
Chat Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
for the API specification. This API mimics the OpenAI
|
||||
ChatCompletion API.
|
||||
|
||||
Chat Completion API.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
@ -118,49 +110,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
model_config = self.model_config
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
conversation, mm_data_future = parse_chat_messages_futures(
|
||||
request.messages, model_config, tokenizer)
|
||||
|
||||
tool_dicts = None if request.tools is None else [
|
||||
tool.model_dump() for tool in request.tools
|
||||
]
|
||||
|
||||
prompt: Union[str, List[int]]
|
||||
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
|
||||
if is_mistral_tokenizer:
|
||||
prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=request.documents,
|
||||
**(request.chat_template_kwargs or {}),
|
||||
)
|
||||
else:
|
||||
prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=request.documents,
|
||||
**(request.chat_template_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error in applying chat template from request")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
try:
|
||||
mm_data = await mm_data_future
|
||||
except Exception as e:
|
||||
logger.exception("Error in loading multi-modal data")
|
||||
return self.create_error_response(str(e))
|
||||
tool_parser = self.tool_parser
|
||||
|
||||
# validation for OpenAI tools
|
||||
# tool_choice = "required" is not supported
|
||||
@ -168,45 +119,55 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
|
||||
self.enable_auto_tools and self.tool_parser is not None):
|
||||
if (request.tool_choice == "auto" and
|
||||
not (self.enable_auto_tools and tool_parser is not None)
|
||||
and not isinstance(tokenizer, MistralTokenizer)):
|
||||
# for hf tokenizers, "auto" tools requires
|
||||
# --enable-auto-tool-choice and --tool-call-parser
|
||||
return self.create_error_response(
|
||||
"\"auto\" tool choice requires "
|
||||
"--enable-auto-tool-choice and --tool-call-parser to be set")
|
||||
"--enable-auto-tool-choice and --tool-call-parser to be set"
|
||||
)
|
||||
|
||||
request_id = f"chat-{request.request_id}"
|
||||
tool_dicts = None if request.tools is None else [
|
||||
tool.model_dump() for tool in request.tools
|
||||
]
|
||||
|
||||
(
|
||||
conversation,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
tool_dicts=tool_dicts,
|
||||
documents=request.documents,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
tool_parser=tool_parser,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
request_id = f"chatcmpl-{request.request_id}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
if self.enable_auto_tools and self.tool_parser:
|
||||
request = self.tool_parser(tokenizer).adjust_request(
|
||||
request=request)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt_inputs = self._tokenize_prompt_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
assert isinstance(prompt, list) and isinstance(
|
||||
prompt[0], int
|
||||
), "Prompt has to be either a string or a list of token ids"
|
||||
prompt_inputs = TextTokensPrompt(
|
||||
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
|
||||
|
||||
assert prompt_inputs is not None
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
prompt_inputs["prompt_token_ids"])
|
||||
engine_prompt["prompt_token_ids"])
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens)
|
||||
@ -215,35 +176,24 @@ class OpenAIServingChat(OpenAIServing):
|
||||
default_max_tokens)
|
||||
|
||||
self._log_inputs(request_id,
|
||||
prompt_inputs,
|
||||
request_prompts[i],
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
engine_inputs = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if mm_data is not None:
|
||||
engine_inputs["multi_modal_data"] = mm_data
|
||||
|
||||
is_tracing_enabled = (await
|
||||
self.engine_client.is_tracing_enabled())
|
||||
trace_headers = None
|
||||
if is_tracing_enabled and raw_request:
|
||||
trace_headers = extract_trace_headers(raw_request.headers)
|
||||
if (not is_tracing_enabled and raw_request
|
||||
and contains_trace_headers(raw_request.headers)):
|
||||
log_tracing_disabled_warning()
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
result_generator = self.engine_client.beam_search(
|
||||
prompt=engine_inputs,
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt=engine_prompt,
|
||||
model_config=self.model_config,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
)
|
||||
else:
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_inputs,
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
@ -251,10 +201,15 @@ class OpenAIServingChat(OpenAIServing):
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
assert len(generators) == 1
|
||||
result_generator, = generators
|
||||
|
||||
if raw_request:
|
||||
result_generator = iterate_with_cancellation(
|
||||
result_generator, raw_request.is_disconnected)
|
||||
@ -626,6 +581,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
final_res = res
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
assert final_res is not None
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
||||
Optional)
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple, Union, cast
|
||||
|
||||
@ -30,18 +29,11 @@ from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TypeTokenIDs = List[int]
|
||||
TypeTopLogProbs = List[Optional[Dict[int, float]]]
|
||||
TypeCreateLogProbsFn = Callable[
|
||||
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
@ -101,8 +93,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
@ -111,19 +101,24 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
prompts = list(
|
||||
self._tokenize_prompt_input_or_inputs(
|
||||
request_prompts, engine_prompts = self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
))
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
for i, prompt_inputs in enumerate(prompts):
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
prompt_inputs["prompt_token_ids"])
|
||||
engine_prompt["prompt_token_ids"])
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens)
|
||||
@ -134,36 +129,24 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
prompt_inputs,
|
||||
request_prompts[i],
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
is_tracing_enabled = (await
|
||||
self.engine_client.is_tracing_enabled())
|
||||
trace_headers = None
|
||||
if is_tracing_enabled:
|
||||
trace_headers = extract_trace_headers(raw_request.headers)
|
||||
if not is_tracing_enabled and contains_trace_headers(
|
||||
raw_request.headers):
|
||||
log_tracing_disabled_warning()
|
||||
trace_headers = (await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt={
|
||||
"prompt_token_ids":
|
||||
prompt_inputs["prompt_token_ids"]
|
||||
},
|
||||
prompt=engine_prompt,
|
||||
model_config=self.model_config,
|
||||
request_id=request_id,
|
||||
params=sampling_params,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
{
|
||||
"prompt_token_ids":
|
||||
prompt_inputs["prompt_token_ids"]
|
||||
},
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
@ -180,6 +163,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
result_generator = merge_async_iterators(
|
||||
*generators, is_cancelled=raw_request.is_disconnected)
|
||||
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. In addition, we do not stream the results when use
|
||||
# beam search.
|
||||
@ -195,16 +180,22 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=len(prompts),
|
||||
num_prompts=num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
|
||||
final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
try:
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
@ -212,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
final_res.prompt = prompts[i]["prompt"]
|
||||
final_res.prompt = request_prompts[i]["prompt"]
|
||||
|
||||
final_res_batch_checked = cast(List[RequestOutput],
|
||||
final_res_batch)
|
||||
@ -226,8 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
@ -9,8 +9,10 @@ from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
@ -21,8 +23,6 @@ from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TypeTokenIDs = List[int]
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: EmbeddingOutput,
|
||||
@ -76,6 +76,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
@ -83,21 +84,20 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger)
|
||||
self._enabled = self._check_embedding_mode(
|
||||
model_config.task == "embedding")
|
||||
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
"""
|
||||
Embedding API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return self.create_error_response("Embedding API disabled")
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@ -122,8 +122,6 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
@ -132,32 +130,60 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for embedding models")
|
||||
|
||||
if isinstance(request, EmbeddingChatRequest):
|
||||
(
|
||||
_,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
request_prompts, engine_prompts = self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
request.input,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
prompts = list(
|
||||
self._tokenize_prompt_input_or_inputs(request, tokenizer,
|
||||
request.input,
|
||||
truncate_prompt_tokens))
|
||||
|
||||
for i, prompt_inputs in enumerate(prompts):
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
prompt_inputs,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError(
|
||||
"Prompt adapter is not supported "
|
||||
"for embedding models")
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
@ -171,13 +197,18 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
is_cancelled=raw_request.is_disconnected if raw_request else None,
|
||||
)
|
||||
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[EmbeddingRequestOutput]]
|
||||
final_res_batch = [None] * len(prompts)
|
||||
final_res_batch = [None] * num_prompts
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
|
||||
try:
|
||||
for final_res in final_res_batch:
|
||||
assert final_res is not None
|
||||
|
||||
@ -187,18 +218,8 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
response = request_output_to_embedding_response(
|
||||
final_res_batch_checked, request_id, created_time, model_name,
|
||||
encoding_format)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
|
||||
if not embedding_mode:
|
||||
logger.warning(
|
||||
"embedding_mode is False. Embedding API will not work.")
|
||||
else:
|
||||
logger.info("Activating the server engine with embedding enabled.")
|
||||
return embedding_mode
|
||||
|
@ -2,28 +2,38 @@ import json
|
||||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
|
||||
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
|
||||
Optional, Sequence, Tuple, TypedDict, Union)
|
||||
|
||||
from pydantic import Field
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingRequest, ErrorResponse,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeRequest,
|
||||
UnloadLoraAdapterRequest)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
# yapf: enable
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -31,8 +41,10 @@ from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import AtomicCounter
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import AtomicCounter, is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -56,8 +68,14 @@ class LoRAModulePath:
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
|
||||
EmbeddingRequest, TokenizeRequest]
|
||||
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
TokenizeCompletionRequest]
|
||||
|
||||
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
TokenizeChatRequest]
|
||||
|
||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
|
||||
|
||||
|
||||
class TextTokensPrompt(TypedDict):
|
||||
@ -65,6 +83,9 @@ class TextTokensPrompt(TypedDict):
|
||||
prompt_token_ids: List[int]
|
||||
|
||||
|
||||
RequestPrompt = Union[List[int], str, TextTokensPrompt]
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
|
||||
def __init__(
|
||||
@ -246,7 +267,8 @@ class OpenAIServing:
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest doesn't have max_tokens
|
||||
if isinstance(request, EmbeddingRequest):
|
||||
if isinstance(request,
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest)):
|
||||
if token_num > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
@ -373,10 +395,115 @@ class OpenAIServing:
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
def _preprocess_completion(
|
||||
self,
|
||||
request: CompletionLikeRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
|
||||
request_prompts = [
|
||||
request_prompt
|
||||
for request_prompt in self._tokenize_prompt_input_or_inputs(
|
||||
request,
|
||||
tokenizer,
|
||||
input_or_inputs,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
]
|
||||
|
||||
engine_prompts = [
|
||||
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
|
||||
for request_prompt in request_prompts
|
||||
]
|
||||
|
||||
return request_prompts, engine_prompts
|
||||
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: ChatLikeRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str] = None,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tool_dicts: Optional[List[Dict[str, Any]]] = None,
|
||||
documents: Optional[List[Dict[str, str]]] = None,
|
||||
chat_template_kwargs: Optional[Dict[str, Any]] = None,
|
||||
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = False,
|
||||
) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
|
||||
List[TokensPrompt]]:
|
||||
conversation, mm_data_future = parse_chat_messages_futures(
|
||||
messages,
|
||||
self.model_config,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
request_prompt: Union[str, List[int]]
|
||||
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
|
||||
if is_mistral_tokenizer:
|
||||
request_prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=documents,
|
||||
**(chat_template_kwargs or {}),
|
||||
)
|
||||
else:
|
||||
request_prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=documents,
|
||||
**(chat_template_kwargs or {}),
|
||||
)
|
||||
|
||||
mm_data = await mm_data_future
|
||||
|
||||
if tool_parser is not None:
|
||||
if not isinstance(request, ChatCompletionRequest):
|
||||
msg = "Tool usage is only supported for Chat Completions API"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
request = tool_parser(tokenizer).adjust_request(request=request)
|
||||
|
||||
if isinstance(request_prompt, str):
|
||||
prompt_inputs = self._tokenize_prompt_input(
|
||||
request,
|
||||
tokenizer,
|
||||
request_prompt,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
# For MistralTokenizer
|
||||
assert is_list_of(request_prompt, int), (
|
||||
"Prompt has to be either a string or a list of token ids")
|
||||
prompt_inputs = TextTokensPrompt(
|
||||
prompt=tokenizer.decode(request_prompt),
|
||||
prompt_token_ids=request_prompt)
|
||||
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
|
||||
return conversation, [request_prompt], [engine_prompt]
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: Union[str, List[int], TextTokensPrompt],
|
||||
inputs: RequestPrompt,
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
@ -404,6 +531,20 @@ class OpenAIServing:
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
async def _get_trace_headers(
|
||||
self,
|
||||
headers: Headers,
|
||||
) -> Optional[Mapping[str, str]]:
|
||||
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
|
||||
|
||||
if is_tracing_enabled:
|
||||
return extract_trace_headers(headers)
|
||||
|
||||
if contains_trace_headers(headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(logprob: Logprob,
|
||||
token_id: int,
|
||||
|
@ -2,10 +2,7 @@ from typing import List, Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -20,7 +17,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -62,6 +58,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
request_id = f"tokn-{random_uuid()}"
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
@ -69,52 +66,43 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
prompt: Union[str, List[int]]
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
model_config = self.model_config
|
||||
|
||||
conversation, mm_data_future = parse_chat_messages_futures(
|
||||
request.messages, model_config, tokenizer)
|
||||
|
||||
mm_data = await mm_data_future
|
||||
if mm_data:
|
||||
logger.warning(
|
||||
"Multi-modal inputs are ignored during tokenization")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
prompt = apply_mistral_chat_template(
|
||||
(
|
||||
_,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
messages=request.messages,
|
||||
request.messages,
|
||||
chat_template=self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
prompt = apply_hf_chat_template(
|
||||
request_prompts, engine_prompts = self._preprocess_completion(
|
||||
request,
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
request.prompt,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
prompt = request.prompt
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
input_ids: List[int] = []
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
self._log_inputs(request_id,
|
||||
prompt,
|
||||
request_prompts[i],
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
# Silently ignore prompt adapter since it does not affect tokenization
|
||||
# Silently ignore prompt adapter since it does not affect
|
||||
# tokenization (Unlike in Embeddings API where an error is raised)
|
||||
|
||||
prompt_input = self._tokenize_prompt_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
input_ids = prompt_input["prompt_token_ids"]
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
|
||||
return TokenizeResponse(tokens=input_ids,
|
||||
count=len(input_ids),
|
||||
@ -143,9 +131,8 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for tokenization")
|
||||
# Silently ignore prompt adapter since it does not affect tokenization
|
||||
# (Unlike in Embeddings API where an error is raised)
|
||||
|
||||
prompt_input = self._tokenize_prompt_input(
|
||||
request,
|
||||
|
@ -7,7 +7,7 @@ class PoolingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""Pooling parameters for pooling.
|
||||
"""Pooling parameters for embeddings API.
|
||||
|
||||
Attributes:
|
||||
additional_data: Any additional data needed for pooling.
|
||||
@ -16,7 +16,7 @@ class PoolingParams(
|
||||
|
||||
def clone(self) -> "PoolingParams":
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
return PoolingParams(additional_data=self.additional_data, )
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"PoolingParams("
|
||||
|
Loading…
x
Reference in New Issue
Block a user