[Speculative decoding] Adding configuration object for speculative decoding (#3706)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
parent a3c226e7eb
commit 5757d90e26
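
Reviewer note: the new options surface through EngineArgs and therefore through the LLM constructor kwargs. A minimal sketch of how the end-to-end test below drives them (illustrative only; with this commit the GPU executor still rejects the config with an assertion, so no speculative decoding actually runs yet):

    from vllm import LLM

    # Mirrors the kwargs used by the new e2e test. With this commit the GPU
    # executor raises: "Speculative decoding not yet supported for GPU backend".
    llm = LLM(model="facebook/opt-125m",
              speculative_model="facebook/opt-125m",
              num_speculative_tokens=5,
              use_v2_block_manager=True)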
tests/spec_decode/e2e/conftest.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import pytest

from tests.conftest import cleanup
from vllm import LLM
from vllm.model_executor.utils import set_random_seed


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                test_llm_kwargs, seed)


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    def generator_inner():
        llm = LLM(**kwargs)

        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    for llm in generator_inner():
        yield llm
        del llm
tests/spec_decode/e2e/test_correctness.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import pytest

from vllm import SamplingParams


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator):
    output_len = 1024
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(
            AssertionError,
            match="Speculative decoding not yet supported for GPU backend"):
        get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                         sampling_params)


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids
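
Note: the three parametrize layers above are merged by create_llm_generator in the conftest, so for this test the LLM is constructed from roughly the following kwargs (a sketch of the merged dict, not code taken from the diff):

    kwargs = {
        # common_llm_kwargs
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
        "use_v2_block_manager": True,
        # per_test_common_llm_kwargs and test_llm_kwargs are both {} here
    }
    # The fixture then builds LLM(**kwargs), seeds the RNG, yields the LLM, and
    # cleans up, which is why the test sees the GPU executor's AssertionError.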
@@ -107,18 +107,16 @@ def create_worker(cls: type,
         block_size=block_size,
         enforce_eager=enforce_eager,
     )

-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
+    engine_config = engine_args.create_engine_config()

     distributed_init_method = get_distributed_init_method(
         get_ip(), get_open_port())

     worker = cls(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
@@ -128,9 +126,9 @@ def create_worker(cls: type,
     worker.init_device()
     worker.load_model()

-    cache_config.num_gpu_blocks = num_gpu_blocks
-    cache_config.num_cpu_blocks = 0
-    worker.init_cache_engine(cache_config)
+    engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
+    engine_config.cache_config.num_cpu_blocks = 0
+    worker.init_cache_engine(engine_config.cache_config)
     worker.warm_up_model()

     return worker
@@ -10,19 +10,18 @@ def test_swap() -> None:
     engine_args = EngineArgs(model="facebook/opt-125m",
                              dtype="half",
                              load_format="dummy")
-    (model_config, cache_config, parallel_config, scheduler_config,
-     device_config, _, _) = engine_args.create_engine_configs()
-    cache_config.num_gpu_blocks = 100
-    cache_config.num_cpu_blocks = 100
+    engine_config = engine_args.create_engine_config()
+    engine_config.cache_config.num_gpu_blocks = 100
+    engine_config.cache_config.num_cpu_blocks = 100

     # Create the worker.
     distributed_init_method = get_distributed_init_method(
         get_ip(), get_open_port())
     worker = Worker(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        scheduler_config=scheduler_config,
-        device_config=device_config,
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
@@ -32,7 +31,7 @@ def test_swap() -> None:
     # Initialize the worker.
     worker.init_device()
     worker.load_model()
-    worker.init_cache_engine(cache_config)
+    worker.init_cache_engine(engine_config.cache_config)
     worker.warm_up_model()

     # Randomly initialize the cache.
vllm/config.py (188 lines changed)
@@ -1,7 +1,7 @@
 import enum
 import json
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import TYPE_CHECKING, ClassVar, Optional, Union

 import torch
@@ -617,6 +617,159 @@ class DeviceConfig:
         self.device = torch.device(self.device_type)


+class SpeculativeConfig:
+    """Configuration for speculative decoding.
+
+    The configuration is currently specialized to draft-model speculative
+    decoding with top-1 proposals.
+    """
+
+    @staticmethod
+    def maybe_create_spec_config(
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        target_dtype: str,
+        speculative_model: Optional[str],
+        num_speculative_tokens: Optional[int],
+    ) -> Optional["SpeculativeConfig"]:
+        """Create a SpeculativeConfig if possible, else return None.
+
+        This function attempts to create a SpeculativeConfig object based on the
+        provided parameters. If the necessary conditions are met, it returns an
+        instance of SpeculativeConfig. Otherwise, it returns None.
+
+        Args:
+            target_model_config (ModelConfig): The configuration of the target
+                model.
+            target_parallel_config (ParallelConfig): The parallel configuration
+                for the target model.
+            target_dtype (str): The data type used for the target model.
+            speculative_model (Optional[str]): The name of the speculative
+                model, if provided.
+            num_speculative_tokens (Optional[int]): The number of speculative
+                tokens, if provided.
+
+        Returns:
+            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
+                the necessary conditions are met, else None.
+        """
+
+        if (speculative_model is None and num_speculative_tokens is None):
+            return None
+
+        if speculative_model is not None and num_speculative_tokens is None:
+            raise ValueError(
+                "Expected both speculative_model and "
+                "num_speculative_tokens to be provided, but found "
+                f"{speculative_model=} and {num_speculative_tokens=}.")
+
+        # TODO: The user should be able to specify revision/quantization/max
+        # model len for the draft model. It is not currently supported.
+        draft_revision = None
+        draft_code_revision = None
+        draft_quantization = None
+        draft_max_model_len = None
+
+        draft_model_config = ModelConfig(
+            model=speculative_model,
+            tokenizer=target_model_config.tokenizer,
+            tokenizer_mode=target_model_config.tokenizer_mode,
+            trust_remote_code=target_model_config.trust_remote_code,
+            download_dir=target_model_config.download_dir,
+            load_format=target_model_config.load_format,
+            dtype=target_model_config.dtype,
+            seed=target_model_config.seed,
+            revision=draft_revision,
+            code_revision=draft_code_revision,
+            tokenizer_revision=target_model_config.tokenizer_revision,
+            max_model_len=draft_max_model_len,
+            quantization=draft_quantization,
+            enforce_eager=target_model_config.enforce_eager,
+            max_context_len_to_capture=target_model_config.
+            max_context_len_to_capture,
+            max_logprobs=target_model_config.max_logprobs,
+        )
+
+        draft_parallel_config = (
+            SpeculativeConfig.create_draft_parallel_config(
+                target_parallel_config))
+
+        return SpeculativeConfig(
+            draft_model_config,
+            draft_parallel_config,
+            num_speculative_tokens,
+        )
+
+    @staticmethod
+    def create_draft_parallel_config(
+            target_parallel_config: ParallelConfig) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+
+        This is mostly a copy of the target parallel config. In the future the
+        draft worker can have a different parallel strategy, e.g. TP=1.
+        """
+        draft_parallel_config = ParallelConfig(
+            pipeline_parallel_size=target_parallel_config.
+            pipeline_parallel_size,
+            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
+            worker_use_ray=target_parallel_config.worker_use_ray,
+            max_parallel_loading_workers=target_parallel_config.
+            max_parallel_loading_workers,
+            disable_custom_all_reduce=target_parallel_config.
+            disable_custom_all_reduce,
+            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
+            ray_workers_use_nsight=target_parallel_config.
+            ray_workers_use_nsight,
+            placement_group=target_parallel_config.placement_group,
+        )
+
+        return draft_parallel_config
+
+    def __init__(
+        self,
+        draft_model_config: ModelConfig,
+        draft_parallel_config: ParallelConfig,
+        num_speculative_tokens: int,
+    ):
+        """Create a SpeculativeConfig object.
+
+        Args:
+            draft_model_config: ModelConfig for the draft model.
+            draft_parallel_config: ParallelConfig for the draft model.
+            num_speculative_tokens: The number of tokens to sample from the
+                draft model before scoring with the target model.
+        """
+        self.draft_model_config = draft_model_config
+        self.draft_parallel_config = draft_parallel_config
+        self.num_speculative_tokens = num_speculative_tokens
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.num_speculative_tokens <= 0:
+            raise ValueError("Expected num_speculative_tokens to be greater "
+                             f"than zero ({self.num_speculative_tokens}).")
+
+        if self.draft_model_config:
+            self.draft_model_config.verify_with_parallel_config(
+                self.draft_parallel_config)
+
+    @property
+    def num_lookahead_slots(self) -> int:
+        """The number of additional slots the scheduler should allocate per
+        step, in addition to the slots allocated for each known token.
+
+        This is equal to the number of speculative tokens, as each speculative
+        token must be scored.
+        """
+        return self.num_speculative_tokens
+
+    def __repr__(self) -> str:
+        draft_model = self.draft_model_config.model
+        num_spec_tokens = self.num_speculative_tokens
+        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
+
+
 @dataclass
 class LoRAConfig:
     max_lora_rank: int
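
Note on intended usage: maybe_create_spec_config is the only public entry point; create_engine_config below calls it with the target model's configs. A hedged sketch of the two paths (here model_config and parallel_config stand in for whatever ModelConfig and ParallelConfig the caller already built):

    # Neither option set: speculative decoding stays disabled.
    assert SpeculativeConfig.maybe_create_spec_config(
        target_model_config=model_config,
        target_parallel_config=parallel_config,
        target_dtype="float16",
        speculative_model=None,
        num_speculative_tokens=None,
    ) is None

    # Both options set: a draft ModelConfig/ParallelConfig pair is derived from
    # the target configs and wrapped in a SpeculativeConfig.
    spec_config = SpeculativeConfig.maybe_create_spec_config(
        target_model_config=model_config,
        target_parallel_config=parallel_config,
        target_dtype="float16",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
    )
    assert spec_config.num_lookahead_slots == 5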
@@ -838,3 +991,36 @@ def _get_and_verify_max_len(
             "to incorrect model outputs or CUDA errors. Make sure the "
             "value is correct and within the model context size.")
     return int(max_model_len)
+
+
+@dataclass(frozen=True)
+class EngineConfig:
+    """Dataclass which contains all engine-related configuration. This
+    simplifies passing around the distinct configurations in the codebase.
+    """
+
+    model_config: ModelConfig
+    cache_config: CacheConfig
+    parallel_config: ParallelConfig
+    scheduler_config: SchedulerConfig
+    device_config: DeviceConfig
+    lora_config: Optional[LoRAConfig]
+    vision_language_config: Optional[VisionLanguageConfig]
+    speculative_config: Optional[SpeculativeConfig]
+
+    def __post_init__(self):
+        """Verify configs are valid & consistent with each other.
+        """
+        self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.lora_config:
+            self.lora_config.verify_with_model_config(self.model_config)
+            self.lora_config.verify_with_scheduler_config(
+                self.scheduler_config)
+
+    def to_dict(self):
+        """Return the configs as a dictionary, for use in **kwargs.
+        """
+        return dict(
+            (field.name, getattr(self, field.name)) for field in fields(self))
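
Note: to_dict exists so the grouped config can be splatted back into constructors that still accept the individual configs as keyword arguments; the field names match the corresponding constructor parameter names. A minimal sketch of the pattern used later in this diff (engine_args and executor_class stand in for values built elsewhere):

    engine_config = engine_args.create_engine_config()

    # Equivalent to passing model_config=..., cache_config=..., and so on by hand.
    engine = LLMEngine(**engine_config.to_dict(),
                       executor_class=executor_class,
                       log_stats=True)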
@@ -1,10 +1,11 @@
 import argparse
 import dataclasses
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional

-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
+from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         SpeculativeConfig, TokenizerPoolConfig,
                          VisionLanguageConfig)
 from vllm.utils import str_to_int_tuple
@@ -61,9 +62,14 @@ class EngineArgs:
     image_token_id: Optional[int] = None
     image_input_shape: Optional[str] = None
     image_feature_size: Optional[int] = None

     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: bool = False
+
+    # Speculative decoding configuration.
+    speculative_model: Optional[str] = None
+    num_speculative_tokens: Optional[int] = None

     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
@@ -371,6 +377,20 @@ class EngineArgs:
             default=False,
             help='If True, the prefill requests can be chunked based on the '
             'max_num_batched_tokens')
+
+        parser.add_argument(
+            '--speculative-model',
+            type=str,
+            default=None,
+            help=
+            'The name of the draft model to be used in speculative decoding.')
+
+        parser.add_argument(
+            '--num-speculative-tokens',
+            type=int,
+            default=None,
+            help='The number of speculative tokens to sample from '
+            'the draft model in speculative decoding')
         return parser

     @classmethod
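
Note: a quick way to exercise the two new flags is to round-trip them through the parser helpers on EngineArgs (a sketch; only --speculative-model and --num-speculative-tokens are introduced by this diff):

    import argparse

    parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
    args = parser.parse_args([
        "--model", "facebook/opt-125m",
        "--speculative-model", "facebook/opt-125m",
        "--num-speculative-tokens", "5",
    ])
    engine_args = EngineArgs.from_cli_args(args)
    assert engine_args.speculative_model == "facebook/opt-125m"
    assert engine_args.num_speculative_tokens == 5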
@@ -381,11 +401,7 @@ class EngineArgs:
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args

-    def create_engine_configs(
-        self,
-    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               DeviceConfig, Optional[LoRAConfig],
-               Optional[VisionLanguageConfig]]:
+    def create_engine_config(self, ) -> EngineConfig:
         device_config = DeviceConfig(self.device)
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
@@ -409,12 +425,23 @@ class EngineArgs:
                 self.tokenizer_pool_type,
                 self.tokenizer_pool_extra_config,
             ), self.ray_workers_use_nsight)

+        speculative_config = SpeculativeConfig.maybe_create_spec_config(
+            target_model_config=model_config,
+            target_parallel_config=parallel_config,
+            target_dtype=self.dtype,
+            speculative_model=self.speculative_model,
+            num_speculative_tokens=self.num_speculative_tokens,
+        )
+
         scheduler_config = SchedulerConfig(
             self.max_num_batched_tokens,
             self.max_num_seqs,
             model_config.max_model_len,
             self.use_v2_block_manager,
-            num_lookahead_slots=self.num_lookahead_slots,
+            num_lookahead_slots=(self.num_lookahead_slots
+                                 if speculative_config is None else
+                                 speculative_config.num_lookahead_slots),
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
         )
@@ -442,8 +469,14 @@ class EngineArgs:
         else:
             vision_language_config = None

-        return (model_config, cache_config, parallel_config, scheduler_config,
-                device_config, lora_config, vision_language_config)
+        return EngineConfig(model_config=model_config,
+                            cache_config=cache_config,
+                            parallel_config=parallel_config,
+                            scheduler_config=scheduler_config,
+                            device_config=device_config,
+                            lora_config=lora_config,
+                            vision_language_config=vision_language_config,
+                            speculative_config=speculative_config)


 @dataclass
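
Note: downstream callers now receive one frozen EngineConfig instead of a 7-tuple. A sketch of the resulting call pattern (model and draft-model values are illustrative):

    engine_args = EngineArgs(model="facebook/opt-125m",
                             speculative_model="facebook/opt-125m",
                             num_speculative_tokens=5,
                             use_v2_block_manager=True)
    engine_config = engine_args.create_engine_config()

    assert engine_config.speculative_config.num_speculative_tokens == 5
    # The scheduler picks up the matching number of lookahead slots automatically.
    assert engine_config.scheduler_config.num_lookahead_slots == 5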
@@ -328,28 +328,27 @@ class AsyncLLMEngine:
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        device_config = engine_configs[4]
+        engine_config = engine_args.create_engine_config()

-        if device_config.device_type == "neuron":
+        if engine_config.device_config.device_type == "neuron":
             raise NotImplementedError("Neuron is not supported for "
                                       "async engine yet.")
-        elif parallel_config.worker_use_ray or engine_args.engine_use_ray:
-            initialize_ray_cluster(parallel_config)
+        elif (engine_config.parallel_config.worker_use_ray
+              or engine_args.engine_use_ray):
+            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         else:
-            assert parallel_config.world_size == 1, (
+            assert engine_config.parallel_config.world_size == 1, (
                 "Ray is required if parallel_config.world_size > 1.")
             from vllm.executor.gpu_executor import GPUExecutorAsync
             executor_class = GPUExecutorAsync
         # Create the async LLM engine.
         engine = cls(
-            parallel_config.worker_use_ray,
+            engine_config.parallel_config.worker_use_ray,
             engine_args.engine_use_ray,
-            *engine_configs,
-            executor_class,
+            **engine_config.to_dict(),
+            executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
             log_stats=not engine_args.disable_log_stats,
             max_log_len=engine_args.max_log_len,
@@ -5,7 +5,8 @@ from transformers import PreTrainedTokenizer

 import vllm
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import StatLogger, Stats
@@ -52,6 +53,11 @@ class LLMEngine:
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
         device_config: The configuration related to the device.
+        lora_config (Optional): The configuration related to serving multi-LoRA.
+        vision_language_config (Optional): The configuration related to vision
+            language models.
+        speculative_config (Optional): The configuration related to speculative
+            decoding.
         executor_class: The model executor class for managing distributed
             execution.
         log_stats: Whether to log statistics.
@@ -66,7 +72,8 @@ class LLMEngine:
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional["VisionLanguageConfig"],
+        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -74,6 +81,7 @@ class LLMEngine:
         logger.info(
             f"Initializing an LLM engine (v{vllm.__version__}) with config: "
             f"model={model_config.model!r}, "
+            f"speculative_config={speculative_config!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"revision={model_config.revision}, "
@@ -100,17 +108,23 @@ class LLMEngine:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
+        self.speculative_config = speculative_config
         self.log_stats = log_stats
-        self._verify_args()

         self._init_tokenizer()
         self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()

-        self.model_executor = executor_class(model_config, cache_config,
-                                             parallel_config, scheduler_config,
-                                             device_config, lora_config,
-                                             vision_language_config)
+        self.model_executor = executor_class(
+            model_config=model_config,
+            cache_config=cache_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            lora_config=lora_config,
+            vision_language_config=vision_language_config,
+            speculative_config=speculative_config,
+        )

         # If usage stat is enabled, collect relevant info.
         if is_usage_stats_enabled():
@@ -171,30 +185,28 @@ class LLMEngine:
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        device_config = engine_configs[4]
+        engine_config = engine_args.create_engine_config()

         # Initialize the cluster and specify the executor class.
-        if device_config.device_type == "neuron":
+        if engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutor
             executor_class = NeuronExecutor
-        elif device_config.device_type == "cpu":
+        elif engine_config.device_config.device_type == "cpu":
             from vllm.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
-        elif parallel_config.worker_use_ray:
-            initialize_ray_cluster(parallel_config)
+        elif engine_config.parallel_config.worker_use_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutor
             executor_class = RayGPUExecutor
         else:
-            assert parallel_config.world_size == 1, (
+            assert engine_config.parallel_config.world_size == 1, (
                 "Ray is required if parallel_config.world_size > 1.")
             from vllm.executor.gpu_executor import GPUExecutor
             executor_class = GPUExecutor

         # Create the LLM engine.
         engine = cls(
-            *engine_configs,
+            **engine_config.to_dict(),
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
             usage_context=usage_context,
@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
 from typing import Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -25,6 +26,7 @@ class ExecutorBase(ABC):
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
     ) -> None:
         raise NotImplementedError
@@ -1,7 +1,8 @@
 from typing import Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.executor.utils import check_block_size_valid
 from vllm.logger import init_logger
@@ -24,6 +25,7 @@ class GPUExecutor(ExecutorBase):
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -33,6 +35,9 @@ class GPUExecutor(ExecutorBase):
         self.device_config = device_config
         self.vision_language_config = vision_language_config

+        assert (not speculative_config
+                ), "Speculative decoding not yet supported for GPU backend"
+
         # Instantiate the worker and load the model to GPU.
         self._init_worker()
@@ -1,7 +1,8 @@
 from typing import Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -21,6 +22,7 @@ class NeuronExecutor(ExecutorBase):
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -28,6 +30,8 @@ class NeuronExecutor(ExecutorBase):
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
+        assert (not speculative_config
+                ), "Speculative decoding not yet supported for Neuron backend."

         # Set the number of GPU blocks to be the same as the maximum number of
         # sequences that can be processed in a single batch. This is equivalent
@@ -6,7 +6,8 @@ from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.executor.utils import check_block_size_valid
@@ -41,6 +42,7 @@ class RayGPUExecutor(ExecutorBase):
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -49,6 +51,8 @@ class RayGPUExecutor(ExecutorBase):
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.vision_language_config = vision_language_config
+        assert (not speculative_config
+                ), "Speculative decoding not yet supported for RayGPU backend."

         assert self.parallel_config.worker_use_ray
         placement_group = self.parallel_config.placement_group