[Typing] Mypy typing part 2 (#4043)
Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
This commit is contained in:
parent a53222544c
commit 533d2a1f39

.github/workflows/mypy.yaml (vendored): 8 changed lines
@@ -41,10 +41,10 @@ jobs:
 mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
 mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

+mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
 # TODO(sang): Follow up
-# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
 # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml

@@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
 mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

 # TODO(sang): Follow up
-# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
 # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml

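The two hunks above enable the same per-directory checks both in CI and in the local lint script. A rough local equivalent is sketched below, assuming mypy is installed and the script runs from the repository root; mypy.api.run is mypy's documented programmatic entry point, and the glob expansion stands in for the shell's.

# Sketch: run the same per-directory mypy checks from Python.
import glob

from mypy import api

PACKAGE_GLOBS = [
    "vllm/*.py",
    "vllm/transformers_utils/*.py",
    "vllm/engine/*.py",
    "vllm/worker/*.py",
    "vllm/spec_decode/*.py",
    "vllm/model_executor/*.py",
]

for pattern in PACKAGE_GLOBS:
    files = sorted(glob.glob(pattern))
    # --follow-imports=skip keeps each run scoped to the matched files only.
    stdout, stderr, exit_status = api.run(
        files + ["--follow-imports=skip", "--config-file", "pyproject.toml"])
    print(f"== {pattern}: exit {exit_status} ==")
    print(stdout, end="")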
@@ -2,8 +2,8 @@ import asyncio
 import os
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
-                    Set, Tuple, Type, Union)
+from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
+                    Optional, Set, Tuple, Type, Union)

 from transformers import PreTrainedTokenizer

@@ -52,7 +52,7 @@ class AsyncStream:

     def __init__(self, request_id: str) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
+        self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

     def put(self, item: Union[RequestOutput, Exception]) -> None:
@@ -312,15 +312,17 @@ class AsyncLLMEngine:
         self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)

-        self.background_loop = None
+        self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
         # task as well to prevent it from being garbage
         # collected
-        self._background_loop_unshielded = None
+        self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None

+        # Lazy initialized fields
+        self._request_tracker: RequestTracker
+
     @classmethod
     def from_engine_args(
         cls,
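Two annotation styles appear in this hunk: attributes that can legitimately hold None keep an explicit Optional type, while attributes that are always assigned before use become bare, value-less declarations ("lazy initialized fields"), so mypy tracks their type without admitting None. A minimal sketch of the distinction, with illustrative names rather than vLLM's:

import asyncio
from typing import Optional


class Tracker:
    pass


class Engine:

    def __init__(self) -> None:
        # May legitimately be None until the background loop is started.
        self.background_loop: Optional[asyncio.Future] = None
        # Lazy initialized field: declared without a value, so mypy types it
        # as Tracker (never None) and later uses need no narrowing.
        self._request_tracker: Tracker

    def start(self) -> None:
        self._request_tracker = Tracker()

    def tracker(self) -> Tracker:
        # mypy accepts this; at runtime it raises AttributeError if start()
        # was never called, which is the trade-off of the pattern.
        return self._request_tracker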
@@ -361,11 +363,13 @@ class AsyncLLMEngine:
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
+                and self._background_loop_unshielded is not None
                 and not self._background_loop_unshielded.done())

     @property
     def is_stopped(self) -> bool:
-        return self.errored or (self.background_loop is not None
+        return self.errored or (self.background_loop is not None and
+                                self._background_loop_unshielded is not None
                                 and self._background_loop_unshielded.done())

     @property
@@ -381,7 +385,7 @@ class AsyncLLMEngine:

     async def get_tokenizer(self) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()
+            return await self.engine.get_tokenizer.remote()  # type: ignore
         else:
             return self.engine.get_tokenizer()

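The `# type: ignore` comments in this and the following hunks all cover the same situation: `self.engine` is annotated as the plain engine class, but when `engine_use_ray` is set it is actually a Ray actor handle whose methods carry a `.remote` attribute that mypy cannot see. A hedged toy illustration of the shape of the problem (hypothetical classes, not the vLLM or Ray API):

class Engine:

    def get_tokenizer(self) -> str:
        return "tokenizer"


class AsyncWrapper:

    def __init__(self, engine: Engine, engine_use_ray: bool) -> None:
        # Annotated as Engine, but at runtime this may be a proxy object
        # whose bound methods expose a `.remote(...)` attribute.
        self.engine = engine
        self.engine_use_ray = engine_use_ray

    def tokenizer(self) -> str:
        if self.engine_use_ray:
            # mypy only knows the plain bound method and reports a missing
            # attribute, so the real code silences exactly this line.
            return self.engine.get_tokenizer.remote()  # type: ignore
        return self.engine.get_tokenizer()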
@@ -434,7 +438,8 @@ class AsyncLLMEngine:
         # TODO: Maybe add add_request_batch to reduce Ray overhead
         try:
             if self.engine_use_ray:
-                await self.engine.add_request.remote(**new_request)
+                await self.engine.add_request.remote(  # type: ignore
+                    **new_request)
             else:
                 await self.engine.add_request_async(**new_request)
         except ValueError as e:
@@ -449,7 +454,7 @@ class AsyncLLMEngine:
             await self._engine_abort(finished_requests)

         if self.engine_use_ray:
-            request_outputs = await self.engine.step.remote()
+            request_outputs = await self.engine.step.remote()  # type: ignore
         else:
             request_outputs = await self.engine.step_async()

@@ -462,7 +467,7 @@ class AsyncLLMEngine:

    async def _engine_abort(self, request_ids: Iterable[str]):
        if self.engine_use_ray:
-            await self.engine.abort_request.remote(request_ids)  # type: ignore
+            await self.engine.abort_request.remote(request_ids)  # type: ignore
        else:
            self.engine.abort_request(request_ids)

@@ -525,11 +530,12 @@ class AsyncLLMEngine:
             arrival_time = time.time()

         if self.engine_use_ray:
-            prompt_token_ids = await self.engine.encode_request_async.remote(
+            prompt_token_ids = await (
+                self.engine.encode_request_async.remote(  # type: ignore
                     request_id=request_id,
                     prompt=prompt,
                     prompt_token_ids=prompt_token_ids,
-                lora_request=lora_request)
+                    lora_request=lora_request))
         else:
             prompt_token_ids = await self.engine.encode_request_async(
                 request_id=request_id,
@@ -676,13 +682,13 @@ class AsyncLLMEngine:
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
         if self.engine_use_ray:
-            return await self.engine.get_model_config.remote()
+            return await self.engine.get_model_config.remote()  # type: ignore
         else:
             return self.engine.get_model_config()

     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
-            await self.engine.do_log_stats.remote()
+            await self.engine.do_log_stats.remote()  # type: ignore
         else:
             self.engine.do_log_stats()

@@ -695,7 +701,7 @@ class AsyncLLMEngine:

         if self.engine_use_ray:
             try:
-                await self.engine.check_health.remote()
+                await self.engine.check_health.remote()  # type: ignore
             except ray.exceptions.RayActorError as e:
                 raise RuntimeError("Engine is dead.") from e
         else:
@@ -107,12 +107,12 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
         self._lora_manager: LoRAModelManager = lora_manager
         return lora_manager.model

-    def set_active_loras(self, lora_requests: List[LoRARequest],
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         self._apply_loras(lora_requests)
         self._lora_manager.set_lora_mapping(lora_mapping)

-    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
+    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
         loras_that_exist = self.list_loras()
         loras_map = {
             lora_request.lora_int_id: lora_request
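Changing `lora_requests` from `List[LoRARequest]` to `Set[LoRARequest]` aligns the signature with the collection the callers actually pass; with the old annotation mypy reports an argument-type mismatch. A small sketch of that error and its fix, using toy types rather than the vLLM classes:

from typing import List, Set


def set_active(ids: List[int]) -> None:       # old annotation
    print(sorted(ids))


def set_active_fixed(ids: Set[int]) -> None:  # annotation matching the caller
    print(sorted(ids))


active: Set[int] = {3, 1, 2}
# set_active(active)    # mypy: Argument 1 has incompatible type "Set[int]";
#                       #       expected "List[int]"
set_active_fixed(active)  # OK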
@@ -55,7 +55,7 @@ global_thread_pool = None  # used for generating logits processor fsm

 async def get_outlines_guided_decoding_logits_processor(
         request: Union[CompletionRequest, ChatCompletionRequest],
-        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
+        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(

 def _get_guide_and_mode(
     request: Union[CompletionRequest, ChatCompletionRequest]
-) -> Tuple[str, GuidedDecodingMode]:
+) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:

     if request.guided_json:
         json = request.guided_json
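Both signatures above now admit a "nothing to do" outcome: the factory may return None and the helper may return `(None, None)`, so callers have to narrow before use. A hedged sketch of the calling pattern, with illustrative names and a simplified Optional-tuple return instead of the exact Union used upstream:

from typing import Optional, Tuple


def get_guide_and_mode(text: str) -> Tuple[Optional[str], Optional[str]]:
    # Returns (guide, mode), or (None, None) when no guided decoding is asked for.
    if text.startswith("{"):
        return text, "json"
    return None, None


def build_processor(text: str) -> Optional[str]:
    guide, mode = get_guide_and_mode(text)
    if guide is None or mode is None:
        return None            # propagate "no processor"
    return f"{mode}:{guide}"   # both names are narrowed to str here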
@@ -21,7 +21,7 @@ from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Optional, Union

 import torch
-from outlines.fsm.fsm import CFGFSM, RegexFSM
+from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
 from outlines.fsm.json_schema import build_regex_from_schema
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
@@ -29,6 +29,10 @@ from transformers import PreTrainedTokenizerBase

 class BaseLogitsProcessor:

+    def __init__(self):
+        # Child class should use initialize in their init.
+        self.fsm: FSM
+
     def init_state(self):
         """Initialize the FSM states."""
         self.fsm_state: DefaultDict[int, int] = defaultdict(int)
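`BaseLogitsProcessor.__init__` now declares `self.fsm: FSM` without assigning it; concrete subclasses are expected to set the attribute, and the declaration gives mypy a type for `self.fsm` inside the shared methods. Roughly, with a toy FSM type rather than the one from outlines:

class FSM:

    def next_state(self, state: int, token: int) -> int:
        return state


class BaseProcessor:

    def __init__(self) -> None:
        # Subclasses must assign this; declaring it here types `self.fsm`
        # for methods defined on the base class.
        self.fsm: FSM

    def step(self, state: int, token: int) -> int:
        return self.fsm.next_state(state, token)


class RegexProcessor(BaseProcessor):

    def __init__(self) -> None:
        super().__init__()
        self.fsm = FSM()  # concrete assignment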
@@ -1,7 +1,7 @@
 """Utilities for selecting and loading neuron models."""
 import importlib
 import os
-from typing import Optional, Type
+from typing import Dict, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -27,7 +27,7 @@ TORCH_DTYPE_TO_NEURON_AMP = {
 }

 # Models supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {
+_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
     "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                          "LlamaForSampling", "LlamaForCausalLM"),
     "MistralForCausalLM": ("transformers_neuronx.mistral.model",
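Annotating the module-level table as `Dict[str, Tuple[str, str, str]]` tells mypy exactly what each entry holds (module path, neuronx class name, HF class name), which in turn types the unpacking in `load_weights` below. A minimal sketch of the same idiom:

from typing import Dict, Tuple

# architecture name -> (module path, model class name, HF class name)
_SUPPORTED: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
                         "LlamaForSampling", "LlamaForCausalLM"),
}

module_path, model_cls_name, hf_cls_name = _SUPPORTED["LlamaForCausalLM"]
# Each unpacked variable is typed as str, so later import_module()/getattr()
# calls operate on well-typed names instead of re-using one variable for both
# the name and the resolved class object.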
@@ -43,11 +43,13 @@ class NeuronCasualLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.model = None
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
         self.sampler = Sampler()

+        # Lazy initialized
+        self.model: nn.Module
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -74,17 +76,17 @@ class NeuronCasualLM(nn.Module):

     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
-        neuronx_module_path, neuronx_model_cls, hf_model_cls = (
+        neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
             _NEURON_SUPPORTED_MODELS[arch])
         neuronx_module = importlib.import_module(neuronx_module_path)
-        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
+        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

         split_model_dir = f"{model_name_or_path}-split"
         if os.path.isdir(os.path.join(model_name_or_path,
                                       "pytorch_model.bin")):
             split_model_dir = model_name_or_path
         elif not os.path.exists(f"{model_name_or_path}-split"):
-            hf_model_cls = getattr(transformers, hf_model_cls)
+            hf_model_cls = getattr(transformers, hf_model_cls_name)
             from transformers_neuronx.module import save_pretrained_split

             hf_model = hf_model_cls.from_pretrained(model_name_or_path,
@@ -96,7 +98,7 @@ class NeuronCasualLM(nn.Module):
         self.model.to_neuron()


-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
         if arch in _NEURON_SUPPORTED_MODELS:
@@ -167,6 +167,7 @@ class TensorizerArgs:
             decryption_params = DecryptionParams.from_key(key)
             self.deserializer_params['encryption'] = decryption_params

+    @staticmethod
     def add_cli_args(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         """Tensorizer CLI arguments"""
@@ -113,6 +113,8 @@ class SamplingTensors:
                                  get_num_triton_sampler_splits(vocab_size))

         sample_indices_start_idx = 0
+        assert sampling_metadata.seq_groups is not None
+        assert sampling_metadata.seq_data is not None
         for i, seq_group in enumerate(sampling_metadata.seq_groups):
             seq_ids, sampling_params = seq_group
             temperature = sampling_params.temperature
@@ -147,6 +149,7 @@ class SamplingTensors:
                     and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get
                 # their logprobs
+                assert sampling_metadata.prompt_lens is not None
                 prompt_len = sampling_metadata.prompt_lens[i]
                 temperatures += [temperature] * (prompt_len - 1)
                 top_ps += [top_p] * (prompt_len - 1)
@@ -172,6 +175,7 @@ class SamplingTensors:
             is_prompt = i < sampling_metadata.num_prompts
             if is_prompt:
                 prompt_best_of.append(sampling_params.best_of)
+                assert sampling_metadata.prompt_lens is not None
                 prompt_len = sampling_metadata.prompt_lens[i]

                 if sampling_params.prompt_logprobs is not None:
@@ -106,7 +106,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     def _expand_batch(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-        proposal_token_ids_list: List[TokenId],
+        proposal_token_ids_list: List[List[TokenId]],
         proposal_lens_list: List[int],
     ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
         """Given the input sequences and potentially multiple corresponding
@@ -218,7 +218,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     def _create_target_seq_group_metadata(
         self,
         input_seq_group_metadata: SequenceGroupMetadata,
-        proposal_token_ids: List[TokenId],  # shape: [batch_size, k]
+        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
         batch_index: int,
         target_seq_ids_iter: Iterator[TargetSeqId],
     ) -> List[SequenceGroupMetadata]:
@@ -360,7 +360,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             [0, 1, 2]
             [0, 1, 2, 3]
         """
-        empty_token_ids = []
+        empty_token_ids: List[TokenId] = []

         token_ids_to_score = [empty_token_ids]
         token_ids_to_score.extend([
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

 import torch

@@ -73,5 +73,5 @@ class SpeculativeScorer(ABC):
         blocks_to_copy: Optional[Dict[int, List[int]]],
         k: int,
         proposals: SpeculativeProposals,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> SpeculativeScores:
         raise NotImplementedError
@@ -112,6 +112,7 @@ class AsyncMetricsCollector:

         Returns a CUDA event recording when the copy is complete.
         """
+        assert self._copy_stream is not None
         self._copy_stream.wait_stream(torch.cuda.current_stream())

         with torch.cuda.stream(self._copy_stream):
@@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-        self._proposer: Optional[DraftModelTop1Proposer] = None
+        # Lazy initialization list.
+        self._proposer: DraftModelTop1Proposer

     def init_device(self):
         super().init_device()
@@ -338,10 +339,10 @@ class DraftModelTop1Proposer(SpeculativeProposer):
                 self._vocab_size,
                 dtype=torch.float32,
                 device=self._device)
-            proposal_lens = torch.zeros(len(proposal_lens),
+            proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                         dtype=torch.long,
                                         device=self._device)
-            return proposal_tokens, proposal_probs, proposal_lens
+            return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output

@@ -376,9 +377,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
         proposal_tokens, proposal_probs = (entire_proposal_tokens,
                                            entire_proposal_probs)

-        proposal_lens = torch.zeros(batch_size,
+        proposal_lens_tensor = torch.zeros(batch_size,
                                     dtype=torch.long,
                                     device=self._device)
-        proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
+        proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len

-        return proposal_tokens, proposal_probs, proposal_lens
+        return proposal_tokens, proposal_probs, proposal_lens_tensor
@@ -89,7 +89,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.probs_dtype = self.rejection_sampler.probs_dtype
         self.token_id_dtype = self.rejection_sampler.token_id_dtype

-        self.scorer: SpeculativeScorer = None
+        # Lazy initiazliation.
+        self.scorer: SpeculativeScorer

     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -233,6 +234,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         logger.info("get spec proposals")
         # Generate proposals using draft worker.
+        assert blocks_to_swap_in is not None
+        assert blocks_to_swap_out is not None
+        assert blocks_to_copy is not None
         proposals = self.proposer_worker.get_spec_proposals(
             seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
             blocks_to_copy, k)
@@ -1,6 +1,7 @@
 from typing import Dict, List, Optional, Tuple

 import torch
+from torch import nn

 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -48,14 +49,15 @@ class CPUModelRunner:
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device

-        self.model = None
-        self.block_size = None  # Set after initial profiling.
-
         self.kv_cache_dtype = kv_cache_dtype

         self.attn_backend = get_attn_backend(
             self.model_config.dtype if model_config is not None else None)

+        # Lazy initialization.
+        self.model: nn.Module  # Set after init_Model
+        self.block_size: int  # Set after initial profiling.
+
     def load_model(self) -> None:
         self.model = get_model(model_config=self.model_config,
                                load_config=self.load_config,
@@ -245,7 +247,11 @@ class CPUModelRunner:
         selected_token_indices: List[int] = []
         generators: List[torch.Generator] = []
         selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices: Dict[SamplingType,
+                                         List[Tuple[int, int]]] = {
+                                             t: []
+                                             for t in SamplingType
+                                         }
         categorized_sample_indices_start_idx = 0
         categorized_sampled_token_indices_start_idx = 0

@@ -262,10 +268,9 @@ class CPUModelRunner:
                 categorized_sample_indices_start_idx += subquery_len - 1

                 categorized_sample_indices[
-                    sampling_params.sampling_type].append([
-                        categorized_sample_indices_start_idx,
-                        categorized_sampled_token_indices_start_idx
-                    ])
+                    sampling_params.sampling_type].append(
+                        (categorized_sample_indices_start_idx,
+                         categorized_sampled_token_indices_start_idx))
                 categorized_sample_indices_start_idx += 1
                 categorized_sampled_token_indices_start_idx += 1

@@ -328,7 +333,7 @@ class CPUModelRunner:

     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
                SamplingMetadata]:
         if self.is_driver_worker:
@@ -381,7 +386,7 @@ class CPUModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata
@@ -1,5 +1,5 @@
 """A CPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import torch.distributed
@@ -152,8 +152,8 @@ class CPUWorker(LoraNotSupportedWorkerBase):
             is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
-        self.cache_engine = None
-        self.cpu_cache = None
+        self.cache_engine: CPUCacheEngine
+        self.cpu_cache: List[torch.Tensor]

     def init_device(self) -> None:
         self.init_distributed_environment()
@@ -257,13 +257,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
     ) -> List[SamplerOutput]:
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
-            num_seq_groups = len(seq_group_metadata_list)
+            num_seq_groups: int = len(seq_group_metadata_list)
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
             assert len(blocks_to_swap_in) == 0
             assert len(blocks_to_swap_out) == 0
-            data = {
+            data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
                 "blocks_to_copy": blocks_to_copy,
             }
@@ -273,6 +273,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
             num_seq_groups = data["num_seq_groups"]
             blocks_to_copy = data["blocks_to_copy"]

+        assert blocks_to_copy is not None
         self.cache_copy(blocks_to_copy)

         # If there is no input, we don't need to execute the model.
@@ -128,23 +128,17 @@ class ModelRunner:
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device

-        self.model = None
-        self.block_size = None  # Set after initial profiling.
-        self.lora_manager = None
+        # Set after load_model.
+        self.lora_manager: LRUCacheWorkerLoRAManager = None

         self.graph_runners: Dict[int, CUDAGraphRunner] = {}
-        self.graph_memory_pool = None  # Set during graph capture.
+        self.graph_memory_pool: Optional[Tuple[
+            int, int]] = None  # Set during graph capture.

         self.max_context_len_to_capture = (
             self.model_config.max_context_len_to_capture
             if self.model_config is not None else 0)
-        # When using CUDA graph, the input block tables must be padded to
-        # max_context_len_to_capture. However, creating the block table in
-        # Python can be expensive. To optimize this, we cache the block table
-        # in numpy and only copy the actual input content at every iteration.
-        # The shape of the cached block table will be
-        # (max batch size to capture, max context len to capture / block size).
-        self.graph_block_tables = None  # Set after initial profiling.
+
         self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = kv_cache_dtype
         self.vision_language_config = vision_language_config
@@ -152,6 +146,17 @@ class ModelRunner:
         self.attn_backend = get_attn_backend(
             self.model_config.dtype if model_config is not None else None)

+        # Lazy initialization
+        self.model: torch.nn.Module  # Set after load_model
+        self.block_size: int  # Set after initial profiling.
+        # When using CUDA graph, the input block tables must be padded to
+        # max_context_len_to_capture. However, creating the block table in
+        # Python can be expensive. To optimize this, we cache the block table
+        # in numpy and only copy the actual input content at every iteration.
+        # The shape of the cached block table will be
+        # (max batch size to capture, max context len to capture / block size).
+        self.graph_block_tables: torch.Tensor  # Set after initial profiling.
+
     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
             self.model = get_model(
@@ -489,16 +494,16 @@ class ModelRunner:
                 lora_index_mapping.append(0)
             batch_size = graph_batch_size

-        context_lens = torch.tensor(context_lens,
+        context_lens_tensor = torch.tensor(context_lens,
                                     dtype=torch.int,
                                     device=self.device)

         if use_captured_graph:
             # When using cuda-graph all these tensors should be
             # padded.
-            assert context_lens.shape[0] == len(input_tokens)
-            assert context_lens.shape[0] == len(input_positions)
-            assert context_lens.shape[0] == len(slot_mapping)
+            assert context_lens_tensor.shape[0] == len(input_tokens)
+            assert context_lens_tensor.shape[0] == len(input_positions)
+            assert context_lens_tensor.shape[0] == len(slot_mapping)

             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
@@ -527,7 +532,7 @@ class ModelRunner:
             max_prompt_len=None,
             subquery_start_loc=None,
             seq_start_loc=None,
-            context_lens=context_lens,
+            context_lens=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
@@ -551,7 +556,11 @@ class ModelRunner:
         selected_token_indices: List[int] = []
         generators: List[torch.Generator] = []
         selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices: Dict[SamplingType,
+                                         List[Tuple[int, int]]] = {
+                                             t: []
+                                             for t in SamplingType
+                                         }
         categorized_sample_indices_start_idx = 0
         categorized_sampled_token_indices_start_idx = 0

@@ -569,10 +578,9 @@ class ModelRunner:
                 categorized_sample_indices_start_idx += subquery_len - 1

                 categorized_sample_indices[
-                    sampling_params.sampling_type].append([
-                        categorized_sample_indices_start_idx,
-                        categorized_sampled_token_indices_start_idx
-                    ])
+                    sampling_params.sampling_type].append(
+                        (categorized_sample_indices_start_idx,
+                         categorized_sampled_token_indices_start_idx))
                 categorized_sample_indices_start_idx += 1
                 categorized_sampled_token_indices_start_idx += 1

@@ -596,6 +604,7 @@ class ModelRunner:

                 categorized_sample_indices[
                     sampling_params.sampling_type].extend(
+                        list(
                             zip(
                                 range(
                                     categorized_sample_indices_start_idx,
@@ -603,8 +612,8 @@ class ModelRunner:
                                     num_seqs),
                                 range(
                                     categorized_sampled_token_indices_start_idx,
-                                    categorized_sampled_token_indices_start_idx +
-                                    num_seqs)))
+                                    categorized_sampled_token_indices_start_idx
+                                    + num_seqs))))
                 categorized_sample_indices_start_idx += num_seqs
                 categorized_sampled_token_indices_start_idx += num_seqs

@@ -641,9 +650,9 @@ class ModelRunner:

     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Set[int], LoRAMapping, torch.Tensor]:
+               Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
             prefill_reqs = []
             decode_reqs = []
@@ -741,6 +750,7 @@ class ModelRunner:
             if prefill_attn_metadata is not None:
                 metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
             else:
+                assert decode_attn_metadata is not None
                 metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
             broadcast_tensor_dict(metadata_dict, src=0)

@@ -809,7 +819,7 @@ class ModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
@@ -923,7 +933,7 @@ class ModelRunner:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.remove_all_loras()

-    def set_active_loras(self, lora_requests: List[LoRARequest],
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
                          lora_mapping: LoRAMapping) -> None:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
@@ -1065,10 +1075,16 @@ class CUDAGraphRunner:

     def __init__(self, model: nn.Module):
         self.model = model
-        self.graph = None
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}

+        self._graph: Optional[torch.cuda.CUDAGraph] = None
+
+    @property
+    def graph(self):
+        assert self._graph is not None
+        return self._graph
+
     def capture(
         self,
         input_ids: torch.Tensor,
@@ -1078,7 +1094,7 @@ class CUDAGraphRunner:
         memory_pool,
         **kwargs,
     ) -> None:
-        assert self.graph is None
+        assert self._graph is None
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
@@ -1095,8 +1111,8 @@ class CUDAGraphRunner:
         # Capture the graph.
         # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
             with _maybe_pynccl():
                 hidden_states = self.model(
                     input_ids,
@@ -1,6 +1,7 @@
 from typing import Dict, List, Optional, Tuple

 import torch
+from torch import nn

 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -34,9 +35,11 @@ class NeuronModelRunner:
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
-        self.model = None
         self.pin_memory = is_pin_memory_available()

+        # Lazy initialization.
+        self.model: nn.Module  # initialize after load_model.
+
     def load_model(self) -> None:
         self.model = get_neuron_model(self.model_config,
                                       parallel_config=self.parallel_config,
@@ -147,7 +150,11 @@ class NeuronModelRunner:
         selected_token_indices: List[int] = []
         generators: List[torch.Generator] = []
         selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices: Dict[SamplingType,
+                                         List[Tuple[int, int]]] = {
+                                             t: []
+                                             for t in SamplingType
+                                         }
         categorized_sample_indices_start_idx = 0
         categorized_sampled_token_indices_start_idx = 0

@@ -165,10 +172,9 @@ class NeuronModelRunner:
                 categorized_sample_indices_start_idx += prompt_len - 1

                 categorized_sample_indices[
-                    sampling_params.sampling_type].append([
-                        categorized_sample_indices_start_idx,
-                        categorized_sampled_token_indices_start_idx
-                    ])
+                    sampling_params.sampling_type].append(
+                        (categorized_sample_indices_start_idx,
+                         categorized_sampled_token_indices_start_idx))
                 categorized_sample_indices_start_idx += 1
                 categorized_sampled_token_indices_start_idx += 1

@@ -237,7 +243,7 @@ class NeuronModelRunner:

     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
@@ -259,7 +265,7 @@ class NeuronModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, input_block_ids, sampling_metadata
          ) = self.prepare_input_tensors(seq_group_metadata_list)
@@ -1,7 +1,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

 import torch
 import torch.distributed
@@ -82,8 +82,8 @@ class Worker(WorkerBase):
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
-        self.cache_engine = None
-        self.gpu_cache = None
+        self.cache_engine: CacheEngine
+        self.gpu_cache: List[torch.Tensor]

     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
@@ -223,7 +223,7 @@ class Worker(WorkerBase):
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
-            data = {
+            data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
                 "blocks_to_swap_in": blocks_to_swap_in,
                 "blocks_to_swap_out": blocks_to_swap_out,
@@ -237,6 +237,9 @@ class Worker(WorkerBase):
             blocks_to_swap_out = data["blocks_to_swap_out"]
             blocks_to_copy = data["blocks_to_copy"]

+        assert blocks_to_swap_in is not None
+        assert blocks_to_swap_out is not None
+        assert blocks_to_copy is not None
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

         # If there is no input, we don't need to execute the model.
@@ -1,7 +1,7 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple

 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -56,7 +56,7 @@ class WorkerBase(ABC):
         raise NotImplementedError

     @abstractmethod
-    def get_cache_block_size_bytes() -> int:
+    def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
@@ -71,7 +71,7 @@ class WorkerBase(ABC):
         raise NotImplementedError

     @abstractmethod
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         raise NotImplementedError


@@ -86,7 +86,7 @@ class LoraNotSupportedWorkerBase(WorkerBase):
     def remove_lora(self, lora_id: int) -> bool:
         raise ValueError(f"{type(self)} does not support LoRA")

-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         raise ValueError(f"{type(self)} does not support LoRA")

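A recurring pattern throughout the worker and spec-decode hunks above is narrowing Optional parameters with `assert x is not None` right before use, instead of loosening the downstream annotations. A condensed sketch of why the assert satisfies mypy, using toy names:

from typing import Dict, List, Optional


def cache_copy(blocks_to_copy: Dict[int, List[int]]) -> None:
    print(len(blocks_to_copy))


def execute(blocks_to_copy: Optional[Dict[int, List[int]]]) -> None:
    # Without the assert, mypy reports an incompatible argument type:
    # "Optional[Dict[int, List[int]]]" where "Dict[int, List[int]]" is expected.
    assert blocks_to_copy is not None
    cache_copy(blocks_to_copy)  # narrowed to Dict[int, List[int]]


execute({0: [1, 2]})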