[Typing] Mypy typing part 2 (#4043)
Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
This commit is contained in:
parent
a53222544c
commit
533d2a1f39
8
.github/workflows/mypy.yaml
vendored
8
.github/workflows/mypy.yaml
vendored
@ -41,10 +41,10 @@ jobs:
|
||||
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# TODO(sang): Follow up
|
||||
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
|
@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
# TODO(sang): Follow up
|
||||
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
|
||||
|
@ -2,8 +2,8 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
|
||||
Set, Tuple, Type, Union)
|
||||
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
|
||||
Optional, Set, Tuple, Type, Union)
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@ -52,7 +52,7 @@ class AsyncStream:
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
self._queue = asyncio.Queue()
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._finished = False
|
||||
|
||||
def put(self, item: Union[RequestOutput, Exception]) -> None:
|
||||
@ -312,15 +312,17 @@ class AsyncLLMEngine:
|
||||
self.max_log_len = max_log_len
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
self.background_loop = None
|
||||
self.background_loop: Optional[asyncio.Future] = None
|
||||
# We need to keep a reference to unshielded
|
||||
# task as well to prevent it from being garbage
|
||||
# collected
|
||||
self._background_loop_unshielded = None
|
||||
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
|
||||
self.start_engine_loop = start_engine_loop
|
||||
self._request_tracker: Optional[RequestTracker] = None
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
# Lazy initialized fields
|
||||
self._request_tracker: RequestTracker
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
@ -361,11 +363,13 @@ class AsyncLLMEngine:
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return (self.background_loop is not None
|
||||
and self._background_loop_unshielded is not None
|
||||
and not self._background_loop_unshielded.done())
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return self.errored or (self.background_loop is not None
|
||||
return self.errored or (self.background_loop is not None and
|
||||
self._background_loop_unshielded is not None
|
||||
and self._background_loop_unshielded.done())
|
||||
|
||||
@property
|
||||
@ -381,7 +385,7 @@ class AsyncLLMEngine:
|
||||
|
||||
async def get_tokenizer(self) -> "PreTrainedTokenizer":
|
||||
if self.engine_use_ray:
|
||||
return await self.engine.get_tokenizer.remote()
|
||||
return await self.engine.get_tokenizer.remote() # type: ignore
|
||||
else:
|
||||
return self.engine.get_tokenizer()
|
||||
|
||||
@ -434,7 +438,8 @@ class AsyncLLMEngine:
|
||||
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||
try:
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(**new_request)
|
||||
await self.engine.add_request.remote( # type: ignore
|
||||
**new_request)
|
||||
else:
|
||||
await self.engine.add_request_async(**new_request)
|
||||
except ValueError as e:
|
||||
@ -449,7 +454,7 @@ class AsyncLLMEngine:
|
||||
await self._engine_abort(finished_requests)
|
||||
|
||||
if self.engine_use_ray:
|
||||
request_outputs = await self.engine.step.remote()
|
||||
request_outputs = await self.engine.step.remote() # type: ignore
|
||||
else:
|
||||
request_outputs = await self.engine.step_async()
|
||||
|
||||
@ -462,7 +467,7 @@ class AsyncLLMEngine:
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_ids)
|
||||
await self.engine.abort_request.remote(request_ids) # type: ignore
|
||||
else:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
@ -525,11 +530,12 @@ class AsyncLLMEngine:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.engine_use_ray:
|
||||
prompt_token_ids = await self.engine.encode_request_async.remote(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request)
|
||||
prompt_token_ids = await (
|
||||
self.engine.encode_request_async.remote( # type: ignore
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request))
|
||||
else:
|
||||
prompt_token_ids = await self.engine.encode_request_async(
|
||||
request_id=request_id,
|
||||
@ -676,13 +682,13 @@ class AsyncLLMEngine:
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
if self.engine_use_ray:
|
||||
return await self.engine.get_model_config.remote()
|
||||
return await self.engine.get_model_config.remote() # type: ignore
|
||||
else:
|
||||
return self.engine.get_model_config()
|
||||
|
||||
async def do_log_stats(self) -> None:
|
||||
if self.engine_use_ray:
|
||||
await self.engine.do_log_stats.remote()
|
||||
await self.engine.do_log_stats.remote() # type: ignore
|
||||
else:
|
||||
self.engine.do_log_stats()
|
||||
|
||||
@ -695,7 +701,7 @@ class AsyncLLMEngine:
|
||||
|
||||
if self.engine_use_ray:
|
||||
try:
|
||||
await self.engine.check_health.remote()
|
||||
await self.engine.check_health.remote() # type: ignore
|
||||
except ray.exceptions.RayActorError as e:
|
||||
raise RuntimeError("Engine is dead.") from e
|
||||
else:
|
||||
|
@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||
self._lora_manager: LoRAModelManager = lora_manager
|
||||
return lora_manager.model
|
||||
|
||||
def set_active_loras(self, lora_requests: List[LoRARequest],
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
self._apply_loras(lora_requests)
|
||||
self._lora_manager.set_lora_mapping(lora_mapping)
|
||||
|
||||
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
|
||||
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
|
||||
loras_that_exist = self.list_loras()
|
||||
loras_map = {
|
||||
lora_request.lora_int_id: lora_request
|
||||
|
@ -55,7 +55,7 @@ global_thread_pool = None # used for generating logits processor fsm
|
||||
|
||||
async def get_outlines_guided_decoding_logits_processor(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest],
|
||||
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
|
||||
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
|
||||
"""
|
||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||
and get the necessary logits processor for the given guide.
|
||||
@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(
|
||||
|
||||
def _get_guide_and_mode(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest]
|
||||
) -> Tuple[str, GuidedDecodingMode]:
|
||||
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
|
||||
|
||||
if request.guided_json:
|
||||
json = request.guided_json
|
||||
|
@ -21,7 +21,7 @@ from functools import lru_cache
|
||||
from typing import Callable, DefaultDict, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from outlines.fsm.fsm import CFGFSM, RegexFSM
|
||||
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@ -29,6 +29,10 @@ from transformers import PreTrainedTokenizerBase
|
||||
|
||||
class BaseLogitsProcessor:
|
||||
|
||||
def __init__(self):
|
||||
# Child class should use initialize in their init.
|
||||
self.fsm: FSM
|
||||
|
||||
def init_state(self):
|
||||
"""Initialize the FSM states."""
|
||||
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Utilities for selecting and loading neuron models."""
|
||||
import importlib
|
||||
import os
|
||||
from typing import Optional, Type
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -27,7 +27,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
}
|
||||
|
||||
# Models supported by Neuron.
|
||||
_NEURON_SUPPORTED_MODELS = {
|
||||
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
|
||||
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
|
||||
"LlamaForSampling", "LlamaForCausalLM"),
|
||||
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
|
||||
@ -43,11 +43,13 @@ class NeuronCasualLM(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = None
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -74,17 +76,17 @@ class NeuronCasualLM(nn.Module):
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls, hf_model_cls = (
|
||||
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
|
||||
split_model_dir = f"{model_name_or_path}-split"
|
||||
if os.path.isdir(os.path.join(model_name_or_path,
|
||||
"pytorch_model.bin")):
|
||||
split_model_dir = model_name_or_path
|
||||
elif not os.path.exists(f"{model_name_or_path}-split"):
|
||||
hf_model_cls = getattr(transformers, hf_model_cls)
|
||||
hf_model_cls = getattr(transformers, hf_model_cls_name)
|
||||
from transformers_neuronx.module import save_pretrained_split
|
||||
|
||||
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
|
||||
@ -96,7 +98,7 @@ class NeuronCasualLM(nn.Module):
|
||||
self.model.to_neuron()
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _NEURON_SUPPORTED_MODELS:
|
||||
|
@ -167,6 +167,7 @@ class TensorizerArgs:
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
self.deserializer_params['encryption'] = decryption_params
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
"""Tensorizer CLI arguments"""
|
||||
|
@ -113,6 +113,8 @@ class SamplingTensors:
|
||||
get_num_triton_sampler_splits(vocab_size))
|
||||
|
||||
sample_indices_start_idx = 0
|
||||
assert sampling_metadata.seq_groups is not None
|
||||
assert sampling_metadata.seq_data is not None
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
temperature = sampling_params.temperature
|
||||
@ -147,6 +149,7 @@ class SamplingTensors:
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
# For tokens in the prompt that we only need to get
|
||||
# their logprobs
|
||||
assert sampling_metadata.prompt_lens is not None
|
||||
prompt_len = sampling_metadata.prompt_lens[i]
|
||||
temperatures += [temperature] * (prompt_len - 1)
|
||||
top_ps += [top_p] * (prompt_len - 1)
|
||||
@ -172,6 +175,7 @@ class SamplingTensors:
|
||||
is_prompt = i < sampling_metadata.num_prompts
|
||||
if is_prompt:
|
||||
prompt_best_of.append(sampling_params.best_of)
|
||||
assert sampling_metadata.prompt_lens is not None
|
||||
prompt_len = sampling_metadata.prompt_lens[i]
|
||||
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
|
@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
def _expand_batch(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_token_ids_list: List[TokenId],
|
||||
proposal_token_ids_list: List[List[TokenId]],
|
||||
proposal_lens_list: List[int],
|
||||
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
|
||||
"""Given the input sequences and potentially multiple corresponding
|
||||
@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
def _create_target_seq_group_metadata(
|
||||
self,
|
||||
input_seq_group_metadata: SequenceGroupMetadata,
|
||||
proposal_token_ids: List[TokenId], # shape: [batch_size, k]
|
||||
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
|
||||
batch_index: int,
|
||||
target_seq_ids_iter: Iterator[TargetSeqId],
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
[0, 1, 2]
|
||||
[0, 1, 2, 3]
|
||||
"""
|
||||
empty_token_ids = []
|
||||
empty_token_ids: List[TokenId] = []
|
||||
|
||||
token_ids_to_score = [empty_token_ids]
|
||||
token_ids_to_score.extend([
|
||||
|
@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||
k: int,
|
||||
proposals: SpeculativeProposals,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> SpeculativeScores:
|
||||
raise NotImplementedError
|
||||
|
@ -112,6 +112,7 @@ class AsyncMetricsCollector:
|
||||
|
||||
Returns a CUDA event recording when the copy is complete.
|
||||
"""
|
||||
assert self._copy_stream is not None
|
||||
self._copy_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
with torch.cuda.stream(self._copy_stream):
|
||||
|
@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._proposer: Optional[DraftModelTop1Proposer] = None
|
||||
# Lazy initialization list.
|
||||
self._proposer: DraftModelTop1Proposer
|
||||
|
||||
def init_device(self):
|
||||
super().init_device()
|
||||
@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
self._vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=self._device)
|
||||
proposal_lens = torch.zeros(len(proposal_lens),
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
return proposal_tokens, proposal_probs, proposal_lens
|
||||
proposal_lens_tensor = torch.zeros(len(proposal_lens),
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
|
||||
@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
proposal_tokens, proposal_probs = (entire_proposal_tokens,
|
||||
entire_proposal_probs)
|
||||
|
||||
proposal_lens = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
|
||||
proposal_lens_tensor = torch.zeros(batch_size,
|
||||
dtype=torch.long,
|
||||
device=self._device)
|
||||
proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
|
||||
|
||||
return proposal_tokens, proposal_probs, proposal_lens
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.probs_dtype = self.rejection_sampler.probs_dtype
|
||||
self.token_id_dtype = self.rejection_sampler.token_id_dtype
|
||||
|
||||
self.scorer: SpeculativeScorer = None
|
||||
# Lazy initiazliation.
|
||||
self.scorer: SpeculativeScorer
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize both scorer and proposer models.
|
||||
@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
logger.info("get spec proposals")
|
||||
# Generate proposals using draft worker.
|
||||
assert blocks_to_swap_in is not None
|
||||
assert blocks_to_swap_out is not None
|
||||
assert blocks_to_copy is not None
|
||||
proposals = self.proposer_worker.get_spec_proposals(
|
||||
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
|
||||
blocks_to_copy, k)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
@ -48,14 +49,15 @@ class CPUModelRunner:
|
||||
if device_config is not None else DeviceConfig())
|
||||
self.device = self.device_config.device
|
||||
|
||||
self.model = None
|
||||
self.block_size = None # Set after initial profiling.
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.dtype if model_config is not None else None)
|
||||
|
||||
# Lazy initialization.
|
||||
self.model: nn.Module # Set after init_Model
|
||||
self.block_size: int # Set after initial profiling.
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(model_config=self.model_config,
|
||||
load_config=self.load_config,
|
||||
@ -245,7 +247,11 @@ class CPUModelRunner:
|
||||
selected_token_indices: List[int] = []
|
||||
generators: List[torch.Generator] = []
|
||||
selected_token_start_idx = 0
|
||||
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices: Dict[SamplingType,
|
||||
List[Tuple[int, int]]] = {
|
||||
t: []
|
||||
for t in SamplingType
|
||||
}
|
||||
categorized_sample_indices_start_idx = 0
|
||||
categorized_sampled_token_indices_start_idx = 0
|
||||
|
||||
@ -262,10 +268,9 @@ class CPUModelRunner:
|
||||
categorized_sample_indices_start_idx += subquery_len - 1
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].append([
|
||||
categorized_sample_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx
|
||||
])
|
||||
sampling_params.sampling_type].append(
|
||||
(categorized_sample_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx))
|
||||
categorized_sample_indices_start_idx += 1
|
||||
categorized_sampled_token_indices_start_idx += 1
|
||||
|
||||
@ -328,7 +333,7 @@ class CPUModelRunner:
|
||||
|
||||
def prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
|
||||
SamplingMetadata]:
|
||||
if self.is_driver_worker:
|
||||
@ -381,7 +386,7 @@ class CPUModelRunner:
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
kv_caches: List[torch.Tensor],
|
||||
) -> Optional[SamplerOutput]:
|
||||
(input_tokens, input_positions, attn_metadata, sampling_metadata
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""A CPU worker class."""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -152,8 +152,8 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
is_driver_worker=is_driver_worker)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine = None
|
||||
self.cpu_cache = None
|
||||
self.cache_engine: CPUCacheEngine
|
||||
self.cpu_cache: List[torch.Tensor]
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.init_distributed_environment()
|
||||
@ -257,13 +257,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
) -> List[SamplerOutput]:
|
||||
if self.is_driver_worker:
|
||||
assert seq_group_metadata_list is not None
|
||||
num_seq_groups = len(seq_group_metadata_list)
|
||||
num_seq_groups: int = len(seq_group_metadata_list)
|
||||
assert blocks_to_swap_in is not None
|
||||
assert blocks_to_swap_out is not None
|
||||
assert blocks_to_copy is not None
|
||||
assert len(blocks_to_swap_in) == 0
|
||||
assert len(blocks_to_swap_out) == 0
|
||||
data = {
|
||||
data: Dict[str, Any] = {
|
||||
"num_seq_groups": num_seq_groups,
|
||||
"blocks_to_copy": blocks_to_copy,
|
||||
}
|
||||
@ -273,6 +273,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
num_seq_groups = data["num_seq_groups"]
|
||||
blocks_to_copy = data["blocks_to_copy"]
|
||||
|
||||
assert blocks_to_copy is not None
|
||||
self.cache_copy(blocks_to_copy)
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
|
@ -128,23 +128,17 @@ class ModelRunner:
|
||||
if device_config is not None else DeviceConfig())
|
||||
self.device = self.device_config.device
|
||||
|
||||
self.model = None
|
||||
self.block_size = None # Set after initial profiling.
|
||||
self.lora_manager = None
|
||||
# Set after load_model.
|
||||
self.lora_manager: LRUCacheWorkerLoRAManager = None
|
||||
|
||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||
self.graph_memory_pool = None # Set during graph capture.
|
||||
self.graph_memory_pool: Optional[Tuple[
|
||||
int, int]] = None # Set during graph capture.
|
||||
|
||||
self.max_context_len_to_capture = (
|
||||
self.model_config.max_context_len_to_capture
|
||||
if self.model_config is not None else 0)
|
||||
# When using CUDA graph, the input block tables must be padded to
|
||||
# max_context_len_to_capture. However, creating the block table in
|
||||
# Python can be expensive. To optimize this, we cache the block table
|
||||
# in numpy and only copy the actual input content at every iteration.
|
||||
# The shape of the cached block table will be
|
||||
# (max batch size to capture, max context len to capture / block size).
|
||||
self.graph_block_tables = None # Set after initial profiling.
|
||||
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.vision_language_config = vision_language_config
|
||||
@ -152,6 +146,17 @@ class ModelRunner:
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.dtype if model_config is not None else None)
|
||||
|
||||
# Lazy initialization
|
||||
self.model: torch.nn.Module # Set after load_model
|
||||
self.block_size: int # Set after initial profiling.
|
||||
# When using CUDA graph, the input block tables must be padded to
|
||||
# max_context_len_to_capture. However, creating the block table in
|
||||
# Python can be expensive. To optimize this, we cache the block table
|
||||
# in numpy and only copy the actual input content at every iteration.
|
||||
# The shape of the cached block table will be
|
||||
# (max batch size to capture, max context len to capture / block size).
|
||||
self.graph_block_tables: torch.Tensor # Set after initial profiling.
|
||||
|
||||
def load_model(self) -> None:
|
||||
with CudaMemoryProfiler() as m:
|
||||
self.model = get_model(
|
||||
@ -489,16 +494,16 @@ class ModelRunner:
|
||||
lora_index_mapping.append(0)
|
||||
batch_size = graph_batch_size
|
||||
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
context_lens_tensor = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
|
||||
if use_captured_graph:
|
||||
# When using cuda-graph all these tensors should be
|
||||
# padded.
|
||||
assert context_lens.shape[0] == len(input_tokens)
|
||||
assert context_lens.shape[0] == len(input_positions)
|
||||
assert context_lens.shape[0] == len(slot_mapping)
|
||||
assert context_lens_tensor.shape[0] == len(input_tokens)
|
||||
assert context_lens_tensor.shape[0] == len(input_positions)
|
||||
assert context_lens_tensor.shape[0] == len(slot_mapping)
|
||||
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
@ -527,7 +532,7 @@ class ModelRunner:
|
||||
max_prompt_len=None,
|
||||
subquery_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens=context_lens,
|
||||
context_lens=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
@ -551,7 +556,11 @@ class ModelRunner:
|
||||
selected_token_indices: List[int] = []
|
||||
generators: List[torch.Generator] = []
|
||||
selected_token_start_idx = 0
|
||||
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices: Dict[SamplingType,
|
||||
List[Tuple[int, int]]] = {
|
||||
t: []
|
||||
for t in SamplingType
|
||||
}
|
||||
categorized_sample_indices_start_idx = 0
|
||||
categorized_sampled_token_indices_start_idx = 0
|
||||
|
||||
@ -569,10 +578,9 @@ class ModelRunner:
|
||||
categorized_sample_indices_start_idx += subquery_len - 1
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].append([
|
||||
categorized_sample_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx
|
||||
])
|
||||
sampling_params.sampling_type].append(
|
||||
(categorized_sample_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx))
|
||||
categorized_sample_indices_start_idx += 1
|
||||
categorized_sampled_token_indices_start_idx += 1
|
||||
|
||||
@ -596,15 +604,16 @@ class ModelRunner:
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].extend(
|
||||
zip(
|
||||
range(
|
||||
categorized_sample_indices_start_idx,
|
||||
categorized_sample_indices_start_idx +
|
||||
num_seqs),
|
||||
range(
|
||||
categorized_sampled_token_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx +
|
||||
num_seqs)))
|
||||
list(
|
||||
zip(
|
||||
range(
|
||||
categorized_sample_indices_start_idx,
|
||||
categorized_sample_indices_start_idx +
|
||||
num_seqs),
|
||||
range(
|
||||
categorized_sampled_token_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx
|
||||
+ num_seqs))))
|
||||
categorized_sample_indices_start_idx += num_seqs
|
||||
categorized_sampled_token_indices_start_idx += num_seqs
|
||||
|
||||
@ -641,9 +650,9 @@ class ModelRunner:
|
||||
|
||||
def prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
||||
Set[int], LoRAMapping, torch.Tensor]:
|
||||
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
||||
if self.is_driver_worker:
|
||||
prefill_reqs = []
|
||||
decode_reqs = []
|
||||
@ -741,6 +750,7 @@ class ModelRunner:
|
||||
if prefill_attn_metadata is not None:
|
||||
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
||||
else:
|
||||
assert decode_attn_metadata is not None
|
||||
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
|
||||
@ -809,7 +819,7 @@ class ModelRunner:
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
kv_caches: List[torch.Tensor],
|
||||
) -> Optional[SamplerOutput]:
|
||||
(input_tokens, input_positions, attn_metadata, sampling_metadata,
|
||||
@ -923,7 +933,7 @@ class ModelRunner:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.remove_all_loras()
|
||||
|
||||
def set_active_loras(self, lora_requests: List[LoRARequest],
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
@ -1065,10 +1075,16 @@ class CUDAGraphRunner:
|
||||
|
||||
def __init__(self, model: nn.Module):
|
||||
self.model = model
|
||||
self.graph = None
|
||||
self.input_buffers: Dict[str, torch.Tensor] = {}
|
||||
self.output_buffers: Dict[str, torch.Tensor] = {}
|
||||
|
||||
self._graph: Optional[torch.cuda.CUDAGraph] = None
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
assert self._graph is not None
|
||||
return self._graph
|
||||
|
||||
def capture(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -1078,7 +1094,7 @@ class CUDAGraphRunner:
|
||||
memory_pool,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert self.graph is None
|
||||
assert self._graph is None
|
||||
# Run the model once without capturing the graph.
|
||||
# This is to make sure that the captured graph does not include the
|
||||
# kernel launches for initial benchmarking (e.g., Triton autotune).
|
||||
@ -1095,8 +1111,8 @@ class CUDAGraphRunner:
|
||||
# Capture the graph.
|
||||
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
|
||||
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
|
||||
self._graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
|
||||
with _maybe_pynccl():
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
@ -34,9 +35,11 @@ class NeuronModelRunner:
|
||||
self.device_config = (device_config
|
||||
if device_config is not None else DeviceConfig())
|
||||
self.device = self.device_config.device
|
||||
self.model = None
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
|
||||
# Lazy initialization.
|
||||
self.model: nn.Module # initialize after load_model.
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_neuron_model(self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
@ -147,7 +150,11 @@ class NeuronModelRunner:
|
||||
selected_token_indices: List[int] = []
|
||||
generators: List[torch.Generator] = []
|
||||
selected_token_start_idx = 0
|
||||
categorized_sample_indices = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices: Dict[SamplingType,
|
||||
List[Tuple[int, int]]] = {
|
||||
t: []
|
||||
for t in SamplingType
|
||||
}
|
||||
categorized_sample_indices_start_idx = 0
|
||||
categorized_sampled_token_indices_start_idx = 0
|
||||
|
||||
@ -165,10 +172,9 @@ class NeuronModelRunner:
|
||||
categorized_sample_indices_start_idx += prompt_len - 1
|
||||
|
||||
categorized_sample_indices[
|
||||
sampling_params.sampling_type].append([
|
||||
categorized_sample_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx
|
||||
])
|
||||
sampling_params.sampling_type].append(
|
||||
(categorized_sample_indices_start_idx,
|
||||
categorized_sampled_token_indices_start_idx))
|
||||
categorized_sample_indices_start_idx += 1
|
||||
categorized_sampled_token_indices_start_idx += 1
|
||||
|
||||
@ -237,7 +243,7 @@ class NeuronModelRunner:
|
||||
|
||||
def prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
@ -259,7 +265,7 @@ class NeuronModelRunner:
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Optional[SamplerOutput]:
|
||||
(input_tokens, input_positions, input_block_ids, sampling_metadata
|
||||
) = self.prepare_input_tensors(seq_group_metadata_list)
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -82,8 +82,8 @@ class Worker(WorkerBase):
|
||||
)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine = None
|
||||
self.gpu_cache = None
|
||||
self.cache_engine: CacheEngine
|
||||
self.gpu_cache: List[torch.Tensor]
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "cuda":
|
||||
@ -223,7 +223,7 @@ class Worker(WorkerBase):
|
||||
assert blocks_to_swap_in is not None
|
||||
assert blocks_to_swap_out is not None
|
||||
assert blocks_to_copy is not None
|
||||
data = {
|
||||
data: Dict[str, Any] = {
|
||||
"num_seq_groups": num_seq_groups,
|
||||
"blocks_to_swap_in": blocks_to_swap_in,
|
||||
"blocks_to_swap_out": blocks_to_swap_out,
|
||||
@ -237,6 +237,9 @@ class Worker(WorkerBase):
|
||||
blocks_to_swap_out = data["blocks_to_swap_out"]
|
||||
blocks_to_copy = data["blocks_to_copy"]
|
||||
|
||||
assert blocks_to_swap_in is not None
|
||||
assert blocks_to_swap_out is not None
|
||||
assert blocks_to_copy is not None
|
||||
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
|
@ -1,7 +1,7 @@
|
||||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -56,7 +56,7 @@ class WorkerBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_cache_block_size_bytes() -> int:
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Return the size of a single cache block, in bytes. Used in
|
||||
speculative decoding.
|
||||
"""
|
||||
@ -71,7 +71,7 @@ class WorkerBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -86,7 +86,7 @@ class LoraNotSupportedWorkerBase(WorkerBase):
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise ValueError(f"{type(self)} does not support LoRA")
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise ValueError(f"{type(self)} does not support LoRA")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user