[V1] [Feature] Collective RPC (#15444)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
wwl2755 2025-03-29 05:39:14 -05:00 committed by GitHub
parent 4965ec42d2
commit 94744ba41a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 86 additions and 10 deletions

View File

@ -150,8 +150,8 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- pushd ../examples/offline_inference
- VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- python3 rlhf.py
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
- label: Metrics, Tracing Test # 10min
@ -520,7 +520,7 @@ steps:
- vllm/v1/engine/
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'

View File

@ -7,8 +7,8 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
_R = TypeVar("_R", default=Any)
@dataclass
@ -2123,6 +2124,14 @@ class LLMEngine:
return sampling_params
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

View File

@ -492,8 +492,8 @@ class LLM:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor = self.llm_engine.model_executor
return executor.collective_rpc(method, timeout, args, kwargs)
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""

View File

@ -8,7 +8,7 @@ import time
from concurrent.futures import Future
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Optional
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import psutil
@ -43,6 +43,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore:
"""Inner loop of vLLM's Engine."""
@ -280,6 +282,14 @@ class EngineCore:
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""

View File

@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TypeVar, Union
import zmq
import zmq.asyncio
@ -33,6 +33,8 @@ logger = init_logger(__name__)
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
class EngineCoreClient(ABC):
"""
@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError
@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
raise NotImplementedError
class InprocClient(EngineCoreClient):
"""
@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngine:
"""One per data parallel rank."""
@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
def execute_dummy_batch(self) -> None:
self.call_utility("execute_dummy_batch")
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.call_utility("collective_rpc", method, timeout, args,
kwargs)
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
async def pin_lora_async(self, lora_id: int) -> bool:
return await self.call_utility_async("pin_lora", lora_id)
async def collective_rpc_async(
self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return await self.call_utility_async("collective_rpc", method, timeout,
args, kwargs)
class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)

View File

@ -2,7 +2,7 @@
from collections.abc import Mapping
from copy import copy
from typing import Optional, Union
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_R = TypeVar("_R", default=Any)
class LLMEngine:
@ -282,6 +283,13 @@ class LLMEngine:
"""Prevent an adapter from being evicted."""
return self.engine_core.pin_lora(lora_id)
def collective_rpc(self,
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def __del__(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

View File

@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
from types import FunctionType
from typing import Any, Optional
import cloudpickle
import torch
from msgspec import msgpack
CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
class MsgpackEncoder:
@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)
raise NotImplementedError(f"Extension type code {code} is not supported")