[Doc] Reorganize online pooling APIs (#11172)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-14 00:22:22 +08:00 committed by GitHub
parent 238c0d93b4
commit 0920ab9131
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 431 additions and 351 deletions

View File

@ -50,10 +50,10 @@ It returns the extracted hidden states directly, which is useful for reward mode
.. code-block:: python
llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
output, = llm.encode("Hello, my name is")
(output,) = llm.encode("Hello, my name is")
data = output.outputs.data
print(f"Prompt: {prompt!r} | Data: {data!r}")
print(f"Data: {data!r}")
``LLM.embed``
^^^^^^^^^^^^^
@ -64,7 +64,7 @@ It is primarily designed for embedding models.
.. code-block:: python
llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
output, = llm.embed("Hello, my name is")
(output,) = llm.embed("Hello, my name is")
embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
@ -80,7 +80,7 @@ It is primarily designed for classification models.
.. code-block:: python
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
output, = llm.classify("Hello, my name is")
(output,) = llm.classify("Hello, my name is")
probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")
@ -102,8 +102,8 @@ These types of models serve as rerankers between candidate query-document pairs
.. code-block:: python
llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
output, = llm.score("What is the capital of France?",
"The capital of Brazil is Brasilia.")
(output,) = llm.score("What is the capital of France?",
"The capital of Brazil is Brasilia.")
score = output.outputs.score
print(f"Score: {score}")
@ -119,7 +119,7 @@ Please click on the above link for more details on how to launch the server.
Embeddings API
^^^^^^^^^^^^^^
Our Embeddings API is similar to ``LLM.encode``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`.
Our Embeddings API is similar to ``LLM.embed``, accepting both text and :ref:`multi-modal inputs <multimodal_inputs>`.
The text-only API is compatible with `OpenAI Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`__
so that you can use OpenAI client to interact with it.

View File

@ -1,13 +1,13 @@
# OpenAI Compatible Server
vLLM provides an HTTP server that implements OpenAI's [Completions](https://platform.openai.com/docs/api-reference/completions) and [Chat](https://platform.openai.com/docs/api-reference/chat) API.
vLLM provides an HTTP server that implements OpenAI's [Completions](https://platform.openai.com/docs/api-reference/completions) and [Chat](https://platform.openai.com/docs/api-reference/chat) API, and more!
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
You can start the server via the [`vllm serve`](#vllm-serve) command, or through [Docker](deploying_with_docker.rst):
```bash
vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123
```
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
To call the server, you can use the [official OpenAI Python client](https://github.com/openai/openai-python), or any other HTTP client.
```python
from openai import OpenAI
client = OpenAI(
@ -25,265 +25,32 @@ completion = client.chat.completions.create(
print(completion.choices[0].message)
```
## API Reference
## Supported APIs
We currently support the following OpenAI APIs:
- [Completions API](https://platform.openai.com/docs/api-reference/completions)
- [Completions API](#completions-api) (`/v1/completions`)
- Only applicable to [text generation models](../models/generative_models.rst) (`--task generate`).
- *Note: `suffix` parameter is not supported.*
- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
- [Chat Completions API](#chat-api) (`/v1/chat/completions`)
- Only applicable to [text generation models](../models/generative_models.rst) (`--task generate`) with a [chat template](#chat-template).
- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst).
- *Note: `image_url.detail` parameter is not supported.*
- We also support `audio_url` content type for audio files.
- Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
- *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).*
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
- Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API),
which will be treated as a single prompt to the model according to its chat template.
- This enables multi-modal inputs to be passed to embedding models, see [this page](../usage/multimodal_inputs.rst) for details.
- *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
- Only applicable to [embedding models](../models/pooling_models.rst) (`--task embed`).
## Score API for Cross Encoder Models
In addition, we have the following custom APIs:
vLLM supports *cross encoders models* at the **/v1/score** endpoint, which is not an OpenAI API standard endpoint. You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
A ***Cross Encoder*** takes exactly two sentences / texts as input and either predicts a score or label for this sentence pair. It can for example predict the similarity of the sentence pair on a scale of 0 … 1.
### Example of usage for a pair of a string and a list of texts
In this case, the model will compare the first given text to each of the texts containing the list.
```bash
curl -X 'POST' \
'http://127.0.0.1:8000/v1/score' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "BAAI/bge-reranker-v2-m3",
"text_1": "What is the capital of France?",
"text_2": [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris."
]
}'
```
Response:
```bash
{
"id": "score-request-id",
"object": "list",
"created": 693570,
"model": "BAAI/bge-reranker-v2-m3",
"data": [
{
"index": 0,
"object": "score",
"score": [
0.001094818115234375
]
},
{
"index": 1,
"object": "score",
"score": [
1
]
}
],
"usage": {}
}
```
### Example of usage for a pair of two lists of texts
In this case, the model will compare the one by one, making pairs by same index correspondent in each list.
```bash
curl -X 'POST' \
'http://127.0.0.1:8000/v1/score' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "BAAI/bge-reranker-v2-m3",
"encoding_format": "float",
"text_1": [
"What is the capital of Brazil?",
"What is the capital of France?"
],
"text_2": [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris."
]
}'
```
Response:
```bash
{
"id": "score-request-id",
"object": "list",
"created": 693447,
"model": "BAAI/bge-reranker-v2-m3",
"data": [
{
"index": 0,
"object": "score",
"score": [
1
]
},
{
"index": 1,
"object": "score",
"score": [
1
]
}
],
"usage": {}
}
```
### Example of usage for a pair of two strings
In this case, the model will compare the strings of texts.
```bash
curl -X 'POST' \
'http://127.0.0.1:8000/v1/score' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "BAAI/bge-reranker-v2-m3",
"encoding_format": "float",
"text_1": "What is the capital of France?",
"text_2": "The capital of France is Paris."
}'
```
Response:
```bash
{
"id": "score-request-id",
"object": "list",
"created": 693447,
"model": "BAAI/bge-reranker-v2-m3",
"data": [
{
"index": 0,
"object": "score",
"score": [
1
]
}
],
"usage": {}
}
```
## Extra Parameters
vLLM supports a set of parameters that are not part of the OpenAI API.
In order to use them, you can pass them as extra parameters in the OpenAI client.
Or directly merge them into the JSON payload if you are using HTTP call directly.
```python
completion = client.chat.completions.create(
model="NousResearch/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_body={
"guided_choice": ["positive", "negative"]
}
)
```
### Extra HTTP Headers
Only `X-Request-Id` HTTP request header is supported for now.
```python
completion = client.chat.completions.create(
model="NousResearch/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_headers={
"x-request-id": "sentiment-classification-00001",
}
)
print(completion._request_id)
completion = client.completions.create(
model="NousResearch/Meta-Llama-3-8B-Instruct",
prompt="A robot may not injure a human being",
extra_headers={
"x-request-id": "completion-test",
}
)
print(completion._request_id)
```
### Extra Parameters for Completions API
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-completion-sampling-params
:end-before: end-completion-sampling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-completion-extra-params
:end-before: end-completion-extra-params
```
### Extra Parameters for Chat Completions API
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-chat-completion-sampling-params
:end-before: end-chat-completion-sampling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-chat-completion-extra-params
:end-before: end-chat-completion-extra-params
```
### Extra Parameters for Embeddings API
The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-embedding-pooling-params
:end-before: end-embedding-pooling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-embedding-extra-params
:end-before: end-embedding-extra-params
```
- [Tokenizer API](#tokenizer-api) (`/tokenize`, `/detokenize`)
- Applicable to any model with a tokenizer.
- [Score API](#score-api) (`/score`)
- Only applicable to [cross-encoder models](../models/pooling_models.rst) (`--task score`).
(chat-template)=
## Chat Template
In order for the language model to support chat protocol, vLLM requires the model to include
@ -329,7 +96,56 @@ the detected format, which can be one of:
If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument
to override which format to use.
## Command line arguments for the server
## Extra Parameters
vLLM supports a set of parameters that are not part of the OpenAI API.
In order to use them, you can pass them as extra parameters in the OpenAI client.
Or directly merge them into the JSON payload if you are using HTTP call directly.
```python
completion = client.chat.completions.create(
model="NousResearch/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_body={
"guided_choice": ["positive", "negative"]
}
)
```
## Extra HTTP Headers
Only `X-Request-Id` HTTP request header is supported for now.
```python
completion = client.chat.completions.create(
model="NousResearch/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_headers={
"x-request-id": "sentiment-classification-00001",
}
)
print(completion._request_id)
completion = client.completions.create(
model="NousResearch/Meta-Llama-3-8B-Instruct",
prompt="A robot may not injure a human being",
extra_headers={
"x-request-id": "completion-test",
}
)
print(completion._request_id)
```
## CLI Reference
(vllm-serve)=
### `vllm serve`
The `vllm serve` command is used to launch the OpenAI-compatible server.
```{argparse}
:module: vllm.entrypoints.openai.cli_args
@ -337,12 +153,10 @@ to override which format to use.
:prog: vllm serve
```
#### Configuration file
### Config file
The `serve` module can also accept arguments from a config file in
`yaml` format. The arguments in the yaml must be specified using the
long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server):
You can load CLI arguments via a [YAML](https://yaml.org/) config file.
The argument names must be the long form of those outlined [above](#vllm-serve).
For example:
@ -354,10 +168,268 @@ port: 6379
uvicorn-log-level: "info"
```
To use the above config file:
```bash
$ vllm serve SOME_MODEL --config config.yaml
```
---
**NOTE**
In case an argument is supplied simultaneously using command line and the config file, the value from the commandline will take precedence.
```{note}
In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence.
The order of priorities is `command line > config file values > defaults`.
```
## API Reference
(completions-api)=
### Completions API
Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/completions) for more details.
#### Extra parameters
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-completion-sampling-params
:end-before: end-completion-sampling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-completion-extra-params
:end-before: end-completion-extra-params
```
(chat-api)=
### Chat Completions API
Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/chat) for more details.
#### Extra parameters
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-chat-completion-sampling-params
:end-before: end-chat-completion-sampling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-chat-completion-extra-params
:end-before: end-chat-completion-extra-params
```
(embeddings-api)=
### Embeddings API
Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/embeddings) for more details.
If the model has a [chat template](#chat-template), you can replace `inputs` with a list of `messages` (same schema as [Chat Completions API](#chat-api))
which will be treated as a single prompt to the model.
```{tip}
This enables multi-modal inputs to be passed to embedding models, see [this page](../usage/multimodal_inputs.rst) for details.
```
#### Extra parameters
The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-embedding-pooling-params
:end-before: end-embedding-pooling-params
```
The following extra parameters are supported by default:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-embedding-extra-params
:end-before: end-embedding-extra-params
```
For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-chat-embedding-extra-params
:end-before: end-chat-embedding-extra-params
```
(tokenizer-api)=
### Tokenizer API
The Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer).
It consists of two endpoints:
- `/tokenize` corresponds to calling `tokenizer.encode()`.
- `/detokenize` corresponds to calling `tokenizer.decode()`.
(score-api)=
### Score API
The Score API applies a cross-encoder model to predict scores for sentence pairs.
Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.
You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
#### Single inference
You can pass a string to both `text_1` and `text_2`, forming a single sentence pair.
Request:
```bash
curl -X 'POST' \
'http://127.0.0.1:8000/score' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "BAAI/bge-reranker-v2-m3",
"encoding_format": "float",
"text_1": "What is the capital of France?",
"text_2": "The capital of France is Paris."
}'
```
Response:
```bash
{
"id": "score-request-id",
"object": "list",
"created": 693447,
"model": "BAAI/bge-reranker-v2-m3",
"data": [
{
"index": 0,
"object": "score",
"score": 1
}
],
"usage": {}
}
```
#### Batch inference
You can pass a string to `text_1` and a list to `text_2`, forming multiple sentence pairs
where each pair is built from `text_1` and a string in `text_2`.
The total number of pairs is `len(text_2)`.
Request:
```bash
curl -X 'POST' \
'http://127.0.0.1:8000/score' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "BAAI/bge-reranker-v2-m3",
"text_1": "What is the capital of France?",
"text_2": [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris."
]
}'
```
Response:
```bash
{
"id": "score-request-id",
"object": "list",
"created": 693570,
"model": "BAAI/bge-reranker-v2-m3",
"data": [
{
"index": 0,
"object": "score",
"score": 0.001094818115234375
},
{
"index": 1,
"object": "score",
"score": 1
}
],
"usage": {}
}
```
You can pass a list to both `text_1` and `text_2`, forming multiple sentence pairs
where each pair is built from a string in `text_1` and the corresponding string in `text_2` (similar to `zip()`).
The total number of pairs is `len(text_2)`.
Request:
```bash
curl -X 'POST' \
'http://127.0.0.1:8000/score' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "BAAI/bge-reranker-v2-m3",
"encoding_format": "float",
"text_1": [
"What is the capital of Brazil?",
"What is the capital of France?"
],
"text_2": [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris."
]
}'
```
Response:
```bash
{
"id": "score-request-id",
"object": "list",
"created": 693447,
"model": "BAAI/bge-reranker-v2-m3",
"data": [
{
"index": 0,
"object": "score",
"score": 1
},
{
"index": 1,
"object": "score",
"score": 1
}
],
"usage": {}
}
```
#### Extra parameters
The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-score-pooling-params
:end-before: end-score-pooling-params
```
The following extra parameters are supported:
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-score-extra-params
:end-before: end-score-extra-params
```

View File

@ -345,12 +345,12 @@ Here is an end-to-end example using VLM2Vec. To serve the model:
.. code-block:: bash
vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
vllm serve TIGER-Lab/VLM2Vec-Full --task embed \
--trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja
.. important::
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embed``
to run this model in embedding mode instead of text generation mode.
The custom chat template is completely different from the original one for this model,
@ -386,12 +386,12 @@ Below is another example, this time using the ``MrLight/dse-qwen2-2b-mrl-v1`` mo
.. code-block:: bash
vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \
--trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja
.. important::
Like with VLM2Vec, we have to explicitly pass ``--task embedding``.
Like with VLM2Vec, we have to explicitly pass ``--task embed``.
Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings, which is handled
by `this custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_dse_qwen2_vl.jinja>`__.

View File

@ -1,45 +1,48 @@
# Offline Inference with the OpenAI Batch file format
**NOTE:** This is a guide to performing batch inference using the OpenAI batch file format, **NOT** the complete Batch (REST) API.
```{important}
This is a guide to performing batch inference using the OpenAI batch file format, **not** the complete Batch (REST) API.
```
## File Format
## File Format
The OpenAI batch file format consists of a series of json objects on new lines.
The OpenAI batch file format consists of a series of json objects on new lines.
[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/openai_example_batch.jsonl)
[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/openai_example_batch.jsonl)
Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.
Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.
```{note}
We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon).
```
**NOTE:** We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon).
## Pre-requisites
* Ensure you are using `vllm >= 0.4.3`. You can check by running `python -c "import vllm; print(vllm.__version__)"`.
## Pre-requisites
* The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`.
- Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens)
- Install the token on your machine (Run `huggingface-cli login`).
- Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions.
## Example 1: Running with a local file
### Step 1: Create your batch file
To follow along with this example, you can download the example batch, or create your own batch file in your working directory.
```
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl
```
Once you've created your batch file it should look like this
```
$ cat openai_example_batch.jsonl
## Example 1: Running with a local file
### Step 1: Create your batch file
To follow along with this example, you can download the example batch, or create your own batch file in your working directory.
```
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl
```
Once you've created your batch file it should look like this
```
$ cat openai_example_batch.jsonl
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
```
### Step 2: Run the batch
```
### Step 2: Run the batch
The batch running tool is designed to be used from the command line.
@ -85,18 +88,18 @@ To integrate with cloud blob storage, we recommend using presigned urls.
### Step 1: Upload your input script
To follow along with this example, you can download the example batch, or create your own batch file in your working directory.
```
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl
```
Once you've created your batch file it should look like this
```
$ cat openai_example_batch.jsonl
```
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl
```
Once you've created your batch file it should look like this
```
$ cat openai_example_batch.jsonl
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
```
```
Now upload your batch file to your S3 bucket.
@ -104,7 +107,6 @@ Now upload your batch file to your S3 bucket.
aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl
```
### Step 2: Generate your presigned urls
Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.
@ -179,21 +181,19 @@ aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl -
### Step 1: Create your batch file
Add embedding requests to your batch file. The following is an example:
Add embedding requests to your batch file. The following is an example:
```
{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
```
{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
```
You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model).
You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model).
### Step 2: Run the batch
### Step 2: Run the batch
You can run the batch using the same command as in earlier examples.
### Step 3: Check your results
You can check your results by running `cat results.jsonl`
@ -201,5 +201,5 @@ You can check your results by running `cat results.jsonl`
```
$ cat results.jsonl
{"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null}
...```
...
```

View File

@ -99,7 +99,7 @@ def dse_qwen2_vl(inp: dict):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
"Script to call a specified VLM through the API. Make sure to serve "
"the model with --task embedding before running this.")
"the model with --task embed before running this.")
parser.add_argument("model",
type=str,
choices=["vlm2vec", "dse_qwen2_vl"],

View File

@ -1,14 +1,15 @@
"""Examples Python client Score for Cross Encoder Models
"""
Example online usage of Score API.
Run `vllm serve <model> --task score` to start up the server in vLLM.
"""
import argparse
import json
import pprint
import requests
def post_http_request(prompt: json, api_url: str) -> requests.Response:
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
headers = {"User-Agent": "Test Client"}
response = requests.post(api_url, headers=headers, json=prompt)
return response
@ -20,20 +21,29 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3")
args = parser.parse_args()
api_url = f"http://{args.host}:{args.port}/v1/score"
api_url = f"http://{args.host}:{args.port}/score"
model_name = args.model
text_1 = "What is the capital of Brazil?"
text_2 = "The capital of Brazil is Brasilia."
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url)
print("Prompt when text_1 and text_2 are both strings:")
pprint.pprint(prompt)
print("Score Response:")
pprint.pprint(score_response.json())
text_1 = "What is the capital of France?"
text_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url)
print("Prompt for text_1 is string and text_2 is a list:")
print("Prompt when text_1 is string and text_2 is a list:")
pprint.pprint(prompt)
print("Score Response:")
pprint.pprint(score_response.data)
pprint.pprint(score_response.json())
text_1 = [
"What is the capital of Brazil?", "What is the capital of France?"
@ -43,16 +53,7 @@ if __name__ == "__main__":
]
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url)
print("Prompt for text_1 and text_2 are lists:")
print("Prompt when text_1 and text_2 are both lists:")
pprint.pprint(prompt)
print("Score Response:")
pprint.pprint(score_response.data)
text_1 = "What is the capital of Brazil?"
text_2 = "The capital of Brazil is Brasilia."
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
score_response = post_http_request(prompt=prompt, api_url=api_url)
print("Prompt for text_1 and text_2 are strings:")
pprint.pprint(prompt)
print("Score Response:")
pprint.pprint(score_response.data)
pprint.pprint(score_response.json())

View File

@ -27,7 +27,7 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
score_response = requests.post(server.url_for("v1/score"),
score_response = requests.post(server.url_for("score"),
json={
"model": model_name,
"text_1": text_1,
@ -55,7 +55,7 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]
score_response = requests.post(server.url_for("v1/score"),
score_response = requests.post(server.url_for("score"),
json={
"model": model_name,
"text_1": text_1,
@ -78,7 +78,7 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
text_1 = "What is the capital of France?"
text_2 = "The capital of France is Paris."
score_response = requests.post(server.url_for("v1/score"),
score_response = requests.post(server.url_for("score"),
json={
"model": model_name,
"text_1": text_1,

View File

@ -406,7 +406,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/score")
@router.post("/score")
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
@ -423,6 +423,15 @@ async def create_score(request: ScoreRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/score")
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "
"have moved it to `/score`. Please update your client accordingly.")
return await create_score(request, raw_request)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "

View File

@ -812,10 +812,11 @@ class ScoreRequest(OpenAIBaseModel):
text_2: Union[List[str], str]
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: begin-chat-embedding-pooling-params
# doc: begin-score-pooling-params
additional_data: Optional[Any] = None
# doc: end-chat-embedding-pooling-params
# doc: end-score-pooling-params
# doc: begin-score-extra-params
priority: int = Field(
default=0,
description=(
@ -823,6 +824,8 @@ class ScoreRequest(OpenAIBaseModel):
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-score-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)

View File

@ -1,12 +1,11 @@
import time
import warnings
from dataclasses import dataclass
from typing import Dict, Generic, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
import torch
from typing_extensions import TypeVar
from typing_extensions import TypeVar, deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
@ -73,13 +72,11 @@ class PoolingOutput:
(self.data == other.data).all()))
@property
@deprecated("`LLM.encode()` now stores raw outputs in the `data` "
"attribute. To return embeddings, use `LLM.embed()`. "
"To return class probabilities, use `LLM.classify()` "
"and access the `probs` attribute. ")
def embedding(self) -> list[float]:
msg = ("`LLM.encode()` now returns raw outputs. "
"To return embeddings, use `LLM.embed()`. "
"To return class probabilities, use `LLM.classify()` "
"and access the `probs` attribute. ")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return self.data.tolist()
@ -491,11 +488,9 @@ class ScoringOutput:
return f"ScoringOutput(score={self.score})"
@property
@deprecated("`LLM.score()` now returns scalar scores. "
"Please access it via the `score` attribute. ")
def embedding(self) -> list[float]:
msg = ("`LLM.score()` now returns scalar scores. "
"Please access it via the `score` attribute. ")
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return [self.score]