[Frontend] Chat-based Embeddings API (#9759)
Commit 06386a64dd (parent d3aa2a8b2f)
@@ -13,5 +13,7 @@ torch
 py-cpuinfo
 transformers
 mistral_common >= 1.3.4
+aiohttp
+starlette
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
@@ -96,7 +96,6 @@ def setup(app):
 
 # Mock out external dependencies here, otherwise the autodoc pages may be blank.
 autodoc_mock_imports = [
-    "aiohttp",
     "compressed_tensors",
     "cpuinfo",
     "cv2",
@@ -143,6 +142,7 @@ intersphinx_mapping = {
     "python": ("https://docs.python.org/3", None),
     "typing_extensions":
     ("https://typing-extensions.readthedocs.io/en/latest", None),
+    "aiohttp": ("https://docs.aiohttp.org/en/stable", None),
     "pillow": ("https://pillow.readthedocs.io/en/stable", None),
     "numpy": ("https://numpy.org/doc/stable", None),
     "torch": ("https://pytorch.org/docs/stable", None),
docs/source/dev/pooling_params.rst (new file, 5 lines)
@@ -0,0 +1,5 @@
+Pooling Parameters
+==================
+
+.. autoclass:: vllm.PoolingParams
+    :members:
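The new page only references the class via ``autoclass``. As an illustrative sketch (not part of this commit), the documented class can be constructed directly; ``additional_data`` is the passthrough field that the request models later in this diff forward to it, and the key used here is purely hypothetical:

```python
# Hedged sketch: vllm.PoolingParams is the class documented above.
# The "normalize" key is a made-up example of passthrough data.
from vllm import PoolingParams

params = PoolingParams(additional_data={"normalize": True})
```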
@@ -138,10 +138,10 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
 
 A more detailed client example can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`__.
 
-OpenAI Chat API with vLLM
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+OpenAI Chat Completions API with vLLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-vLLM is designed to also support the OpenAI Chat API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
+vLLM is designed to also support the OpenAI Chat Completions API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
 
 You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to interact with the model:
 
@@ -157,7 +157,7 @@ You can use the `create chat completion <https://platform.openai.com/docs/api-re
     $ ]
     $ }'
 
-Alternatively, you can use the `openai` python package:
+Alternatively, you can use the ``openai`` python package:
 
 .. code-block:: python
 
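The code block opened above continues in the docs with an ``openai``-client example that this diff does not show. A generic sketch of such a call (written here for illustration only; the model name and API key are placeholders, not taken from the file):

```python
# Illustration only: assumes a vLLM OpenAI-compatible server on localhost:8000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="your-chat-model",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message)
```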
@@ -134,6 +134,7 @@ Documentation
    :caption: Developer Documentation
 
    dev/sampling_params
+   dev/pooling_params
    dev/offline_inference/offline_index
    dev/engine/engine_index
    dev/kernel/paged_attention
@@ -185,7 +185,7 @@ Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruc
     --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
 
 .. important::
-    Since OpenAI Vision API is based on `Chat Completions <https://platform.openai.com/docs/api-reference/chat>`_ API,
+    Since OpenAI Vision API is based on `Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`_,
     a chat template is **required** to launch the API server.
 
 Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it.
@@ -243,6 +243,10 @@ To consume the server, you can use the OpenAI client like in the example below:
 
 A full code example can be found in `examples/openai_api_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_api_client_for_multimodal.py>`_.
 
+.. tip::
+    There is no need to place image placeholders in the text content of the API request - they are already represented by the image content.
+    In fact, you can place image placeholders in the middle of the text by interleaving text and image content.
+
 .. note::
 
     By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable:
@@ -251,5 +255,49 @@ A full code example can be found in `examples/openai_api_client_for_multimodal.p
 
     $ export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
 
-.. note::
-    There is no need to format the prompt in the API request since it will be handled by the server.
+Chat Embeddings API
+^^^^^^^^^^^^^^^^^^^
+
+vLLM's Chat Embeddings API is a superset of OpenAI's `Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`_,
+where a list of ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models.
+
+.. tip::
+    The schema of ``messages`` is exactly the same as in Chat Completions API.
+
+In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
+
+.. code-block:: bash
+
+    vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
+        --trust-remote-code --max-model-len 4096
+
+.. important::
+
+    Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
+    to run this model in embedding mode instead of text generation mode.
+
+Since this schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:
+
+.. code-block:: python
+
+    import requests
+
+    image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+
+    response = requests.post(
+        "http://localhost:8000/v1/embeddings",
+        json={
+            "model": "TIGER-Lab/VLM2Vec-Full",
+            "messages": [{
+                "role": "user",
+                "content": [
+                    {"type": "image_url", "image_url": {"url": image_url}},
+                    {"type": "text", "text": "Represent the given image."},
+                ],
+            }],
+            "encoding_format": "float",
+        },
+    )
+    response.raise_for_status()
+    response_json = response.json()
+    print("Embedding output:", response_json["data"][0]["embedding"])
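The documentation above demonstrates an image input. The same endpoint also accepts plain-text conversations, since the ``messages`` schema matches Chat Completions; the embedding test added later in this commit exercises exactly that. A minimal text-only sketch, assuming a text embedding model is being served with ``--task embedding`` and a chat template:

```python
# Sketch: text-only Chat Embeddings request. Assumes a local vLLM server is
# serving intfloat/e5-mistral-7b-instruct in embedding mode with a chat
# template; the conversation mirrors the test added in this commit.
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "messages": [
            {"role": "user", "content": "The cat sat on the mat."},
            {"role": "assistant", "content": "A feline was resting on a rug."},
            {"role": "user", "content": "Stars twinkle brightly in the night sky."},
        ],
        "encoding_format": "float",
    },
)
response.raise_for_status()
print(len(response.json()["data"][0]["embedding"]))
```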
@@ -26,13 +26,26 @@ print(completion.choices[0].message)
 ```
 
 ## API Reference
-Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-reference) for more information on the API. We support all parameters except:
-- Chat: `tools`, and `tool_choice`.
-- Completions: `suffix`.
 
-vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst).
+We currently support the following OpenAI APIs:
+
+- [Completions API](https://platform.openai.com/docs/api-reference/completions)
+  - *Note: `suffix` parameter is not supported.*
+- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
+  - [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Using VLMs](../models/vlm.rst).
+    - *Note: `image_url.detail` parameter is not supported.*
+  - We also support `audio_url` content type for audio files.
+    - Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
+    - *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).*
+  - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
+- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
+  - Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API),
+    which will be treated as a single prompt to the model according to its chat template.
+  - This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
+  - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
 
 ## Extra Parameters
 
 vLLM supports a set of parameters that are not part of the OpenAI API.
 In order to use them, you can pass them as extra parameters in the OpenAI client.
 Or directly merge them into the JSON payload if you are using HTTP call directly.
@@ -49,7 +62,26 @@ completion = client.chat.completions.create(
 )
 ```
 
-### Extra Parameters for Chat API
+### Extra Parameters for Completions API
+
+The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
+
+```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-completion-sampling-params
+:end-before: end-completion-sampling-params
+```
+
+The following extra parameters are supported:
+
+```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-completion-extra-params
+:end-before: end-completion-extra-params
+```
+
+### Extra Parameters for Chat Completions API
 
 The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
 
 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
@@ -66,21 +98,22 @@ The following extra parameters are supported:
 :end-before: end-chat-completion-extra-params
 ```
 
-### Extra Parameters for Completions API
-The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
+### Extra Parameters for Embeddings API
+
+The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported.
 
 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python
-:start-after: begin-completion-sampling-params
-:end-before: end-completion-sampling-params
+:start-after: begin-embedding-pooling-params
+:end-before: end-embedding-pooling-params
 ```
 
 The following extra parameters are supported:
 
 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python
-:start-after: begin-completion-extra-params
-:end-before: end-completion-extra-params
+:start-after: begin-embedding-extra-params
+:end-before: end-embedding-extra-params
 ```
 
 ## Chat Template
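When going through the OpenAI client rather than raw HTTP, the embedding pooling/extra parameters documented above travel in `extra_body`, mirroring the tests in this commit (e.g. `truncate_prompt_tokens` and `add_special_tokens`). A hedged sketch:

```python
# Sketch only: assumes an embedding model is being served at localhost:8000.
# The extra_body key shown is one exercised by the tests in this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

embeddings = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # placeholder embedding model
    input=["Como o Brasil pode fomentar o desenvolvimento de modelos de IA?"],
    extra_body={"truncate_prompt_tokens": 10},  # vLLM-specific extra parameter
)
print(len(embeddings.data[0].embedding))
```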
@@ -1,7 +1,6 @@
 from http import HTTPStatus
 from typing import List
 
-import openai
 import pytest
 import pytest_asyncio
 import requests
@@ -83,10 +82,8 @@ async def client(server):
     indirect=True,
 )
 @pytest.mark.asyncio
-async def test_show_version(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
-    response = requests.get(base_url + "/version")
+async def test_show_version(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("version"))
     response.raise_for_status()
 
     assert response.json() == {"version": VLLM_VERSION}
@@ -102,9 +99,7 @@ async def test_show_version(client: openai.AsyncOpenAI):
     indirect=True,
 )
 @pytest.mark.asyncio
-async def test_check_health(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
-    response = requests.get(base_url + "/health")
+async def test_check_health(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("health"))
 
     assert response.status_code == HTTPStatus.OK
@@ -4,14 +4,18 @@ import numpy as np
 import openai
 import pytest
 import pytest_asyncio
+import requests
 
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
 from ...utils import RemoteOpenAIServer
 
-EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
 
 
 @pytest.fixture(scope="module")
-def embedding_server():
+def server():
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -19,31 +23,29 @@ def embedding_server():
         "--enforce-eager",
         "--max-model-len",
         "8192",
+        "--chat-template",
+        DUMMY_CHAT_TEMPLATE,
     ]
 
-    with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
 
 @pytest_asyncio.fixture
-async def embedding_client(embedding_server):
-    async with embedding_server.get_async_client() as async_client:
+async def client(server):
+    async with server.get_async_client() as async_client:
         yield async_client
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
-                                model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
     input_texts = [
         "The chef prepared a delicious meal.",
     ]
 
     # test single embedding
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         encoding_format="float",
@@ -57,7 +59,7 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
 
     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         encoding_format="float",
@@ -71,18 +73,14 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
-                               model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
     # test List[str]
     input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
         "Stars twinkle brightly in the night sky."
     ]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         encoding_format="float",
@@ -90,11 +88,14 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
     assert embeddings.id is not None
     assert len(embeddings.data) == 3
     assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 32
+    assert embeddings.usage.total_tokens == 32
 
     # test List[List[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                     [25, 32, 64, 77]]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         encoding_format="float",
@@ -108,22 +109,70 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_conversation_embedding(server: RemoteOpenAIServer,
+                                      client: openai.AsyncOpenAI,
+                                      model_name: str):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    chat_response = requests.post(server.url_for("v1/embeddings"),
+                                  json={
+                                      "model": model_name,
+                                      "messages": messages,
+                                      "encoding_format": "float",
+                                  })
+    chat_response.raise_for_status()
+    chat_embeddings = chat_response.json()
+
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        chat_template=DUMMY_CHAT_TEMPLATE,
+        add_generation_prompt=True,
+        continue_final_message=False,
+        tokenize=False,
+    )
+    completion_response = await client.embeddings.create(
+        model=model_name,
+        input=prompt,
+        encoding_format="float",
+        # To be consistent with chat
+        extra_body={"add_special_tokens": False},
+    )
+    completion_embeddings = completion_response.model_dump(mode="json")
+
+    assert chat_embeddings.pop("id") is not None
+    assert completion_embeddings.pop("id") is not None
+    assert chat_embeddings.pop("created") <= completion_embeddings.pop(
+        "created")
+    assert chat_embeddings == completion_embeddings
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
                                       model_name: str):
     input_texts = [
         "Hello my name is",
         "The best thing about vLLM is that it supports many different models"
     ]
 
-    responses_float = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name, encoding_format="float")
+    responses_float = await client.embeddings.create(input=input_texts,
+                                                     model=model_name,
+                                                     encoding_format="float")
 
-    responses_base64 = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name, encoding_format="base64")
+    responses_base64 = await client.embeddings.create(input=input_texts,
+                                                      model=model_name,
+                                                      encoding_format="base64")
 
     decoded_responses_base64_data = []
     for data in responses_base64.data:
@@ -137,8 +186,8 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
                                                        1]
 
     # Default response is float32 decoded from base64 by OpenAI Client
-    responses_default = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name)
+    responses_default = await client.embeddings.create(input=input_texts,
+                                                       model=model_name)
 
     assert responses_float.data[0].embedding == responses_default.data[
         0].embedding
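For context on the assertions above: OpenAI-style `base64` embeddings are the raw little-endian float32 buffer, base64-encoded, so the decode step elided from this hunk can be sketched as follows (not the exact test code):

```python
# Sketch: decode a base64-encoded embedding into a list of floats.
import base64

import numpy as np


def decode_base64_embedding(b64_data: str) -> list:
    return np.frombuffer(base64.b64decode(b64_data), dtype="float32").tolist()
```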
@@ -147,18 +196,15 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding_truncation(
-    embedding_client: openai.AsyncOpenAI, model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
+                                           model_name: str):
     input_texts = [
         "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
     ]
 
     # test single embedding
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         extra_body={"truncate_prompt_tokens": 10})
@@ -173,7 +219,7 @@ async def test_single_embedding_truncation(
         1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
         9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
     ]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         extra_body={"truncate_prompt_tokens": 10})
@@ -187,18 +233,15 @@ async def test_single_embedding_truncation(
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding_truncation_invalid(
-    embedding_client: openai.AsyncOpenAI, model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
+                                                   model_name: str):
     input_texts = [
         "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
     ]
 
     with pytest.raises(openai.BadRequestError):
-        embeddings = await embedding_client.embeddings.create(
+        embeddings = await client.embeddings.create(
             model=model_name,
             input=input_texts,
             extra_body={"truncate_prompt_tokens": 8193})
@@ -79,9 +79,8 @@ EXPECTED_VALUES = {
 
 
 @pytest.mark.asyncio
-async def test_metrics_counts(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
+async def test_metrics_counts(server: RemoteOpenAIServer,
+                              client: openai.AsyncClient):
     for _ in range(_NUM_REQUESTS):
         # sending a request triggers the metrics to be logged.
         await client.completions.create(
|
|||||||
prompt=_TOKENIZED_PROMPT,
|
prompt=_TOKENIZED_PROMPT,
|
||||||
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)
|
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST)
|
||||||
|
|
||||||
response = requests.get(base_url + "/metrics")
|
response = requests.get(server.url_for("metrics"))
|
||||||
print(response.text)
|
print(response.text)
|
||||||
assert response.status_code == HTTPStatus.OK
|
assert response.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
@@ -170,16 +169,15 @@ EXPECTED_METRICS = [
 
 
 @pytest.mark.asyncio
-async def test_metrics_exist(client: openai.AsyncOpenAI):
-    base_url = str(client.base_url)[:-3].strip("/")
-
+async def test_metrics_exist(server: RemoteOpenAIServer,
+                             client: openai.AsyncClient):
     # sending a request triggers the metrics to be logged.
     await client.completions.create(model=MODEL_NAME,
                                     prompt="Hello, my name is",
                                     max_tokens=5,
                                     temperature=0.0)
 
-    response = requests.get(base_url + "/metrics")
+    response = requests.get(server.url_for("metrics"))
     assert response.status_code == HTTPStatus.OK
 
     for metric in EXPECTED_METRICS:
@@ -1,4 +1,3 @@
-import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
 import requests
@@ -55,9 +54,11 @@ async def client(server):
     [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
     indirect=["tokenizer_name"],
 )
-async def test_tokenize_completions(client: openai.AsyncOpenAI,
-                                    model_name: str, tokenizer_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
+async def test_tokenize_completions(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
     tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                               tokenizer_mode="fast")
 
@@ -65,7 +66,7 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
     prompt = "vllm1 This is a test prompt."
     tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
 
-    response = requests.post(base_url + "/tokenize",
+    response = requests.post(server.url_for("tokenize"),
                              json={
                                  "add_special_tokens": add_special,
                                  "model": model_name,
@@ -86,9 +87,11 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
     [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
     indirect=["tokenizer_name"],
 )
-async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
-                             tokenizer_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
+async def test_tokenize_chat(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
     tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                               tokenizer_mode="fast")
 
@@ -121,7 +124,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
     tokens = tokenizer.encode(prompt,
                               add_special_tokens=add_special)
 
-    response = requests.post(base_url + "/tokenize",
+    response = requests.post(server.url_for("tokenize"),
                              json={
                                  "add_generation_prompt":
                                  add_generation,
@@ -146,17 +149,18 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
     [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
     indirect=["tokenizer_name"],
 )
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
-                          tokenizer_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
+async def test_detokenize(
+    server: RemoteOpenAIServer,
+    model_name: str,
+    tokenizer_name: str,
+):
     tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
                               tokenizer_mode="fast")
 
     prompt = "This is a test prompt. vllm1"
     tokens = tokenizer.encode(prompt, add_special_tokens=False)
 
-    print(f"CALLING {base_url} FOR {model_name}")
-    response = requests.post(base_url + "/detokenize",
+    response = requests.post(server.url_for("detokenize"),
                              json={
                                  "model": model_name,
                                  "tokens": tokens
tests/entrypoints/openai/test_vision_embedding.py (new file, 94 lines)
@@ -0,0 +1,94 @@
+from typing import Dict
+
+import pytest
+import pytest_asyncio
+import requests
+
+from vllm.multimodal.utils import encode_image_base64, fetch_image
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
+MAXIMUM_IMAGES = 2
+
+# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
+TEST_IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
+    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
+]
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--task",
+        "embedding",
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "2048",
+        "--max-num-seqs",
+        "5",
+        "--enforce-eager",
+        "--trust-remote-code",
+        "--limit-mm-per-prompt",
+        f"image={MAXIMUM_IMAGES}",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.fixture(scope="session")
+def base64_encoded_image() -> Dict[str, str]:
+    return {
+        image_url: encode_image_base64(fetch_image(image_url))
+        for image_url in TEST_IMAGE_URLS
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
+                               image_url: str):
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "Represent the given image."
+            },
+        ],
+    }]
+
+    response = requests.post(server.url_for("v1/embeddings"),
+                             json={
+                                 "model": model_name,
+                                 "messages": messages,
+                                 "encoding_format": "float"
+                             })
+    response.raise_for_status()
+
+    embeddings = response.json()
+    assert embeddings["id"] is not None
+    assert len(embeddings["data"]) == 1
+    assert len(embeddings["data"][0]["embedding"]) == 3072
+    assert embeddings["usage"]["completion_tokens"] == 0
+    assert embeddings["usage"]["prompt_tokens"] == 771
+    assert embeddings["usage"]["total_tokens"] == 771
@@ -11,7 +11,7 @@ from argparse import Namespace
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import AsyncIterator, Set
+from typing import AsyncIterator, Optional, Set
 
 import uvloop
 from fastapi import APIRouter, FastAPI, Request
@@ -51,7 +51,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -248,22 +248,27 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)
 
 
-def chat(request: Request) -> OpenAIServingChat:
+def base(request: Request) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(request)
+
+
+def chat(request: Request) -> Optional[OpenAIServingChat]:
     return request.app.state.openai_serving_chat
 
 
-def completion(request: Request) -> OpenAIServingCompletion:
+def completion(request: Request) -> Optional[OpenAIServingCompletion]:
     return request.app.state.openai_serving_completion
 
 
+def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
+    return request.app.state.openai_serving_embedding
+
+
 def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization
 
 
-def embedding(request: Request) -> OpenAIServingEmbedding:
-    return request.app.state.openai_serving_embedding
-
-
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
 
@@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response:
 
 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_tokenize(request)
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
 
 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_detokenize(request)
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
 
 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
-    models = await completion(raw_request).show_available_models()
+    handler = base(raw_request)
+
+    models = await handler.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
@@ -314,9 +325,12 @@ async def show_version():
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
+    handler = chat(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Chat Completions API")
 
-    generator = await chat(raw_request).create_chat_completion(
-        request, raw_request)
+    generator = await handler.create_chat_completion(request, raw_request)
 
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
 
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await completion(raw_request).create_completion(
-        request, raw_request)
+    handler = completion(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Completions API")
+
+    generator = await handler.create_completion(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
|
|||||||
|
|
||||||
@router.post("/v1/embeddings")
|
@router.post("/v1/embeddings")
|
||||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||||
generator = await embedding(raw_request).create_embedding(
|
handler = embedding(raw_request)
|
||||||
request, raw_request)
|
if handler is None:
|
||||||
|
return base(raw_request).create_error_response(
|
||||||
|
message="The model does not support Embeddings API")
|
||||||
|
|
||||||
|
generator = await handler.create_embedding(request, raw_request)
|
||||||
if isinstance(generator, ErrorResponse):
|
if isinstance(generator, ErrorResponse):
|
||||||
return JSONResponse(content=generator.model_dump(),
|
return JSONResponse(content=generator.model_dump(),
|
||||||
status_code=generator.code)
|
status_code=generator.code)
|
||||||
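Because the serving handlers are now `Optional`, a request to an API that the served model does not support gets a structured error instead of a crash. A hedged client-side sketch (the exact HTTP status comes from `create_error_response` and is not asserted here):

```python
# Sketch: probing the Chat Completions endpoint of a server started with
# `--task embedding`. The expected message text comes from the handler above.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.status_code)
print(resp.json())  # body should mention "The model does not support Chat Completions API"
```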
@@ -382,12 +404,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     @router.post("/v1/load_lora_adapter")
     async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                 raw_request: Request):
-        response = await chat(raw_request).load_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
-
-        response = await completion(raw_request).load_lora_adapter(request)
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -397,12 +417,10 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     @router.post("/v1/unload_lora_adapter")
     async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                   raw_request: Request):
-        response = await chat(raw_request).unload_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
-
-        response = await completion(raw_request).unload_lora_adapter(request)
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -501,7 +519,8 @@ def init_app_state(
         chat_template=args.chat_template,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
-        tool_parser=args.tool_call_parser)
+        tool_parser=args.tool_call_parser,
+    ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -510,13 +529,14 @@ def init_app_state(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    )
+    ) if model_config.task == "generate" else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
         base_model_paths,
         request_logger=request_logger,
-    )
+        chat_template=args.chat_template,
+    ) if model_config.task == "embedding" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
@@ -708,7 +708,7 @@ class CompletionRequest(OpenAIBaseModel):
         return data
 
 
-class EmbeddingRequest(OpenAIBaseModel):
+class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
     model: str
@@ -720,10 +720,15 @@ class EmbeddingRequest(OpenAIBaseModel):
 
     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
 
     # doc: end-embedding-pooling-params
 
     # doc: begin-embedding-extra-params
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
+    )
     priority: int = Field(
         default=0,
         description=(
@@ -737,6 +742,82 @@ class EmbeddingRequest(OpenAIBaseModel):
         return PoolingParams(additional_data=self.additional_data)
 
 
+class EmbeddingChatRequest(OpenAIBaseModel):
+    model: str
+    messages: List[ChatCompletionMessageParam]
+
+    encoding_format: Literal["float", "base64"] = "float"
+    dimensions: Optional[int] = None
+    user: Optional[str] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+
+    # doc: begin-chat-embedding-pooling-params
+    additional_data: Optional[Any] = None
+    # doc: end-chat-embedding-pooling-params
+
+    # doc: begin-chat-embedding-extra-params
+    add_generation_prompt: bool = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."))
+    # doc: end-chat-embedding-extra-params
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_generation_prompt(cls, data):
+        if data.get("continue_final_message") and data.get(
+                "add_generation_prompt"):
+            raise ValueError("Cannot set both `continue_final_message` and "
+                             "`add_generation_prompt` to True.")
+        return data
+
+    def to_pooling_params(self):
+        return PoolingParams(additional_data=self.additional_data)
+
+
+EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
+
+
 class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
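A short sketch of how the new request model behaves (illustrative; assumes the class exactly as defined above): the `model_validator` rejects conflicting prompt flags, and `to_pooling_params` forwards `additional_data`.

```python
# Sketch exercising EmbeddingChatRequest as added above.
from vllm.entrypoints.openai.protocol import EmbeddingChatRequest

req = EmbeddingChatRequest(
    model="TIGER-Lab/VLM2Vec-Full",
    messages=[{"role": "user", "content": "Represent the given sentence."}],
)
print(req.add_special_tokens)   # False by default for the chat variant
print(req.to_pooling_params())  # PoolingParams carrying additional_data=None

# Setting both prompt flags trips check_generation_prompt():
try:
    EmbeddingChatRequest(
        model="TIGER-Lab/VLM2Vec-Full",
        messages=[{"role": "user", "content": "Hello"}],
        add_generation_prompt=True,
        continue_final_message=True,
    )
except Exception as exc:  # pydantic surfaces the ValueError from the validator
    print(exc)
```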
@@ -799,7 +880,7 @@ class EmbeddingResponseData(OpenAIBaseModel):
 
 
 class EmbeddingResponse(OpenAIBaseModel):
-    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
     object: str = "list"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
@@ -217,13 +217,14 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
-    )
+    ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
         model_config,
         base_model_paths,
         request_logger=request_logger,
-    )
+        chat_template=None,
+    ) if model_config.task == "embedding" else None
 
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
@ -240,14 +241,31 @@ async def main(args):
|
|||||||
|
|
||||||
# Determine the type of request and run it.
|
# Determine the type of request and run it.
|
||||||
if request.url == "/v1/chat/completions":
|
if request.url == "/v1/chat/completions":
|
||||||
|
handler_fn = (None if openai_serving_chat is None else
|
||||||
|
openai_serving_chat.create_chat_completion)
|
||||||
|
if handler_fn is None:
|
||||||
response_futures.append(
|
response_futures.append(
|
||||||
run_request(openai_serving_chat.create_chat_completion,
|
make_async_error_request_output(
|
||||||
request, tracker))
|
request,
|
||||||
|
error_msg=
|
||||||
|
"The model does not support Chat Completions API",
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_futures.append(run_request(handler_fn, request, tracker))
|
||||||
tracker.submitted()
|
tracker.submitted()
|
||||||
elif request.url == "/v1/embeddings":
|
elif request.url == "/v1/embeddings":
|
||||||
|
handler_fn = (None if openai_serving_embedding is None else
|
||||||
|
openai_serving_embedding.create_embedding)
|
||||||
|
if handler_fn is None:
|
||||||
response_futures.append(
|
response_futures.append(
|
||||||
run_request(openai_serving_embedding.create_embedding, request,
|
make_async_error_request_output(
|
||||||
tracker))
|
request,
|
||||||
|
error_msg="The model does not support Embeddings API",
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_futures.append(run_request(handler_fn, request, tracker))
|
||||||
tracker.submitted()
|
tracker.submitted()
|
||||||
else:
|
else:
|
||||||
response_futures.append(
|
response_futures.append(
|
||||||
|
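Note (reviewer sketch, not part of the diff): run_batch consumes an OpenAI-style batch file, so an embeddings request now travels through the same JSONL format as chat completions. The schema below (custom_id/method/url/body) is assumed from that format, and the model names are placeholders:

    # Hypothetical input lines for run_batch.
    import json

    chat_line = json.dumps({
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "meta-llama/Llama-3.1-8B-Instruct",
                 "messages": [{"role": "user", "content": "Hello!"}]},
    })

    # Dispatched to openai_serving_embedding.create_embedding when the served
    # model's task is "embedding"; otherwise it yields the error output above.
    embedding_line = json.dumps({
        "custom_id": "request-2",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {"model": "intfloat/e5-mistral-7b-instruct",
                 "input": "Hello world!"},
    })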
vllm/entrypoints/openai/serving_chat.py

@@ -10,11 +10,7 @@ from fastapi import Request

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         apply_hf_chat_template,
-                                         apply_mistral_chat_template,
-                                         load_chat_template,
-                                         parse_chat_messages_futures)
+from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -27,16 +23,12 @@ from vllm.entrypoints.openai.protocol import (
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing,
-                                                    PromptAdapterPath,
-                                                    TextTokensPrompt)
+                                                    PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
-from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
-from vllm.tracing import (contains_trace_headers, extract_trace_headers,
-                          log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import iterate_with_cancellation

@@ -94,12 +86,12 @@ class OpenAIServingChat(OpenAIServing):
         raw_request: Optional[Request] = None,
     ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
                ErrorResponse]:
-        """Completion API similar to OpenAI's API.
+        """
+        Chat Completion API similar to OpenAI's API.

         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI
-        ChatCompletion API.
+        Chat Completion API.

         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -118,49 +110,8 @@ class OpenAIServingChat(OpenAIServing):
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)

-            model_config = self.model_config
             tokenizer = await self.engine_client.get_tokenizer(lora_request)

-            conversation, mm_data_future = parse_chat_messages_futures(
-                request.messages, model_config, tokenizer)
-
-            tool_dicts = None if request.tools is None else [
-                tool.model_dump() for tool in request.tools
-            ]
-
-            prompt: Union[str, List[int]]
-            is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-            if is_mistral_tokenizer:
-                prompt = apply_mistral_chat_template(
-                    tokenizer,
-                    messages=request.messages,
-                    chat_template=request.chat_template or self.chat_template,
-                    add_generation_prompt=request.add_generation_prompt,
-                    continue_final_message=request.continue_final_message,
-                    tools=tool_dicts,
-                    documents=request.documents,
-                    **(request.chat_template_kwargs or {}),
-                )
-            else:
-                prompt = apply_hf_chat_template(
-                    tokenizer,
-                    conversation=conversation,
-                    chat_template=request.chat_template or self.chat_template,
-                    add_generation_prompt=request.add_generation_prompt,
-                    continue_final_message=request.continue_final_message,
-                    tools=tool_dicts,
-                    documents=request.documents,
-                    **(request.chat_template_kwargs or {}),
-                )
-        except Exception as e:
-            logger.exception("Error in applying chat template from request")
-            return self.create_error_response(str(e))
-
-        try:
-            mm_data = await mm_data_future
-        except Exception as e:
-            logger.exception("Error in loading multi-modal data")
-            return self.create_error_response(str(e))
+            tool_parser = self.tool_parser

             # validation for OpenAI tools
             # tool_choice = "required" is not supported
@@ -168,45 +119,55 @@ class OpenAIServingChat(OpenAIServing):
                 return self.create_error_response(
                     "tool_choice = \"required\" is not supported!")

-        if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
-                self.enable_auto_tools and self.tool_parser is not None):
-            # for hf tokenizers, "auto" tools requires
-            # --enable-auto-tool-choice and --tool-call-parser
-            return self.create_error_response(
-                "\"auto\" tool choice requires "
-                "--enable-auto-tool-choice and --tool-call-parser to be set")
+            if (request.tool_choice == "auto" and
+                    not (self.enable_auto_tools and tool_parser is not None)
+                    and not isinstance(tokenizer, MistralTokenizer)):
+                # for hf tokenizers, "auto" tools requires
+                # --enable-auto-tool-choice and --tool-call-parser
+                return self.create_error_response(
+                    "\"auto\" tool choice requires "
+                    "--enable-auto-tool-choice and --tool-call-parser to be set"
+                )

-        request_id = f"chat-{request.request_id}"
+            tool_dicts = None if request.tools is None else [
+                tool.model_dump() for tool in request.tools
+            ]
+
+            (
+                conversation,
+                request_prompts,
+                engine_prompts,
+            ) = await self._preprocess_chat(
+                request,
+                tokenizer,
+                request.messages,
+                chat_template=request.chat_template or self.chat_template,
+                add_generation_prompt=request.add_generation_prompt,
+                continue_final_message=request.continue_final_message,
+                tool_dicts=tool_dicts,
+                documents=request.documents,
+                chat_template_kwargs=request.chat_template_kwargs,
+                tool_parser=tool_parser,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
+
+        request_id = f"chatcmpl-{request.request_id}"

         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:
             raw_request.state.request_metadata = request_metadata

+        # Schedule the request and get the result generator.
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
-            if self.enable_auto_tools and self.tool_parser:
-                request = self.tool_parser(tokenizer).adjust_request(
-                    request=request)
-
-            if isinstance(prompt, str):
-                prompt_inputs = self._tokenize_prompt_input(
-                    request,
-                    tokenizer,
-                    prompt,
-                    truncate_prompt_tokens=request.truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                )
-            else:
-                assert isinstance(prompt, list) and isinstance(
-                    prompt[0], int
-                ), "Prompt has to be either a string or a list of token ids"
-                prompt_inputs = TextTokensPrompt(
-                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
-
-            assert prompt_inputs is not None
-
+            for i, engine_prompt in enumerate(engine_prompts):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
-                    prompt_inputs["prompt_token_ids"])
+                    engine_prompt["prompt_token_ids"])
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         default_max_tokens)
@@ -215,35 +176,24 @@ class OpenAIServingChat(OpenAIServing):
                         default_max_tokens)

                 self._log_inputs(request_id,
-                                 prompt_inputs,
+                                 request_prompts[i],
                                  params=sampling_params,
                                  lora_request=lora_request,
                                  prompt_adapter_request=prompt_adapter_request)

-            engine_inputs = TokensPrompt(
-                prompt_token_ids=prompt_inputs["prompt_token_ids"])
-            if mm_data is not None:
-                engine_inputs["multi_modal_data"] = mm_data
-
-            is_tracing_enabled = (await
-                                  self.engine_client.is_tracing_enabled())
-            trace_headers = None
-            if is_tracing_enabled and raw_request:
-                trace_headers = extract_trace_headers(raw_request.headers)
-            if (not is_tracing_enabled and raw_request
-                    and contains_trace_headers(raw_request.headers)):
-                log_tracing_disabled_warning()
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(raw_request.headers))

                 if isinstance(sampling_params, BeamSearchParams):
-                    result_generator = self.engine_client.beam_search(
-                        prompt=engine_inputs,
+                    generator = self.engine_client.beam_search(
+                        prompt=engine_prompt,
                         model_config=self.model_config,
                         request_id=request_id,
                         params=sampling_params,
                     )
                 else:
-                    result_generator = self.engine_client.generate(
-                        engine_inputs,
+                    generator = self.engine_client.generate(
+                        engine_prompt,
                         sampling_params,
                         request_id,
                         lora_request=lora_request,
@@ -251,10 +201,15 @@ class OpenAIServingChat(OpenAIServing):
                         prompt_adapter_request=prompt_adapter_request,
                         priority=request.priority,
                     )
+
+                generators.append(generator)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

+        assert len(generators) == 1
+        result_generator, = generators
+
         if raw_request:
             result_generator = iterate_with_cancellation(
                 result_generator, raw_request.is_disconnected)
@@ -626,6 +581,9 @@ class OpenAIServingChat(OpenAIServing):
                 final_res = res
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))

         assert final_res is not None
vllm/entrypoints/openai/serving_completion.py

@@ -1,7 +1,6 @@
 import asyncio
 import time
-from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
-                    Optional)
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple, Union, cast

@@ -30,18 +29,11 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
-from vllm.tracing import (contains_trace_headers, extract_trace_headers,
-                          log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import merge_async_iterators, random_uuid

 logger = init_logger(__name__)

-TypeTokenIDs = List[int]
-TypeTopLogProbs = List[Optional[Dict[int, float]]]
-TypeCreateLogProbsFn = Callable[
-    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
-

 class OpenAIServingCompletion(OpenAIServing):

@@ -101,8 +93,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata

-        # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -111,19 +101,24 @@ class OpenAIServingCompletion(OpenAIServing):

             tokenizer = await self.engine_client.get_tokenizer(lora_request)

-            prompts = list(
-                self._tokenize_prompt_input_or_inputs(
-                    request,
-                    tokenizer,
-                    request.prompt,
-                    truncate_prompt_tokens=request.truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                ))
+            request_prompts, engine_prompts = self._preprocess_completion(
+                request,
+                tokenizer,
+                request.prompt,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))

-            for i, prompt_inputs in enumerate(prompts):
+        # Schedule the request and get the result generator.
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
+        try:
+            for i, engine_prompt in enumerate(engine_prompts):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
-                    prompt_inputs["prompt_token_ids"])
+                    engine_prompt["prompt_token_ids"])
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         default_max_tokens)
@@ -134,36 +129,24 @@ class OpenAIServingCompletion(OpenAIServing):
                 request_id_item = f"{request_id}-{i}"

                 self._log_inputs(request_id_item,
-                                 prompt_inputs,
+                                 request_prompts[i],
                                  params=sampling_params,
                                  lora_request=lora_request,
                                  prompt_adapter_request=prompt_adapter_request)

-                is_tracing_enabled = (await
-                                      self.engine_client.is_tracing_enabled())
-                trace_headers = None
-                if is_tracing_enabled:
-                    trace_headers = extract_trace_headers(raw_request.headers)
-                if not is_tracing_enabled and contains_trace_headers(
-                        raw_request.headers):
-                    log_tracing_disabled_warning()
+                trace_headers = (await
                                 self._get_trace_headers(raw_request.headers))

                 if isinstance(sampling_params, BeamSearchParams):
                     generator = self.engine_client.beam_search(
-                        prompt={
-                            "prompt_token_ids":
-                            prompt_inputs["prompt_token_ids"]
-                        },
+                        prompt=engine_prompt,
                         model_config=self.model_config,
                         request_id=request_id,
                         params=sampling_params,
                     )
                 else:
                     generator = self.engine_client.generate(
-                        {
-                            "prompt_token_ids":
-                            prompt_inputs["prompt_token_ids"]
-                        },
+                        engine_prompt,
                         sampling_params,
                         request_id_item,
                         lora_request=lora_request,
@@ -180,6 +163,8 @@ class OpenAIServingCompletion(OpenAIServing):
         result_generator = merge_async_iterators(
             *generators, is_cancelled=raw_request.is_disconnected)

+        num_prompts = len(engine_prompts)
+
         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use
         # beam search.
@@ -195,16 +180,22 @@ class OpenAIServingCompletion(OpenAIServing):
                 request_id,
                 created_time,
                 model_name,
-                num_prompts=len(prompts),
+                num_prompts=num_prompts,
                 tokenizer=tokenizer,
                 request_metadata=request_metadata)

         # Non-streaming response
-        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
+        final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
         try:
             async for i, res in result_generator:
                 final_res_batch[i] = res
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))

+        try:
             for i, final_res in enumerate(final_res_batch):
                 assert final_res is not None
@@ -212,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 # We did not pass it into vLLM engine to avoid being redundant
                 # with the inputs token IDs
                 if final_res.prompt is None:
-                    final_res.prompt = prompts[i]["prompt"]
+                    final_res.prompt = request_prompts[i]["prompt"]

             final_res_batch_checked = cast(List[RequestOutput],
                                            final_res_batch)
@@ -226,8 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
                 tokenizer,
                 request_metadata,
             )
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
vllm/entrypoints/openai/serving_embedding.py

@@ -9,8 +9,10 @@ from typing_extensions import assert_never

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
+from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
+                                              EmbeddingRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData,
                                               ErrorResponse, UsageInfo)
@@ -21,8 +23,6 @@ from vllm.utils import merge_async_iterators, random_uuid

 logger = init_logger(__name__)

-TypeTokenIDs = List[int]
-

 def _get_embedding(
     output: EmbeddingOutput,
@@ -76,6 +76,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         base_model_paths: List[BaseModelPath],
         *,
         request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
@@ -83,21 +84,20 @@ class OpenAIServingEmbedding(OpenAIServing):
                          lora_modules=None,
                          prompt_adapters=None,
                          request_logger=request_logger)
-        self._enabled = self._check_embedding_mode(
-            model_config.task == "embedding")
+
+        self.chat_template = load_chat_template(chat_template)

     async def create_embedding(
         self,
         request: EmbeddingRequest,
         raw_request: Optional[Request] = None,
     ) -> Union[EmbeddingResponse, ErrorResponse]:
-        """Completion API similar to OpenAI's API.
+        """
+        Embedding API similar to OpenAI's API.

         See https://platform.openai.com/docs/api-reference/embeddings/create
         for the API specification. This API mimics the OpenAI Embedding API.
         """
-        if not self._enabled:
-            return self.create_error_response("Embedding API disabled")
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
@@ -122,8 +122,6 @@ class OpenAIServingEmbedding(OpenAIServing):
                 "greater than max_model_len."
                 " Please, select a smaller truncation size.")

-        # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -132,32 +130,60 @@ class OpenAIServingEmbedding(OpenAIServing):

             tokenizer = await self.engine_client.get_tokenizer(lora_request)

+            if prompt_adapter_request is not None:
+                raise NotImplementedError("Prompt adapter is not supported "
+                                          "for embedding models")
+
+            if isinstance(request, EmbeddingChatRequest):
+                (
+                    _,
+                    request_prompts,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    request,
+                    tokenizer,
+                    request.messages,
+                    chat_template=request.chat_template or self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                request_prompts, engine_prompts = self._preprocess_completion(
+                    request,
+                    tokenizer,
+                    request.input,
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                )
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
+
+        # Schedule the request and get the result generator.
+        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
+        try:
             pooling_params = request.to_pooling_params()

-            prompts = list(
-                self._tokenize_prompt_input_or_inputs(request, tokenizer,
-                                                      request.input,
-                                                      truncate_prompt_tokens))
-
-            for i, prompt_inputs in enumerate(prompts):
+            for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"

                 self._log_inputs(request_id_item,
-                                 prompt_inputs,
+                                 request_prompts[i],
                                  params=pooling_params,
                                  lora_request=lora_request,
                                  prompt_adapter_request=prompt_adapter_request)

-                if prompt_adapter_request is not None:
-                    raise NotImplementedError(
-                        "Prompt adapter is not supported "
-                        "for embedding models")
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(raw_request.headers))

                 generator = self.engine_client.encode(
-                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
+                    engine_prompt,
                     pooling_params,
                     request_id_item,
                     lora_request=lora_request,
+                    trace_headers=trace_headers,
                     priority=request.priority,
                 )
@@ -171,13 +197,18 @@ class OpenAIServingEmbedding(OpenAIServing):
             is_cancelled=raw_request.is_disconnected if raw_request else None,
         )

+        num_prompts = len(engine_prompts)
+
         # Non-streaming response
         final_res_batch: List[Optional[EmbeddingRequestOutput]]
-        final_res_batch = [None] * len(prompts)
+        final_res_batch = [None] * num_prompts
         try:
             async for i, res in result_generator:
                 final_res_batch[i] = res
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")

+        try:
             for final_res in final_res_batch:
                 assert final_res is not None
@@ -187,18 +218,8 @@ class OpenAIServingEmbedding(OpenAIServing):
             response = request_output_to_embedding_response(
                 final_res_batch_checked, request_id, created_time, model_name,
                 encoding_format)
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

         return response

-    def _check_embedding_mode(self, embedding_mode: bool) -> bool:
-        if not embedding_mode:
-            logger.warning(
-                "embedding_mode is False. Embedding API will not work.")
-        else:
-            logger.info("Activating the server engine with embedding enabled.")
-        return embedding_mode
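Note (reviewer sketch, not part of the diff): create_embedding now serves two request shapes against the same /v1/embeddings route: the classic completion-style `input`, and the new chat-style `messages`, which is rendered through the request's or server's chat template before pooling. An illustrative client, with the server URL and model name as placeholders:

    # Illustrative only; assumes a vLLM OpenAI-compatible server running an
    # embedding model. URL and model name are made up.
    import requests

    base_url = "http://localhost:8000/v1"
    model = "intfloat/e5-mistral-7b-instruct"

    # Completion-style payload: dispatched to _preprocess_completion.
    completion_style = {"model": model, "input": "What is the capital of France?"}

    # Chat-style payload: dispatched to _preprocess_chat.
    chat_style = {
        "model": model,
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
    }

    for payload in (completion_style, chat_style):
        resp = requests.post(f"{base_url}/embeddings", json=payload)
        print(resp.json()["data"][0]["embedding"][:4])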
vllm/entrypoints/openai/serving_engine.py

@@ -2,28 +2,38 @@ import json
 import pathlib
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
+from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
+                    Optional, Sequence, Tuple, TypedDict, Union)

 from pydantic import Field
+from starlette.datastructures import Headers
 from typing_extensions import Annotated

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         ConversationMessage,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
+                                         parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest,
                                               DetokenizeRequest,
-                                              EmbeddingRequest, ErrorResponse,
+                                              EmbeddingChatRequest,
+                                              EmbeddingCompletionRequest,
+                                              ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               ModelCard, ModelList,
                                               ModelPermission,
                                               TokenizeChatRequest,
                                               TokenizeCompletionRequest,
-                                              TokenizeRequest,
                                               UnloadLoraAdapterRequest)
+from vllm.entrypoints.openai.tool_parsers import ToolParser
 # yapf: enable
+from vllm.inputs import TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -31,8 +41,10 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
-from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import AtomicCounter
+from vllm.tracing import (contains_trace_headers, extract_trace_headers,
+                          log_tracing_disabled_warning)
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.utils import AtomicCounter, is_list_of

 logger = init_logger(__name__)

@@ -56,8 +68,14 @@ class LoRAModulePath:
     base_model_name: Optional[str] = None


-AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
-                   EmbeddingRequest, TokenizeRequest]
+CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
+                              EmbeddingCompletionRequest,
+                              TokenizeCompletionRequest]
+
+ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
+                        TokenizeChatRequest]
+
+AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]


 class TextTokensPrompt(TypedDict):
@@ -65,6 +83,9 @@ class TextTokensPrompt(TypedDict):
     prompt_token_ids: List[int]


+RequestPrompt = Union[List[int], str, TextTokensPrompt]
+
+
 class OpenAIServing:

     def __init__(
@@ -246,7 +267,8 @@ class OpenAIServing:
         token_num = len(input_ids)

         # Note: EmbeddingRequest doesn't have max_tokens
-        if isinstance(request, EmbeddingRequest):
+        if isinstance(request,
+                      (EmbeddingChatRequest, EmbeddingCompletionRequest)):
             if token_num > self.max_model_len:
                 raise ValueError(
                     f"This model's maximum context length is "
@@ -373,10 +395,115 @@ class OpenAIServing:
             truncate_prompt_tokens=truncate_prompt_tokens,
         )

+    def _preprocess_completion(
+        self,
+        request: CompletionLikeRequest,
+        tokenizer: AnyTokenizer,
+        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = True,
+    ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
+        request_prompts = [
+            request_prompt
+            for request_prompt in self._tokenize_prompt_input_or_inputs(
+                request,
+                tokenizer,
+                input_or_inputs,
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=add_special_tokens,
+            )
+        ]
+
+        engine_prompts = [
+            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
+            for request_prompt in request_prompts
+        ]
+
+        return request_prompts, engine_prompts
+
+    async def _preprocess_chat(
+        self,
+        request: ChatLikeRequest,
+        tokenizer: AnyTokenizer,
+        messages: List[ChatCompletionMessageParam],
+        chat_template: Optional[str] = None,
+        add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
+        tool_dicts: Optional[List[Dict[str, Any]]] = None,
+        documents: Optional[List[Dict[str, str]]] = None,
+        chat_template_kwargs: Optional[Dict[str, Any]] = None,
+        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = False,
+    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
+               List[TokensPrompt]]:
+        conversation, mm_data_future = parse_chat_messages_futures(
+            messages,
+            self.model_config,
+            tokenizer,
+        )
+
+        request_prompt: Union[str, List[int]]
+        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
+        if is_mistral_tokenizer:
+            request_prompt = apply_mistral_chat_template(
+                tokenizer,
+                messages=messages,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
+                tools=tool_dicts,
+                documents=documents,
+                **(chat_template_kwargs or {}),
+            )
+        else:
+            request_prompt = apply_hf_chat_template(
+                tokenizer,
+                conversation=conversation,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=continue_final_message,
+                tools=tool_dicts,
+                documents=documents,
+                **(chat_template_kwargs or {}),
+            )
+
+        mm_data = await mm_data_future
+
+        if tool_parser is not None:
+            if not isinstance(request, ChatCompletionRequest):
+                msg = "Tool usage is only supported for Chat Completions API"
+                raise NotImplementedError(msg)
+
+            request = tool_parser(tokenizer).adjust_request(request=request)
+
+        if isinstance(request_prompt, str):
+            prompt_inputs = self._tokenize_prompt_input(
+                request,
+                tokenizer,
+                request_prompt,
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=add_special_tokens,
+            )
+        else:
+            # For MistralTokenizer
+            assert is_list_of(request_prompt, int), (
+                "Prompt has to be either a string or a list of token ids")
+            prompt_inputs = TextTokensPrompt(
+                prompt=tokenizer.decode(request_prompt),
+                prompt_token_ids=request_prompt)
+
+        engine_prompt = TokensPrompt(
+            prompt_token_ids=prompt_inputs["prompt_token_ids"])
+        if mm_data is not None:
+            engine_prompt["multi_modal_data"] = mm_data
+
+        return conversation, [request_prompt], [engine_prompt]
+
     def _log_inputs(
         self,
         request_id: str,
-        inputs: Union[str, List[int], TextTokensPrompt],
+        inputs: RequestPrompt,
         params: Optional[Union[SamplingParams, PoolingParams,
                                BeamSearchParams]],
         lora_request: Optional[LoRARequest],
@@ -404,6 +531,20 @@ class OpenAIServing:
             prompt_adapter_request=prompt_adapter_request,
         )

+    async def _get_trace_headers(
+        self,
+        headers: Headers,
+    ) -> Optional[Mapping[str, str]]:
+        is_tracing_enabled = await self.engine_client.is_tracing_enabled()
+
+        if is_tracing_enabled:
+            return extract_trace_headers(headers)
+
+        if contains_trace_headers(headers):
+            log_tracing_disabled_warning()
+
+        return None
+
     @staticmethod
     def _get_decoded_token(logprob: Logprob,
                            token_id: int,
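Note (reviewer sketch, not part of the diff): the two new helpers return parallel lists. `request_prompts` preserve what was sent (the tokenized text for the completion path, or the rendered chat-template string / token ids for the chat path) for logging and for echoing prompts back, while `engine_prompts` are the TokensPrompt dicts handed to the engine client. Roughly, for a single toy prompt (token ids made up):

    # Hypothetical values for illustration only.
    # Completion path: request_prompts holds TextTokensPrompt dicts.
    request_prompts = [{"prompt": "Hello world", "prompt_token_ids": [9906, 1917]}]
    # Chat path: request_prompts would instead hold the rendered template string.
    # Either way, engine_prompts carry only what the engine needs
    # (plus "multi_modal_data" in the chat path when present).
    engine_prompts = [{"prompt_token_ids": [9906, 1917]}]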
vllm/entrypoints/openai/serving_tokenization.py

@@ -2,10 +2,7 @@ from typing import List, Optional, Union

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
-                                         apply_mistral_chat_template,
-                                         load_chat_template,
-                                         parse_chat_messages_futures)
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -20,7 +17,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -62,6 +58,7 @@ class OpenAIServingTokenization(OpenAIServing):

         request_id = f"tokn-{random_uuid()}"

+        try:
             (
                 lora_request,
                 prompt_adapter_request,
@@ -69,52 +66,43 @@ class OpenAIServingTokenization(OpenAIServing):

             tokenizer = await self.engine_client.get_tokenizer(lora_request)

-        prompt: Union[str, List[int]]
-        if isinstance(request, TokenizeChatRequest):
-            model_config = self.model_config
-
-            conversation, mm_data_future = parse_chat_messages_futures(
-                request.messages, model_config, tokenizer)
-
-            mm_data = await mm_data_future
-            if mm_data:
-                logger.warning(
-                    "Multi-modal inputs are ignored during tokenization")
-
-            if isinstance(tokenizer, MistralTokenizer):
-                prompt = apply_mistral_chat_template(
-                    tokenizer,
-                    messages=request.messages,
-                    chat_template=self.chat_template,
-                    add_generation_prompt=request.add_generation_prompt,
-                    continue_final_message=request.continue_final_message,
-                )
-            else:
-                prompt = apply_hf_chat_template(
-                    tokenizer,
-                    conversation=conversation,
-                    chat_template=self.chat_template,
-                    add_generation_prompt=request.add_generation_prompt,
-                    continue_final_message=request.continue_final_message,
-                )
-        else:
-            prompt = request.prompt
-
-        self._log_inputs(request_id,
-                         prompt,
-                         params=None,
-                         lora_request=lora_request,
-                         prompt_adapter_request=prompt_adapter_request)
-
-        # Silently ignore prompt adapter since it does not affect tokenization
-
-        prompt_input = self._tokenize_prompt_input(
-            request,
-            tokenizer,
-            prompt,
-            add_special_tokens=request.add_special_tokens,
-        )
-        input_ids = prompt_input["prompt_token_ids"]
+            if isinstance(request, TokenizeChatRequest):
+                (
+                    _,
+                    request_prompts,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    request,
+                    tokenizer,
+                    request.messages,
+                    chat_template=self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    continue_final_message=request.continue_final_message,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                request_prompts, engine_prompts = self._preprocess_completion(
+                    request,
+                    tokenizer,
+                    request.prompt,
+                    add_special_tokens=request.add_special_tokens,
+                )
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
+
+        input_ids: List[int] = []
+        for i, engine_prompt in enumerate(engine_prompts):
+            self._log_inputs(request_id,
+                             request_prompts[i],
+                             params=None,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
+
+            # Silently ignore prompt adapter since it does not affect
+            # tokenization (Unlike in Embeddings API where an error is raised)
+
+            input_ids.extend(engine_prompt["prompt_token_ids"])

         return TokenizeResponse(tokens=input_ids,
                                 count=len(input_ids),
@@ -143,9 +131,8 @@ class OpenAIServingTokenization(OpenAIServing):
                          lora_request=lora_request,
                          prompt_adapter_request=prompt_adapter_request)

-        if prompt_adapter_request is not None:
-            raise NotImplementedError("Prompt adapter is not supported "
-                                      "for tokenization")
+        # Silently ignore prompt adapter since it does not affect tokenization
+        # (Unlike in Embeddings API where an error is raised)

         prompt_input = self._tokenize_prompt_input(
             request,
vllm/pooling_params.py

@@ -7,7 +7,7 @@ class PoolingParams(
         msgspec.Struct,
         omit_defaults=True,  # type: ignore[call-arg]
         array_like=True):  # type: ignore[call-arg]
-    """Pooling parameters for pooling.
+    """Pooling parameters for embeddings API.

     Attributes:
         additional_data: Any additional data needed for pooling.
@@ -16,7 +16,7 @@ class PoolingParams(

     def clone(self) -> "PoolingParams":
         """Returns a deep copy of the PoolingParams instance."""
-        return PoolingParams(additional_data=self.additional_data, )
+        return PoolingParams(additional_data=self.additional_data)

     def __repr__(self) -> str:
         return (f"PoolingParams("