[misc] remove engine_use_ray (#8126)
Parent: a65cb16067
Commit: f842a7aff1
@@ -1,4 +1,3 @@
-import os
 import subprocess
 import sys
 import time
@@ -26,8 +25,7 @@ def _query_server_long(prompt: str) -> dict:
 
 
 @pytest.fixture
-def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
-               worker_use_ray: bool):
+def api_server(tokenizer_pool_size: int, worker_use_ray: bool):
     script_path = Path(__file__).parent.joinpath(
         "api_server_async_engine.py").absolute()
     commands = [
@@ -37,25 +35,17 @@ def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
         str(tokenizer_pool_size)
     ]
 
-    # Copy the environment variables and append `VLLM_ALLOW_ENGINE_USE_RAY=1`
-    # to prevent `--engine-use-ray` raises an exception due to it deprecation
-    env_vars = os.environ.copy()
-    env_vars["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"
-
-    if engine_use_ray:
-        commands.append("--engine-use-ray")
     if worker_use_ray:
         commands.append("--worker-use-ray")
-    uvicorn_process = subprocess.Popen(commands, env=env_vars)
+    uvicorn_process = subprocess.Popen(commands)
     yield
     uvicorn_process.terminate()
 
 
 @pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
 @pytest.mark.parametrize("worker_use_ray", [False, True])
-@pytest.mark.parametrize("engine_use_ray", [False, True])
-def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
-                    engine_use_ray: bool):
+def test_api_server(api_server, tokenizer_pool_size: int,
+                    worker_use_ray: bool):
     """
     Run the API server and test it.
@@ -1,5 +1,4 @@
 import asyncio
-import os
 from asyncio import CancelledError
 from dataclasses import dataclass
 from typing import Optional
@@ -72,14 +71,12 @@ class MockEngine:
 
 
 class MockAsyncLLMEngine(AsyncLLMEngine):
-
-    def _init_engine(self, *args, **kwargs):
-        return MockEngine()
+    _engine_class = MockEngine
 
 
 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncLLMEngine(worker_use_ray=False)
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -112,16 +109,11 @@ async def test_new_requests_event():
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
 
-    # Allow deprecated engine_use_ray to not raise exception
-    os.environ["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"
-
-    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    engine = MockAsyncLLMEngine(worker_use_ray=True)
    assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
     assert engine.get_decoding_config() is not None
 
-    os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")
-
 
 def start_engine():
     wait_for_gpu_memory_to_clear(
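
With `_init_engine` removed, the hunk above switches the test double over to the `_engine_class` class attribute, which the constructor now instantiates directly. A minimal sketch of that pattern; `FakeEngine`, `FakeAsyncLLMEngine`, and their members are hypothetical, only the `_engine_class` hook comes from the diff:

from vllm.engine.async_llm_engine import AsyncLLMEngine


class FakeEngine:
    """Hypothetical stand-in that records activity instead of running a model."""

    def __init__(self, *args, **kwargs):
        self.step_calls = 0


class FakeAsyncLLMEngine(AsyncLLMEngine):
    # AsyncLLMEngine.__init__ now calls self._engine_class(*args, **kwargs),
    # so overriding the attribute is enough to inject a test double.
    _engine_class = FakeEngine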
@@ -19,16 +19,11 @@ def server():
         "--max-model-len",
         "2048",
         "--enforce-eager",
-        "--engine-use-ray",
         "--chat-template",
         str(chatml_jinja_path),
     ]
 
-    # Allow `--engine-use-ray`, otherwise the launch of the server throw
-    # an error due to try to use a deprecated feature
-    env_dict = {"VLLM_ALLOW_ENGINE_USE_RAY": "1"}
-    with RemoteOpenAIServer(MODEL_NAME, args,
-                            env_dict=env_dict) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
 
@@ -1035,7 +1035,6 @@ class EngineArgs:
 @dataclass
 class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
-    engine_use_ray: bool = False
     disable_log_requests: bool = False
 
     @staticmethod
@@ -1043,16 +1042,6 @@ class AsyncEngineArgs(EngineArgs):
                      async_args_only: bool = False) -> FlexibleArgumentParser:
         if not async_args_only:
             parser = EngineArgs.add_cli_args(parser)
-        parser.add_argument('--engine-use-ray',
-                            action='store_true',
-                            help='Use Ray to start the LLM engine in a '
-                            'separate process as the server process.'
-                            '(DEPRECATED. This argument is deprecated '
-                            'and will be removed in a future update. '
-                            'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
-                            'use it. See '
-                            'https://github.com/vllm-project/vllm/issues/7045.'
-                            ')')
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
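
After this hunk, `disable_log_requests` is the only async-specific field left on `AsyncEngineArgs`, and there is no `--engine-use-ray` flag for the parser to register. A minimal sketch of constructing the dataclass directly; the model name is a hypothetical placeholder:

from vllm.engine.arg_utils import AsyncEngineArgs

# `engine_use_ray` is no longer a dataclass field, so passing it as a keyword
# would now raise a TypeError.
engine_args = AsyncEngineArgs(model="facebook/opt-125m",  # hypothetical model
                              disable_log_requests=True)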
@@ -16,7 +16,7 @@ from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
                                     PromptComponents, SchedulerOutputState)
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
-from vllm.executor.ray_utils import initialize_ray_cluster, ray
+from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                          SingletonPromptInputs)
 from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
@@ -30,7 +30,6 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -590,9 +589,6 @@ class AsyncLLMEngine:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
             `parallel_config.worker_use_ray`.
-        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
-            async frontend will be executed in a separate process as the
-            model workers.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -604,41 +600,23 @@ class AsyncLLMEngine:
 
     def __init__(self,
                  worker_use_ray: bool,
-                 engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
-        self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.engine = self._init_engine(*args, **kwargs)
+        self.engine = self._engine_class(*args, **kwargs)
 
         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        # TODO: Currently, disabled for engine_use_ray, ask
-        # Cody/Will/Woosuk about this case.
-        self.use_process_request_outputs_callback = not self.engine_use_ray
+        self.use_process_request_outputs_callback = True
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
                 self.process_request_outputs
 
-        if self.engine_use_ray:
-            print_warning_once(
-                "DEPRECATED. `--engine-use-ray` is deprecated and will "
-                "be removed in a future update. "
-                "See https://github.com/vllm-project/vllm/issues/7045.")
-
-            if envs.VLLM_ALLOW_ENGINE_USE_RAY:
-                print_warning_once(
-                    "VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
-            else:
-                raise ValueError("`--engine-use-ray` is deprecated. "
-                                 "Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
-                                 "force use it")
-
         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
         # task as well to prevent it from being garbage
@@ -725,16 +703,11 @@ class AsyncLLMEngine:
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
 
-        if engine_args.engine_use_ray:
-            from vllm.executor import ray_utils
-            ray_utils.assert_ray_available()
-
         executor_class = cls._get_executor_cls(engine_config)
 
         # Create the async LLM engine.
         engine = cls(
             executor_class.uses_ray,
-            engine_args.engine_use_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
@@ -777,10 +750,6 @@ class AsyncLLMEngine:
             self,
             lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
-        if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote(  # type: ignore
-                lora_request)
-
         return await (self.engine.get_tokenizer_group().
                       get_lora_tokenizer_async(lora_request))
 
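
The factory above now always builds the engine inside the calling process; Ray only remains in play for the workers when the selected executor uses it. A minimal end-to-end sketch under that assumption; the model name, prompt, and request id are hypothetical placeholders:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams


async def main():
    # No engine_use_ray argument anymore: the engine runs in-process.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # hypothetical model
    final = None
    async for output in engine.generate("Hello, my name is",
                                        SamplingParams(max_tokens=16),
                                        request_id="demo-0"):
        final = output
    print(final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())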
@@ -814,26 +783,6 @@ class AsyncLLMEngine:
         self._background_loop_unshielded = None
         self.background_loop = None
 
-    def _init_engine(self, *args,
-                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
-        if not self.engine_use_ray:
-            engine_class = self._engine_class
-        elif self.worker_use_ray:
-            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
-        else:
-            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
-            # order of the arguments.
-            cache_config = kwargs["cache_config"]
-            parallel_config = kwargs["parallel_config"]
-            if (parallel_config.tensor_parallel_size == 1
-                    and parallel_config.pipeline_parallel_size == 1):
-                num_gpus = cache_config.gpu_memory_utilization
-            else:
-                num_gpus = 1
-            engine_class = ray.remote(num_gpus=num_gpus)(
-                self._engine_class).remote
-        return engine_class(*args, **kwargs)
-
     async def engine_step(self, virtual_engine: int) -> bool:
         """Kick the engine to process the waiting requests.
 
@@ -844,13 +793,8 @@ class AsyncLLMEngine:
 
         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
-            # TODO: Maybe add add_request_batch to reduce Ray overhead
             try:
-                if self.engine_use_ray:
-                    await self.engine.add_request.remote(  # type: ignore
-                        **new_request)
-                else:
-                    await self.engine.add_request_async(**new_request)
+                await self.engine.add_request_async(**new_request)
             except ValueError as e:
                 # TODO: use a vLLM specific error for failed validation
                 self._request_tracker.process_exception(
@@ -862,10 +806,7 @@ class AsyncLLMEngine:
         if aborted_requests:
             await self._engine_abort(aborted_requests)
 
-        if self.engine_use_ray:
-            request_outputs = await self.engine.step.remote()  # type: ignore
-        else:
-            request_outputs = await self.engine.step_async(virtual_engine)
+        request_outputs = await self.engine.step_async(virtual_engine)
 
         # Put the outputs into the corresponding streams.
         # If used as a callback, then already invoked inside
@@ -891,16 +832,10 @@ class AsyncLLMEngine:
         return all_finished
 
     async def _engine_abort(self, request_ids: Iterable[str]):
-        if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_ids)  # type: ignore
-        else:
-            self.engine.abort_request(request_ids)
+        self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
-        if self.engine_use_ray:
-            pipeline_parallel_size = 1  # type: ignore
-        else:
-            pipeline_parallel_size = \
+        pipeline_parallel_size = \
             self.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
@@ -912,12 +847,7 @@ class AsyncLLMEngine:
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                if self.engine_use_ray:
-                    await (self.engine.stop_remote_worker_execution_loop.
-                           remote()  # type: ignore
-                           )
-                else:
-                    await self.engine.stop_remote_worker_execution_loop_async()
+                await self.engine.stop_remote_worker_execution_loop_async()
                 await self._request_tracker.wait_for_new_requests()
                 logger.debug("Got new requests!")
                 requests_in_progress = [
@@ -938,17 +868,9 @@ class AsyncLLMEngine:
             for task in done:
                 result = task.result()
                 virtual_engine = requests_in_progress.index(task)
-                if self.engine_use_ray:
-                    has_unfinished_requests = (
-                        await (self.engine.
-                               has_unfinished_requests_for_virtual_engine.
-                               remote(  # type: ignore
-                                   virtual_engine)))
-                else:
-                    has_unfinished_requests = (
-                        self.engine.
-                        has_unfinished_requests_for_virtual_engine(
-                            virtual_engine))
+                has_unfinished_requests = (
+                    self.engine.has_unfinished_requests_for_virtual_engine(
+                        virtual_engine))
                 if result or has_unfinished_requests:
                     requests_in_progress[virtual_engine] = (
                         asyncio.create_task(
@@ -1190,52 +1112,29 @@ class AsyncLLMEngine:
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_model_config.remote()  # type: ignore
-        else:
-            return self.engine.get_model_config()
+        return self.engine.get_model_config()
 
     async def get_parallel_config(self) -> ParallelConfig:
         """Get the parallel configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_parallel_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_parallel_config()
+        return self.engine.get_parallel_config()
 
     async def get_decoding_config(self) -> DecodingConfig:
         """Get the decoding configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_decoding_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_decoding_config()
+        return self.engine.get_decoding_config()
 
     async def get_scheduler_config(self) -> SchedulerConfig:
         """Get the scheduling configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_scheduler_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_scheduler_config()
+        return self.engine.get_scheduler_config()
 
     async def get_lora_config(self) -> LoRAConfig:
         """Get the lora configuration of the vLLM engine."""
-        if self.engine_use_ray:
-            return await self.engine.get_lora_config.remote(  # type: ignore
-            )
-        else:
-            return self.engine.get_lora_config()
+        return self.engine.get_lora_config()
 
     async def do_log_stats(
             self,
             scheduler_outputs: Optional[SchedulerOutputs] = None,
             model_output: Optional[List[SamplerOutput]] = None) -> None:
-        if self.engine_use_ray:
-            await self.engine.do_log_stats.remote(  # type: ignore
-                scheduler_outputs, model_output)
-        else:
-            self.engine.do_log_stats()
+        self.engine.do_log_stats()
 
     async def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
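
These accessors stay `async` for interface compatibility even though they now call the in-process engine directly, so existing callers keep awaiting them. A small sketch; `engine` is assumed to be an already constructed AsyncLLMEngine:

async def show_config(engine):
    # Both awaits now resolve to plain local calls on the wrapped engine.
    model_config = await engine.get_model_config()
    parallel_config = await engine.get_parallel_config()
    print(model_config.max_model_len,
          parallel_config.pipeline_parallel_size)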
@@ -1244,37 +1143,17 @@ class AsyncLLMEngine:
         if self.is_stopped:
             raise AsyncEngineDeadError("Background loop is stopped.")
 
-        if self.engine_use_ray:
-            try:
-                await self.engine.check_health.remote()  # type: ignore
-            except ray.exceptions.RayActorError as e:
-                raise RuntimeError("Engine is dead.") from e
-        else:
-            await self.engine.check_health_async()
+        await self.engine.check_health_async()
         logger.debug("Health check took %fs", time.perf_counter() - t)
 
     async def is_tracing_enabled(self) -> bool:
-        if self.engine_use_ray:
-            return await self.engine.is_tracing_enabled.remote(  # type: ignore
-            )
-        else:
-            return self.engine.is_tracing_enabled()
+        return self.engine.is_tracing_enabled()
 
     def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
-        if self.engine_use_ray:
-            ray.get(
-                self.engine.add_logger.remote(  # type: ignore
-                    logger_name=logger_name, logger=logger))
-        else:
-            self.engine.add_logger(logger_name=logger_name, logger=logger)
+        self.engine.add_logger(logger_name=logger_name, logger=logger)
 
     def remove_logger(self, logger_name: str) -> None:
-        if self.engine_use_ray:
-            ray.get(
-                self.engine.remove_logger.remote(  # type: ignore
-                    logger_name=logger_name))
-        else:
-            self.engine.remove_logger(logger_name=logger_name)
+        self.engine.remove_logger(logger_name=logger_name)
 
     async def start_profile(self) -> None:
         self.engine.model_executor._run_workers("start_profile")
@@ -3,8 +3,8 @@ import time
 from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List,
-                    Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Tuple, Type, Union
 
@@ -397,7 +397,7 @@ class LLMEngine:
 
         # Currently used by AsyncLLMEngine to ensure quick append
         # of request outputs to asyncio queues
-        self.process_request_outputs_callback = None
+        self.process_request_outputs_callback: Optional[Callable] = None
 
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
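
The annotation above describes the hook that AsyncLLMEngine now wires up unconditionally (`self.engine.process_request_outputs_callback = self.process_request_outputs` in the constructor hunk earlier). A stripped-down sketch of that callback pattern with hypothetical `TinyEngine`/`TinyAsyncWrapper` names:

from typing import Callable, List, Optional


class TinyEngine:
    """Hypothetical synchronous engine exposing the callback hook."""

    def __init__(self) -> None:
        # Set by an async front end so outputs are handed over immediately.
        self.process_request_outputs_callback: Optional[Callable] = None

    def step(self, outputs: List[str]) -> None:
        if self.process_request_outputs_callback is not None:
            self.process_request_outputs_callback(outputs)


class TinyAsyncWrapper:
    """Hypothetical async front end that registers the callback."""

    def __init__(self) -> None:
        self.engine = TinyEngine()
        self.engine.process_request_outputs_callback = \
            self.process_request_outputs

    def process_request_outputs(self, outputs: List[str]) -> None:
        print("received", outputs)


TinyAsyncWrapper().engine.step(["hello"])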
@@ -195,7 +195,6 @@ async def main(args):
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
 
-    # When using single vLLM without engine_use_ray
     model_config = await engine.get_model_config()
 
     if args.disable_log_requests:
@@ -58,7 +58,6 @@ if TYPE_CHECKING:
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
     VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
-    VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
@@ -391,14 +390,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_RPC_GET_DATA_TIMEOUT_MS":
     lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
 
-    # If set, allow running the engine as a separate ray actor,
-    # which is a deprecated feature soon to be removed.
-    # See https://github.com/vllm-project/vllm/issues/7045
-    "VLLM_ALLOW_ENGINE_USE_RAY":
-    lambda:
-    (os.environ.get("VLLM_ALLOW_ENGINE_USE_RAY", "0").strip().lower() in
-     ("1", "true")),
-
     # a list of plugin names to load, separated by commas.
     # if this is not set, it means all plugins will be loaded
     # if this is set to an empty string, no plugins will be loaded
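
The deleted entry followed the registry pattern visible in the hunk header: `environment_variables` maps each variable name to a zero-argument lambda that parses the environment lazily at lookup time. A minimal sketch of that pattern with a hypothetical flag name:

import os
from typing import Any, Callable, Dict

# Hypothetical registry mirroring the structure shown in the hunk above.
environment_variables: Dict[str, Callable[[], Any]] = {
    # "1" or "true" (case-insensitive) enables the hypothetical flag.
    "MYAPP_ENABLE_FEATURE":
    lambda: os.environ.get("MYAPP_ENABLE_FEATURE", "0").strip().lower() in
    ("1", "true"),
}


def get_env(name: str) -> Any:
    # Each lookup re-reads the process environment at call time.
    return environment_variables[name]()


print(get_env("MYAPP_ENABLE_FEATURE"))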