[Frontend] Support for chat completions input in the tokenize endpoint (#5923)

sasha0552 authored on 2024-07-16 12:18:09 +00:00; committed by GitHub
parent d97011512e
commit 7a3d2a5b95
9 changed files with 386 additions and 244 deletions
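
In short: the /tokenize endpoint now accepts either a plain prompt or a list of chat messages; when messages are given, the server renders them through the model's chat template (optionally appending the generation prompt) before tokenizing. A minimal client-side sketch of the new request shape, assuming an OpenAI-compatible vLLM server is already running at http://localhost:8000 (host, port, and model name are illustrative):

import requests

# Tokenize a chat conversation instead of a raw prompt. The server applies the
# model's chat template and returns the resulting token IDs.
response = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "HuggingFaceH4/zephyr-7b-beta",
        "messages": [{"role": "user", "content": "Hi there!"}],
        "add_generation_prompt": True,
        "add_special_tokens": False,
    })
response.raise_for_status()
print(response.json())  # {"tokens": [...], "count": ..., "max_model_len": ...}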

View File

@@ -4,8 +4,8 @@ from dataclasses import dataclass
 import pytest
 
+from vllm.entrypoints.openai.chat_utils import load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
-from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
@@ -64,8 +64,7 @@ def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)
 
     template_content = tokenizer.chat_template
@@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)
 
     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        load_chat_template(mock_serving_chat, chat_template=template)
 
 
 def test_no_load_chat_template_literallike():
@@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike():
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)
 
     template_content = tokenizer.chat_template
     assert template_content == template
@@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(

View File

@@ -6,7 +6,6 @@ from typing import List
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -636,51 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
             prompt="Give an example string that fits this regex",
             extra_body=dict(guided_regex=sample_regex,
                             guided_json=sample_json_schema))
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    for add_special in [False, True]:
-        prompt = "This is a test prompt."
-        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
-
-        response = requests.post(base_url + "/tokenize",
-                                 json={
-                                     "add_special_tokens": add_special,
-                                     "model": model_name,
-                                     "prompt": prompt
-                                 })
-        response.raise_for_status()
-
-        assert response.json() == {
-            "tokens": tokens,
-            "count": len(tokens),
-            "max_model_len": 8192
-        }
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3]
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    prompt = "This is a test prompt."
-    tokens = tokenizer.encode(prompt, add_special_tokens=False)
-
-    response = requests.post(base_url + "detokenize",
-                             json={
-                                 "model": model_name,
-                                 "tokens": tokens
-                             })
-    response.raise_for_status()
-
-    assert response.json() == {"prompt": prompt}

View File

@@ -0,0 +1,128 @@
+import openai  # use the official client for correctness check
+import pytest
+import requests
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+    with RemoteOpenAIServer([
+            "--model",
+            MODEL_NAME,
+            # use half precision for speed and memory savings in CI environment
+            "--dtype",
+            "bfloat16",
+            "--max-model-len",
+            "8192",
+            "--enforce-eager",
+            "--max-num-seqs",
+            "128",
+    ]) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_completions(client: openai.AsyncOpenAI,
+                                    model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_generation in [False, True]:
+        for add_special in [False, True]:
+            conversation = [{
+                "role": "user",
+                "content": "Hi there!"
+            }, {
+                "role": "assistant",
+                "content": "Nice to meet you!"
+            }, {
+                "role": "user",
+                "content": "Can I ask a question?"
+            }]
+
+            prompt = tokenizer.apply_chat_template(
+                add_generation_prompt=add_generation,
+                conversation=conversation,
+                tokenize=False)
+            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+            response = requests.post(base_url + "/tokenize",
+                                     json={
+                                         "add_generation_prompt":
+                                         add_generation,
+                                         "add_special_tokens": add_special,
+                                         "messages": conversation,
+                                         "model": model_name
+                                     })
+            response.raise_for_status()
+
+            assert response.json() == {
+                "tokens": tokens,
+                "count": len(tokens),
+                "max_model_len": 8192
+            }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    prompt = "This is a test prompt."
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    response = requests.post(base_url + "/detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+
+    assert response.json() == {"prompt": prompt}

View File

@@ -33,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
@@ -46,6 +48,7 @@ engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization
 
 logger = init_logger('vllm.entrypoints.openai.api_server')
@@ -86,7 +89,7 @@ async def health() -> Response:
 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_completion.create_tokenize(request)
+    generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest):
 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_completion.create_detokenize(request)
+    generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -241,6 +244,7 @@ def run_server(args, llm_engine=None):
     global openai_serving_chat
     global openai_serving_completion
     global openai_serving_embedding
+    global openai_serving_tokenization
 
     openai_serving_chat = OpenAIServingChat(engine, model_config,
                                             served_model_names,
@@ -252,6 +256,8 @@ def run_server(args, llm_engine=None):
         args.prompt_adapters)
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
+    openai_serving_tokenization = OpenAIServingTokenization(
+        engine, model_config, served_model_names, args.chat_template)
     app.root_path = args.root_path
 
     logger.info("Available routes are:")

View File

@@ -0,0 +1,156 @@
+import codecs
+from dataclasses import dataclass, field
+from functools import lru_cache
+from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
+
+from openai.types.chat import (ChatCompletionContentPartImageParam,
+                               ChatCompletionContentPartTextParam)
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
+                                              ChatCompletionMessageParam)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.logger import init_logger
+from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.utils import async_get_and_parse_image
+
+logger = init_logger(__name__)
+
+
+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
+@dataclass(frozen=True)
+class ChatMessageParseResult:
+    messages: List[ConversationMessage]
+    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
+        default_factory=list)
+
+
+def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
+    tokenizer = engine.tokenizer
+
+    if chat_template is not None:
+        try:
+            with open(chat_template, "r") as f:
+                tokenizer.chat_template = f.read()
+        except OSError as e:
+            JINJA_CHARS = "{}\n"
+            if not any(c in chat_template for c in JINJA_CHARS):
+                msg = (f"The supplied chat template ({chat_template}) "
+                       f"looks like a file path, but it failed to be "
+                       f"opened. Reason: {e}")
+                raise ValueError(msg) from e
+
+            # If opening a file fails, set chat template to be args to
+            # ensure we decode so our escape are interpreted correctly
+            tokenizer.chat_template = codecs.decode(chat_template,
+                                                    "unicode_escape")
+
+        logger.info("Using supplied chat template:\n%s",
+                    tokenizer.chat_template)
+    elif tokenizer.chat_template is not None:
+        logger.info("Using default chat template:\n%s",
+                    tokenizer.chat_template)
+    else:
+        logger.warning("No chat template provided. Chat API will not work.")
+
+
+@lru_cache(maxsize=None)
+def _image_token_str(engine: OpenAIServing) -> Optional[str]:
+    # TODO: Let user specify how to insert image tokens into prompt
+    # (similar to chat template)
+    model_type = engine.model_config.hf_config.model_type
+    if model_type == "phi3_v":
+        # Workaround since this token is not defined in the tokenizer
+        return "<|image_1|>"
+    if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
+        # These models do not use image tokens in the prompt
+        return None
+    if model_type.startswith("llava"):
+        return engine.tokenizer.decode(
+            engine.model_config.hf_config.image_token_index)
+    else:
+        raise TypeError("Unknown model type: {model_type}")
+
+
+# TODO: Let user specify how to insert image tokens into prompt
+# (similar to chat template)
+def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
+                                text_prompt: str) -> str:
+    """Combine image and text prompts for vision language model"""
+
+    # NOTE: For now we assume all model architectures use the same
+    # image + text prompt format. This may change in the future.
+    return f"{image_token_str}\n{text_prompt}"
+
+
+def _parse_chat_message_content_parts(
+    engine: OpenAIServing,
+    role: str,
+    parts: Iterable[ChatCompletionContentPartParam],
+) -> ChatMessageParseResult:
+    texts: List[str] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for part in parts:
+        part_type = part["type"]
+        if part_type == "text":
+            text = cast(ChatCompletionContentPartTextParam, part)["text"]
+            texts.append(text)
+        elif part_type == "image_url":
+            if len(mm_futures) > 0:
+                raise NotImplementedError(
+                    "Multiple 'image_url' input is currently not supported.")
+
+            image_url = cast(ChatCompletionContentPartImageParam,
+                             part)["image_url"]
+
+            if image_url.get("detail", "auto") != "auto":
+                logger.warning(
+                    "'image_url.detail' is currently not supported and "
+                    "will be ignored.")
+
+            image_future = async_get_and_parse_image(image_url["url"])
+            mm_futures.append(image_future)
+        else:
+            raise NotImplementedError(f"Unknown part type: {part_type}")
+
+    text_prompt = "\n".join(texts)
+
+    if mm_futures:
+        image_token_str = _image_token_str(engine)
+        if image_token_str is not None:
+            if image_token_str in text_prompt:
+                logger.warning(
+                    "Detected image token string in the text prompt. "
+                    "Skipping prompt formatting.")
+            else:
+                text_prompt = _get_full_image_text_prompt(
+                    engine,
+                    image_token_str=image_token_str,
+                    text_prompt=text_prompt,
+                )
+
+    messages = [ConversationMessage(role=role, content=text_prompt)]
+
+    return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
+
+
+def parse_chat_message_content(
+    engine: OpenAIServing,
+    message: ChatCompletionMessageParam,
+) -> ChatMessageParseResult:
+    role = message["role"]
+    content = message.get("content")
+
+    if content is None:
+        return ChatMessageParseResult(messages=[], mm_futures=[])
+    if isinstance(content, str):
+        messages = [ConversationMessage(role=role, content=content)]
+        return ChatMessageParseResult(messages=messages, mm_futures=[])
+
+    return _parse_chat_message_content_parts(engine, role, content)
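
As the chat-template tests earlier in this diff exercise, load_chat_template accepts either a path to a template file or an inline Jinja template string: a value that fails to open as a file but contains Jinja-ish characters ({, }, or a newline) is unicode-unescaped and installed on the tokenizer, while a path-looking value that cannot be opened raises ValueError. A small illustrative sketch using stub objects (the stubs below are hypothetical stand-ins for a real tokenizer and OpenAIServing instance):

from dataclasses import dataclass
from typing import Optional

from vllm.entrypoints.openai.chat_utils import load_chat_template


@dataclass
class _StubTokenizer:  # stand-in for a real tokenizer
    chat_template: Optional[str] = None


@dataclass
class _StubServing:  # stand-in for an OpenAIServing instance
    tokenizer: _StubTokenizer


serving = _StubServing(tokenizer=_StubTokenizer())

# An inline template: opening it as a file fails, so the escaped "\n" is
# decoded with unicode_escape and the result is stored on the tokenizer.
load_chat_template(serving, chat_template="{{ messages }}\\n")
assert serving.tokenizer.chat_template == "{{ messages }}\n"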

View File

@@ -738,15 +738,17 @@ class BatchRequestOutput(OpenAIBaseModel):
 
 
 class TokenizeRequest(OpenAIBaseModel):
+    add_generation_prompt: bool = Field(default=True)
+    add_special_tokens: bool = Field(default=False)
+    prompt: Optional[str] = Field(default=None)
+    messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None)
     model: str
-    prompt: str
-    add_special_tokens: bool = Field(default=True)
 
 
 class TokenizeResponse(OpenAIBaseModel):
-    tokens: List[int]
     count: int
     max_model_len: int
+    tokens: List[int]
 
 
 class DetokenizeRequest(OpenAIBaseModel):
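
Note the schema change here: prompt becomes optional, messages is added as an alternative input, and the default for add_special_tokens flips from True to False, while add_generation_prompt (used only for chat-style input) defaults to True. A sketch of the two request shapes the updated TokenizeRequest accepts (field values are illustrative):

# Prompt-style request, as before (special tokens are now opt-in)
prompt_request = {
    "model": "HuggingFaceH4/zephyr-7b-beta",
    "prompt": "This is a test prompt.",
    "add_special_tokens": True,
}

# Chat-style request (new): messages are rendered via the chat template first
chat_request = {
    "model": "HuggingFaceH4/zephyr-7b-beta",
    "messages": [{"role": "user", "content": "Can I ask a question?"}],
    "add_generation_prompt": True,
}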

View File

@@ -1,22 +1,19 @@
-import codecs
 import time
-from dataclasses import dataclass, field
-from functools import cached_property
-from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
-                    List, Optional)
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
+                    Optional)
 from typing import Sequence as GenericSequence
-from typing import TypedDict, Union, cast, final
+from typing import Union
 
 from fastapi import Request
-from openai.types.chat import (ChatCompletionContentPartImageParam,
-                               ChatCompletionContentPartTextParam)
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.chat_utils import (ConversationMessage,
+                                                load_chat_template,
+                                                parse_chat_message_content)
 from vllm.entrypoints.openai.protocol import (
-    ChatCompletionContentPartParam, ChatCompletionLogProb,
-    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
-    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
+    ChatCompletionLogProb, ChatCompletionLogProbs,
+    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
@@ -28,7 +25,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.multimodal import MultiModalDataDict
-from vllm.multimodal.utils import async_get_and_parse_image
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -38,19 +34,6 @@ from vllm.utils import random_uuid
 logger = init_logger(__name__)
 
 
-@final  # So that it should be compatible with Dict[str, str]
-class ConversationMessage(TypedDict):
-    role: str
-    content: str
-
-
-@dataclass(frozen=True)
-class ChatMessageParseResult:
-    messages: List[ConversationMessage]
-    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
-        default_factory=list)
-
-
 class OpenAIServingChat(OpenAIServing):
 
     def __init__(self,
@@ -66,131 +49,7 @@ class OpenAIServingChat(OpenAIServing):
                          lora_modules=lora_modules)
 
         self.response_role = response_role
-        self._load_chat_template(chat_template)
-
-    def _load_chat_template(self, chat_template: Optional[str]):
-        tokenizer = self.tokenizer
-
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    tokenizer.chat_template = f.read()
-            except OSError as e:
-                JINJA_CHARS = "{}\n"
-                if not any(c in chat_template for c in JINJA_CHARS):
-                    msg = (f"The supplied chat template ({chat_template}) "
-                           f"looks like a file path, but it failed to be "
-                           f"opened. Reason: {e}")
-                    raise ValueError(msg) from e
-
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using supplied chat template:\n%s",
-                        tokenizer.chat_template)
-        elif tokenizer.chat_template is not None:
-            logger.info("Using default chat template:\n%s",
-                        tokenizer.chat_template)
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
-
-    @cached_property
-    def image_token_str(self) -> Optional[str]:
-        # TODO: Let user specify how to insert image tokens into prompt
-        # (similar to chat template)
-        model_type = self.model_config.hf_config.model_type
-        if model_type == "phi3_v":
-            # Workaround since this token is not defined in the tokenizer
-            return "<|image_1|>"
-        if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
-                          "paligemma"):
-            # These models do not use image tokens in the prompt
-            return None
-        if model_type.startswith("llava"):
-            return self.tokenizer.decode(
-                self.model_config.hf_config.image_token_index)
-        else:
-            raise TypeError("Unknown model type: {model_type}")
-
-    # TODO: Let user specify how to insert image tokens into prompt
-    # (similar to chat template)
-    def _get_full_image_text_prompt(self, image_token_str: str,
-                                    text_prompt: str) -> str:
-        """Combine image and text prompts for vision language model"""
-
-        # NOTE: For now we assume all model architectures use the same
-        # image + text prompt format. This may change in the future.
-        return f"{image_token_str}\n{text_prompt}"
-
-    def _parse_chat_message_content_parts(
-        self,
-        role: str,
-        parts: Iterable[ChatCompletionContentPartParam],
-    ) -> ChatMessageParseResult:
-        texts: List[str] = []
-        mm_futures: List[Awaitable[MultiModalDataDict]] = []
-
-        for part in parts:
-            part_type = part["type"]
-            if part_type == "text":
-                text = cast(ChatCompletionContentPartTextParam, part)["text"]
-                texts.append(text)
-            elif part_type == "image_url":
-                if len(mm_futures) > 0:
-                    raise NotImplementedError(
-                        "Multiple 'image_url' input is currently not supported."
-                    )
-
-                image_url = cast(ChatCompletionContentPartImageParam,
-                                 part)["image_url"]
-
-                if image_url.get("detail", "auto") != "auto":
-                    logger.warning(
-                        "'image_url.detail' is currently not supported and "
-                        "will be ignored.")
-
-                image_future = async_get_and_parse_image(image_url["url"])
-                mm_futures.append(image_future)
-            else:
-                raise NotImplementedError(f"Unknown part type: {part_type}")
-
-        text_prompt = "\n".join(texts)
-
-        if mm_futures:
-            image_token_str = self.image_token_str
-            if image_token_str is not None:
-                if image_token_str in text_prompt:
-                    logger.warning(
-                        "Detected image token string in the text prompt. "
-                        "Skipping prompt formatting.")
-                else:
-                    text_prompt = self._get_full_image_text_prompt(
-                        image_token_str=image_token_str,
-                        text_prompt=text_prompt,
-                    )
-
-        messages = [ConversationMessage(role=role, content=text_prompt)]
-
-        return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
-
-    def _parse_chat_message_content(
-        self,
-        message: ChatCompletionMessageParam,
-    ) -> ChatMessageParseResult:
-        role = message["role"]
-        content = message.get("content")
-
-        if content is None:
-            return ChatMessageParseResult(messages=[], mm_futures=[])
-        if isinstance(content, str):
-            messages = [ConversationMessage(role=role, content=content)]
-            return ChatMessageParseResult(messages=messages, mm_futures=[])
-
-        return self._parse_chat_message_content_parts(role, content)
+        load_chat_template(self, chat_template)
 
     async def create_chat_completion(
         self,
@@ -216,7 +75,7 @@ class OpenAIServingChat(OpenAIServing):
         mm_futures: List[Awaitable[MultiModalDataDict]] = []
 
         for msg in request.messages:
-            chat_parsed_result = self._parse_chat_message_content(msg)
+            chat_parsed_result = parse_chat_message_content(self, msg)
 
             conversation.extend(chat_parsed_result.messages)
             mm_futures.extend(chat_parsed_result.mm_futures)

View File

@@ -16,10 +16,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              DetokenizeRequest,
-                                              DetokenizeResponse,
-                                              TokenizeRequest,
-                                              TokenizeResponse, UsageInfo)
+                                              UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
@@ -457,29 +454,3 @@ class OpenAIServingCompletion(OpenAIServing):
             tokens=out_tokens,
             top_logprobs=out_top_logprobs,
         )
-
-    async def create_tokenize(self,
-                              request: TokenizeRequest) -> TokenizeResponse:
-        error_check_ret = await self._check_model(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
-            request,
-            prompt=request.prompt,
-            add_special_tokens=request.add_special_tokens)
-
-        return TokenizeResponse(tokens=input_ids,
-                                count=len(input_ids),
-                                max_model_len=self.max_model_len)
-
-    async def create_detokenize(
-            self, request: DetokenizeRequest) -> DetokenizeResponse:
-        error_check_ret = await self._check_model(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
-            request, prompt_ids=request.tokens)
-
-        return DetokenizeResponse(prompt=input_text)

View File

@@ -0,0 +1,73 @@
+from typing import List, Optional
+
+from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.chat_utils import (ConversationMessage,
+                                                load_chat_template,
+                                                parse_chat_message_content)
+from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+
+
+class OpenAIServingTokenization(OpenAIServing):
+
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 model_config: ModelConfig,
+                 served_model_names: List[str],
+                 chat_template: Optional[str] = None):
+        super().__init__(engine=engine,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=None)
+
+        load_chat_template(self, chat_template)
+
+    async def create_tokenize(self,
+                              request: TokenizeRequest) -> TokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        if not (request.prompt or request.messages):
+            return self.create_error_response(
+                "Either `prompt` or `messages` should be provided.")
+
+        if (request.prompt and request.messages):
+            return self.create_error_response(
+                "Only one of `prompt` or `messages` should be provided.")
+
+        if request.messages:
+            conversation: List[ConversationMessage] = []
+
+            for message in request.messages:
+                conversation.extend(
+                    parse_chat_message_content(self, message).messages)
+
+            request.prompt = self.tokenizer.apply_chat_template(
+                add_generation_prompt=request.add_generation_prompt,
+                conversation=conversation,
+                tokenize=False)
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request,
+            prompt=request.prompt,
+            add_special_tokens=request.add_special_tokens)
+
+        return TokenizeResponse(tokens=input_ids,
+                                count=len(input_ids),
+                                max_model_len=self.max_model_len)
+
+    async def create_detokenize(
+            self, request: DetokenizeRequest) -> DetokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request, prompt_ids=request.tokens)
+
+        return DetokenizeResponse(prompt=input_text)
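
The messages branch above mirrors what a client can reproduce locally: the conversation is rendered to a prompt string with the tokenizer's chat template and only then tokenized, honoring add_generation_prompt and add_special_tokens from the request. A rough local equivalent using a Hugging Face tokenizer directly (the model name is illustrative and the exact token IDs depend on that tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
conversation = [{"role": "user", "content": "Hi there!"}]

# Step 1: messages -> prompt string via the chat template (tokenize=False),
# matching TokenizeRequest.add_generation_prompt.
prompt = tokenizer.apply_chat_template(conversation=conversation,
                                       add_generation_prompt=True,
                                       tokenize=False)

# Step 2: prompt string -> token IDs, matching TokenizeRequest.add_special_tokens.
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print(len(tokens), tokens[:8])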