[Bugfix][Core] Fix get decoding config from ray (#4335)
commit 7134303cbb (parent 3da24c2df7)
@@ -91,4 +91,6 @@ async def test_new_requests_event():
     assert engine.engine.step_calls == old_step_calls + 1
 
     engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
+    assert engine.get_decoding_config() is not None
tests/async_engine/test_openapi_server_ray.py (new file, 157 lines)
@@ -0,0 +1,157 @@
# imports for guided decoding tests
import os
import subprocess
import sys
import time

import openai  # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests

MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for server start
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"


@ray.remote(num_gpus=1)
class ServerRunner:

    def __init__(self, args):
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"
        self.proc = subprocess.Popen(
            ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(
                        "http://localhost:8000/health").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

                time.sleep(0.5)
                if time.time() - start > MAX_SERVER_START_WAIT_S:
                    raise RuntimeError(
                        "Server failed to start in time.") from err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()


@pytest.fixture(scope="session")
def server():
    ray.init()
    server_runner = ServerRunner.remote([
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--enforce-eager",
        "--engine-use-ray"
    ])
    ray.get(server_runner.ready.remote())
    yield server_runner
    ray.shutdown()


@pytest.fixture(scope="session")
def client():
    client = openai.AsyncOpenAI(
        base_url="http://localhost:8000/v1",
        api_key="token-abc123",
    )
    yield client


@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI):
    models = await client.models.list()
    models = models.data
    served_model = models[0]
    assert served_model.id == MODEL_NAME
    assert all(model.root == MODEL_NAME for model in models)


@pytest.mark.asyncio
async def test_single_completion(server, client: openai.AsyncOpenAI):
    completion = await client.completions.create(model=MODEL_NAME,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    assert completion.choices[0].finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5


@pytest.mark.asyncio
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": "what is 1+1?"
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert chat_completion.id is not None
    assert chat_completion.choices is not None and len(
        chat_completion.choices) == 1
    assert chat_completion.choices[0].message is not None
    assert chat_completion.choices[0].logprobs is not None
    assert chat_completion.choices[0].logprobs.top_logprobs is not None
    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
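Editor's note on the new test file: the server fixture passes --engine-use-ray, which runs the LLMEngine inside a Ray actor, the configuration in which the old decoding_config lookup failed. To exercise just this file locally, a rough sketch (the pytest flags and the single-GPU requirement are assumptions, not part of the commit):

# Requires a CUDA GPU plus the ray, openai, pytest and pytest-asyncio test deps.
# -s streams the subprocess server output; -x stops on the first failure.
import pytest

pytest.main(["-s", "-x", "tests/async_engine/test_openapi_server_ray.py"])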
@@ -7,7 +7,7 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
 from transformers import PreTrainedTokenizer
 
-from vllm.config import ModelConfig
+from vllm.config import DecodingConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
@@ -697,6 +697,14 @@ class AsyncLLMEngine:
         else:
             return self.engine.get_model_config()
 
+    async def get_decoding_config(self) -> DecodingConfig:
+        """Get the decoding configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_decoding_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_decoding_config()
+
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
             await self.engine.do_log_stats.remote()  # type: ignore
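Editor's note, not part of the diff: when engine_use_ray is set, self.engine is a Ray actor handle rather than a local LLMEngine, so the config has to be fetched with an awaited .remote() call instead of attribute access. A minimal usage sketch of the new accessor (model name and engine arguments are placeholders):

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main():
    # engine_use_ray=True is exactly the case the new method handles.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", engine_use_ray=True))
    decoding_config = await engine.get_decoding_config()
    print(decoding_config.guided_decoding_backend)


asyncio.run(main())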
@@ -467,6 +467,10 @@ class LLMEngine:
         """Gets the model configuration."""
         return self.model_config
 
+    def get_decoding_config(self) -> DecodingConfig:
+        """Gets the decoding configuration."""
+        return self.decoding_config
+
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
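The synchronous engine gains the matching accessor, mirroring get_model_config(). A short illustrative sketch (placeholder model name, not taken from the commit):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

# Build a local, non-Ray engine and read its decoding defaults.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
print(engine.get_decoding_config().guided_decoding_backend)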
@@ -101,7 +101,7 @@ class OpenAIServingChat(OpenAIServing):
                 request, prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
-            decoding_config = self.engine.engine.decoding_config
+            decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
@@ -89,7 +89,7 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
-            decoding_config = self.engine.engine.decoding_config
+            decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logit_processor = (
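In both serving frontends the change is the same: instead of reaching through self.engine.engine, which is a Ray actor handle rather than an LLMEngine when the server is started with --engine-use-ray, the frontend awaits the engine's accessor and falls back to the engine-wide default only when the request does not name a backend. A condensed, hypothetical helper (not in the commit) showing that resolution:

from vllm.config import DecodingConfig


async def resolve_guided_decoding_backend(engine, request) -> str:
    # The per-request setting wins; otherwise use the engine-wide default
    # carried by DecodingConfig.
    decoding_config: DecodingConfig = await engine.get_decoding_config()
    return (request.guided_decoding_backend
            or decoding_config.guided_decoding_backend)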