[mypy] Add mypy type annotation part 1 (#4006)
This commit is contained in:
parent d4ec9ffb95
commit 09473ee41c
.github/workflows/mypy.yaml (vendored, new file, 50 lines added)
@@ -0,0 +1,50 @@
+name: mypy
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.8"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install mypy==1.9.0
+        pip install types-setuptools
+        pip install types-PyYAML
+        pip install types-requests
+        pip install types-setuptools
+    - name: Mypy
+      run: |
+        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+
+        # TODO(sang): Follow up
+        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+
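As a side note (not part of the commit), the same per-directory checks can be reproduced outside CI through mypy's Python API; the target list and helper below are illustrative only and assume mypy is installed and the script runs from the repository root.

# Hypothetical local helper: runs a few of the per-directory checks from the
# workflow above through mypy's programmatic API.
from mypy import api

TARGETS = ["vllm/attention", "vllm/core", "vllm/distributed"]

for target in TARGETS:
    stdout, stderr, exit_status = api.run(
        [target, "--follow-imports=skip", "--config-file", "pyproject.toml"])
    print(f"[{target}] exit status: {exit_status}")
    if stdout:
        print(stdout)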
format.sh (22 lines changed)
@@ -93,9 +93,23 @@ fi
 echo 'vLLM yapf: Done'
 
 # Run mypy
-# TODO(zhuohan): Enable mypy
-# echo 'vLLM mypy:'
-# mypy
+echo 'vLLM mypy:'
+mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+
+# TODO(sang): Follow up
+# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+
 
 CODESPELL_EXCLUDES=(
     '--skip' '*docs/source/_build/**'
@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
 
     exit 1
 fi
-
-
@@ -46,10 +46,13 @@ ignore = [
 python_version = "3.8"
 
 ignore_missing_imports = true
+check_untyped_defs = true
 
 files = "vllm"
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
-exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
+exclude = [
+    "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
+]
 
 
 [tool.codespell]
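A rough illustration of what the new check_untyped_defs setting buys (the snippet is hypothetical, not from the repository): by default mypy skips the bodies of functions that carry no annotations at all, whereas with the option enabled those bodies are checked too.

# Hypothetical example: this function has no annotations, so without
# check_untyped_defs its body is ignored; with check_untyped_defs = true,
# mypy reports the str/int mix-up below.
def count_tokens():
    total = "0"
    total += 1  # error: Unsupported operand types for + ("str" and "int")
    return total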
@@ -11,4 +11,5 @@ uvicorn[standard]
 pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 tiktoken == 0.6.0  # Required for DBRX tokenizer
 outlines == 0.0.34  # Requires torch >= 2.1.0
+typing_extensions
@@ -7,7 +7,7 @@ codespell==2.2.6
 isort==5.13.2
 
 # type checking
-mypy==0.991
+mypy==1.9.0
 types-PyYAML
 types-requests
 types-setuptools
@@ -2,7 +2,7 @@ import enum
 import json
 import os
 from dataclasses import dataclass, fields
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 
 import torch
 from packaging.version import Version
@@ -141,7 +141,7 @@ class ModelConfig:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = []
+        rocm_not_supported_load_format: List[str] = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
@@ -679,6 +679,9 @@ class SpeculativeConfig:
                 "num_speculative_tokens to be provided, but found "
                 f"{speculative_model=} and {num_speculative_tokens=}.")
 
+        assert (speculative_model is not None
+                and num_speculative_tokens is not None)
+
         # TODO: The user should be able to specify revision/quantization/max
         # model len for the draft model. It is not currently supported.
         draft_revision = None
@@ -993,7 +996,7 @@ def _get_and_verify_max_len(
         derived_max_model_len *= scaling_factor
 
     if max_model_len is None:
-        max_model_len = derived_max_model_len
+        max_model_len = int(derived_max_model_len)
     elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
@@ -1,5 +1,6 @@
 """A block manager that manages token blocks."""
 from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
 from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set
@@ -231,10 +232,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
 
         if self.enable_caching:
             logger.info("Automatic prefix caching is enabled.")
-            self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
-                                                      num_gpu_blocks)
-            self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
-                                                      num_cpu_blocks)
+            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+                Device.GPU, block_size, num_gpu_blocks)
+            self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+                Device.CPU, block_size, num_cpu_blocks)
         else:
             self.gpu_allocator = UncachedBlockAllocator(
                 Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +589,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             for b in takewhile(lambda b: b.computed, block_table[:-1])
         ]
 
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         """Return the block ids that are common for a given sequence group.
 
         Used in prefill (can skip prefill of some blocks).
@@ -1,4 +1,5 @@
 """A block manager that manages token blocks."""
+from collections.abc import Sequence as GenericSequence
 from typing import Dict, List, Optional
 
 from vllm.core.block.block_table import BlockTable
@@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         # as computed.
         self.block_allocator.mark_blocks_as_computed()
 
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         """Determine which blocks for which we skip prefill.
 
         With prefix caching we can skip prefill for previously-generated blocks.
@@ -1,5 +1,6 @@
 import enum
 from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
 from typing import Dict, List
 
 from vllm.sequence import Sequence, SequenceGroup
@@ -103,7 +104,8 @@ class BlockSpaceManager(ABC):
         pass
 
     @abstractmethod
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         pass
 
     @abstractmethod
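One way to read the List[int] -> Sequence change across the block managers, sketched here with typing.Sequence in place of the GenericSequence alias used above (function names are invented for the illustration): returning a read-only sequence lets an implementation hand back a list, tuple, or other view, and mypy stops callers from mutating the result.

# Illustrative only; these helpers are not part of vLLM.
from typing import List, Sequence


def computed_ids_as_list() -> List[int]:
    return [0, 1, 2]


def computed_ids_as_seq() -> Sequence[int]:
    return (0, 1, 2)  # a tuple is fine here, but would not satisfy List[int]


ids = computed_ids_as_seq()
# ids.append(3)  # mypy: "Sequence[int]" has no attribute "append"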
@@ -42,8 +42,8 @@ class SchedulingBudget:
     """
     token_budget: int
    max_num_seqs: int
-    _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
-    _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
+    _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
+    _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
     _num_batched_tokens: int = 0
     _num_curr_seqs: int = 0
 
@@ -133,7 +133,7 @@ class SchedulerOutputs:
         return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
-    def _sort_by_lora_ids(self) -> bool:
+    def _sort_by_lora_ids(self):
         self.scheduled_seq_groups = sorted(
             self.scheduled_seq_groups,
             key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +337,8 @@ class Scheduler:
             self.free_seq(seq)
 
     def has_unfinished_seqs(self) -> bool:
-        return self.waiting or self.running or self.swapped
+        return len(self.waiting) != 0 or len(self.running) != 0 or len(
+            self.swapped) != 0
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ class Scheduler:
                 budget.subtract_num_seqs(seq_group.request_id,
                                          num_running_seqs)
                 if curr_loras is not None and seq_group.lora_int_id > 0:
-                    curr_loras.pop(seq_group.lora_int_id)
+                    curr_loras.remove(seq_group.lora_int_id)
 
             if running_queue:
                 # Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ class Scheduler:
         now = time.time()
         swapped_queue = policy.sort_by_priority(now, swapped_queue)
 
-        leftover_swapped = deque()
+        leftover_swapped: Deque[SequenceGroup] = deque()
         while swapped_queue:
             seq_group = swapped_queue[0]
 
@@ -507,7 +508,9 @@ class Scheduler:
             lora_int_id = 0
             if self.lora_enabled:
                 lora_int_id = seq_group.lora_int_id
-                if (lora_int_id > 0 and lora_int_id not in curr_loras
+                assert curr_loras is not None
+                assert self.lora_config is not None
+                if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                         and len(curr_loras) >= self.lora_config.max_loras):
                     # We don't have a space for another LoRA, so
                     # we ignore this request for now.
@@ -593,7 +596,7 @@ class Scheduler:
         # Copy the queue so that the input queue is not modified.
         waiting_queue = deque([s for s in waiting_queue])
 
-        leftover_waiting_sequences = deque()
+        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         while self._passed_delay(time.time()) and waiting_queue:
             seq_group = waiting_queue[0]
 
@@ -635,6 +638,8 @@ class Scheduler:
             lora_int_id = 0
             if self.lora_enabled:
                 lora_int_id = seq_group.lora_int_id
+                assert curr_loras is not None
+                assert self.lora_config is not None
                 if (self.lora_enabled and lora_int_id > 0
                         and lora_int_id not in curr_loras
                         and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ class Scheduler:
            token_budget=self.scheduler_config.max_num_batched_tokens,
            max_num_seqs=self.scheduler_config.max_num_seqs,
        )
-        curr_loras = set()
+        curr_loras: Set[int] = set()
 
         remaining_waiting, prefills = (self.waiting,
                                        SchedulerPrefillOutputs.create_empty())
@@ -1087,7 +1092,7 @@ class Scheduler:
 
     def _get_num_new_tokens(self, seq_group: SequenceGroup,
                             status: SequenceStatus, enable_chunking: bool,
-                            budget: SchedulingBudget) -> Tuple[int, bool]:
+                            budget: SchedulingBudget) -> int:
         """Get the next new tokens to compute for a given sequence group
         that's in a given `status`.
 
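The new Deque and Set annotations above are needed because mypy cannot infer an element type for an empty container; a small stand-alone illustration (variable names here are invented):

from collections import deque
from typing import Deque, Set

# Without the annotations mypy reports "Need type annotation" for both lines,
# since deque() and set() start empty and give it nothing to infer from.
pending_requests: Deque[str] = deque()
active_lora_ids: Set[int] = set()

pending_requests.append("request-0")
active_lora_ids.add(1)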
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
     tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
     src: int = 0,
     group: Optional[ProcessGroup] = None,
-) -> Dict[Any, Union[torch.Tensor, Any]]:
+) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
     """Broadcast the input tensor dictionary."""
     group = group or torch.distributed.group.WORLD
     ranks = torch.distributed.get_process_group_ranks(group)
@@ -157,10 +157,10 @@ def broadcast_tensor_dict(
 
     rank = torch.distributed.get_rank()
     if rank == src:
+        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict,
             dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list = []
         for key, value in tensor_dict.items():
             if isinstance(value, torch.Tensor):
                 assert value.is_cuda, (
@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
         torch.distributed.broadcast_object_list(recv_metadata_list,
                                                 src=src,
                                                 group=group)
-        metadata_list = recv_metadata_list[0]
+        assert recv_metadata_list[0] is not None
         tensor_dict = {}
         async_handles = []
-        for key, value in metadata_list:
+        for key, value in recv_metadata_list[0]:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
@@ -1,9 +1,10 @@
 import pickle
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
+from vllm.worker.worker import Worker
 
 logger = init_logger(__name__)
 
@@ -18,15 +19,20 @@ try:
             if init_cached_hf_modules:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
-            self.worker = None
+            self._worker: Optional[Worker] = None
             # Since the compiled DAG runs a main execution
             # in a different thread that calls cuda.set_device.
             # The flag indicates is set_device is called on
             # that thread.
             self.compiled_dag_cuda_device_set = False
 
-        def init_worker(self, worker_init_fn):
-            self.worker = worker_init_fn()
+        def init_worker(self, worker_init_fn: Callable[[], Worker]):
+            self._worker = worker_init_fn()
+
+        @property
+        def worker(self) -> Worker:
+            assert self._worker is not None
+            return self._worker
 
         def __getattr__(self, name):
             return getattr(self.worker, name)
@@ -70,8 +76,8 @@ except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
                    "`pip install ray`.")
-    ray = None
-    RayWorkerVllm = None
+    ray = None  # type: ignore
+    RayWorkerVllm = None  # type: ignore
 
 
 def initialize_ray_cluster(
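The _worker/worker split above follows a common pattern for attributes that start out unset; a generic sketch with placeholder names (Engine, Holder, and factory are invented for the illustration, not vLLM symbols):

from typing import Callable, Optional


class Engine:
    """Placeholder for whatever object is constructed lazily."""


class Holder:

    def __init__(self) -> None:
        # Optional while the object has not been created yet.
        self._engine: Optional[Engine] = None

    def init_engine(self, factory: Callable[[], Engine]) -> None:
        self._engine = factory()

    @property
    def engine(self) -> Engine:
        # The assert narrows Optional[Engine] to Engine for mypy and fails
        # loudly if the accessor is used before init_engine().
        assert self._engine is not None
        return self._engine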
@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
+    assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
@@ -170,8 +170,12 @@ class LLM:
                 multi_modal_data.data = multi_modal_data.data.to(torch.float16)
 
         # Add requests to the engine.
-        num_requests = len(prompts) if prompts is not None else len(
-            prompt_token_ids)
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            assert prompt_token_ids is not None
+            num_requests = len(prompt_token_ids)
+
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
 
@@ -61,7 +61,7 @@ class CPUExecutor(ExecutorBase):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
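The tuple -> Tuple changes here and in the executor hunks that follow exist because subscripting the builtin tuple in annotations is a Python 3.9+ feature (PEP 585), while the mypy configuration above targets 3.8; a minimal sketch (the function name and values are illustrative):

from typing import Tuple


# On Python 3.8, writing "-> tuple[int, int]" raises a TypeError when the
# function is defined; typing.Tuple works on 3.8 and newer alike.
def determine_blocks() -> Tuple[int, int]:
    num_gpu_blocks, num_cpu_blocks = 1024, 256
    return num_gpu_blocks, num_cpu_blocks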
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -66,7 +66,7 @@ class GPUExecutor(ExecutorBase):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -47,7 +47,7 @@ class NeuronExecutor(ExecutorBase):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
@@ -3,7 +3,7 @@ import copy
 import os
 import pickle
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -197,7 +197,7 @@ class RayGPUExecutor(ExecutorBase):
             max_parallel_loading_workers,
         )
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
 
         This invokes `determine_num_available_blocks` on each worker and takes
@@ -205,7 +205,7 @@ class RayGPUExecutor(ExecutorBase):
         compatible with all workers.
 
         Returns:
-            - tuple[num_gpu_blocks, num_cpu_blocks]
+            - Tuple[num_gpu_blocks, num_cpu_blocks]
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers("determine_num_available_blocks", )
@@ -276,7 +276,7 @@ class RayGPUExecutor(ExecutorBase):
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
@@ -291,6 +291,7 @@ class RayGPUExecutor(ExecutorBase):
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
             # input. TODO(sang): Fix it.
+            assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
@@ -369,7 +370,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:
@@ -5,7 +5,8 @@ from functools import cached_property
 from typing import Callable, List, Optional, Union
 
 import torch
-from pydantic import conint
+from pydantic import Field
+from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
 
@@ -127,7 +128,7 @@ class SamplingParams:
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
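mypy rejects conint(ge=1) inside an annotation because it is a function call rather than a type; Annotated[int, Field(ge=1)] is an ordinary int to the type checker while pydantic still enforces the bound. A small self-contained sketch, assuming pydantic v2 and typing_extensions are installed (the model and alias names are made up):

from typing import Optional

from pydantic import BaseModel, Field
from typing_extensions import Annotated

# An int as far as mypy is concerned, with a ">= 1" constraint for pydantic.
PositiveTokenCount = Annotated[int, Field(ge=1)]


class RequestOptions(BaseModel):
    truncate_prompt_tokens: Optional[PositiveTokenCount] = None


RequestOptions(truncate_prompt_tokens=4)    # accepted
# RequestOptions(truncate_prompt_tokens=0)  # pydantic raises a ValidationError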
@@ -171,10 +171,10 @@ class SequenceData:
             return self.prompt_token_ids[-1]
         return self.output_token_ids[-1]
 
-    def get_prompt_token_ids(self) -> int:
+    def get_prompt_token_ids(self) -> List[int]:
         return self.prompt_token_ids
 
-    def get_output_token_ids(self) -> int:
+    def get_output_token_ids(self) -> List[int]:
         return self.output_token_ids
 
     @property
@@ -370,7 +370,7 @@ class SequenceGroupState:
     """Mutable state tied to a specific sequence group"""
 
     # torch.Generator used in seeded sampling
-    generator: Optional = None
+    generator: Optional = None  # type: ignore
 
 
 class MultiModalData:
@@ -599,7 +599,7 @@ class SequenceGroupMetadata:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
     @property
-    def token_chunk_size(self) -> int:
+    def token_chunk_size(self) -> Optional[int]:
         """Return the number of tokens to be processed (chunk size)."""
         return self._token_chunk_size
 
@@ -2,7 +2,8 @@ from typing import Dict, Optional
 
 from transformers import AutoConfig, PretrainedConfig
 
-from vllm.transformers_utils.configs import *
+from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
+                                             JAISConfig, MPTConfig, RWConfig)
 
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders(
     # NOTE(woosuk): The following code is slow because it runs a for loop over
     # the output_tokens. In Python, running a for loop over a list can be slow
     # even when the loop body is very simple.
-    sub_texts = []
-    current_sub_text = []
+    sub_texts: List[str] = []
+    current_sub_text: List[str] = []
     all_special_tokens = set(tokenizer.all_special_tokens)
     for token in output_tokens:
         if skip_special_tokens and token in all_special_tokens:
@@ -263,6 +263,7 @@ def detokenize_incrementally(
             tokenizer,
            all_input_ids[:-1],
            skip_special_tokens=skip_special_tokens)
+        assert prev_tokens is not None
 
     # If the new token id is out of bounds, return an empty string.
     if new_token_id >= len(tokenizer):
@@ -271,6 +272,8 @@ def detokenize_incrementally(
         # Put new_token_id in a list so skip_special_tokens is respected
         new_tokens = tokenizer.convert_ids_to_tokens(
             [new_token_id], skip_special_tokens=skip_special_tokens)
+        if isinstance(new_tokens, str):
+            new_tokens = [new_tokens]
     output_tokens = prev_tokens + new_tokens
 
     # If this is the first iteration, return all tokens.
@@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import *
+from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
@@ -28,7 +28,7 @@ def get_cached_tokenizer(
     tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_len = len(tokenizer)
 
-    class CachedTokenizer(tokenizer.__class__):
+    class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
        @property
        def all_special_ids(self):
@@ -7,7 +7,7 @@ import time
 from enum import Enum
 from pathlib import Path
 from threading import Thread
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 from uuid import uuid4
 
 import cpuinfo
@@ -124,7 +124,7 @@ class UsageMessage:
     def report_usage(self,
                      model_architecture: str,
                      usage_context: UsageContext,
-                     extra_kvs: Dict[str, any] = None) -> None:
+                     extra_kvs: Optional[Dict[str, Any]] = None) -> None:
         t = Thread(target=self._report_usage_worker,
                    args=(model_architecture, usage_context, extra_kvs or {}),
                    daemon=True)
@@ -132,13 +132,13 @@ class UsageMessage:
 
     def _report_usage_worker(self, model_architecture: str,
                              usage_context: UsageContext,
-                             extra_kvs: Dict[str, any]) -> None:
+                             extra_kvs: Dict[str, Any]) -> None:
         self._report_usage_once(model_architecture, usage_context, extra_kvs)
         self._report_continous_usage()
 
     def _report_usage_once(self, model_architecture: str,
                            usage_context: UsageContext,
-                           extra_kvs: Dict[str, any]) -> None:
+                           extra_kvs: Dict[str, Any]) -> None:
         # Platform information
         if torch.cuda.is_available():
             device_property = torch.cuda.get_device_properties(0)
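The any -> Any changes above matter because lowercase any is the builtin function, not a type, so mypy refuses it in annotations; a tiny illustration (variable name and values invented):

from typing import Any, Dict

usage_payload: Dict[str, Any] = {"architecture": "LlamaForCausalLM", "gpus": 1}
# usage_payload: Dict[str, any] = {}  # mypy: Function "builtins.any" is not valid as a type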
@@ -60,7 +60,7 @@ class LRUCache(Generic[T]):
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: Hashable) -> Optional[T]:
         return self.get(key)
 
     def __setitem__(self, key: Hashable, value: T) -> None:
@@ -76,7 +76,7 @@ class LRUCache(Generic[T]):
             key: Hashable,
             default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
-            value = self.cache[key]
+            value: Optional[T] = self.cache[key]
             self.cache.move_to_end(key)
         else:
             value = default_value
@@ -87,7 +87,7 @@ class LRUCache(Generic[T]):
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: T):
+    def _on_remove(self, key: Hashable, value: Optional[T]):
         pass
 
     def remove_oldest(self):
@@ -100,9 +100,11 @@ class LRUCache(Generic[T]):
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
+    def pop(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
-        value = self.cache.pop(key, default_value)
+        value: Optional[T] = self.cache.pop(key, default_value)
         if run_on_remove:
             self._on_remove(key, value)
         return value
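Taken together, a Generic[T] cache whose lookups return Optional[T] forces callers to handle misses before using the value; a compact, self-contained sketch (TinyCache is a stand-in written for this note, not the vLLM class):

from typing import Dict, Generic, Optional, TypeVar

T = TypeVar("T")


class TinyCache(Generic[T]):

    def __init__(self) -> None:
        self._data: Dict[str, T] = {}

    def put(self, key: str, value: T) -> None:
        self._data[key] = value

    def get(self, key: str) -> Optional[T]:
        return self._data.get(key)


cache: TinyCache[int] = TinyCache()
cache.put("a", 1)
value = cache.get("b")
if value is not None:  # narrow Optional[int] to int before arithmetic
    print(value + 1)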