[Frontend] Support embeddings in the run_batch API (#7132)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Pooya Davoodi 2024-08-09 09:48:21 -07:00 committed by GitHub
parent 74af2bbd90
commit 249b88228d
5 changed files with 138 additions and 24 deletions


@ -10,7 +10,7 @@
Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.
**NOTE:** We currently only support to `/v1/chat/completions` endpoint (embeddings and completions coming soon).
**NOTE:** We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon).
## Pre-requisites
@ -21,7 +21,7 @@
- Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions.
## Example: Running with a local file
## Example 1: Running with a local file
### Step 1: Create your batch file
@ -54,7 +54,7 @@ python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o res
You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl`
```
$ cat ../results.jsonl
$ cat results.jsonl
{"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null}
{"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null}
```
@ -107,7 +107,7 @@ aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl
### Step 2: Generate your presigned urls
Presigned put urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.
Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.
(The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py)
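For reference, a minimal sketch of what such a script does with `boto3` (bucket and key names are the same placeholders as above, and the one-hour expiry is illustrative):

```python
import boto3

# Create presigned URLs: a readable (GET) URL for the input batch file and a
# writable (PUT) URL for the output file the batch runner will upload.
s3_client = boto3.client("s3")
input_url = s3_client.generate_presigned_url(
    "get_object",
    Params={"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"},
    ExpiresIn=3600)
output_url = s3_client.generate_presigned_url(
    "put_object",
    Params={"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"},
    ExpiresIn=3600)
print(input_url)
print(output_url)
```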
@ -170,3 +170,36 @@ Your results are now on S3. You can view them in your terminal by running
```
aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl -
```
## Example 4: Using embeddings endpoint
### Additional prerequisites
* Ensure you are using `vllm >= 0.5.5`.
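A quick way to confirm this, as a minimal sketch (assuming `vllm.__version__` is exposed by your install):

```python
import vllm

# Embedding support in run_batch requires vllm >= 0.5.5.
print(vllm.__version__)
```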
### Step 1: Create your batch file
Add embedding requests to your batch file. The following is an example:
```
{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
```
You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model).
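For illustration, a mixed batch file could look like the sketch below; `MY_MULTITASK_MODEL` is a placeholder for a model that serves both endpoints:
```
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "MY_MULTITASK_MODEL", "messages": [{"role": "user", "content": "Hello world!"}], "max_tokens": 100}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "MY_MULTITASK_MODEL", "input": "Hello world!"}}
```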
### Step 2: Run the batch
You can run the batch using the same command as in earlier examples.
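For instance, assuming your embedding requests are in `openai_example_batch.jsonl` as above, the invocation would be:
```
python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o results.jsonl --model intfloat/e5-mistral-7b-instruct
```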
### Step 3: Check your results
You can check your results by running `cat results.jsonl`
```
$ cat results.jsonl
{"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null}
...
```
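Since each output line is standard JSON, a short Python sketch like the following can pull the embedding vectors back out (field names follow the sample output above):

```python
import json

# Read each batch result and extract the embedding vector, skipping any
# requests that came back with an error.
with open("results.jsonl") as f:
    for line in f:
        result = json.loads(line)
        if result["error"] is not None:
            continue
        body = result["response"]["body"]
        embedding = body["data"][0]["embedding"]
        print(result["custom_id"], len(embedding))
```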


@ -7,13 +7,39 @@ from vllm.entrypoints.openai.protocol import BatchRequestOutput
# ruff: noqa: E501
INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""
INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world!"}}
{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
def test_e2e():
def test_empty_file():
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
input_file.write("")
input_file.flush()
proc = subprocess.Popen([
sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i",
input_file.name, "-o", output_file.name, "--model",
"intfloat/e5-mistral-7b-instruct"
], )
proc.communicate()
proc.wait()
assert proc.returncode == 0, f"{proc=}"
contents = output_file.read()
assert contents.strip() == ""
def test_completions():
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
@ -35,7 +61,7 @@ def test_e2e():
BatchRequestOutput.model_validate_json(line)
def test_e2e_invalid_input():
def test_completions_invalid_input():
"""
Ensure that we fail when the input doesn't conform to the openai api.
"""
@ -52,3 +78,25 @@ def test_e2e_invalid_input():
proc.communicate()
proc.wait()
assert proc.returncode != 0, f"{proc=}"
def test_embeddings():
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
input_file.write(INPUT_EMBEDDING_BATCH)
input_file.flush()
proc = subprocess.Popen([
sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i",
input_file.name, "-o", output_file.name, "--model",
"intfloat/e5-mistral-7b-instruct"
], )
proc.communicate()
proc.wait()
assert proc.returncode == 0, f"{proc=}"
contents = output_file.read()
for line in contents.strip().split("\n"):
# Ensure that the output format conforms to the openai api.
# Validation should throw if the schema is wrong.
BatchRequestOutput.model_validate_json(line)


@ -672,7 +672,7 @@ class BatchRequestInput(OpenAIBaseModel):
url: str
# The parameters of the request.
body: ChatCompletionRequest
body: Union[ChatCompletionRequest, EmbeddingRequest]
class BatchResponseData(OpenAIBaseModel):
@ -683,7 +683,7 @@ class BatchResponseData(OpenAIBaseModel):
request_id: str
# The body of the response.
body: Optional[ChatCompletionResponse] = None
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
class BatchRequestOutput(OpenAIBaseModel):
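As a rough sketch of how the widened `body` union behaves: a batch line targeting `/v1/embeddings` carries `input` rather than `messages`, so validation should resolve to `EmbeddingRequest`:

```python
from vllm.entrypoints.openai.protocol import BatchRequestInput

# One /v1/embeddings line in the batch-file format shown in the docs above.
line = ('{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", '
        '"body": {"model": "intfloat/e5-mistral-7b-instruct", '
        '"input": "You are a helpful assistant."}}')

request = BatchRequestInput.model_validate_json(line)
print(type(request.body).__name__)  # expected: EmbeddingRequest
```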


@ -1,18 +1,21 @@
import asyncio
from io import StringIO
from typing import Awaitable, List
from typing import Awaitable, Callable, List
import aiohttp
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
ErrorResponse)
EmbeddingResponse, ErrorResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
@ -82,27 +85,26 @@ async def write_file(path_or_url: str, data: str) -> None:
f.write(data)
async def run_request(chat_serving: OpenAIServingChat,
async def run_request(serving_engine_func: Callable,
request: BatchRequestInput) -> BatchRequestOutput:
chat_request = request.body
chat_response = await chat_serving.create_chat_completion(chat_request)
response = await serving_engine_func(request.body)
if isinstance(chat_response, ChatCompletionResponse):
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=chat_response, request_id=f"vllm-batch-{random_uuid()}"),
body=response, request_id=f"vllm-batch-{random_uuid()}"),
error=None,
)
elif isinstance(chat_response, ErrorResponse):
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=chat_response.code,
status_code=response.code,
request_id=f"vllm-batch-{random_uuid()}"),
error=chat_response,
error=response,
)
else:
raise ValueError("Request must not be sent in stream mode")
@ -128,6 +130,7 @@ async def main(args):
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
# Create the openai serving objects.
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
@ -138,12 +141,35 @@ async def main(args):
request_logger=request_logger,
chat_template=None,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
request_logger=request_logger,
)
# Submit all requests in the file to the engine "concurrently".
response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
response_futures.append(run_request(openai_serving_chat, request))
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
response_futures.append(
run_request(openai_serving_chat.create_chat_completion,
request))
elif request.url == "/v1/embeddings":
response_futures.append(
run_request(openai_serving_embedding.create_embedding,
request))
else:
raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint.")
responses = await asyncio.gather(*response_futures)


@ -1,7 +1,8 @@
import asyncio
import base64
import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
Union, cast)
import numpy as np
from fastapi import Request
@ -11,7 +12,8 @@ from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
@ -71,8 +73,11 @@ class OpenAIServingEmbedding(OpenAIServing):
request_logger=request_logger)
self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest,
raw_request: Request):
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, EmbeddingResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
@ -140,7 +145,9 @@ class OpenAIServingEmbedding(OpenAIServing):
result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
*generators,
is_cancelled=raw_request.is_disconnected
if raw_request else None)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
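To illustrate why `raw_request` becomes optional here: the batch runner has no FastAPI `Request` to pass, so it awaits the handler directly. A minimal sketch under that assumption (the helper function is hypothetical):

```python
from vllm.entrypoints.openai.protocol import EmbeddingRequest
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding


async def embed_offline(serving_embedding: OpenAIServingEmbedding, text: str):
    # No HTTP request object exists in the batch path, so raw_request is
    # simply left as its new default of None.
    request = EmbeddingRequest(model="intfloat/e5-mistral-7b-instruct",
                               input=text)
    return await serving_embedding.create_embedding(request)
```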