[bugfix] respect distributed_executor_backend in world_size=1 (#12934)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-02-08 16:17:08 +08:00 committed by GitHub
parent d01f66b039
commit 91dd8f7aa6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 32 deletions

View File

@ -55,6 +55,7 @@ def test_custom_executor(model, tmp_path):
engine_args = EngineArgs( engine_args = EngineArgs(
model=model, model=model,
distributed_executor_backend=CustomUniExecutor, distributed_executor_backend=CustomUniExecutor,
enforce_eager=True, # reduce test time
) )
engine = LLMEngine.from_engine_args(engine_args) engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1) sampling_params = SamplingParams(max_tokens=1)
@ -75,7 +76,10 @@ def test_custom_executor_async(model, tmp_path):
assert not os.path.exists(".marker") assert not os.path.exists(".marker")
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomUniExecutorAsync) model=model,
distributed_executor_backend=CustomUniExecutorAsync,
enforce_eager=True, # reduce test time
)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1) sampling_params = SamplingParams(max_tokens=1)
@ -89,3 +93,18 @@ def test_custom_executor_async(model, tmp_path):
assert os.path.exists(".marker") assert os.path.exists(".marker")
finally: finally:
os.chdir(cwd) os.chdir(cwd)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_respect_ray(model):
    """Check that an explicit ``"ray"`` backend is honored for a single worker.

    No tensor- or pipeline-parallel size is passed here, so the engine runs
    with its defaults (TP=1 and PP=1 per the original test's note). Users may
    still request ray explicitly because they want ray to manage resources,
    so the resulting executor must report that it uses ray.
    """
    ray_engine_args = EngineArgs(
        model=model,
        distributed_executor_backend="ray",
        # Eager mode keeps the test fast.
        enforce_eager=True,
    )
    # Keep the engine bound to a local name for the duration of the check.
    engine = LLMEngine.from_engine_args(ray_engine_args)
    executor = engine.model_executor
    assert executor.uses_ray

View File

@ -1401,6 +1401,9 @@ class ParallelConfig:
logger.info("Defaulting to use %s for distributed inference", logger.info("Defaulting to use %s for distributed inference",
backend) backend)
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"
self._verify_args() self._verify_args()
@property @property

View File

@ -434,6 +434,7 @@ class LLMEngine:
@classmethod @classmethod
def _get_executor_cls(cls, def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]: engine_config: VllmConfig) -> Type[ExecutorBase]:
# distributed_executor_backend must be set in VllmConfig.__post_init__
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class. # Initialize the cluster and specify the executor class.
@ -443,8 +444,7 @@ class LLMEngine:
"distributed_executor_backend must be a subclass of " "distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.") f"ExecutorBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend executor_class = distributed_executor_backend
elif engine_config.parallel_config.world_size > 1: elif distributed_executor_backend == "ray":
if distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import ( from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor) RayDistributedExecutor)
executor_class = RayDistributedExecutor executor_class = RayDistributedExecutor
@ -465,8 +465,8 @@ class LLMEngine:
ExecutorWithExternalLauncher) ExecutorWithExternalLauncher)
executor_class = ExecutorWithExternalLauncher executor_class = ExecutorWithExternalLauncher
else: else:
from vllm.executor.uniproc_executor import UniProcExecutor raise ValueError("unrecognized distributed_executor_backend: "
executor_class = UniProcExecutor f"{distributed_executor_backend}")
return executor_class return executor_class
@classmethod @classmethod

View File

@ -25,15 +25,14 @@ class Executor(ExecutorBase):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
distributed_executor_backend = ( distributed_executor_backend = (
parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend)
if distributed_executor_backend is None: # distributed_executor_backend must be set in VllmConfig.__post_init__
# If the user does not specify the distributed executor backend, if isinstance(distributed_executor_backend, type):
# we will choose the backend based on the world size. if not issubclass(distributed_executor_backend, ExecutorBase):
if parallel_config.world_size > 1: raise TypeError(
distributed_executor_backend = "mp" "distributed_executor_backend must be a subclass of "
else: f"ExecutorBase. Got {distributed_executor_backend}.")
distributed_executor_backend = "uni" executor_class = distributed_executor_backend
elif distributed_executor_backend == "ray":
if distributed_executor_backend == "ray":
executor_class = RayDistributedExecutor executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp": elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor from vllm.v1.executor.multiproc_executor import MultiprocExecutor