[Typing] Mypy typing part 2 (#4043)

Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
SangBin Cho 2024-04-18 09:28:43 +09:00 committed by GitHub
parent a53222544c
commit 533d2a1f39
20 changed files with 180 additions and 126 deletions

View File

@@ -41,10 +41,10 @@ jobs:
         mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
         mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
         # TODO(sang): Follow up
-        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
-        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
         # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml

View File

@@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
 mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
 # TODO(sang): Follow up
-# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
 # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml

View File

@@ -2,8 +2,8 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
-                    Set, Tuple, Type, Union)
+from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
+                    Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
@@ -52,7 +52,7 @@ class AsyncStream:
     def __init__(self, request_id: str) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
+        self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
 
     def put(self, item: Union[RequestOutput, Exception]) -> None:
@@ -312,15 +312,17 @@ class AsyncLLMEngine:
         self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
-        self.background_loop = None
+        self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
         # task as well to prevent it from being garbage
         # collected
-        self._background_loop_unshielded = None
+        self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None
 
+        # Lazy initialized fields
+        self._request_tracker: RequestTracker
+
     @classmethod
     def from_engine_args(
         cls,
@@ -361,11 +363,13 @@ class AsyncLLMEngine:
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
+                and self._background_loop_unshielded is not None
                 and not self._background_loop_unshielded.done())
 
     @property
     def is_stopped(self) -> bool:
-        return self.errored or (self.background_loop is not None
+        return self.errored or (self.background_loop is not None and
+                                self._background_loop_unshielded is not None
                                 and self._background_loop_unshielded.done())
 
     @property
@@ -381,7 +385,7 @@ class AsyncLLMEngine:
     async def get_tokenizer(self) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()
+            return await self.engine.get_tokenizer.remote()  # type: ignore
         else:
             return self.engine.get_tokenizer()
@@ -434,7 +438,8 @@ class AsyncLLMEngine:
             # TODO: Maybe add add_request_batch to reduce Ray overhead
             try:
                 if self.engine_use_ray:
-                    await self.engine.add_request.remote(**new_request)
+                    await self.engine.add_request.remote(  # type: ignore
+                        **new_request)
                 else:
                     await self.engine.add_request_async(**new_request)
             except ValueError as e:
@@ -449,7 +454,7 @@ class AsyncLLMEngine:
             await self._engine_abort(finished_requests)
 
         if self.engine_use_ray:
-            request_outputs = await self.engine.step.remote()
+            request_outputs = await self.engine.step.remote()  # type: ignore
         else:
             request_outputs = await self.engine.step_async()
@@ -462,7 +467,7 @@ class AsyncLLMEngine:
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_ids)
+            await self.engine.abort_request.remote(request_ids)  # type: ignore
         else:
             self.engine.abort_request(request_ids)
@@ -525,11 +530,12 @@ class AsyncLLMEngine:
             arrival_time = time.time()
 
         if self.engine_use_ray:
-            prompt_token_ids = await self.engine.encode_request_async.remote(
-                request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
-                lora_request=lora_request)
+            prompt_token_ids = await (
+                self.engine.encode_request_async.remote(  # type: ignore
+                    request_id=request_id,
+                    prompt=prompt,
+                    prompt_token_ids=prompt_token_ids,
+                    lora_request=lora_request))
         else:
             prompt_token_ids = await self.engine.encode_request_async(
                 request_id=request_id,
@@ -676,13 +682,13 @@ class AsyncLLMEngine:
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
         if self.engine_use_ray:
-            return await self.engine.get_model_config.remote()
+            return await self.engine.get_model_config.remote()  # type: ignore
         else:
             return self.engine.get_model_config()
 
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
-            await self.engine.do_log_stats.remote()
+            await self.engine.do_log_stats.remote()  # type: ignore
         else:
             self.engine.do_log_stats()
@@ -695,7 +701,7 @@ class AsyncLLMEngine:
         if self.engine_use_ray:
             try:
-                await self.engine.check_health.remote()
+                await self.engine.check_health.remote()  # type: ignore
             except ray.exceptions.RayActorError as e:
                 raise RuntimeError("Engine is dead.") from e
         else:
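
The recurring change above, replacing attributes assigned to None with either an Optional annotation or a bare "lazy initialized" annotation such as self._request_tracker: RequestTracker, is the pattern most of this commit relies on. A minimal sketch of the two options, using a hypothetical Tracker/Engine pair rather than the vLLM classes:

    from typing import Optional


    class Tracker:
        """Hypothetical helper class, standing in for vLLM's RequestTracker."""


    class Engine:
        def __init__(self) -> None:
            # Optional attribute: every use site must narrow away None first.
            self.loop: Optional[Tracker] = None
            # Bare annotation (lazy initialization): mypy treats the attribute
            # as always having this type, but reading it before start() raises
            # AttributeError at runtime.
            self.tracker: Tracker

        def start(self) -> None:
            self.tracker = Tracker()


    engine = Engine()
    engine.start()
    print(type(engine.tracker).__name__)

The trade-off is that the bare annotation removes the None checks from every caller, at the cost of relying on the start-up path to assign the attribute before first use.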

View File

@@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
         self._lora_manager: LoRAModelManager = lora_manager
         return lora_manager.model
 
-    def set_active_loras(self, lora_requests: List[LoRARequest],
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         self._apply_loras(lora_requests)
         self._lora_manager.set_lora_mapping(lora_mapping)
 
-    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
+    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
         loras_that_exist = self.list_loras()
         loras_map = {
             lora_request.lora_int_id: lora_request

View File

@@ -55,7 +55,7 @@ global_thread_pool = None # used for generating logits processor fsm
 
 async def get_outlines_guided_decoding_logits_processor(
         request: Union[CompletionRequest, ChatCompletionRequest],
-        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
+        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(
 
 def _get_guide_and_mode(
     request: Union[CompletionRequest, ChatCompletionRequest]
-) -> Tuple[str, GuidedDecodingMode]:
+) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
     if request.guided_json:
         json = request.guided_json
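
Widening these return annotations makes the "no guided decoding requested" path visible to mypy, so callers must check before using the result. The actual signature above returns Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]; the sketch below simplifies that to a tuple of Optionals and uses made-up parameter names, purely to show the caller-side narrowing:

    from typing import Optional, Tuple


    def get_guide_and_mode(guided_json: Optional[str],
                           guided_regex: Optional[str]
                           ) -> Tuple[Optional[str], Optional[str]]:
        """Return (guide, mode), or (None, None) if no guide was requested."""
        if guided_json is not None:
            return guided_json, "json"
        if guided_regex is not None:
            return guided_regex, "regex"
        return None, None


    guide, mode = get_guide_and_mode(None, r"\d+")
    if guide is not None and mode is not None:
        print(f"building {mode} FSM for {guide}")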

View File

@@ -21,7 +21,7 @@ from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Optional, Union
 
 import torch
-from outlines.fsm.fsm import CFGFSM, RegexFSM
+from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
 from outlines.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
@@ -29,6 +29,10 @@ from transformers import PreTrainedTokenizerBase
 
 class BaseLogitsProcessor:
 
+    def __init__(self):
+        # Child class should use initialize in their init.
+        self.fsm: FSM
+
     def init_state(self):
         """Initialize the FSM states."""
         self.fsm_state: DefaultDict[int, int] = defaultdict(int)

View File

@@ -1,7 +1,7 @@
 """Utilities for selecting and loading neuron models."""
 import importlib
 import os
-from typing import Optional, Type
+from typing import Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -27,7 +27,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
 }
 
 # Models supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {
+_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
     "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                          "LlamaForSampling", "LlamaForCausalLM"),
     "MistralForCausalLM": ("transformers_neuronx.mistral.model",
@@ -43,11 +43,13 @@ class NeuronCasualLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.model = None
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
         self.sampler = Sampler()
 
+        # Lazy initialized
+        self.model: nn.Module
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -74,17 +76,17 @@ class NeuronCasualLM(nn.Module):
 
     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
-        neuronx_module_path, neuronx_model_cls, hf_model_cls = (
+        neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
             _NEURON_SUPPORTED_MODELS[arch])
         neuronx_module = importlib.import_module(neuronx_module_path)
-        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
+        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
 
         split_model_dir = f"{model_name_or_path}-split"
         if os.path.isdir(os.path.join(model_name_or_path,
                                       "pytorch_model.bin")):
             split_model_dir = model_name_or_path
         elif not os.path.exists(f"{model_name_or_path}-split"):
-            hf_model_cls = getattr(transformers, hf_model_cls)
+            hf_model_cls = getattr(transformers, hf_model_cls_name)
             from transformers_neuronx.module import save_pretrained_split
 
             hf_model = hf_model_cls.from_pretrained(model_name_or_path,
@@ -96,7 +98,7 @@ class NeuronCasualLM(nn.Module):
         self.model.to_neuron()
 
 
-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
         if arch in _NEURON_SUPPORTED_MODELS:

View File

@@ -167,6 +167,7 @@ class TensorizerArgs:
             decryption_params = DecryptionParams.from_key(key)
             self.deserializer_params['encryption'] = decryption_params
 
+    @staticmethod
     def add_cli_args(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         """Tensorizer CLI arguments"""

View File

@@ -113,6 +113,8 @@ class SamplingTensors:
                                     get_num_triton_sampler_splits(vocab_size))
         sample_indices_start_idx = 0
 
+        assert sampling_metadata.seq_groups is not None
+        assert sampling_metadata.seq_data is not None
         for i, seq_group in enumerate(sampling_metadata.seq_groups):
             seq_ids, sampling_params = seq_group
             temperature = sampling_params.temperature
@@ -147,6 +149,7 @@ class SamplingTensors:
                     and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get
                 # their logprobs
+                assert sampling_metadata.prompt_lens is not None
                 prompt_len = sampling_metadata.prompt_lens[i]
                 temperatures += [temperature] * (prompt_len - 1)
                 top_ps += [top_p] * (prompt_len - 1)
@@ -172,6 +175,7 @@ class SamplingTensors:
             is_prompt = i < sampling_metadata.num_prompts
             if is_prompt:
                 prompt_best_of.append(sampling_params.best_of)
+                assert sampling_metadata.prompt_lens is not None
                 prompt_len = sampling_metadata.prompt_lens[i]
 
                 if sampling_params.prompt_logprobs is not None:
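
The added asserts are mypy's standard narrowing idiom: after "assert x is not None", a value typed Optional[X] is treated as X for the rest of the block. A tiny standalone illustration with a hypothetical dataclass, not the vLLM SamplingMetadata itself:

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class Metadata:
        """Hypothetical stand-in for a metadata object with optional fields."""
        prompt_lens: Optional[List[int]] = None


    def first_prompt_len(meta: Metadata) -> int:
        # Without the assert, mypy reports that Optional[List[int]] is not
        # indexable; the assert narrows the type to List[int].
        assert meta.prompt_lens is not None
        return meta.prompt_lens[0]


    print(first_prompt_len(Metadata(prompt_lens=[7, 3])))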

View File

@@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     def _expand_batch(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-        proposal_token_ids_list: List[TokenId],
+        proposal_token_ids_list: List[List[TokenId]],
         proposal_lens_list: List[int],
     ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
         """Given the input sequences and potentially multiple corresponding
@@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     def _create_target_seq_group_metadata(
         self,
         input_seq_group_metadata: SequenceGroupMetadata,
-        proposal_token_ids: List[TokenId],  # shape: [batch_size, k]
+        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
         batch_index: int,
         target_seq_ids_iter: Iterator[TargetSeqId],
     ) -> List[SequenceGroupMetadata]:
@@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         [0, 1, 2]
         [0, 1, 2, 3]
         """
-        empty_token_ids = []
+        empty_token_ids: List[TokenId] = []
 
         token_ids_to_score = [empty_token_ids]
         token_ids_to_score.extend([

View File

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 import torch
@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
         blocks_to_copy: Optional[Dict[int, List[int]]],
         k: int,
         proposals: SpeculativeProposals,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> SpeculativeScores:
         raise NotImplementedError

View File

@@ -112,6 +112,7 @@ class AsyncMetricsCollector:
 
         Returns a CUDA event recording when the copy is complete.
         """
+        assert self._copy_stream is not None
         self._copy_stream.wait_stream(torch.cuda.current_stream())
 
         with torch.cuda.stream(self._copy_stream):

View File

@@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self._proposer: Optional[DraftModelTop1Proposer] = None
+        # Lazy initialization list.
+        self._proposer: DraftModelTop1Proposer
 
     def init_device(self):
         super().init_device()
@@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
                 self._vocab_size,
                 dtype=torch.float32,
                 device=self._device)
-            proposal_lens = torch.zeros(len(proposal_lens),
-                                        dtype=torch.long,
-                                        device=self._device)
-            return proposal_tokens, proposal_probs, proposal_lens
+            proposal_lens_tensor = torch.zeros(len(proposal_lens),
+                                               dtype=torch.long,
+                                               device=self._device)
+            return proposal_tokens, proposal_probs, proposal_lens_tensor
 
         sampler_output = maybe_sampler_output
@@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
             proposal_tokens, proposal_probs = (entire_proposal_tokens,
                                                entire_proposal_probs)
 
-        proposal_lens = torch.zeros(batch_size,
-                                    dtype=torch.long,
-                                    device=self._device)
-        proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
+        proposal_lens_tensor = torch.zeros(batch_size,
+                                           dtype=torch.long,
+                                           device=self._device)
+        proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
 
-        return proposal_tokens, proposal_probs, proposal_lens
+        return proposal_tokens, proposal_probs, proposal_lens_tensor
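
The rename to proposal_lens_tensor avoids rebinding a name that mypy has already inferred as a list of ints to a torch.Tensor, which would otherwise be reported as an incompatible assignment. A minimal reproduction of the pattern, using plain ints instead of tensors and a made-up function name:

    from typing import List


    def total_proposal_len(proposal_lens: List[int]) -> int:
        # Rebinding proposal_lens itself to an int would be flagged by mypy as
        # an incompatible assignment, so the result gets its own name instead.
        proposal_lens_total: int = sum(proposal_lens)
        return proposal_lens_total


    print(total_proposal_len([1, 2, 3]))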

View File

@@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.probs_dtype = self.rejection_sampler.probs_dtype
         self.token_id_dtype = self.rejection_sampler.token_id_dtype
 
-        self.scorer: SpeculativeScorer = None
+        # Lazy initiazliation.
+        self.scorer: SpeculativeScorer
 
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         logger.info("get spec proposals")
         # Generate proposals using draft worker.
+        assert blocks_to_swap_in is not None
+        assert blocks_to_swap_out is not None
+        assert blocks_to_copy is not None
         proposals = self.proposer_worker.get_spec_proposals(
             seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
             blocks_to_copy, k)

View File

@@ -1,6 +1,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+from torch import nn
 
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -48,14 +49,15 @@ class CPUModelRunner:
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
 
-        self.model = None
-        self.block_size = None  # Set after initial profiling.
         self.kv_cache_dtype = kv_cache_dtype
 
         self.attn_backend = get_attn_backend(
             self.model_config.dtype if model_config is not None else None)
 
+        # Lazy initialization.
+        self.model: nn.Module  # Set after init_Model
+        self.block_size: int  # Set after initial profiling.
+
     def load_model(self) -> None:
         self.model = get_model(model_config=self.model_config,
                                load_config=self.load_config,
@@ -245,7 +247,11 @@ class CPUModelRunner:
         selected_token_indices: List[int] = []
         generators: List[torch.Generator] = []
         selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices: Dict[SamplingType,
+                                         List[Tuple[int, int]]] = {
+                                             t: []
+                                             for t in SamplingType
+                                         }
         categorized_sample_indices_start_idx = 0
         categorized_sampled_token_indices_start_idx = 0
@@ -262,10 +268,9 @@ class CPUModelRunner:
                 categorized_sample_indices_start_idx += subquery_len - 1
 
                 categorized_sample_indices[
-                    sampling_params.sampling_type].append([
-                        categorized_sample_indices_start_idx,
-                        categorized_sampled_token_indices_start_idx
-                    ])
+                    sampling_params.sampling_type].append(
+                        (categorized_sample_indices_start_idx,
+                         categorized_sampled_token_indices_start_idx))
                 categorized_sample_indices_start_idx += 1
                 categorized_sampled_token_indices_start_idx += 1
@@ -328,7 +333,7 @@ class CPUModelRunner:
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
                SamplingMetadata]:
         if self.is_driver_worker:
@@ -381,7 +386,7 @@ class CPUModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata

View File

@@ -1,5 +1,5 @@
 """A CPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed
@@ -152,8 +152,8 @@ class CPUWorker(LoraNotSupportedWorkerBase):
             is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
-        self.cache_engine = None
-        self.cpu_cache = None
+        self.cache_engine: CPUCacheEngine
+        self.cpu_cache: List[torch.Tensor]
 
     def init_device(self) -> None:
         self.init_distributed_environment()
@@ -257,13 +257,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
     ) -> List[SamplerOutput]:
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
-            num_seq_groups = len(seq_group_metadata_list)
+            num_seq_groups: int = len(seq_group_metadata_list)
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
             assert len(blocks_to_swap_in) == 0
             assert len(blocks_to_swap_out) == 0
-            data = {
+            data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
                 "blocks_to_copy": blocks_to_copy,
             }
@@ -273,6 +273,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
             num_seq_groups = data["num_seq_groups"]
             blocks_to_copy = data["blocks_to_copy"]
 
+        assert blocks_to_copy is not None
         self.cache_copy(blocks_to_copy)
 
         # If there is no input, we don't need to execute the model.

View File

@@ -128,23 +128,17 @@ class ModelRunner:
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
 
-        self.model = None  # Set after load_model.
-        self.block_size = None  # Set after initial profiling.
-        self.lora_manager = None
+        self.lora_manager: LRUCacheWorkerLoRAManager = None
 
         self.graph_runners: Dict[int, CUDAGraphRunner] = {}
-        self.graph_memory_pool = None  # Set during graph capture.
+        self.graph_memory_pool: Optional[Tuple[
+            int, int]] = None  # Set during graph capture.
 
         self.max_context_len_to_capture = (
             self.model_config.max_context_len_to_capture
             if self.model_config is not None else 0)
-        # When using CUDA graph, the input block tables must be padded to
-        # max_context_len_to_capture. However, creating the block table in
-        # Python can be expensive. To optimize this, we cache the block table
-        # in numpy and only copy the actual input content at every iteration.
-        # The shape of the cached block table will be
-        # (max batch size to capture, max context len to capture / block size).
-        self.graph_block_tables = None  # Set after initial profiling.
 
         self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = kv_cache_dtype
         self.vision_language_config = vision_language_config
@@ -152,6 +146,17 @@ class ModelRunner:
         self.attn_backend = get_attn_backend(
             self.model_config.dtype if model_config is not None else None)
 
+        # Lazy initialization
+        self.model: torch.nn.Module  # Set after load_model
+        self.block_size: int  # Set after initial profiling.
+        # When using CUDA graph, the input block tables must be padded to
+        # max_context_len_to_capture. However, creating the block table in
+        # Python can be expensive. To optimize this, we cache the block table
+        # in numpy and only copy the actual input content at every iteration.
+        # The shape of the cached block table will be
+        # (max batch size to capture, max context len to capture / block size).
+        self.graph_block_tables: torch.Tensor  # Set after initial profiling.
+
     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
             self.model = get_model(
@@ -489,16 +494,16 @@ class ModelRunner:
                 lora_index_mapping.append(0)
             batch_size = graph_batch_size
 
-        context_lens = torch.tensor(context_lens,
-                                    dtype=torch.int,
-                                    device=self.device)
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device=self.device)
 
         if use_captured_graph:
             # When using cuda-graph all these tensors should be
             # padded.
-            assert context_lens.shape[0] == len(input_tokens)
-            assert context_lens.shape[0] == len(input_positions)
-            assert context_lens.shape[0] == len(slot_mapping)
+            assert context_lens_tensor.shape[0] == len(input_tokens)
+            assert context_lens_tensor.shape[0] == len(input_positions)
+            assert context_lens_tensor.shape[0] == len(slot_mapping)
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
@@ -527,7 +532,7 @@ class ModelRunner:
             max_prompt_len=None,
             subquery_start_loc=None,
             seq_start_loc=None,
-            context_lens=context_lens,
+            context_lens=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
@@ -551,7 +556,11 @@ class ModelRunner:
         selected_token_indices: List[int] = []
         generators: List[torch.Generator] = []
         selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices: Dict[SamplingType,
+                                         List[Tuple[int, int]]] = {
+                                             t: []
+                                             for t in SamplingType
+                                         }
         categorized_sample_indices_start_idx = 0
         categorized_sampled_token_indices_start_idx = 0
@@ -569,10 +578,9 @@ class ModelRunner:
                 categorized_sample_indices_start_idx += subquery_len - 1
 
                 categorized_sample_indices[
-                    sampling_params.sampling_type].append([
-                        categorized_sample_indices_start_idx,
-                        categorized_sampled_token_indices_start_idx
-                    ])
+                    sampling_params.sampling_type].append(
+                        (categorized_sample_indices_start_idx,
+                         categorized_sampled_token_indices_start_idx))
                 categorized_sample_indices_start_idx += 1
                 categorized_sampled_token_indices_start_idx += 1
@@ -596,15 +604,16 @@ class ModelRunner:
                 categorized_sample_indices[
                     sampling_params.sampling_type].extend(
-                        zip(
-                            range(
-                                categorized_sample_indices_start_idx,
-                                categorized_sample_indices_start_idx +
-                                num_seqs),
-                            range(
-                                categorized_sampled_token_indices_start_idx,
-                                categorized_sampled_token_indices_start_idx +
-                                num_seqs)))
+                        list(
+                            zip(
+                                range(
+                                    categorized_sample_indices_start_idx,
+                                    categorized_sample_indices_start_idx +
+                                    num_seqs),
+                                range(
+                                    categorized_sampled_token_indices_start_idx,
+                                    categorized_sampled_token_indices_start_idx
+                                    + num_seqs))))
 
                 categorized_sample_indices_start_idx += num_seqs
                 categorized_sampled_token_indices_start_idx += num_seqs
@@ -641,9 +650,9 @@ class ModelRunner:
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Set[int], LoRAMapping, torch.Tensor]:
+               Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
             prefill_reqs = []
             decode_reqs = []
@@ -741,6 +750,7 @@ class ModelRunner:
             if prefill_attn_metadata is not None:
                 metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
             else:
+                assert decode_attn_metadata is not None
                 metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
             broadcast_tensor_dict(metadata_dict, src=0)
@@ -809,7 +819,7 @@ class ModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
@@ -923,7 +933,7 @@ class ModelRunner:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.remove_all_loras()
 
-    def set_active_loras(self, lora_requests: List[LoRARequest],
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
@@ -1065,10 +1075,16 @@ class CUDAGraphRunner:
 
     def __init__(self, model: nn.Module):
         self.model = model
-        self.graph = None
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
 
+        self._graph: Optional[torch.cuda.CUDAGraph] = None
+
+    @property
+    def graph(self):
+        assert self._graph is not None
+        return self._graph
+
     def capture(
         self,
         input_ids: torch.Tensor,
@@ -1078,7 +1094,7 @@ class CUDAGraphRunner:
         memory_pool,
         **kwargs,
     ) -> None:
-        assert self.graph is None
+        assert self._graph is None
         # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
@@ -1095,8 +1111,8 @@ class CUDAGraphRunner:
        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
             with _maybe_pynccl():
                 hidden_states = self.model(
                     input_ids,
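
The CUDAGraphRunner change above is a private-Optional-plus-property pattern: the backing field is Optional so it can start as None, while the public graph property asserts and therefore hands every caller a non-Optional value. A condensed sketch of the same structure, with a toy Graph class standing in for torch.cuda.CUDAGraph:

    from typing import Optional


    class Graph:
        """Toy stand-in for torch.cuda.CUDAGraph in this sketch."""


    class GraphRunner:
        def __init__(self) -> None:
            # Private backing field may be None until capture() runs.
            self._graph: Optional[Graph] = None

        @property
        def graph(self) -> Graph:
            # Callers always see a Graph, never Optional[Graph].
            assert self._graph is not None
            return self._graph

        def capture(self) -> None:
            assert self._graph is None, "graph already captured"
            self._graph = Graph()


    runner = GraphRunner()
    runner.capture()
    print(type(runner.graph).__name__)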

View File

@@ -1,6 +1,7 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+from torch import nn
 
 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -34,9 +35,11 @@ class NeuronModelRunner:
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
-        self.model = None
         self.pin_memory = is_pin_memory_available()
 
+        # Lazy initialization.
+        self.model: nn.Module  # initialize after load_model.
+
     def load_model(self) -> None:
         self.model = get_neuron_model(self.model_config,
                                       parallel_config=self.parallel_config,
@@ -147,7 +150,11 @@ class NeuronModelRunner:
         selected_token_indices: List[int] = []
         generators: List[torch.Generator] = []
         selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices: Dict[SamplingType,
+                                         List[Tuple[int, int]]] = {
+                                             t: []
+                                             for t in SamplingType
+                                         }
         categorized_sample_indices_start_idx = 0
         categorized_sampled_token_indices_start_idx = 0
@@ -165,10 +172,9 @@ class NeuronModelRunner:
                 categorized_sample_indices_start_idx += prompt_len - 1
 
                 categorized_sample_indices[
-                    sampling_params.sampling_type].append([
-                        categorized_sample_indices_start_idx,
-                        categorized_sampled_token_indices_start_idx
-                    ])
+                    sampling_params.sampling_type].append(
+                        (categorized_sample_indices_start_idx,
+                         categorized_sampled_token_indices_start_idx))
                 categorized_sample_indices_start_idx += 1
                 categorized_sampled_token_indices_start_idx += 1
@@ -237,7 +243,7 @@ class NeuronModelRunner:
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
@@ -259,7 +265,7 @@ class NeuronModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, input_block_ids, sampling_metadata
          ) = self.prepare_input_tensors(seq_group_metadata_list)

View File

@@ -1,7 +1,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import torch
 import torch.distributed
@@ -82,8 +82,8 @@ class Worker(WorkerBase):
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
-        self.cache_engine = None
-        self.gpu_cache = None
+        self.cache_engine: CacheEngine
+        self.gpu_cache: List[torch.Tensor]
 
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
@@ -223,7 +223,7 @@ class Worker(WorkerBase):
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
-            data = {
+            data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
                 "blocks_to_swap_in": blocks_to_swap_in,
                 "blocks_to_swap_out": blocks_to_swap_out,
@@ -237,6 +237,9 @@ class Worker(WorkerBase):
             blocks_to_swap_out = data["blocks_to_swap_out"]
             blocks_to_copy = data["blocks_to_copy"]
 
+        assert blocks_to_swap_in is not None
+        assert blocks_to_swap_out is not None
+        assert blocks_to_copy is not None
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 
         # If there is no input, we don't need to execute the model.

View File

@@ -1,7 +1,7 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
 
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -56,7 +56,7 @@ class WorkerBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def get_cache_block_size_bytes() -> int:
+    def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
@@ -71,7 +71,7 @@ class WorkerBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         raise NotImplementedError
@@ -86,7 +86,7 @@ class LoraNotSupportedWorkerBase(WorkerBase):
     def remove_lora(self, lora_id: int) -> bool:
         raise ValueError(f"{type(self)} does not support LoRA")
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         raise ValueError(f"{type(self)} does not support LoRA")