[Typing] Mypy typing part 2 (#4043)

Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
SangBin Cho authored on 2024-04-18 09:28:43 +09:00, committed by GitHub
parent a53222544c
commit 533d2a1f39
20 changed files with 180 additions and 126 deletions


@ -41,10 +41,10 @@ jobs:
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml


@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
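
Taken together, these two hunks switch the per-directory mypy checks for vllm/engine, vllm/worker, vllm/spec_decode and vllm/model_executor from commented-out TODOs to active commands, leaving only vllm/lora deferred; --follow-imports=skip tells mypy not to descend into imported modules, so each command checks only the files it lists. A rough, hypothetical sketch (not from this diff) of the kind of error the newly enabled checks surface, which is why many hunks below add explicit annotations:

    from typing import Dict, List

    # An empty container with no annotation makes mypy ask for one
    # ("need type annotation"); the annotated form is accepted.
    pending: Dict[str, List[int]] = {}

    def track(request_id: str, token_id: int) -> None:
        pending.setdefault(request_id, []).append(token_id)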


@ -2,8 +2,8 @@ import asyncio
import os
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
@ -52,7 +52,7 @@ class AsyncStream:
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue = asyncio.Queue()
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None:
@ -312,15 +312,17 @@ class AsyncLLMEngine:
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.background_loop = None
self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded = None
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
self.start_engine_loop = start_engine_loop
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
# Lazy initialized fields
self._request_tracker: RequestTracker
@classmethod
def from_engine_args(
cls,
@ -361,11 +363,13 @@ class AsyncLLMEngine:
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and self._background_loop_unshielded is not None
and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
return self.errored or (self.background_loop is not None
return self.errored or (self.background_loop is not None and
self._background_loop_unshielded is not None
and self._background_loop_unshielded.done())
@property
@ -381,7 +385,7 @@ class AsyncLLMEngine:
async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote()
return await self.engine.get_tokenizer.remote() # type: ignore
else:
return self.engine.get_tokenizer()
@ -434,7 +438,8 @@ class AsyncLLMEngine:
# TODO: Maybe add add_request_batch to reduce Ray overhead
try:
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
await self.engine.add_request.remote( # type: ignore
**new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
@ -449,7 +454,7 @@ class AsyncLLMEngine:
await self._engine_abort(finished_requests)
if self.engine_use_ray:
request_outputs = await self.engine.step.remote()
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()
@ -462,7 +467,7 @@ class AsyncLLMEngine:
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids)
await self.engine.abort_request.remote(request_ids) # type: ignore
else:
self.engine.abort_request(request_ids)
@ -525,11 +530,12 @@ class AsyncLLMEngine:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await self.engine.encode_request_async.remote(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
else:
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
@ -676,13 +682,13 @@ class AsyncLLMEngine:
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_model_config.remote()
return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config()
async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote()
await self.engine.do_log_stats.remote() # type: ignore
else:
self.engine.do_log_stats()
@ -695,7 +701,7 @@ class AsyncLLMEngine:
if self.engine_use_ray:
try:
await self.engine.check_health.remote()
await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e
else:
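
The AsyncLLMEngine hunks apply three recurring patterns from this PR: attributes that start as None become Optional[...], fields that are only created later get a bare annotation with no value, and awaits on the engine when it runs as a Ray actor are suppressed with # type: ignore because the handle's .remote attributes are invisible to mypy. A minimal sketch with toy stand-ins (not vLLM's real classes):

    import asyncio
    from typing import Optional


    class RequestTracker:            # stand-in for vLLM's RequestTracker
        ...


    class Engine:                    # stand-in for the wrapped engine
        async def step_async(self) -> list:
            return []


    class EngineWrapper:
        def __init__(self, engine: Engine, use_ray: bool) -> None:
            self.engine = engine
            self.use_ray = use_ray
            # starts as None and is filled in later -> Optional[...]
            self.background_loop: Optional[asyncio.Future] = None
            # asyncio.Queue is generic, so mypy wants an annotation here
            self._queue: asyncio.Queue = asyncio.Queue()
            # created lazily in start(): bare annotation, no Optional
            self._request_tracker: RequestTracker

        def start(self) -> None:
            self._request_tracker = RequestTracker()

        async def step(self) -> list:
            if self.use_ray:
                # a Ray actor handle exposes .remote() only at runtime,
                # so the call is suppressed rather than retyped
                return await self.engine.step.remote()  # type: ignore
            return await self.engine.step_async()

Declaring lazily initialized fields without Optional (self._request_tracker above) keeps every later use free of None checks, at the cost of an AttributeError if the field is read before start(); the diff takes that trade for fields that are always set during startup.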


@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
self._lora_manager: LoRAModelManager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: List[LoRARequest],
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
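
set_active_loras and _apply_loras now take Set[LoRARequest], matching what callers actually pass. A small, hypothetical illustration of the arg-type mismatch mypy reports when the annotation says List but the call site builds a set:

    from typing import List, Set

    def apply_as_list(ids: List[int]) -> None: ...
    def apply_as_set(ids: Set[int]) -> None: ...

    active = {1, 2, 3}        # call sites build a set, not a list
    # apply_as_list(active)   # rejected by mypy: a set is not a List[int]
    apply_as_set(active)      # matches the corrected Set[...] signature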


@ -55,7 +55,7 @@ global_thread_pool = None # used for generating logits processor fsm
async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(
def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[str, GuidedDecodingMode]:
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
if request.guided_json:
json = request.guided_json
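
Both guided-decoding helpers can legitimately find nothing to guide on, so their return annotations gain a None arm (Union[..., None] and Tuple[None, None]). A hedged sketch with plain strings standing in for the processor and mode types:

    from typing import Optional, Tuple, Union

    def get_guide_and_mode(
            request: dict) -> Union[Tuple[str, str], Tuple[None, None]]:
        guide = request.get("guided_regex")
        if guide is None:
            return None, None       # allowed by the Tuple[None, None] arm
        return guide, "regex"

    def get_logits_processor(request: dict) -> Optional[str]:
        guide, mode = get_guide_and_mode(request)
        if guide is None:
            return None             # callers must handle the None case
        return f"{mode}:{guide}"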


@ -21,7 +21,7 @@ from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union
import torch
from outlines.fsm.fsm import CFGFSM, RegexFSM
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
@ -29,6 +29,10 @@ from transformers import PreTrainedTokenizerBase
class BaseLogitsProcessor:
def __init__(self):
# Child class should use initialize in their init.
self.fsm: FSM
def init_state(self):
"""Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
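
Declaring self.fsm: FSM in BaseLogitsProcessor.__init__ gives mypy the attribute's type even though only the subclasses assign it. A minimal sketch of the pattern with illustrative names (not outlines' real API):

    class FSM:                        # stand-in for outlines' FSM types
        def next_state(self, state: int, token: int) -> int:
            return state


    class BaseProcessor:
        def __init__(self) -> None:
            # declared, not assigned: subclasses provide the value
            self.fsm: FSM

        def advance(self, state: int, token: int) -> int:
            return self.fsm.next_state(state, token)


    class RegexProcessor(BaseProcessor):
        def __init__(self) -> None:
            super().__init__()
            self.fsm = FSM()          # concrete FSM built by the subclass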


@ -1,7 +1,7 @@
"""Utilities for selecting and loading neuron models."""
import importlib
import os
from typing import Optional, Type
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
@ -27,7 +27,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS = {
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
@ -43,11 +43,13 @@ class NeuronCasualLM(nn.Module):
) -> None:
super().__init__()
self.config = config
self.model = None
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
@ -74,17 +76,17 @@ class NeuronCasualLM(nn.Module):
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls, hf_model_cls = (
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls)
hf_model_cls = getattr(transformers, hf_model_cls_name)
from transformers_neuronx.module import save_pretrained_split
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
@ -96,7 +98,7 @@ class NeuronCasualLM(nn.Module):
self.model.to_neuron()
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
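
The Neuron loader hunks make three small moves: _NEURON_SUPPORTED_MODELS gets an explicit Dict[str, Tuple[str, str, str]] type, the string names pulled from that table keep *_name variables so no variable is rebound from str to a class, and _get_model_architecture is corrected to return the architecture string rather than a Type. A rough sketch of the same shape, with made-up table entries:

    import importlib
    from typing import Dict, Tuple

    _SUPPORTED: Dict[str, Tuple[str, str]] = {
        "LlamaForCausalLM": ("transformers", "AutoModelForCausalLM"),
    }

    def resolve(arch: str) -> type:
        module_path, cls_name = _SUPPORTED[arch]    # cls_name stays a str
        module = importlib.import_module(module_path)
        model_cls = getattr(module, cls_name)       # the class gets its own name
        return model_cls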


@ -167,6 +167,7 @@ class TensorizerArgs:
decryption_params = DecryptionParams.from_key(key)
self.deserializer_params['encryption'] = decryption_params
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Tensorizer CLI arguments"""


@ -113,6 +113,8 @@ class SamplingTensors:
get_num_triton_sampler_splits(vocab_size))
sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None
assert sampling_metadata.seq_data is not None
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
@ -147,6 +149,7 @@ class SamplingTensors:
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1)
@ -172,6 +175,7 @@ class SamplingTensors:
is_prompt = i < sampling_metadata.num_prompts
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
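
The SamplingTensors hunks index Optional fields on the metadata object (seq_groups, seq_data, prompt_lens), so asserts are added to narrow them to non-None first. A sketch of assert-based narrowing on a toy dataclass:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class MetadataSketch:                     # stand-in for SamplingMetadata
        prompt_lens: Optional[List[int]] = None

    def first_prompt_len(meta: MetadataSketch) -> int:
        assert meta.prompt_lens is not None   # narrows Optional[List[int]]
        return meta.prompt_lens[0]            # indexing now type-checks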


@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _expand_batch(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids_list: List[TokenId],
proposal_token_ids_list: List[List[TokenId]],
proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
"""Given the input sequences and potentially multiple corresponding
@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _create_target_seq_group_metadata(
self,
input_seq_group_metadata: SequenceGroupMetadata,
proposal_token_ids: List[TokenId], # shape: [batch_size, k]
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
batch_index: int,
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
[0, 1, 2]
[0, 1, 2, 3]
"""
empty_token_ids = []
empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([
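
An empty list literal gives mypy nothing to infer an element type from, hence empty_token_ids: List[TokenId] = []; the proposal ids are one list of token ids per sequence, hence List[List[TokenId]]. Illustrative sketch, treating TokenId as a plain int alias:

    from typing import List

    TokenId = int

    empty_token_ids: List[TokenId] = []            # element type made explicit
    token_ids_to_score: List[List[TokenId]] = [empty_token_ids]
    token_ids_to_score.append([1, 2, 3])           # later appends type-check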


@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import torch
@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> SpeculativeScores:
raise NotImplementedError


@ -112,6 +112,7 @@ class AsyncMetricsCollector:
Returns a CUDA event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._copy_stream):


@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._proposer: Optional[DraftModelTop1Proposer] = None
# Lazy initialization.
self._proposer: DraftModelTop1Proposer
def init_device(self):
super().init_device()
@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
self._vocab_size,
dtype=torch.float32,
device=self._device)
proposal_lens = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens
proposal_lens_tensor = torch.zeros(len(proposal_lens),
dtype=torch.long,
device=self._device)
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
proposal_lens = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
return proposal_tokens, proposal_probs, proposal_lens
return proposal_tokens, proposal_probs, proposal_lens_tensor
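
Inside DraftModelTop1Proposer the Python list of proposal lengths and the tensor built from it now live under different names, because mypy will not let one variable change type from List[int] to torch.Tensor. Minimal sketch of the rename:

    from typing import List

    import torch

    def to_length_tensor(proposal_lens: List[int]) -> torch.Tensor:
        # rebinding proposal_lens itself would be an incompatible
        # assignment for mypy, so the tensor gets its own name
        proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                           dtype=torch.long)
        return proposal_lens_tensor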


@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.probs_dtype = self.rejection_sampler.probs_dtype
self.token_id_dtype = self.rejection_sampler.token_id_dtype
self.scorer: SpeculativeScorer = None
# Lazy initialization.
self.scorer: SpeculativeScorer
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger.info("get spec proposals")
# Generate proposals using draft worker.
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)


@ -1,6 +1,7 @@
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
@ -48,14 +49,15 @@ class CPUModelRunner:
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None
self.block_size = None # Set after initial profiling.
self.kv_cache_dtype = kv_cache_dtype
self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None)
# Lazy initialization.
self.model: nn.Module # Set after load_model
self.block_size: int # Set after initial profiling.
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
@ -245,7 +247,11 @@ class CPUModelRunner:
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
@ -262,10 +268,9 @@ class CPUModelRunner:
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append([
categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx
])
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
@ -328,7 +333,7 @@ class CPUModelRunner:
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
SamplingMetadata]:
if self.is_driver_worker:
@ -381,7 +386,7 @@ class CPUModelRunner:
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata
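
CPUModelRunner (like the other model runners touched below) annotates the comprehension that buckets sample indices by SamplingType, and the bucketed entries become tuples rather than two-element lists. Sketch with a stand-in enum:

    import enum
    from typing import Dict, List, Tuple

    class SamplingType(enum.IntEnum):      # stand-in for vLLM's SamplingType
        GREEDY = 0
        RANDOM = 1

    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    categorized_sample_indices[SamplingType.GREEDY].append((0, 0))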


@ -1,5 +1,5 @@
"""A CPU worker class."""
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed
@ -152,8 +152,8 @@ class CPUWorker(LoraNotSupportedWorkerBase):
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine = None
self.cpu_cache = None
self.cache_engine: CPUCacheEngine
self.cpu_cache: List[torch.Tensor]
def init_device(self) -> None:
self.init_distributed_environment()
@ -257,13 +257,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
) -> List[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
num_seq_groups: int = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
assert len(blocks_to_swap_in) == 0
assert len(blocks_to_swap_out) == 0
data = {
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy,
}
@ -273,6 +273,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_copy is not None
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.


@ -128,23 +128,17 @@ class ModelRunner:
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None
self.block_size = None # Set after initial profiling.
self.lora_manager = None
# Set after load_model.
self.lora_manager: LRUCacheWorkerLoRAManager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = (
self.model_config.max_context_len_to_capture
if self.model_config is not None else 0)
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling.
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
self.vision_language_config = vision_language_config
@ -152,6 +146,17 @@ class ModelRunner:
self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None)
# Lazy initialization
self.model: torch.nn.Module # Set after load_model
self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables: torch.Tensor # Set after initial profiling.
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
@ -489,16 +494,16 @@ class ModelRunner:
lora_index_mapping.append(0)
batch_size = graph_batch_size
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens.shape[0] == len(input_tokens)
assert context_lens.shape[0] == len(input_positions)
assert context_lens.shape[0] == len(slot_mapping)
assert context_lens_tensor.shape[0] == len(input_tokens)
assert context_lens_tensor.shape[0] == len(input_positions)
assert context_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
@ -527,7 +532,7 @@ class ModelRunner:
max_prompt_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens,
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
@ -551,7 +556,11 @@ class ModelRunner:
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
@ -569,10 +578,9 @@ class ModelRunner:
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append([
categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx
])
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
@ -596,15 +604,16 @@ class ModelRunner:
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
list(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx
+ num_seqs))))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
@ -641,9 +650,9 @@ class ModelRunner:
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]:
Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
prefill_reqs = []
decode_reqs = []
@ -741,6 +750,7 @@ class ModelRunner:
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
@ -809,7 +819,7 @@ class ModelRunner:
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
@ -923,7 +933,7 @@ class ModelRunner:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: List[LoRARequest],
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
@ -1065,10 +1075,16 @@ class CUDAGraphRunner:
def __init__(self, model: nn.Module):
self.model = model
self.graph = None
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self._graph: Optional[torch.cuda.CUDAGraph] = None
@property
def graph(self):
assert self._graph is not None
return self._graph
def capture(
self,
input_ids: torch.Tensor,
@ -1078,7 +1094,7 @@ class CUDAGraphRunner:
memory_pool,
**kwargs,
) -> None:
assert self.graph is None
assert self._graph is None
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
@ -1095,8 +1111,8 @@ class CUDAGraphRunner:
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
with _maybe_pynccl():
hidden_states = self.model(
input_ids,
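
CUDAGraphRunner now keeps the captured graph in a private Optional attribute and exposes it through a property that asserts it exists, while capture() asserts the opposite before recording. Sketch of the pattern with a generic resource standing in for torch.cuda.CUDAGraph:

    from typing import Optional

    class Resource:                    # stand-in for torch.cuda.CUDAGraph
        ...

    class RunnerSketch:
        def __init__(self) -> None:
            self._graph: Optional[Resource] = None   # nothing captured yet

        @property
        def graph(self) -> Resource:
            assert self._graph is not None   # callers see a non-Optional type
            return self._graph

        def capture(self) -> None:
            assert self._graph is None       # capture happens exactly once
            self._graph = Resource()

The property keeps the existing self.graph call sites unchanged while still giving mypy a precise non-Optional type at each of them.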


@ -1,6 +1,7 @@
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
@ -34,9 +35,11 @@ class NeuronModelRunner:
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None
self.pin_memory = is_pin_memory_available()
# Lazy initialization.
self.model: nn.Module # initialize after load_model.
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
@ -147,7 +150,11 @@ class NeuronModelRunner:
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
@ -165,10 +172,9 @@ class NeuronModelRunner:
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append([
categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx
])
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
@ -237,7 +243,7 @@ class NeuronModelRunner:
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
@ -259,7 +265,7 @@ class NeuronModelRunner:
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list)


@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
import torch
import torch.distributed
@ -82,8 +82,8 @@ class Worker(WorkerBase):
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine = None
self.gpu_cache = None
self.cache_engine: CacheEngine
self.gpu_cache: List[torch.Tensor]
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
@ -223,7 +223,7 @@ class Worker(WorkerBase):
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
data = {
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
@ -237,6 +237,9 @@ class Worker(WorkerBase):
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.


@ -1,7 +1,7 @@
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
from typing import Dict, List, Set, Tuple
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -56,7 +56,7 @@ class WorkerBase(ABC):
raise NotImplementedError
@abstractmethod
def get_cache_block_size_bytes() -> int:
def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
@ -71,7 +71,7 @@ class WorkerBase(ABC):
raise NotImplementedError
@abstractmethod
def list_loras(self) -> List[int]:
def list_loras(self) -> Set[int]:
raise NotImplementedError
@ -86,7 +86,7 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def list_loras(self) -> List[int]:
def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")
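
Finally, WorkerBase's abstract methods gain the missing self parameter and list_loras is declared to return Set[int], matching the concrete workers (and the Set[LoRARequest] changes earlier in the diff). Minimal sketch of the corrected base class:

    from abc import ABC, abstractmethod
    from typing import Set

    class WorkerBaseSketch(ABC):
        @abstractmethod
        def get_cache_block_size_bytes(self) -> int:   # self was missing before
            raise NotImplementedError

        @abstractmethod
        def list_loras(self) -> Set[int]:              # implementations return sets
            raise NotImplementedError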