[Frontend] Support embeddings in the run_batch API (#7132)
Co-authored-by: Simon Mo <simon.mo@hey.com>
Parent: 74af2bbd90
Commit: 249b88228d
@@ -10,7 +10,7 @@

 Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.

-**NOTE:** We currently only support to `/v1/chat/completions` endpoint (embeddings and completions coming soon).
+**NOTE:** We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon).

 ## Pre-requisites

@@ -21,7 +21,7 @@

 - Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions.

-## Example: Running with a local file
+## Example 1: Running with a local file

 ### Step 1: Create your batch file

@@ -54,7 +54,7 @@ python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o res

 You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl`

 ```
-$ cat ../results.jsonl
+$ cat results.jsonl
 {"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null}
 {"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null}
 ```

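
If you want to consume these results programmatically rather than just `cat` them, a minimal sketch (not part of this commit) along these lines works. It assumes the output layout shown above; note that newer outputs wrap the completion under a `response.body` field (as in the embedding example later in this diff), so the sketch checks for both.

```python
import json

# Minimal sketch: print each request's custom_id and the assistant reply from
# results.jsonl. Older outputs put the completion directly under "response";
# newer ones nest it under "response"["body"].
with open("results.jsonl") as f:
    for line in f:
        result = json.loads(line)
        completion = result["response"].get("body") or result["response"]
        content = completion["choices"][0]["message"]["content"]
        print(result["custom_id"], "->", content)
```
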
@@ -107,7 +107,7 @@ aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl

 ### Step 2: Generate your presigned urls

-Presigned put urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.
+Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.

 (The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py)

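
The linked AWS example script is the authoritative reference; as a rough sketch of its core, generating a presigned GET url for the input file and a presigned PUT url for the output file with boto3 looks like this (bucket and key names are the same placeholders as above, error handling omitted):

```python
import boto3

# Rough sketch of what the linked AWS example does.
# MY_BUCKET / MY_INPUT_FILE.jsonl / MY_OUTPUT_FILE.jsonl are placeholders.
s3 = boto3.client("s3")

input_url = s3.generate_presigned_url(
    ClientMethod="get_object",
    Params={"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"},
    ExpiresIn=3600)
output_url = s3.generate_presigned_url(
    ClientMethod="put_object",
    Params={"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"},
    ExpiresIn=3600)

print(f"{input_url=}")
print(f"{output_url=}")
```
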
@@ -170,3 +170,36 @@ Your results are now on S3. You can view them in your terminal by running
 ```
 aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl -
 ```
+
+## Example 4: Using embeddings endpoint
+
+### Additional prerequisites
+
+* Ensure you are using `vllm >= 0.5.5`.
+
+### Step 1: Create your batch file
+
+Add embedding requests to your batch file. The following is an example:
+
+```
+{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
+```
+
+You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model).
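
As a sketch of the mixed-batch case mentioned in the added text above, the snippet below writes one chat completion request and one embedding request into the same file. `MY_CHAT_AND_EMBEDDING_MODEL` is a placeholder for a model that serves both APIs; the field layout follows the examples in this document.

```python
import json

# Placeholder model name: substitute one that supports both chat completions
# and embeddings (all requests in a batch must use the same model).
MODEL = "MY_CHAT_AND_EMBEDDING_MODEL"

requests = [
    {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions",
     "body": {"model": MODEL,
              "messages": [{"role": "user", "content": "Hello world!"}],
              "max_tokens": 64}},
    {"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings",
     "body": {"model": MODEL, "input": "Hello world!"}},
]

with open("openai_example_batch.jsonl", "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")
```
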
+
+### Step 2: Run the batch
+
+You can run the batch using the same command as in earlier examples.
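
Concretely, this is the same invocation the repository's tests (later in this diff) use; expressed here as a Python sketch, with the file names taken from the earlier examples:

```python
import subprocess
import sys

# Run the batch runner as a module against the embedding batch file.
subprocess.run(
    [
        sys.executable, "-m", "vllm.entrypoints.openai.run_batch",
        "-i", "openai_example_batch.jsonl",
        "-o", "results.jsonl",
        "--model", "intfloat/e5-mistral-7b-instruct",
    ],
    check=True,
)
```
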
+
+### Step 3: Check your results
+
+You can check your results by running `cat results.jsonl`
+
+```
+$ cat results.jsonl
+{"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null}
+...
+```
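
To pull the vectors back out of `results.jsonl`, a minimal sketch (not part of this commit) that follows the response layout shown above:

```python
import json

# Read the embedding vectors back out of results.jsonl, following the layout
# of the example output above; errored requests carry a non-null "error".
with open("results.jsonl") as f:
    for line in f:
        result = json.loads(line)
        if result["error"] is not None:
            print(result["custom_id"], "failed:", result["error"])
            continue
        embedding = result["response"]["body"]["data"][0]["embedding"]
        print(result["custom_id"], "->", len(embedding), "dimensions")
```
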
@@ -7,13 +7,39 @@ from vllm.entrypoints.openai.protocol import BatchRequestOutput
 # ruff: noqa: E501
 INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
 {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}

 {"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""

 INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
 {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""

+INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
+{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
+{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world!"}}
+{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""
+
+
+def test_empty_file():
+    with tempfile.NamedTemporaryFile(
+            "w") as input_file, tempfile.NamedTemporaryFile(
+                "r") as output_file:
+        input_file.write("")
+        input_file.flush()
+        proc = subprocess.Popen([
+            sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i",
+            input_file.name, "-o", output_file.name, "--model",
+            "intfloat/e5-mistral-7b-instruct"
+        ], )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        assert contents.strip() == ""


-def test_e2e():
+def test_completions():
     with tempfile.NamedTemporaryFile(
             "w") as input_file, tempfile.NamedTemporaryFile(
                 "r") as output_file:

@@ -35,7 +61,7 @@ def test_e2e():
             BatchRequestOutput.model_validate_json(line)


-def test_e2e_invalid_input():
+def test_completions_invalid_input():
     """
     Ensure that we fail when the input doesn't conform to the openai api.
     """

@@ -52,3 +78,25 @@ def test_e2e_invalid_input():
         proc.communicate()
         proc.wait()
         assert proc.returncode != 0, f"{proc=}"
+
+
+def test_embeddings():
+    with tempfile.NamedTemporaryFile(
+            "w") as input_file, tempfile.NamedTemporaryFile(
+                "r") as output_file:
+        input_file.write(INPUT_EMBEDDING_BATCH)
+        input_file.flush()
+        proc = subprocess.Popen([
+            sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i",
+            input_file.name, "-o", output_file.name, "--model",
+            "intfloat/e5-mistral-7b-instruct"
+        ], )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        for line in contents.strip().split("\n"):
+            # Ensure that the output format conforms to the openai api.
+            # Validation should throw if the schema is wrong.
+            BatchRequestOutput.model_validate_json(line)

@@ -672,7 +672,7 @@ class BatchRequestInput(OpenAIBaseModel):
     url: str

     # The parameters of the request.
-    body: ChatCompletionRequest
+    body: Union[ChatCompletionRequest, EmbeddingRequest]


 class BatchResponseData(OpenAIBaseModel):
@@ -683,7 +683,7 @@ class BatchResponseData(OpenAIBaseModel):
     request_id: str

     # The body of the response.
-    body: Optional[ChatCompletionResponse] = None
+    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


 class BatchRequestOutput(OpenAIBaseModel):

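
With `body` widened to a union, a single `BatchRequestInput.model_validate_json` call now accepts either request shape; which branch pydantic picks follows from the required fields (`messages` vs. `input`). A small sketch under that assumption, using only the classes shown in this diff:

```python
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                              ChatCompletionRequest,
                                              EmbeddingRequest)

chat_line = '{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "user", "content": "Hello world!"}]}}'
embedding_line = '{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world!"}}'

# body resolves to whichever request model the payload satisfies.
assert isinstance(BatchRequestInput.model_validate_json(chat_line).body,
                  ChatCompletionRequest)
assert isinstance(BatchRequestInput.model_validate_json(embedding_line).body,
                  EmbeddingRequest)
```
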
@@ -1,18 +1,21 @@
 import asyncio
 from io import StringIO
-from typing import Awaitable, List
+from typing import Awaitable, Callable, List

 import aiohttp

 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.logger import RequestLogger
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchRequestOutput,
                                               BatchResponseData,
                                               ChatCompletionResponse,
-                                              ErrorResponse)
+                                              EmbeddingResponse, ErrorResponse)
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid

@@ -82,27 +85,26 @@ async def write_file(path_or_url: str, data: str) -> None:
         f.write(data)


-async def run_request(chat_serving: OpenAIServingChat,
+async def run_request(serving_engine_func: Callable,
                       request: BatchRequestInput) -> BatchRequestOutput:
-    chat_request = request.body
-    chat_response = await chat_serving.create_chat_completion(chat_request)
+    response = await serving_engine_func(request.body)

-    if isinstance(chat_response, ChatCompletionResponse):
+    if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
             response=BatchResponseData(
-                body=chat_response, request_id=f"vllm-batch-{random_uuid()}"),
+                body=response, request_id=f"vllm-batch-{random_uuid()}"),
             error=None,
         )
-    elif isinstance(chat_response, ErrorResponse):
+    elif isinstance(response, ErrorResponse):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
             response=BatchResponseData(
-                status_code=chat_response.code,
+                status_code=response.code,
                 request_id=f"vllm-batch-{random_uuid()}"),
-            error=chat_response,
+            error=response,
         )
     else:
         raise ValueError("Request must not be sent in stream mode")

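
The `Callable` annotation here is deliberately loose. A hypothetical, more explicit alias for what `run_request` actually expects (not part of this commit) would read:

```python
from typing import Awaitable, Callable, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse)

# Hypothetical alias: any coroutine taking the request body and returning a
# response (or an ErrorResponse) satisfies run_request's contract, e.g.
# openai_serving_chat.create_chat_completion or
# openai_serving_embedding.create_embedding.
BatchHandler = Callable[[Union[ChatCompletionRequest, EmbeddingRequest]],
                        Awaitable[Union[ChatCompletionResponse,
                                        EmbeddingResponse, ErrorResponse]]]
```
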
@@ -128,6 +130,7 @@ async def main(args):
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

+    # Create the openai serving objects.
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
@@ -138,12 +141,35 @@ async def main(args):
         request_logger=request_logger,
         chat_template=None,
     )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        engine,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )

     # Submit all requests in the file to the engine "concurrently".
     response_futures: List[Awaitable[BatchRequestOutput]] = []
     for request_json in (await read_file(args.input_file)).strip().split("\n"):
+        # Skip empty lines.
+        request_json = request_json.strip()
+        if not request_json:
+            continue
+
         request = BatchRequestInput.model_validate_json(request_json)
-        response_futures.append(run_request(openai_serving_chat, request))
+
+        # Determine the type of request and run it.
+        if request.url == "/v1/chat/completions":
+            response_futures.append(
+                run_request(openai_serving_chat.create_chat_completion,
+                            request))
+        elif request.url == "/v1/embeddings":
+            response_futures.append(
+                run_request(openai_serving_embedding.create_embedding,
+                            request))
+        else:
+            raise ValueError("Only /v1/chat/completions and /v1/embeddings are "
+                             "supported in the batch endpoint.")

     responses = await asyncio.gather(*response_futures)

@@ -1,7 +1,8 @@
 import asyncio
 import base64
 import time
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast
+from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
+                    Union, cast)

 import numpy as np
 from fastapi import Request
@@ -11,7 +12,8 @@ from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                               EmbeddingResponse,
-                                              EmbeddingResponseData, UsageInfo)
+                                              EmbeddingResponseData,
+                                              ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.logger import init_logger
 from vllm.outputs import EmbeddingRequestOutput
@@ -71,8 +73,11 @@ class OpenAIServingEmbedding(OpenAIServing):
                          request_logger=request_logger)
         self._check_embedding_mode(model_config.embedding_mode)

-    async def create_embedding(self, request: EmbeddingRequest,
-                               raw_request: Request):
+    async def create_embedding(
+        self,
+        request: EmbeddingRequest,
+        raw_request: Optional[Request] = None
+    ) -> Union[ErrorResponse, EmbeddingResponse]:
         """Completion API similar to OpenAI's API.

         See https://platform.openai.com/docs/api-reference/embeddings/create

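
Making `raw_request` optional is what lets the batch runner call this handler outside of a live FastAPI request. A rough sketch of that offline call path (the serving object is assumed to be constructed the same way `run_batch`'s `main()` does it):

```python
from vllm.entrypoints.openai.protocol import EmbeddingRequest, EmbeddingResponse


# Rough sketch: call the embedding handler directly, with no raw_request.
# `openai_serving_embedding` is assumed to be an OpenAIServingEmbedding
# instance built as in run_batch's main().
async def embed_offline(openai_serving_embedding) -> None:
    request = EmbeddingRequest(model="intfloat/e5-mistral-7b-instruct",
                               input="You are a helpful assistant.")
    response = await openai_serving_embedding.create_embedding(request)
    if isinstance(response, EmbeddingResponse):
        print(len(response.data[0].embedding))
    else:  # ErrorResponse
        print(response.code)
```
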
@@ -140,7 +145,9 @@ class OpenAIServingEmbedding(OpenAIServing):

         result_generator: AsyncIterator[Tuple[
             int, EmbeddingRequestOutput]] = merge_async_iterators(
-                *generators, is_cancelled=raw_request.is_disconnected)
+                *generators,
+                is_cancelled=raw_request.is_disconnected
+                if raw_request else None)

         # Non-streaming response
         final_res_batch: List[Optional[EmbeddingRequestOutput]]