[misc] remove engine_use_ray (#8126)

Author: youkaichao
Date: 2024-09-11 18:23:36 -07:00 (committed by GitHub)
Parent: a65cb16067
Commit: f842a7aff1
8 changed files with 32 additions and 197 deletions
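
In short, the async engine now always runs inside the server process: the `engine_use_ray` constructor argument, the `--engine-use-ray` CLI flag, and the `VLLM_ALLOW_ENGINE_USE_RAY` escape hatch are all removed. A hedged sketch of post-change construction follows (the model name is only a placeholder and every option not shown keeps its default):

# Hedged sketch of post-change usage; "facebook/opt-125m" is a placeholder
# model name and all other engine options are left at their defaults.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(model="facebook/opt-125m")
# The engine is now always constructed in-process via _engine_class;
# the former Ray-actor path (engine_use_ray=True) no longer exists.
engine = AsyncLLMEngine.from_engine_args(engine_args)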

View File

@@ -1,4 +1,3 @@
-import os
 import subprocess
 import sys
 import time
@@ -26,8 +25,7 @@ def _query_server_long(prompt: str) -> dict:
 @pytest.fixture
-def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
-               worker_use_ray: bool):
+def api_server(tokenizer_pool_size: int, worker_use_ray: bool):
     script_path = Path(__file__).parent.joinpath(
         "api_server_async_engine.py").absolute()
     commands = [
@@ -37,25 +35,17 @@ def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
         str(tokenizer_pool_size)
     ]
-    # Copy the environment variables and append `VLLM_ALLOW_ENGINE_USE_RAY=1`
-    # to prevent `--engine-use-ray` raises an exception due to it deprecation
-    env_vars = os.environ.copy()
-    env_vars["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"
-    if engine_use_ray:
-        commands.append("--engine-use-ray")
     if worker_use_ray:
         commands.append("--worker-use-ray")
-    uvicorn_process = subprocess.Popen(commands, env=env_vars)
+    uvicorn_process = subprocess.Popen(commands)
     yield
     uvicorn_process.terminate()
 @pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
 @pytest.mark.parametrize("worker_use_ray", [False, True])
-@pytest.mark.parametrize("engine_use_ray", [False, True])
-def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
-                    engine_use_ray: bool):
+def test_api_server(api_server, tokenizer_pool_size: int,
+                    worker_use_ray: bool):
     """
     Run the API server and test it.

View File

@@ -1,5 +1,4 @@
 import asyncio
-import os
 from asyncio import CancelledError
 from dataclasses import dataclass
 from typing import Optional
@@ -72,14 +71,12 @@ class MockEngine:
 class MockAsyncLLMEngine(AsyncLLMEngine):
-    def _init_engine(self, *args, **kwargs):
-        return MockEngine()
+    _engine_class = MockEngine
 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncLLMEngine(worker_use_ray=False)
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -112,16 +109,11 @@ async def test_new_requests_event():
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
-    # Allow deprecated engine_use_ray to not raise exception
-    os.environ["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"
-    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    engine = MockAsyncLLMEngine(worker_use_ray=True)
     assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
     assert engine.get_decoding_config() is not None
-    os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")
 def start_engine():
     wait_for_gpu_memory_to_clear(
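
For code that, like the old version of this test, customized engine construction by overriding `_init_engine`, the replacement pattern shown above is to point the `_engine_class` attribute at the substitute engine. A minimal sketch of that pattern, with `FakeEngine` as a hypothetical stub:

# Hedged sketch of the new extension point used by the updated test.
# FakeEngine is a made-up stand-in; a real replacement must provide the
# _AsyncLLMEngine methods the async wrapper calls (add_request_async,
# step_async, abort_request, ...).
from vllm.engine.async_llm_engine import AsyncLLMEngine


class FakeEngine:
    """Stand-in engine that records calls, mirroring MockEngine above."""

    def __init__(self, *args, **kwargs):
        self.add_request_calls = 0

    async def add_request_async(self, **kwargs):
        self.add_request_calls += 1


class FakeAsyncLLMEngine(AsyncLLMEngine):
    # Previously: override _init_engine() and return the stub there.
    # Now: point _engine_class at the stub; __init__ instantiates it directly.
    _engine_class = FakeEngine


# Usage mirrors the updated test: the stub is created by __init__ itself.
engine = FakeAsyncLLMEngine(worker_use_ray=False)
assert isinstance(engine.engine, FakeEngine)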

View File

@@ -19,16 +19,11 @@ def server():
         "--max-model-len",
         "2048",
         "--enforce-eager",
-        "--engine-use-ray",
         "--chat-template",
         str(chatml_jinja_path),
     ]
-    # Allow `--engine-use-ray`, otherwise the launch of the server throw
-    # an error due to try to use a deprecated feature
-    env_dict = {"VLLM_ALLOW_ENGINE_USE_RAY": "1"}
-    with RemoteOpenAIServer(MODEL_NAME, args,
-                            env_dict=env_dict) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server

View File

@@ -1035,7 +1035,6 @@ class EngineArgs:
 @dataclass
 class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
-    engine_use_ray: bool = False
     disable_log_requests: bool = False
     @staticmethod
@@ -1043,16 +1042,6 @@ class AsyncEngineArgs(EngineArgs):
                      async_args_only: bool = False) -> FlexibleArgumentParser:
         if not async_args_only:
             parser = EngineArgs.add_cli_args(parser)
-        parser.add_argument('--engine-use-ray',
-                            action='store_true',
-                            help='Use Ray to start the LLM engine in a '
-                            'separate process as the server process.'
-                            '(DEPRECATED. This argument is deprecated '
-                            'and will be removed in a future update. '
-                            'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
-                            'use it. See '
-                            'https://github.com/vllm-project/vllm/issues/7045.'
-                            ')')
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
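
Since both the `engine_use_ray` field and its `--engine-use-ray` option are gone, a launch command that still passes the flag now fails at argument parsing rather than hitting the old deprecation check. A hedged sketch of the remaining wiring (the model name and flag list are illustrative only):

# Hedged sketch: build the async CLI parser the way the API server
# entrypoints do and parse a minimal argument list. After this commit,
# "--engine-use-ray" is simply an unrecognized argument, while
# "--disable-log-requests" and the regular EngineArgs flags remain.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser

parser = AsyncEngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m",
                          "--disable-log-requests"])
assert not hasattr(args, "engine_use_ray")  # the option no longer exists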

View File

@@ -16,7 +16,7 @@ from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
                                     PromptComponents, SchedulerOutputState)
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
-from vllm.executor.ray_utils import initialize_ray_cluster, ray
+from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                          SingletonPromptInputs)
 from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
@@ -30,7 +30,6 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import print_warning_once
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -590,9 +589,6 @@ class AsyncLLMEngine:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
             `parallel_config.worker_use_ray`.
-        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
-            async frontend will be executed in a separate process as the
-            model workers.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -604,41 +600,23 @@ class AsyncLLMEngine:
     def __init__(self,
                  worker_use_ray: bool,
-                 engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
-        self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.engine = self._init_engine(*args, **kwargs)
+        self.engine = self._engine_class(*args, **kwargs)
         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        # TODO: Currently, disabled for engine_use_ray, ask
-        # Cody/Will/Woosuk about this case.
-        self.use_process_request_outputs_callback = not self.engine_use_ray
+        self.use_process_request_outputs_callback = True
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
                 self.process_request_outputs
-        if self.engine_use_ray:
-            print_warning_once(
-                "DEPRECATED. `--engine-use-ray` is deprecated and will "
-                "be removed in a future update. "
-                "See https://github.com/vllm-project/vllm/issues/7045.")
-            if envs.VLLM_ALLOW_ENGINE_USE_RAY:
-                print_warning_once(
-                    "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
-            else:
-                raise ValueError("`--engine-use-ray` is deprecated. "
-                                 "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
-                                 "force use it")
         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
         # task as well to prevent it from being garbage
@@ -725,16 +703,11 @@ class AsyncLLMEngine:
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
-        if engine_args.engine_use_ray:
-            from vllm.executor import ray_utils
-            ray_utils.assert_ray_available()
         executor_class = cls._get_executor_cls(engine_config)
         # Create the async LLM engine.
         engine = cls(
             executor_class.uses_ray,
-            engine_args.engine_use_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
@@ -777,10 +750,6 @@ class AsyncLLMEngine:
         self,
         lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
-        if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote(  # type: ignore
-                lora_request)
         return await (self.engine.get_tokenizer_group().
                       get_lora_tokenizer_async(lora_request))
@@ -814,26 +783,6 @@ class AsyncLLMEngine:
         self._background_loop_unshielded = None
         self.background_loop = None
-    def _init_engine(self, *args,
-                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
-        if not self.engine_use_ray:
-            engine_class = self._engine_class
-        elif self.worker_use_ray:
-            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
-        else:
-            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
-            # order of the arguments.
-            cache_config = kwargs["cache_config"]
-            parallel_config = kwargs["parallel_config"]
-            if (parallel_config.tensor_parallel_size == 1
-                    and parallel_config.pipeline_parallel_size == 1):
-                num_gpus = cache_config.gpu_memory_utilization
-            else:
-                num_gpus = 1
-            engine_class = ray.remote(num_gpus=num_gpus)(
-                self._engine_class).remote
-        return engine_class(*args, **kwargs)
     async def engine_step(self, virtual_engine: int) -> bool:
         """Kick the engine to process the waiting requests.
@@ -844,13 +793,8 @@ class AsyncLLMEngine:
         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
             # TODO: Maybe add add_request_batch to reduce Ray overhead
             try:
-                if self.engine_use_ray:
-                    await self.engine.add_request.remote(  # type: ignore
-                        **new_request)
-                else:
-                    await self.engine.add_request_async(**new_request)
+                await self.engine.add_request_async(**new_request)
             except ValueError as e:
                 # TODO: use a vLLM specific error for failed validation
                 self._request_tracker.process_exception(
@@ -862,10 +806,7 @@ class AsyncLLMEngine:
         if aborted_requests:
             await self._engine_abort(aborted_requests)
-        if self.engine_use_ray:
-            request_outputs = await self.engine.step.remote()  # type: ignore
-        else:
-            request_outputs = await self.engine.step_async(virtual_engine)
+        request_outputs = await self.engine.step_async(virtual_engine)
         # Put the outputs into the corresponding streams.
         # If used as a callback, then already invoked inside
@@ -891,16 +832,10 @@ class AsyncLLMEngine:
         return all_finished
     async def _engine_abort(self, request_ids: Iterable[str]):
-        if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_ids)  # type: ignore
-        else:
-            self.engine.abort_request(request_ids)
+        self.engine.abort_request(request_ids)
     async def run_engine_loop(self):
-        if self.engine_use_ray:
-            pipeline_parallel_size = 1  # type: ignore
-        else:
-            pipeline_parallel_size = \
+        pipeline_parallel_size = \
                 self.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
@@ -912,12 +847,7 @@ class AsyncLLMEngine:
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                if self.engine_use_ray:
-                    await (self.engine.stop_remote_worker_execution_loop.
-                           remote()  # type: ignore
-                           )
-                else:
-                    await self.engine.stop_remote_worker_execution_loop_async()
+                await self.engine.stop_remote_worker_execution_loop_async()
                 await self._request_tracker.wait_for_new_requests()
                 logger.debug("Got new requests!")
                 requests_in_progress = [
@@ -938,17 +868,9 @@ class AsyncLLMEngine:
                 for task in done:
                     result = task.result()
                     virtual_engine = requests_in_progress.index(task)
-                    if self.engine_use_ray:
-                        has_unfinished_requests = (
-                            await (self.engine.
-                                   has_unfinished_requests_for_virtual_engine.
-                                   remote(  # type: ignore
-                                       virtual_engine)))
-                    else:
-                        has_unfinished_requests = (
-                            self.engine.
-                            has_unfinished_requests_for_virtual_engine(
-                                virtual_engine))
+                    has_unfinished_requests = (
+                        self.engine.has_unfinished_requests_for_virtual_engine(
+                            virtual_engine))
                     if result or has_unfinished_requests:
                         requests_in_progress[virtual_engine] = (
                             asyncio.create_task(
@@ -1190,52 +1112,29 @@ class AsyncLLMEngine:
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_model_config.remote()  # type: ignore
-        else:
-            return self.engine.get_model_config()
+        return self.engine.get_model_config()
     async def get_parallel_config(self) -> ParallelConfig:
         """Get the parallel configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_parallel_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_parallel_config()
+        return self.engine.get_parallel_config()
     async def get_decoding_config(self) -> DecodingConfig:
         """Get the decoding configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_decoding_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_decoding_config()
+        return self.engine.get_decoding_config()
     async def get_scheduler_config(self) -> SchedulerConfig:
         """Get the scheduling configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_scheduler_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_scheduler_config()
+        return self.engine.get_scheduler_config()
     async def get_lora_config(self) -> LoRAConfig:
         """Get the lora configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_lora_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_lora_config()
+        return self.engine.get_lora_config()
     async def do_log_stats(
             self,
             scheduler_outputs: Optional[SchedulerOutputs] = None,
             model_output: Optional[List[SamplerOutput]] = None) -> None:
-        if self.engine_use_ray:
-            await self.engine.do_log_stats.remote(  # type: ignore
-                scheduler_outputs, model_output)
-        else:
-            self.engine.do_log_stats()
+        self.engine.do_log_stats()
     async def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
@@ -1244,37 +1143,17 @@ class AsyncLLMEngine:
         if self.is_stopped:
             raise AsyncEngineDeadError("Background loop is stopped.")
-        if self.engine_use_ray:
-            try:
-                await self.engine.check_health.remote()  # type: ignore
-            except ray.exceptions.RayActorError as e:
-                raise RuntimeError("Engine is dead.") from e
-        else:
-            await self.engine.check_health_async()
+        await self.engine.check_health_async()
         logger.debug("Health check took %fs", time.perf_counter() - t)
     async def is_tracing_enabled(self) -> bool:
-        if self.engine_use_ray:
-            return await self.engine.is_tracing_enabled.remote(  # type: ignore
-            )
-        else:
-            return self.engine.is_tracing_enabled()
+        return self.engine.is_tracing_enabled()
     def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
-        if self.engine_use_ray:
-            ray.get(
-                self.engine.add_logger.remote(  # type: ignore
-                    logger_name=logger_name, logger=logger))
-        else:
-            self.engine.add_logger(logger_name=logger_name, logger=logger)
+        self.engine.add_logger(logger_name=logger_name, logger=logger)
     def remove_logger(self, logger_name: str) -> None:
-        if self.engine_use_ray:
-            ray.get(
-                self.engine.remove_logger.remote(  # type: ignore
-                    logger_name=logger_name))
-        else:
-            self.engine.remove_logger(logger_name=logger_name)
+        self.engine.remove_logger(logger_name=logger_name)
     async def start_profile(self) -> None:
         self.engine.model_executor._run_workers("start_profile")
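
With the Ray-actor branches gone, the accessors above reduce to direct calls on the in-process engine. A hedged usage sketch, assuming `engine` is an `AsyncLLMEngine` built via `from_engine_args` as in the sketch near the top of this page:

# Hedged sketch: the async accessors are now thin awaitable wrappers with no
# .remote() indirection, since the wrapped LLMEngine lives in this process.
import asyncio

async def inspect(engine) -> None:
    model_config = await engine.get_model_config()
    decoding_config = await engine.get_decoding_config()
    tokenizer = await engine.get_tokenizer()
    print(model_config.model, type(tokenizer).__name__, decoding_config)

# asyncio.run(inspect(engine))  # engine from AsyncLLMEngine.from_engine_args(...)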

View File

@@ -3,8 +3,8 @@ import time
 from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
-                    Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Type, Union
@@ -397,7 +397,7 @@ class LLMEngine:
         # Currently used by AsyncLLMEngine to ensure quick append
         # of request outputs to asyncio queues
-        self.process_request_outputs_callback = None
+        self.process_request_outputs_callback: Optional[Callable] = None
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of

View File

@@ -195,7 +195,6 @@ async def main(args):
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
-    # When using single vLLM without engine_use_ray
     model_config = await engine.get_model_config()
     if args.disable_log_requests:

View File

@@ -58,7 +58,6 @@ if TYPE_CHECKING:
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
     VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
-    VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
@@ -391,14 +390,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_RPC_GET_DATA_TIMEOUT_MS":
     lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
-    # If set, allow running the engine as a separate ray actor,
-    # which is a deprecated feature soon to be removed.
-    # See https://github.com/vllm-project/vllm/issues/7045
-    "VLLM_ALLOW_ENGINE_USE_RAY":
-    lambda:
-    (os.environ.get("VLLM_ALLOW_ENGINE_USE_RAY", "0").strip().lower() in
-     ("1", "true")),
     # a list of plugin names to load, separated by commas.
     # if this is not set, it means all plugins will be loaded
     # if this is set to an empty string, no plugins will be loaded