[Core] Allow specifying custom Executor (#6557)

Antoni Baum 2024-07-19 18:25:06 -07:00 committed by GitHub
parent 2e26564259
commit 7bd82002ae
22 changed files with 310 additions and 92 deletions
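Summary: this change lets EngineArgs.distributed_executor_backend (and its async counterpart) accept an ExecutorBase / ExecutorAsyncBase subclass in addition to the "ray" and "mp" strings, and lets tokenizer_pool_type accept a BaseTokenizerGroup subclass. A minimal usage sketch, mirroring the tests added in this commit (the model name and max_tokens value are taken from those tests; LoggingGPUExecutor is an illustrative name, not part of the diff):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor
from vllm.sampling_params import SamplingParams

class LoggingGPUExecutor(GPUExecutor):
    # Hypothetical custom executor: log every model execution step.
    def execute_model(self, *args, **kwargs):
        print("executing model")
        return super().execute_model(*args, **kwargs)

# Pass the class itself instead of "ray" or "mp".
engine_args = EngineArgs(model="facebook/opt-125m",
                         distributed_executor_backend=LoggingGPUExecutor)
engine = LLMEngine.from_engine_args(engine_args)
engine.add_request("0", "Hello", SamplingParams(max_tokens=1))
engine.step()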


@ -564,6 +564,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
return TokenizerPoolConfig(pool_size=1,
pool_type="ray",
extra_config={})
if isinstance(tokenizer_group_type, type):
return TokenizerPoolConfig(pool_size=1,
pool_type=tokenizer_group_type,
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@ -0,0 +1,91 @@
import asyncio
import os
import pytest
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
from vllm.sampling_params import SamplingParams
class Mock:
...
class CustomGPUExecutor(GPUExecutor):
def execute_model(self, *args, **kwargs):
# Drop a marker file to show that this was run
with open(".marker", "w"):
...
return super().execute_model(*args, **kwargs)
class CustomGPUExecutorAsync(GPUExecutorAsync):
async def execute_model_async(self, *args, **kwargs):
with open(".marker", "w"):
...
return await super().execute_model_async(*args, **kwargs)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor_type_checking(model):
with pytest.raises(ValueError):
engine_args = EngineArgs(model=model,
distributed_executor_backend=Mock)
LLMEngine.from_engine_args(engine_args)
with pytest.raises(ValueError):
engine_args = AsyncEngineArgs(model=model,
distributed_executor_backend=Mock)
AsyncLLMEngine.from_engine_args(engine_args)
with pytest.raises(TypeError):
engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
AsyncLLMEngine.from_engine_args(engine_args)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor(model, tmpdir):
cwd = os.path.abspath(".")
os.chdir(tmpdir)
try:
assert not os.path.exists(".marker")
engine_args = EngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
engine.add_request("0", "foo", sampling_params)
engine.step()
assert os.path.exists(".marker")
finally:
os.chdir(cwd)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor_async(model, tmpdir):
cwd = os.path.abspath(".")
os.chdir(tmpdir)
try:
assert not os.path.exists(".marker")
engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutorAsync)
engine = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
async def t():
stream = await engine.add_request("0", "foo", sampling_params)
async for x in stream:
...
asyncio.run(t())
assert os.path.exists(".marker")
finally:
os.chdir(cwd)


@ -7,17 +7,28 @@ from unittest.mock import patch
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from ..conftest import get_tokenizer_pool_config
class CustomTokenizerGroup(TokenizerGroup):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._i = 0
def encode(self, *args, **kwargs):
self._i += 1
return super().encode(*args, **kwargs)
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
@pytest.mark.parametrize("tokenizer_group_type",
[None, "ray", CustomTokenizerGroup])
async def test_tokenizer_group(tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group = get_tokenizer_group(
@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type):
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
if tokenizer_group_type is CustomTokenizerGroup:
assert tokenizer_group._i > 0
@pytest.mark.asyncio


@ -1,7 +1,7 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
import torch
from transformers import PretrainedConfig
@ -18,7 +18,10 @@ from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
logger = init_logger(__name__)
@ -527,11 +530,12 @@ class TokenizerPoolConfig:
pool type.
"""
pool_size: int
pool_type: str
pool_type: Union[str, Type["BaseTokenizerGroup"]]
extra_config: dict
def __post_init__(self):
if self.pool_type not in ("ray", ):
if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")
@ -661,7 +665,8 @@ class ParallelConfig:
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
distributed_executor_backend: Optional[str] = None,
distributed_executor_backend: Optional[Union[
str, Type["ExecutorBase"]]] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
@ -676,7 +681,7 @@ class ParallelConfig:
if worker_use_ray:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
elif self.distributed_executor_backend != "ray":
elif not self.use_ray:
raise ValueError(f"worker-use-ray can't be used with "
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")
@ -711,12 +716,25 @@ class ParallelConfig:
self._verify_args()
self.rank = 0
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray)
def _verify_args(self) -> None:
if self.distributed_executor_backend not in ("ray", "mp", None):
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
if self.distributed_executor_backend not in (
"ray", "mp", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'.")
if self.distributed_executor_backend == "ray":
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' or custom ExecutorBase subclass.")
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
if is_hip():
@ -724,8 +742,7 @@ class ParallelConfig:
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
if self.ray_workers_use_nsight and (
not self.distributed_executor_backend == "ray"):
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")


@ -2,16 +2,21 @@ import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
def nullable_str(val: str):
if not val or val == "None":
@ -36,7 +41,11 @@ class EngineArgs:
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
distributed_executor_backend: Optional[str] = None
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[Union[str,
Type[ExecutorBase]]] = None
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
@ -62,7 +71,10 @@ class EngineArgs:
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False
max_loras: int = 1


@ -7,12 +7,13 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import DecodingConfig, ModelConfig
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
@ -385,25 +386,19 @@ class AsyncLLMEngine:
self._request_tracker: RequestTracker
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if engine_config.device_config.device_type == "neuron":
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
@ -442,9 +437,29 @@ class AsyncLLMEngine:
else:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
return executor_class
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
executor_class = cls._get_executor_cls(engine_config)
# Create the async LLM engine.
engine = cls(
distributed_executor_backend == "ray",
executor_class.uses_ray,
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,
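On the async side the same pattern applies, but the supplied class must derive from ExecutorAsyncBase, otherwise from_engine_args raises the TypeError shown above. A hedged sketch (model name taken from the new tests; MyAsyncExecutor is illustrative, and running this needs a GPU plus the model weights):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.executor.gpu_executor import GPUExecutorAsync

class MyAsyncExecutor(GPUExecutorAsync):
    async def execute_model_async(self, *args, **kwargs):
        # Custom pre/post-processing around model execution could go here.
        return await super().execute_model_async(*args, **kwargs)

engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              distributed_executor_backend=MyAsyncExecutor)
engine = AsyncLLMEngine.from_engine_args(engine_args)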


@ -7,9 +7,9 @@ from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, MultiModalConfig,
ObservabilityConfig, ParallelConfig,
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
@ -376,19 +376,20 @@ class LLMEngine:
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
def _get_executor_cls(cls,
engine_config: EngineConfig) -> Type[ExecutorBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
@ -422,6 +423,19 @@ class LLMEngine:
else:
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
return executor_class
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),


@ -17,6 +17,8 @@ logger = init_logger(__name__)
class CPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
assert self.lora_config is None, "cpu backend doesn't support LoRA"


@ -18,6 +18,8 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices.
"""
uses_ray: bool # whether the executor uses Ray for orchestration.
def __init__(
self,
model_config: ModelConfig,


@ -23,6 +23,8 @@ def create_worker(worker_module_name, worker_class_name, **kwargs):
class GPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
"""


@ -25,6 +25,8 @@ logger = init_logger(__name__)
class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"""Python multiprocessing-based multi-GPU executor"""
uses_ray: bool = False
def _init_executor(self) -> None:
# Create the parallel GPU workers.
world_size = self.parallel_config.world_size


@ -11,6 +11,8 @@ logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert (self.lora_config is
None), "LoRA is not supported for Neuron backend."


@ -18,6 +18,8 @@ logger = init_logger(__name__)
class OpenVINOExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"


@ -26,6 +26,8 @@ logger = init_logger(__name__)
class RayGPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def _init_executor(self) -> None:
# If the env var is set, it uses Ray's compiled DAG API,
# which optimizes the control plane overhead.
@ -47,7 +49,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.parallel_config.distributed_executor_backend == "ray"
assert self.uses_ray
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
@ -75,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
return ray_remote_kwargs
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
return dict(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if (self.parallel_config.tensor_parallel_size == 1
@ -97,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
@ -106,23 +123,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_bundle_index=bundle_id,
)
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
if self.use_ray_spmd_worker:
self.workers.append(worker)
@ -133,10 +139,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
@ -378,7 +381,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.distributed_executor_backend == "ray"
assert self.parallel_config.use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.


@ -35,6 +35,8 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def __init__(
self,
model_config: ModelConfig,
@ -107,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
return num_gpu_blocks, num_cpu_blocks
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
return dict(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
@ -124,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
@ -137,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
@ -337,7 +339,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.distributed_executor_backend == "ray"
assert self.parallel_config.use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.


@ -14,6 +14,8 @@ logger = init_logger(__name__)
class TPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
assert not self.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend")


@ -18,6 +18,8 @@ logger = init_logger(__name__)
class XPUExecutor(GPUExecutor):
uses_ray: bool = False
def __init__(
self,
model_config: ModelConfig,


@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Type
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
@ -16,18 +16,22 @@ else:
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
tokenizer_cls: Type[BaseTokenizerGroup]
if tokenizer_pool_config is None:
return TokenizerGroup(**init_kwargs)
if tokenizer_pool_config.pool_type == "ray":
tokenizer_cls = TokenizerGroup
elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass(
tokenizer_pool_config.pool_type, BaseTokenizerGroup):
tokenizer_cls = tokenizer_pool_config.pool_type
elif tokenizer_pool_config.pool_type == "ray":
if RayTokenizerGroupPool is None:
raise ImportError(
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool.")
return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
**init_kwargs)
tokenizer_cls = RayTokenizerGroupPool
else:
raise ValueError(
f"Unknown pool type: {tokenizer_pool_config.pool_type}")
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
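With this dispatch in place, any BaseTokenizerGroup subclass can be used as the pool type, either via TokenizerPoolConfig directly or through EngineArgs.tokenizer_pool_type. A hedged sketch mirroring the CustomTokenizerGroup test above (the init kwargs copy those used in the tokenizer group tests; CountingTokenizerGroup is illustrative):

from vllm.config import TokenizerPoolConfig
from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
                                                     get_tokenizer_group)

class CountingTokenizerGroup(TokenizerGroup):
    # Hypothetical subclass that counts how many times encode() is called.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_encodes = 0

    def encode(self, *args, **kwargs):
        self.num_encodes += 1
        return super().encode(*args, **kwargs)

pool_config = TokenizerPoolConfig(pool_size=1,
                                  pool_type=CountingTokenizerGroup,
                                  extra_config={})
tokenizer_group = get_tokenizer_group(pool_config,
                                      tokenizer_id="gpt2",
                                      enable_lora=False,
                                      max_num_seqs=1,
                                      max_input_length=None)
tokenizer_group.encode(prompt="Hello", request_id="0", lora_request=None)
assert tokenizer_group.num_encodes == 1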


@ -3,12 +3,19 @@ from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters."""
@classmethod
@abstractmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "BaseTokenizerGroup":
pass
@abstractmethod
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""


@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
_worker_cls = TokenizerGroup
@classmethod
def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "RayTokenizerGroupPool":
if not tokenizer_pool_config:
raise ValueError("tokenizer_pool_config must not be None.")
ray_actor_options = (tokenizer_pool_config.extra_config or {
"num_cpus": 0
})


@ -2,6 +2,7 @@ from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async,
@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> "TokenizerGroup":
return cls(**init_kwargs)
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True


@ -2,7 +2,7 @@ import dataclasses
import importlib
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
@ -315,14 +315,23 @@ class WorkerWrapperBase:
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
If worker_class_fn is specified, it will be executed to get the worker
class.
Otherwise, the worker class will be obtained by dynamically importing it
using worker_module_name and worker_class_name.
"""
def __init__(self,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False) -> None:
def __init__(
self,
worker_module_name: str,
worker_class_name: str,
trust_remote_code: bool = False,
worker_class_fn: Optional[Callable[[],
Type[WorkerBase]]] = None) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker_class_fn = worker_class_fn
self.worker: Optional[WorkerBase] = None
if trust_remote_code:
# note: lazy import to avoid importing torch before initializing
@ -348,8 +357,11 @@ class WorkerWrapperBase:
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
if self.worker_class_fn:
worker_class = self.worker_class_fn()
else:
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs)
assert self.worker is not None
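The optional worker_class_fn gives callers a way to hand the wrapper a worker class directly instead of having it resolved by dynamic import. A hedged sketch of the two construction paths (illustration only; in practice executors build the wrapper through their _get_worker_wrapper_args):

from vllm.worker.worker_base import WorkerWrapperBase

# Existing path: the worker class is resolved lazily by importing
# worker_module_name and looking up worker_class_name.
wrapper = WorkerWrapperBase(worker_module_name="vllm.worker.worker",
                            worker_class_name="Worker")

# New path: supply a callable returning the class; when set, it takes
# precedence over the module/class names during init_worker.
def get_worker_class():
    from vllm.worker.worker import Worker
    return Worker

wrapper = WorkerWrapperBase(worker_module_name="vllm.worker.worker",
                            worker_class_name="Worker",
                            worker_class_fn=get_worker_class)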