[V1] [Feature] Collective RPC (#15444)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
parent 4965ec42d2
commit 94744ba41a
@@ -150,8 +150,8 @@ steps:
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - pushd ../examples/offline_inference
-  - VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py
-  - VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
+  - python3 rlhf.py
+  - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
   - popd
 
 - label: Metrics, Tracing Test # 10min
@@ -520,7 +520,7 @@ steps:
   - vllm/v1/engine/
   commands:
   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-  - VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
+  - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
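Note: both pipeline hunks drop the VLLM_ENABLE_V1_MULTIPROCESSING=0 override. The RLHF examples and the collective-RPC test previously had to force the in-process V1 engine because collective_rpc was only reachable through model_executor; the rest of this commit plumbs the call through the multiprocess engine-core client, so the default engine now works. A hedged sketch of the kind of check entrypoints/llm/test_collective_rpc.py can now run under the default engine (the worker method echo_rank and the model choice are illustrative, not the actual test code):

from vllm import LLM

def echo_rank(self) -> int:
    # Runs on every worker; a passed callable receives the worker as its
    # first argument, so `self` here is the worker instance.
    return self.rank

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
assert sorted(llm.collective_rpc(echo_rank)) == [0, 1]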
@@ -7,8 +7,8 @@ from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
-                    List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
 
@@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
+_R = TypeVar("_R", default=Any)
 
 
 @dataclass
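_R joins the existing _G/_O module-level type variables. The default=Any argument is a PEP 696 type-variable default, which typing_extensions.TypeVar supports (this file already imports TypeVar from typing_extensions for _G): collective_rpc is typed as returning list[Any] unless the caller passes a callable that binds _R more precisely. A rough illustration, not part of the diff:

from typing import Any, Callable, Union
from typing_extensions import TypeVar

_R = TypeVar("_R", default=Any)

def collective_rpc(method: Union[str, Callable[..., _R]]) -> list[_R]:
    ...

# A string method name leaves _R unbound, so the default applies: list[Any].
# A callable such as `lambda w: 0` binds _R to int, giving list[int].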
@@ -2123,6 +2124,14 @@ class LLMEngine:
 
         return sampling_params
 
+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+
 
 if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
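In the V0 engine this is a one-line pass-through, since the executor layer already implements collective_rpc; exposing it on the engine gives LLM a uniform surface to call through (see the entrypoints hunk below). A sketch of the two dispatch modes, where `engine` is an LLMEngine and both targets are illustrative rather than guaranteed API:

results = engine.collective_rpc("get_cache_block_size_bytes")  # by name
ranks = engine.collective_rpc(lambda worker: worker.rank)      # by callable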
@@ -492,8 +492,8 @@ class LLM:
         It is recommended to use this API to only pass control messages,
         and set up data-plane communication to pass data.
         """
-        executor = self.llm_engine.model_executor
-        return executor.collective_rpc(method, timeout, args, kwargs)
+        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
 
     def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
         """
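This is the user-visible change: LLM.collective_rpc no longer reaches into self.llm_engine.model_executor (which is not available on the V1 frontend when the engine core runs in a separate process) and instead goes through the engine's own collective_rpc. A hedged usage sketch in the control-plane style the docstring recommends, loosely modeled on examples/offline_inference/rlhf.py; the method name and its arguments are illustrative:

from vllm import LLM

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)

# Control message to every worker: join a process group that a trainer
# process will later use to broadcast updated weights. The weight tensors
# themselves travel over that group (data plane), not over this RPC.
llm.collective_rpc("init_weight_update_group",
                   args=("127.0.0.1", 29500, 1, 3))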
@@ -8,7 +8,7 @@ import time
 from concurrent.futures import Future
 from inspect import isclass, signature
 from logging import DEBUG
-from typing import Any, Optional
+from typing import Any, Callable, Optional, TypeVar, Union
 
 import msgspec
 import psutil
@@ -43,6 +43,8 @@ logger = init_logger(__name__)
 
 POLLING_TIMEOUT_S = 2.5
 
+_R = TypeVar('_R')  # Return type for collective_rpc
+
 
 class EngineCore:
     """Inner loop of vLLM's Engine."""
@@ -280,6 +282,14 @@ class EngineCore:
     def pin_lora(self, lora_id: int) -> bool:
         return self.model_executor.pin_lora(lora_id)
 
+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence
 from concurrent.futures import Future
 from dataclasses import dataclass, field
 from threading import Thread
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TypeVar, Union
 
 import zmq
 import zmq.asyncio
@@ -33,6 +33,8 @@ logger = init_logger(__name__)
 
 AnyFuture = Union[asyncio.Future[Any], Future[Any]]
 
+_R = TypeVar('_R')  # Return type for collective_rpc
+
 
 class EngineCoreClient(ABC):
     """
@@ -117,6 +119,13 @@ class EngineCoreClient(ABC):
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError
 
@@ -153,6 +162,14 @@ class EngineCoreClient(ABC):
     async def pin_lora_async(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    async def collective_rpc_async(
+            self,
+            method: Union[str, Callable[..., _R]],
+            timeout: Optional[float] = None,
+            args: tuple = (),
+            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        raise NotImplementedError
+
 
 class InprocClient(EngineCoreClient):
     """
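The abstract client grows a sync/async pair, mirroring the existing pattern (pin_lora / pin_lora_async); asyncio front-ends are expected to await the _async variant. A minimal sketch, assuming `client` is any concrete EngineCoreClient and `sync_weights` is an invented worker method name:

async def broadcast_control(client) -> list[bool]:
    # One result per worker; timeout bounds the whole fan-out.
    return await client.collective_rpc_async("sync_weights", timeout=30.0)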
@@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient):
     def pin_lora(self, lora_id: int) -> bool:
         return self.engine_core.pin_lora(lora_id)
 
+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.engine_core.collective_rpc(method, timeout, args, kwargs)
+
 
 class CoreEngine:
     """One per data parallel rank."""
@@ -505,6 +529,14 @@ class SyncMPClient(MPClient):
     def execute_dummy_batch(self) -> None:
         self.call_utility("execute_dummy_batch")
 
+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.call_utility("collective_rpc", method, timeout, args,
+                                 kwargs)
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -636,6 +668,15 @@ class AsyncMPClient(MPClient):
     async def pin_lora_async(self, lora_id: int) -> bool:
         return await self.call_utility_async("pin_lora", lora_id)
 
+    async def collective_rpc_async(
+            self,
+            method: Union[str, Callable[..., _R]],
+            timeout: Optional[float] = None,
+            args: tuple = (),
+            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return await self.call_utility_async("collective_rpc", method, timeout,
+                                             args, kwargs)
+
 
 class DPAsyncMPClient(AsyncMPClient):
     """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
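Both multiprocess clients piggyback on the existing utility-call machinery rather than adding a new message type. The flow, paraphrased as comments (not verbatim source):

# SyncMPClient.collective_rpc(method, timeout, args, kwargs)
#   -> self.call_utility("collective_rpc", method, timeout, args, kwargs)
#      1. msgpack-encodes the request; a callable `method` is handled by the
#         CUSTOM_TYPE_CLOUDPICKLE extension added in serial_utils.py below
#      2. ships it over the ZMQ input socket to EngineCoreProc
#      3. EngineCoreProc runs EngineCore.collective_rpc(...) and returns the
#         per-worker result list as the utility response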
@@ -2,7 +2,7 @@
 
 from collections.abc import Mapping
 from copy import copy
-from typing import Optional, Union
+from typing import Any, Callable, Optional, Union
 
 from typing_extensions import TypeVar
 
@@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor
 logger = init_logger(__name__)
 
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
+_R = TypeVar("_R", default=Any)
 
 
 class LLMEngine:
@@ -282,6 +283,13 @@ class LLMEngine:
         """Prevent an adapter from being evicted."""
         return self.engine_core.pin_lora(lora_id)
 
+    def collective_rpc(self,
+                       method: Union[str, Callable[..., _R]],
+                       timeout: Optional[float] = None,
+                       args: tuple = (),
+                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
+        return self.engine_core.collective_rpc(method, timeout, args, kwargs)
+
     def __del__(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
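With this method in place, the V1 call chain is complete end to end. Roughly (paraphrased, not verbatim source):

# LLM.collective_rpc(...)                       # vllm/entrypoints/llm.py
#   -> LLMEngine.collective_rpc(...)            # vllm/v1/engine/llm_engine.py
#     -> EngineCoreClient.collective_rpc(...)   # in-proc, or over ZMQ for the
#       -> EngineCore.collective_rpc(...)       #   multiprocess clients
#         -> Executor.collective_rpc(...)       # fans out to all workers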
@@ -1,13 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pickle
+from types import FunctionType
 from typing import Any, Optional
 
+import cloudpickle
 import torch
 from msgspec import msgpack
 
 CUSTOM_TYPE_TENSOR = 1
 CUSTOM_TYPE_PICKLE = 2
+CUSTOM_TYPE_CLOUDPICKLE = 3
 
 
 class MsgpackEncoder:
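The serializer changes are what let a callable cross the process boundary at all: stdlib pickle stores functions by reference (module plus qualified name), so lambdas and locally defined functions fail to pickle, while cloudpickle serializes the function body by value. A standalone demonstration, not from the diff:

import pickle

import cloudpickle

square = lambda x: x * x  # noqa: E731

try:
    pickle.dumps(square)
except (pickle.PicklingError, AttributeError) as exc:
    # Module-level lambdas raise PicklingError; local functions raise
    # AttributeError ("Can't pickle local object ...").
    print(f"pickle refused: {exc}")

blob = cloudpickle.dumps(square)  # serialized by value
assert cloudpickle.loads(blob)(4) == 16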
@@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any:
         # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
         return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
 
+    if isinstance(obj, FunctionType):
+        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
+
     return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
 
 
@@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any:
         return torch.from_numpy(pickle.loads(data))
     if code == CUSTOM_TYPE_PICKLE:
         return pickle.loads(data)
+    if code == CUSTOM_TYPE_CLOUDPICKLE:
+        return cloudpickle.loads(data)
 
     raise NotImplementedError(f"Extension type code {code} is not supported")
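A self-contained round trip through the same hook pattern, using msgspec directly; this mirrors custom_enc_hook/custom_ext_hook above but is a sketch, not vLLM's actual MsgpackEncoder/MsgpackDecoder classes:

import pickle
from types import FunctionType

import cloudpickle
from msgspec import msgpack

CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3

def enc_hook(obj):
    # Functions go through cloudpickle (by value); anything else msgspec
    # cannot natively encode falls back to plain pickle.
    if isinstance(obj, FunctionType):
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))

def ext_hook(code, data):
    if code == CUSTOM_TYPE_CLOUDPICKLE:
        return cloudpickle.loads(data)
    if code == CUSTOM_TYPE_PICKLE:
        return pickle.loads(data)
    raise NotImplementedError(f"Extension type code {code} is not supported")

encoder = msgpack.Encoder(enc_hook=enc_hook)
decoder = msgpack.Decoder(ext_hook=ext_hook)

double = lambda x: x * 2  # noqa: E731
assert decoder.decode(encoder.encode(double))(21) == 42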
|
Loading…
x
Reference in New Issue
Block a user