[Core][Doc] Default to multiprocessing for single-node distributed case (#5230)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
commit 99dac099ab
parent c4bd03c7c5
@@ -3,11 +3,9 @@
 Distributed Inference and Serving
 =================================
 
-vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:
+vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with either `Ray <https://github.com/ray-project/ray>`_ or Python's native multiprocessing. Multiprocessing can be used when deploying on a single node; multi-node inference currently requires Ray.
 
-.. code-block:: console
+Multiprocessing will be used by default when not running in a Ray placement group and when there are sufficient GPUs available on the same node for the configured :code:`tensor_parallel_size`; otherwise Ray will be used. This default can be overridden via the :code:`LLM` class :code:`distributed_executor_backend` argument or the :code:`--distributed-executor-backend` API server argument. Set it to :code:`mp` for multiprocessing or :code:`ray` for Ray. Ray does not need to be installed for the multiprocessing case.
 
-    $ pip install ray
-
 To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
 
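For reference, a minimal sketch of overriding the backend selection from the :code:`LLM` class, as the updated documentation describes. It assumes a single node with at least four GPUs; the model name is only an example.

.. code-block:: python

    from vllm import LLM, SamplingParams

    # Force the multiprocessing backend instead of relying on the default
    # selection; Ray does not need to be installed for this path.
    llm = LLM(
        model="facebook/opt-13b",           # example model from the docs
        tensor_parallel_size=4,             # needs 4 GPUs on this node
        distributed_executor_backend="mp",  # "ray" selects the Ray backend
    )
    outputs = llm.generate(["San Francisco is a"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)
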
@@ -25,10 +23,12 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh
     $     --model facebook/opt-13b \
     $     --tensor-parallel-size 4
 
-To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
+To scale vLLM beyond a single machine, install and start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
 
 .. code-block:: console
 
+    $ pip install ray
+
     $ # On head node
     $ ray start --head
 
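Once Ray is installed and the cluster is started as above, the Ray backend can also be requested explicitly from Python. A small sketch, assuming the head and worker nodes together expose enough GPUs (model name and sizes are illustrative):

.. code-block:: python

    from vllm import LLM

    # Explicitly select Ray, e.g. when tensor_parallel_size exceeds the GPU
    # count of a single node; assumes `ray start` has been run on every node.
    llm = LLM(
        model="facebook/opt-13b",
        tensor_parallel_size=8,
        distributed_executor_backend="ray",
    )
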
@@ -77,7 +77,11 @@ class AsyncLLM:
             swap_space=swap_space,
             enforce_eager=enforce_eager,
             max_seq_len_to_capture=max_seq_len_to_capture,
+            # For now use ray for the distributed back-end, since
+            # we rely on the use of engine_use_ray=True to avoid
+            # reinitializing CUDA in the same process (driver worker)
             engine_use_ray=True,
+            distributed_executor_backend="ray",
             disable_custom_all_reduce=disable_custom_all_reduce,
             **kwargs,
         )
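The new comment records why the backend is pinned here: with :code:`engine_use_ray=True` the engine runs in a separate Ray actor, so the driver process never re-initializes CUDA. A hedged sketch of the equivalent engine configuration, assuming the wrapper forwards these keyword arguments to vLLM's :code:`AsyncEngineArgs` (the surrounding wrapper class is not shown in the hunk):

.. code-block:: python

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    # Keep both the engine and the distributed executor on Ray so that CUDA
    # is not re-initialized in the driver process (same reasoning as above).
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model="facebook/opt-13b",
            tensor_parallel_size=4,
            engine_use_ray=True,
            distributed_executor_backend="ray",
        ))
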
@@ -603,9 +603,25 @@ class ParallelConfig:
                 f"'{self.distributed_executor_backend}'.")
 
         if self.distributed_executor_backend is None and self.world_size > 1:
+            # We use multiprocessing by default if world_size fits on the
+            # current node and we aren't in a ray placement group.
+            from torch.cuda import device_count
+
             from vllm.executor import ray_utils
+            backend = "mp"
             ray_found = ray_utils.ray is not None
-            self.distributed_executor_backend = "ray" if ray_found else "mp"
+            if device_count() < self.world_size:
+                if not ray_found:
+                    raise ValueError("Unable to load Ray which is "
+                                     "required for multi-node inference")
+                backend = "ray"
+            elif ray_found:
+                from ray.util import get_current_placement_group
+                if self.placement_group or get_current_placement_group():
+                    backend = "ray"
+            self.distributed_executor_backend = backend
+            logger.info("Defaulting to use %s for distributed inference",
+                        backend)
 
         self._verify_args()
 
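To make the new default explicit, here is a standalone sketch of the same decision logic as a free function. The helper name and its arguments are hypothetical; in vLLM the logic lives inside :code:`ParallelConfig` as shown above.

.. code-block:: python

    def choose_default_backend(world_size: int, gpus_on_node: int,
                               ray_installed: bool,
                               in_placement_group: bool) -> str:
        """Mirror the precedence in the diff: "mp" unless Ray is needed."""
        if gpus_on_node < world_size:
            # Not enough local GPUs: only Ray can span multiple nodes.
            if not ray_installed:
                raise ValueError("Unable to load Ray which is "
                                 "required for multi-node inference")
            return "ray"
        if ray_installed and in_placement_group:
            # Respect an existing Ray placement group even on a single node.
            return "ray"
        return "mp"

    # 4-way tensor parallelism, 8 local GPUs, no Ray installed -> "mp".
    print(choose_default_backend(4, 8, False, False))
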
@@ -19,10 +19,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     """Python multiprocessing-based multi-GPU executor"""
 
     def _init_executor(self) -> None:
-        assert (
-            not self.speculative_config
-        ), "Speculative decoding not yet supported for MultiProcGPU backend."
-
         # Create the parallel GPU workers.
         world_size = self.parallel_config.tensor_parallel_size
 
@@ -65,10 +65,11 @@ def _set_future_result(future: Union[ResultFuture, asyncio.Future],
         future.set_result(result)
         return
     loop = future.get_loop()
-    if result.exception is not None:
-        loop.call_soon_threadsafe(future.set_exception, result.exception)
-    else:
-        loop.call_soon_threadsafe(future.set_result, result.value)
+    if not loop.is_closed():
+        if result.exception is not None:
+            loop.call_soon_threadsafe(future.set_exception, result.exception)
+        else:
+            loop.call_soon_threadsafe(future.set_result, result.value)
 
 
 class ResultHandler(threading.Thread):
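The added :code:`loop.is_closed()` guard prevents a :code:`RuntimeError` when a worker delivers a result after the event loop has already shut down. A self-contained sketch (not vLLM code) of the same thread-to-loop handoff pattern:

.. code-block:: python

    import asyncio
    import threading

    def set_result_threadsafe(future: asyncio.Future, value) -> None:
        loop = future.get_loop()
        # Without this check, call_soon_threadsafe raises "Event loop is
        # closed" if shutdown races with result delivery.
        if not loop.is_closed():
            loop.call_soon_threadsafe(future.set_result, value)

    async def main() -> None:
        fut = asyncio.get_running_loop().create_future()
        threading.Thread(target=set_result_threadsafe, args=(fut, 42)).start()
        print(await fut)  # -> 42

    asyncio.run(main())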