[V1] [Feature] Collective RPC (#15444)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
parent
4965ec42d2
commit
94744ba41a
@ -150,8 +150,8 @@ steps:
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- pushd ../examples/offline_inference
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- python3 rlhf.py
|
||||
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- popd
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
@ -520,7 +520,7 @@ steps:
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
|
@ -7,8 +7,8 @@ from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
|
||||
List, Mapping, NamedTuple, Optional)
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
|
||||
Iterable, List, Mapping, NamedTuple, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, cast, overload
|
||||
|
||||
@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
|
||||
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -2123,6 +2124,14 @@ class LLMEngine:
|
||||
|
||||
return sampling_params
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.model_executor.collective_rpc(method, timeout, args,
|
||||
kwargs)
|
||||
|
||||
|
||||
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
|
@ -492,8 +492,8 @@ class LLM:
|
||||
It is recommended to use this API to only pass control messages,
|
||||
and set up data-plane communication to pass data.
|
||||
"""
|
||||
executor = self.llm_engine.model_executor
|
||||
return executor.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||
"""
|
||||
|
@ -8,7 +8,7 @@ import time
|
||||
from concurrent.futures import Future
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import msgspec
|
||||
import psutil
|
||||
@ -43,6 +43,8 @@ logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_S = 2.5
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
|
||||
class EngineCore:
|
||||
"""Inner loop of vLLM's Engine."""
|
||||
@ -280,6 +282,14 @@ class EngineCore:
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_executor.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.model_executor.collective_rpc(method, timeout, args,
|
||||
kwargs)
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -33,6 +33,8 @@ logger = init_logger(__name__)
|
||||
|
||||
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
|
||||
class EngineCoreClient(ABC):
|
||||
"""
|
||||
@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_output_async(self) -> EngineCoreOutputs:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
|
||||
async def pin_lora_async(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def collective_rpc_async(
|
||||
self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class InprocClient(EngineCoreClient):
|
||||
"""
|
||||
@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
|
||||
class CoreEngine:
|
||||
"""One per data parallel rank."""
|
||||
@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.call_utility("execute_dummy_batch")
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.call_utility("collective_rpc", method, timeout, args,
|
||||
kwargs)
|
||||
|
||||
|
||||
class AsyncMPClient(MPClient):
|
||||
"""Asyncio-compatible client for multi-proc EngineCore."""
|
||||
@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
|
||||
async def pin_lora_async(self, lora_id: int) -> bool:
|
||||
return await self.call_utility_async("pin_lora", lora_id)
|
||||
|
||||
async def collective_rpc_async(
|
||||
self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return await self.call_utility_async("collective_rpc", method, timeout,
|
||||
args, kwargs)
|
||||
|
||||
|
||||
class DPAsyncMPClient(AsyncMPClient):
|
||||
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
@ -282,6 +283,13 @@ class LLMEngine:
|
||||
"""Prevent an adapter from being evicted."""
|
||||
return self.engine_core.pin_lora(lora_id)
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable[..., _R]],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
|
||||
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
def __del__(self):
|
||||
if dp_group := getattr(self, "dp_group", None):
|
||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
|
@ -1,13 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pickle
|
||||
from types import FunctionType
|
||||
from typing import Any, Optional
|
||||
|
||||
import cloudpickle
|
||||
import torch
|
||||
from msgspec import msgpack
|
||||
|
||||
CUSTOM_TYPE_TENSOR = 1
|
||||
CUSTOM_TYPE_PICKLE = 2
|
||||
CUSTOM_TYPE_CLOUDPICKLE = 3
|
||||
|
||||
|
||||
class MsgpackEncoder:
|
||||
@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
|
||||
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
|
||||
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
|
||||
|
||||
if isinstance(obj, FunctionType):
|
||||
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
|
||||
|
||||
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
|
||||
|
||||
|
||||
@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
|
||||
return torch.from_numpy(pickle.loads(data))
|
||||
if code == CUSTOM_TYPE_PICKLE:
|
||||
return pickle.loads(data)
|
||||
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
||||
return cloudpickle.loads(data)
|
||||
|
||||
raise NotImplementedError(f"Extension type code {code} is not supported")
|
||||
|
Loading…
x
Reference in New Issue
Block a user