[Frontend] Support for chat completions input in the tokenize endpoint (#5923)
commit 7a3d2a5b95, parent d97011512e
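This change teaches the /tokenize endpoint to accept chat-completions style `messages` (rendered through the server's chat template) as an alternative to a plain `prompt`, and moves tokenize/detokenize handling out of OpenAIServingCompletion into a new OpenAIServingTokenization handler. A minimal sketch of calling the extended endpoint, not part of the diff itself; it mirrors the new test file below, and the localhost URL is an assumption:

import requests

# Assumed local vLLM OpenAI-compatible server; adjust host/port as needed.
BASE_URL = "http://localhost:8000"

# New in this change: `messages` may be sent instead of `prompt`; the server
# applies its chat template before tokenizing.
response = requests.post(BASE_URL + "/tokenize",
                         json={
                             "model": "HuggingFaceH4/zephyr-7b-beta",
                             "messages": [{
                                 "role": "user",
                                 "content": "Hi there!"
                             }],
                             "add_generation_prompt": True,
                             "add_special_tokens": False,
                         })
response.raise_for_status()
print(response.json())  # {"tokens": [...], "count": ..., "max_model_len": ...}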
@@ -4,8 +4,8 @@ from dataclasses import dataclass

 import pytest

+from vllm.entrypoints.openai.chat_utils import load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
-from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.transformers_utils.tokenizer import get_tokenizer

 chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
@@ -64,8 +64,7 @@ def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)

     template_content = tokenizer.chat_template

@@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        load_chat_template(mock_serving_chat, chat_template=template)


 def test_no_load_chat_template_literallike():
@@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike():
     tokenizer = MockTokenizer()

     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)
     template_content = tokenizer.chat_template

     assert template_content == template
@@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
@@ -6,7 +6,6 @@ from typing import List
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -636,51 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=sample_regex,
                            guided_json=sample_json_schema))
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    for add_special in [False, True]:
-        prompt = "This is a test prompt."
-        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
-
-        response = requests.post(base_url + "/tokenize",
-                                 json={
-                                     "add_special_tokens": add_special,
-                                     "model": model_name,
-                                     "prompt": prompt
-                                 })
-        response.raise_for_status()
-        assert response.json() == {
-            "tokens": tokens,
-            "count": len(tokens),
-            "max_model_len": 8192
-        }
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3]
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    prompt = "This is a test prompt."
-    tokens = tokenizer.encode(prompt, add_special_tokens=False)
-
-    response = requests.post(base_url + "detokenize",
-                             json={
-                                 "model": model_name,
-                                 "tokens": tokens
-                             })
-    response.raise_for_status()
-    assert response.json() == {"prompt": prompt}
tests/entrypoints/openai/test_tokenization.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+import openai  # use the official client for correctness check
+import pytest
+import requests
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+    with RemoteOpenAIServer([
+            "--model",
+            MODEL_NAME,
+            # use half precision for speed and memory savings in CI environment
+            "--dtype",
+            "bfloat16",
+            "--max-model-len",
+            "8192",
+            "--enforce-eager",
+            "--max-num-seqs",
+            "128",
+    ]) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_completions(client: openai.AsyncOpenAI,
+                                    model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_generation in [False, True]:
+        for add_special in [False, True]:
+            conversation = [{
+                "role": "user",
+                "content": "Hi there!"
+            }, {
+                "role": "assistant",
+                "content": "Nice to meet you!"
+            }, {
+                "role": "user",
+                "content": "Can I ask a question?"
+            }]
+
+            prompt = tokenizer.apply_chat_template(
+                add_generation_prompt=add_generation,
+                conversation=conversation,
+                tokenize=False)
+            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+            response = requests.post(base_url + "/tokenize",
+                                     json={
+                                         "add_generation_prompt":
+                                         add_generation,
+                                         "add_special_tokens": add_special,
+                                         "messages": conversation,
+                                         "model": model_name
+                                     })
+            response.raise_for_status()
+
+            assert response.json() == {
+                "tokens": tokens,
+                "count": len(tokens),
+                "max_model_len": 8192
+            }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    prompt = "This is a test prompt."
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    response = requests.post(base_url + "/detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+
+    assert response.json() == {"prompt": prompt}
@@ -33,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
@@ -46,6 +48,7 @@ engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization

 logger = init_logger('vllm.entrypoints.openai.api_server')

@@ -86,7 +89,7 @@ async def health() -> Response:

 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_completion.create_tokenize(request)
+    generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest):

 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_completion.create_detokenize(request)
+    generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -241,6 +244,7 @@ def run_server(args, llm_engine=None):
     global openai_serving_chat
     global openai_serving_completion
     global openai_serving_embedding
+    global openai_serving_tokenization

     openai_serving_chat = OpenAIServingChat(engine, model_config,
                                             served_model_names,
@@ -252,6 +256,8 @@ def run_server(args, llm_engine=None):
                                                       args.prompt_adapters)
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
+    openai_serving_tokenization = OpenAIServingTokenization(
+        engine, model_config, served_model_names, args.chat_template)
     app.root_path = args.root_path

     logger.info("Available routes are:")
vllm/entrypoints/openai/chat_utils.py (new file, 156 lines)
@@ -0,0 +1,156 @@
+import codecs
+from dataclasses import dataclass, field
+from functools import lru_cache
+from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
+
+from openai.types.chat import (ChatCompletionContentPartImageParam,
+                               ChatCompletionContentPartTextParam)
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
+                                              ChatCompletionMessageParam)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.logger import init_logger
+from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.utils import async_get_and_parse_image
+
+logger = init_logger(__name__)
+
+
+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
+@dataclass(frozen=True)
+class ChatMessageParseResult:
+    messages: List[ConversationMessage]
+    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
+        default_factory=list)
+
+
+def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
+    tokenizer = engine.tokenizer
+
+    if chat_template is not None:
+        try:
+            with open(chat_template, "r") as f:
+                tokenizer.chat_template = f.read()
+        except OSError as e:
+            JINJA_CHARS = "{}\n"
+            if not any(c in chat_template for c in JINJA_CHARS):
+                msg = (f"The supplied chat template ({chat_template}) "
+                       f"looks like a file path, but it failed to be "
+                       f"opened. Reason: {e}")
+                raise ValueError(msg) from e
+
+            # If opening a file fails, set chat template to be args to
+            # ensure we decode so our escape are interpreted correctly
+            tokenizer.chat_template = codecs.decode(chat_template,
+                                                    "unicode_escape")
+
+        logger.info("Using supplied chat template:\n%s",
+                    tokenizer.chat_template)
+    elif tokenizer.chat_template is not None:
+        logger.info("Using default chat template:\n%s",
+                    tokenizer.chat_template)
+    else:
+        logger.warning("No chat template provided. Chat API will not work.")
+
+
+@lru_cache(maxsize=None)
+def _image_token_str(engine: OpenAIServing) -> Optional[str]:
+    # TODO: Let user specify how to insert image tokens into prompt
+    # (similar to chat template)
+    model_type = engine.model_config.hf_config.model_type
+    if model_type == "phi3_v":
+        # Workaround since this token is not defined in the tokenizer
+        return "<|image_1|>"
+    if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
+        # These models do not use image tokens in the prompt
+        return None
+    if model_type.startswith("llava"):
+        return engine.tokenizer.decode(
+            engine.model_config.hf_config.image_token_index)
+
+    else:
+        raise TypeError("Unknown model type: {model_type}")
+
+
+# TODO: Let user specify how to insert image tokens into prompt
+# (similar to chat template)
+def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
+                                text_prompt: str) -> str:
+    """Combine image and text prompts for vision language model"""
+
+    # NOTE: For now we assume all model architectures use the same
+    # image + text prompt format. This may change in the future.
+    return f"{image_token_str}\n{text_prompt}"
+
+
+def _parse_chat_message_content_parts(
+    engine: OpenAIServing,
+    role: str,
+    parts: Iterable[ChatCompletionContentPartParam],
+) -> ChatMessageParseResult:
+    texts: List[str] = []
+    mm_futures: List[Awaitable[MultiModalDataDict]] = []
+
+    for part in parts:
+        part_type = part["type"]
+        if part_type == "text":
+            text = cast(ChatCompletionContentPartTextParam, part)["text"]
+            texts.append(text)
+        elif part_type == "image_url":
+            if len(mm_futures) > 0:
+                raise NotImplementedError(
+                    "Multiple 'image_url' input is currently not supported.")
+
+            image_url = cast(ChatCompletionContentPartImageParam,
+                             part)["image_url"]
+
+            if image_url.get("detail", "auto") != "auto":
+                logger.warning(
+                    "'image_url.detail' is currently not supported and "
+                    "will be ignored.")
+
+            image_future = async_get_and_parse_image(image_url["url"])
+            mm_futures.append(image_future)
+        else:
+            raise NotImplementedError(f"Unknown part type: {part_type}")
+
+    text_prompt = "\n".join(texts)
+
+    if mm_futures:
+        image_token_str = _image_token_str(engine)
+        if image_token_str is not None:
+            if image_token_str in text_prompt:
+                logger.warning(
+                    "Detected image token string in the text prompt. "
+                    "Skipping prompt formatting.")
+            else:
+                text_prompt = _get_full_image_text_prompt(
+                    engine,
+                    image_token_str=image_token_str,
+                    text_prompt=text_prompt,
+                )
+
+    messages = [ConversationMessage(role=role, content=text_prompt)]
+
+    return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
+
+
+def parse_chat_message_content(
+    engine: OpenAIServing,
+    message: ChatCompletionMessageParam,
+) -> ChatMessageParseResult:
+    role = message["role"]
+    content = message.get("content")
+
+    if content is None:
+        return ChatMessageParseResult(messages=[], mm_futures=[])
+    if isinstance(content, str):
+        messages = [ConversationMessage(role=role, content=content)]
+        return ChatMessageParseResult(messages=messages, mm_futures=[])
+
+    return _parse_chat_message_content_parts(engine, role, content)
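Not part of the diff: a short sketch of what parse_chat_message_content in the new chat_utils.py returns for a text-only, multi-part message. Here `engine` is assumed to be an OpenAIServing instance; it is only consulted when image parts are present.

message = {
    "role": "user",
    "content": [{
        "type": "text",
        "text": "Hello"
    }, {
        "type": "text",
        "text": "world"
    }],
}

result = parse_chat_message_content(engine, message)
# Text parts are joined with "\n":
#   result.messages == [{"role": "user", "content": "Hello\nworld"}]
#   result.mm_futures == []  (no image_url parts, so nothing to fetch)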
@@ -738,15 +738,17 @@ class BatchRequestOutput(OpenAIBaseModel):


 class TokenizeRequest(OpenAIBaseModel):
+    add_generation_prompt: bool = Field(default=True)
+    add_special_tokens: bool = Field(default=False)
+    prompt: Optional[str] = Field(default=None)
+    messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None)
     model: str
-    prompt: str
-    add_special_tokens: bool = Field(default=True)


 class TokenizeResponse(OpenAIBaseModel):
-    tokens: List[int]
     count: int
     max_model_len: int
+    tokens: List[int]


 class DetokenizeRequest(OpenAIBaseModel):
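For reference, the reworked TokenizeRequest now carries either a plain prompt or chat messages, and add_special_tokens defaults to False. A sketch of both request shapes, not part of the diff; the field values are illustrative:

from vllm.entrypoints.openai.protocol import TokenizeRequest

# Plain-text input, as before (add_special_tokens now defaults to False).
plain = TokenizeRequest(model="HuggingFaceH4/zephyr-7b-beta",
                        prompt="This is a test prompt.")

# Chat-style input, new in this change; the chat template is applied
# server-side before tokenization.
chat = TokenizeRequest(model="HuggingFaceH4/zephyr-7b-beta",
                       messages=[{
                           "role": "user",
                           "content": "Hi there!"
                       }],
                       add_generation_prompt=True)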
@@ -1,22 +1,19 @@
-import codecs
 import time
-from dataclasses import dataclass, field
-from functools import cached_property
-from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
-                    List, Optional)
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List,
+                    Optional)
 from typing import Sequence as GenericSequence
-from typing import TypedDict, Union, cast, final
+from typing import Union

 from fastapi import Request
-from openai.types.chat import (ChatCompletionContentPartImageParam,
-                               ChatCompletionContentPartTextParam)

 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.chat_utils import (ConversationMessage,
+                                                load_chat_template,
+                                                parse_chat_message_content)
 from vllm.entrypoints.openai.protocol import (
-    ChatCompletionContentPartParam, ChatCompletionLogProb,
-    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
-    ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
+    ChatCompletionLogProb, ChatCompletionLogProbs,
+    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
@@ -28,7 +25,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.multimodal import MultiModalDataDict
-from vllm.multimodal.utils import async_get_and_parse_image
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -38,19 +34,6 @@ from vllm.utils import random_uuid
 logger = init_logger(__name__)


-@final  # So that it should be compatible with Dict[str, str]
-class ConversationMessage(TypedDict):
-    role: str
-    content: str
-
-
-@dataclass(frozen=True)
-class ChatMessageParseResult:
-    messages: List[ConversationMessage]
-    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
-        default_factory=list)
-
-
 class OpenAIServingChat(OpenAIServing):

     def __init__(self,
@@ -66,131 +49,7 @@ class OpenAIServingChat(OpenAIServing):
                          lora_modules=lora_modules)

         self.response_role = response_role
-        self._load_chat_template(chat_template)
-
-    def _load_chat_template(self, chat_template: Optional[str]):
-        tokenizer = self.tokenizer
-
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    tokenizer.chat_template = f.read()
-            except OSError as e:
-                JINJA_CHARS = "{}\n"
-                if not any(c in chat_template for c in JINJA_CHARS):
-                    msg = (f"The supplied chat template ({chat_template}) "
-                           f"looks like a file path, but it failed to be "
-                           f"opened. Reason: {e}")
-                    raise ValueError(msg) from e
-
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using supplied chat template:\n%s",
-                        tokenizer.chat_template)
-        elif tokenizer.chat_template is not None:
-            logger.info("Using default chat template:\n%s",
-                        tokenizer.chat_template)
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
-
-    @cached_property
-    def image_token_str(self) -> Optional[str]:
-        # TODO: Let user specify how to insert image tokens into prompt
-        # (similar to chat template)
-        model_type = self.model_config.hf_config.model_type
-        if model_type == "phi3_v":
-            # Workaround since this token is not defined in the tokenizer
-            return "<|image_1|>"
-        if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv",
-                          "paligemma"):
-            # These models do not use image tokens in the prompt
-            return None
-        if model_type.startswith("llava"):
-            return self.tokenizer.decode(
-                self.model_config.hf_config.image_token_index)
-
-        else:
-            raise TypeError("Unknown model type: {model_type}")
-
-    # TODO: Let user specify how to insert image tokens into prompt
-    # (similar to chat template)
-    def _get_full_image_text_prompt(self, image_token_str: str,
-                                    text_prompt: str) -> str:
-        """Combine image and text prompts for vision language model"""
-
-        # NOTE: For now we assume all model architectures use the same
-        # image + text prompt format. This may change in the future.
-        return f"{image_token_str}\n{text_prompt}"
-
-    def _parse_chat_message_content_parts(
-        self,
-        role: str,
-        parts: Iterable[ChatCompletionContentPartParam],
-    ) -> ChatMessageParseResult:
-        texts: List[str] = []
-        mm_futures: List[Awaitable[MultiModalDataDict]] = []
-
-        for part in parts:
-            part_type = part["type"]
-            if part_type == "text":
-                text = cast(ChatCompletionContentPartTextParam, part)["text"]
-                texts.append(text)
-            elif part_type == "image_url":
-                if len(mm_futures) > 0:
-                    raise NotImplementedError(
-                        "Multiple 'image_url' input is currently not supported."
-                    )
-
-                image_url = cast(ChatCompletionContentPartImageParam,
-                                 part)["image_url"]
-
-                if image_url.get("detail", "auto") != "auto":
-                    logger.warning(
-                        "'image_url.detail' is currently not supported and "
-                        "will be ignored.")
-
-                image_future = async_get_and_parse_image(image_url["url"])
-                mm_futures.append(image_future)
-            else:
-                raise NotImplementedError(f"Unknown part type: {part_type}")
-
-        text_prompt = "\n".join(texts)
-
-        if mm_futures:
-            image_token_str = self.image_token_str
-            if image_token_str is not None:
-                if image_token_str in text_prompt:
-                    logger.warning(
-                        "Detected image token string in the text prompt. "
-                        "Skipping prompt formatting.")
-                else:
-                    text_prompt = self._get_full_image_text_prompt(
-                        image_token_str=image_token_str,
-                        text_prompt=text_prompt,
-                    )
-
-        messages = [ConversationMessage(role=role, content=text_prompt)]
-
-        return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
-
-    def _parse_chat_message_content(
-        self,
-        message: ChatCompletionMessageParam,
-    ) -> ChatMessageParseResult:
-        role = message["role"]
-        content = message.get("content")
-
-        if content is None:
-            return ChatMessageParseResult(messages=[], mm_futures=[])
-        if isinstance(content, str):
-            messages = [ConversationMessage(role=role, content=content)]
-            return ChatMessageParseResult(messages=messages, mm_futures=[])
-
-        return self._parse_chat_message_content_parts(role, content)
+        load_chat_template(self, chat_template)

     async def create_chat_completion(
         self,
@@ -216,7 +75,7 @@ class OpenAIServingChat(OpenAIServing):
         mm_futures: List[Awaitable[MultiModalDataDict]] = []

         for msg in request.messages:
-            chat_parsed_result = self._parse_chat_message_content(msg)
+            chat_parsed_result = parse_chat_message_content(self, msg)

             conversation.extend(chat_parsed_result.messages)
             mm_futures.extend(chat_parsed_result.mm_futures)
@@ -16,10 +16,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              DetokenizeRequest,
-                                              DetokenizeResponse,
-                                              TokenizeRequest,
-                                              TokenizeResponse, UsageInfo)
+                                              UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
@@ -457,29 +454,3 @@ class OpenAIServingCompletion(OpenAIServing):
             tokens=out_tokens,
             top_logprobs=out_top_logprobs,
         )
-
-    async def create_tokenize(self,
-                              request: TokenizeRequest) -> TokenizeResponse:
-        error_check_ret = await self._check_model(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
-            request,
-            prompt=request.prompt,
-            add_special_tokens=request.add_special_tokens)
-
-        return TokenizeResponse(tokens=input_ids,
-                                count=len(input_ids),
-                                max_model_len=self.max_model_len)
-
-    async def create_detokenize(
-            self, request: DetokenizeRequest) -> DetokenizeResponse:
-        error_check_ret = await self._check_model(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
-            request, prompt_ids=request.tokens)
-
-        return DetokenizeResponse(prompt=input_text)
vllm/entrypoints/openai/serving_tokenization.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+from typing import List, Optional
+
+from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.chat_utils import (ConversationMessage,
+                                                load_chat_template,
+                                                parse_chat_message_content)
+from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+
+
+class OpenAIServingTokenization(OpenAIServing):
+
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 model_config: ModelConfig,
+                 served_model_names: List[str],
+                 chat_template: Optional[str] = None):
+        super().__init__(engine=engine,
+                         model_config=model_config,
+                         served_model_names=served_model_names,
+                         lora_modules=None)
+
+        load_chat_template(self, chat_template)
+
+    async def create_tokenize(self,
+                              request: TokenizeRequest) -> TokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        if not (request.prompt or request.messages):
+            return self.create_error_response(
+                "Either `prompt` or `messages` should be provided.")
+
+        if (request.prompt and request.messages):
+            return self.create_error_response(
+                "Only one of `prompt` or `messages` should be provided.")
+
+        if request.messages:
+            conversation: List[ConversationMessage] = []
+
+            for message in request.messages:
+                conversation.extend(
+                    parse_chat_message_content(self, message).messages)
+
+            request.prompt = self.tokenizer.apply_chat_template(
+                add_generation_prompt=request.add_generation_prompt,
+                conversation=conversation,
+                tokenize=False)
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request,
+            prompt=request.prompt,
+            add_special_tokens=request.add_special_tokens)
+
+        return TokenizeResponse(tokens=input_ids,
+                                count=len(input_ids),
+                                max_model_len=self.max_model_len)
+
+    async def create_detokenize(
+            self, request: DetokenizeRequest) -> DetokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request, prompt_ids=request.tokens)
+
+        return DetokenizeResponse(prompt=input_text)
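Usage note, not part of the diff: create_tokenize requires exactly one of `prompt` or `messages`; supplying both (or neither) yields an error response rather than a TokenizeResponse. A hedged sketch of the failure case against a running server (the localhost URL is an assumption):

import requests

# Sending both `prompt` and `messages` is rejected by create_tokenize.
response = requests.post("http://localhost:8000/tokenize",
                         json={
                             "model": "HuggingFaceH4/zephyr-7b-beta",
                             "prompt": "Hi there!",
                             "messages": [{
                                 "role": "user",
                                 "content": "Hi there!"
                             }],
                         })
# Expected: an error payload containing
# "Only one of `prompt` or `messages` should be provided."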