[Speculative decoding] Adding configuration object for speculative decoding (#3706)

Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Cade Daniel 2024-04-02 17:40:57 -07:00 committed by GitHub
parent a3c226e7eb
commit 5757d90e26
12 changed files with 394 additions and 61 deletions
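The change is configuration-only: it adds a SpeculativeConfig object and the EngineArgs plumbing for it, without wiring up speculative execution itself. A minimal usage sketch of the new user-facing knobs (mirrors the kwargs used by the new end-to-end test below; assumes this commit's vLLM build on the GPU backend):

from vllm import LLM

# At this commit the GPU executor still rejects the config, so engine
# construction fails with the assertion that the new e2e test checks for.
try:
    llm = LLM(
        model="facebook/opt-125m",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        # Required for spec decode.
        use_v2_block_manager=True,
    )
except AssertionError as e:
    assert "Speculative decoding not yet supported for GPU backend" in str(e)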

View File

@@ -0,0 +1,41 @@
import pytest

from tests.conftest import cleanup
from vllm import LLM
from vllm.model_executor.utils import set_random_seed


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                test_llm_kwargs, seed)


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    def generator_inner():
        llm = LLM(**kwargs)
        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    for llm in generator_inner():
        yield llm
        del llm

View File

@@ -0,0 +1,50 @@
import pytest

from vllm import SamplingParams


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator):
    output_len = 1024
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(
            AssertionError,
            match="Speculative decoding not yet supported for GPU backend"):
        get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                         sampling_params)


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids

View File

@@ -107,18 +107,16 @@ def create_worker(cls: type,
        block_size=block_size,
        enforce_eager=enforce_eager,
    )

-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
+    engine_config = engine_args.create_engine_config()

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())

    worker = cls(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
@@ -128,9 +126,9 @@ def create_worker(cls: type,
    worker.init_device()
    worker.load_model()

-    cache_config.num_gpu_blocks = num_gpu_blocks
-    cache_config.num_cpu_blocks = 0
-    worker.init_cache_engine(cache_config)
+    engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
+    engine_config.cache_config.num_cpu_blocks = 0
+    worker.init_cache_engine(engine_config.cache_config)
    worker.warm_up_model()

    return worker

View File

@@ -10,19 +10,18 @@ def test_swap() -> None:
    engine_args = EngineArgs(model="facebook/opt-125m",
                             dtype="half",
                             load_format="dummy")
-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
-    cache_config.num_gpu_blocks = 100
-    cache_config.num_cpu_blocks = 100
+    engine_config = engine_args.create_engine_config()
+    engine_config.cache_config.num_gpu_blocks = 100
+    engine_config.cache_config.num_cpu_blocks = 100

    # Create the worker.
    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
    worker = Worker(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
@@ -32,7 +31,7 @@ def test_swap() -> None:
    # Initialize the worker.
    worker.init_device()
    worker.load_model()
-    worker.init_cache_engine(cache_config)
+    worker.init_cache_engine(engine_config.cache_config)
    worker.warm_up_model()

    # Randomly initialize the cache.

View File

@@ -1,7 +1,7 @@
import enum
import json
import os
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union

import torch
@@ -617,6 +617,159 @@ class DeviceConfig:
        self.device = torch.device(self.device_type)

class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        num_speculative_tokens: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on
        the provided parameters. If the necessary conditions are met, it
        returns an instance of SpeculativeConfig. Otherwise, it returns None.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.
        """

        if (speculative_model is None and num_speculative_tokens is None):
            return None

        if speculative_model is not None and num_speculative_tokens is None:
            raise ValueError(
                "Expected both speculative_model and "
                "num_speculative_tokens to be provided, but found "
                f"{speculative_model=} and {num_speculative_tokens=}.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None
        draft_max_model_len = None

        draft_model_config = ModelConfig(
            model=speculative_model,
            tokenizer=target_model_config.tokenizer,
            tokenizer_mode=target_model_config.tokenizer_mode,
            trust_remote_code=target_model_config.trust_remote_code,
            download_dir=target_model_config.download_dir,
            load_format=target_model_config.load_format,
            dtype=target_model_config.dtype,
            seed=target_model_config.seed,
            revision=draft_revision,
            code_revision=draft_code_revision,
            tokenizer_revision=target_model_config.tokenizer_revision,
            max_model_len=draft_max_model_len,
            quantization=draft_quantization,
            enforce_eager=target_model_config.enforce_eager,
            max_context_len_to_capture=target_model_config.
            max_context_len_to_capture,
            max_logprobs=target_model_config.max_logprobs,
        )

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
        )

    @staticmethod
    def create_draft_parallel_config(
            target_parallel_config: ParallelConfig) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config. In the future the
        draft worker can have a different parallel strategy, e.g. TP=1.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
            worker_use_ray=target_parallel_config.worker_use_ray,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens

        self._verify_args()

    def _verify_args(self) -> None:
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
@dataclass
class LoRAConfig:
    max_lora_rank: int
@@ -838,3 +991,36 @@ def _get_and_verify_max_len(
            "to incorrect model outputs or CUDA errors. Make sure the "
            "value is correct and within the model context size.")
    return int(max_model_len)
@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]

    def __post_init__(self):
        """Verify configs are valid & consistent with each other.
        """
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)

    def to_dict(self):
        """Return the configs as a dictionary, for use in **kwargs.
        """
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))
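Outside the diff: to_dict is what lets the engines below unpack an EngineConfig straight into keyword arguments (cls(**engine_config.to_dict(), ...)). A self-contained illustration of the fields-based pattern follows; the Cfg class and build_engine function are hypothetical stand-ins, not vLLM code. Passing configs by name rather than positionally (the old *engine_configs) means a new field like speculative_config cannot silently shift argument order.

from dataclasses import dataclass, fields

@dataclass(frozen=True)
class Cfg:  # hypothetical stand-in for EngineConfig
    model_config: str
    cache_config: int

    def to_dict(self):
        # Same pattern as EngineConfig.to_dict: one entry per dataclass field.
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))

def build_engine(model_config, cache_config, log_stats=False):
    # hypothetical consumer, mirroring the keyword-only engine constructors
    return (model_config, cache_config, log_stats)

cfg = Cfg(model_config="opt-125m", cache_config=16)
assert cfg.to_dict() == {"model_config": "opt-125m", "cache_config": 16}
assert build_engine(**cfg.to_dict(), log_stats=True) == ("opt-125m", 16, True)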

View File

@@ -1,10 +1,11 @@
import argparse
import dataclasses
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional

-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
-                         VisionLanguageConfig)
+from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         SpeculativeConfig, TokenizerPoolConfig,
+                         VisionLanguageConfig)
from vllm.utils import str_to_int_tuple
@@ -61,9 +62,14 @@ class EngineArgs:
    image_token_id: Optional[int] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

+    # Speculative decoding configuration.
+    speculative_model: Optional[str] = None
+    num_speculative_tokens: Optional[int] = None
+
    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model
@@ -371,6 +377,20 @@ class EngineArgs:
            default=False,
            help='If True, the prefill requests can be chunked based on the '
            'max_num_batched_tokens')

+        parser.add_argument(
+            '--speculative-model',
+            type=str,
+            default=None,
+            help=
+            'The name of the draft model to be used in speculative decoding.')
+
+        parser.add_argument(
+            '--num-speculative-tokens',
+            type=int,
+            default=None,
+            help='The number of speculative tokens to sample from '
+            'the draft model in speculative decoding')
+
        return parser

    @classmethod
@@ -381,11 +401,7 @@ class EngineArgs:
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

-    def create_engine_configs(
-        self,
-    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               DeviceConfig, Optional[LoRAConfig],
-               Optional[VisionLanguageConfig]]:
+    def create_engine_config(self, ) -> EngineConfig:
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
@@ -409,12 +425,23 @@ class EngineArgs:
                self.tokenizer_pool_type,
                self.tokenizer_pool_extra_config,
            ), self.ray_workers_use_nsight)

+        speculative_config = SpeculativeConfig.maybe_create_spec_config(
+            target_model_config=model_config,
+            target_parallel_config=parallel_config,
+            target_dtype=self.dtype,
+            speculative_model=self.speculative_model,
+            num_speculative_tokens=self.num_speculative_tokens,
+        )
+
        scheduler_config = SchedulerConfig(
            self.max_num_batched_tokens,
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
-            num_lookahead_slots=self.num_lookahead_slots,
+            num_lookahead_slots=(self.num_lookahead_slots
+                                 if speculative_config is None else
+                                 speculative_config.num_lookahead_slots),
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )
@@ -442,8 +469,14 @@ class EngineArgs:
        else:
            vision_language_config = None

-        return (model_config, cache_config, parallel_config, scheduler_config,
-                device_config, lora_config, vision_language_config)
+        return EngineConfig(model_config=model_config,
+                            cache_config=cache_config,
+                            parallel_config=parallel_config,
+                            scheduler_config=scheduler_config,
+                            device_config=device_config,
+                            lora_config=lora_config,
+                            vision_language_config=vision_language_config,
+                            speculative_config=speculative_config)


@dataclass
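Outside the diff, a sketch of the new entry point in use (assumes network access to fetch the facebook/opt-125m config and an environment where device auto-detection succeeds, e.g. a CUDA machine). Note how the scheduler's lookahead slots follow the speculative config when one is present:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)
engine_config = engine_args.create_engine_config()

# One object instead of a 7-tuple; the speculative config rides along.
print(engine_config.speculative_config)
# SpeculativeConfig(draft_model='facebook/opt-125m', num_spec_tokens=5)

# num_lookahead_slots is taken from the speculative config so the scheduler
# reserves a slot for every token the draft model proposes.
assert engine_config.scheduler_config.num_lookahead_slots == 5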

View File

@@ -328,28 +328,27 @@ class AsyncLLMEngine:
    ) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        device_config = engine_configs[4]
+        engine_config = engine_args.create_engine_config()

-        if device_config.device_type == "neuron":
+        if engine_config.device_config.device_type == "neuron":
            raise NotImplementedError("Neuron is not supported for "
                                      "async engine yet.")
-        elif parallel_config.worker_use_ray or engine_args.engine_use_ray:
-            initialize_ray_cluster(parallel_config)
+        elif (engine_config.parallel_config.worker_use_ray
+              or engine_args.engine_use_ray):
+            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        else:
-            assert parallel_config.world_size == 1, (
+            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        # Create the async LLM engine.
        engine = cls(
-            parallel_config.worker_use_ray,
+            engine_config.parallel_config.worker_use_ray,
            engine_args.engine_use_ray,
-            *engine_configs,
-            executor_class,
+            **engine_config.to_dict(),
+            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            max_log_len=engine_args.max_log_len,

View File

@@ -5,7 +5,8 @@ from transformers import PreTrainedTokenizer

import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@@ -52,6 +53,11 @@ class LLMEngine:
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
+        lora_config (Optional): The configuration related to serving multi-LoRA.
+        vision_language_config (Optional): The configuration related to vision
+            language models.
+        speculative_config (Optional): The configuration related to speculative
+            decoding.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
@@ -66,7 +72,8 @@ class LLMEngine:
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional["VisionLanguageConfig"],
+        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -74,6 +81,7 @@ class LLMEngine:
        logger.info(
            f"Initializing an LLM engine (v{vllm.__version__}) with config: "
            f"model={model_config.model!r}, "
+            f"speculative_config={speculative_config!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
@@ -100,17 +108,23 @@ class LLMEngine:
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
+        self.speculative_config = speculative_config
        self.log_stats = log_stats
-        self._verify_args()

        self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
        self.seq_counter = Counter()

-        self.model_executor = executor_class(model_config, cache_config,
-                                             parallel_config, scheduler_config,
-                                             device_config, lora_config,
-                                             vision_language_config)
+        self.model_executor = executor_class(
+            model_config=model_config,
+            cache_config=cache_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            lora_config=lora_config,
+            vision_language_config=vision_language_config,
+            speculative_config=speculative_config,
+        )

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
@@ -171,30 +185,28 @@
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        device_config = engine_configs[4]
+        engine_config = engine_args.create_engine_config()

        # Initialize the cluster and specify the executor class.
-        if device_config.device_type == "neuron":
+        if engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
-        elif device_config.device_type == "cpu":
+        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
-        elif parallel_config.worker_use_ray:
-            initialize_ray_cluster(parallel_config)
+        elif engine_config.parallel_config.worker_use_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        else:
-            assert parallel_config.world_size == 1, (
+            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            from vllm.executor.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        # Create the LLM engine.
        engine = cls(
-            *engine_configs,
+            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,

View File

@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -25,6 +26,7 @@ class ExecutorBase(ABC):
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        raise NotImplementedError

View File

@@ -1,7 +1,8 @@
from typing import Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
@@ -24,6 +25,7 @@ class GPUExecutor(ExecutorBase):
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        self.model_config = model_config
        self.cache_config = cache_config
@@ -33,6 +35,9 @@ class GPUExecutor(ExecutorBase):
        self.device_config = device_config
        self.vision_language_config = vision_language_config

+        assert (not speculative_config
+                ), "Speculative decoding not yet supported for GPU backend"
+
        # Instantiate the worker and load the model to GPU.
        self._init_worker()

View File

@@ -1,7 +1,8 @@
from typing import Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -21,6 +22,7 @@ class NeuronExecutor(ExecutorBase):
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        self.model_config = model_config
        self.cache_config = cache_config
@@ -28,6 +30,8 @@ class NeuronExecutor(ExecutorBase):
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
+        assert (not speculative_config
+                ), "Speculative decoding not yet supported for Neuron backend."

        # Set the number of GPU blocks to be the same as the maximum number of
        # sequences that can be processed in a single batch. This is equivalent

View File

@@ -6,7 +6,8 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
@@ -41,6 +42,7 @@ class RayGPUExecutor(ExecutorBase):
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        self.model_config = model_config
        self.cache_config = cache_config
@@ -49,6 +51,8 @@ class RayGPUExecutor(ExecutorBase):
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config
+        assert (not speculative_config
+                ), "Speculative decoding not yet supported for RayGPU backend."

        assert self.parallel_config.worker_use_ray
        placement_group = self.parallel_config.placement_group