[mypy] Add mypy type annotation part 1 (#4006)

SangBin Cho 2024-04-13 06:35:50 +09:00 committed by GitHub
parent d4ec9ffb95
commit 09473ee41c
25 changed files with 171 additions and 72 deletions

.github/workflows/mypy.yaml (new file)
View File

@@ -0,0 +1,50 @@
name: mypy
on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  mypy:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install mypy==1.9.0
        pip install types-setuptools
        pip install types-PyYAML
        pip install types-requests
    - name: Mypy
      run: |
        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
        # TODO(sang): Follow up
        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
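The workflow pins mypy 1.9.0, installs the stub packages the codebase needs (types-setuptools, types-PyYAML, types-requests), and checks one directory at a time with --follow-imports=skip so untyped areas can be brought in during later parts. A minimal, hedged illustration of why the stub packages matter (not taken from the diff):

    import requests

    def fetch_status(url: str) -> int:
        # Without types-requests installed, mypy reports
        # 'Library stubs not installed for "requests"' and treats the library
        # as Any; with the stubs, the int return type here is actually checked.
        return requests.get(url, timeout=5).status_code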

View File

@@ -93,9 +93,23 @@ fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
exit 1
fi

View File

@@ -46,10 +46,13 @@ ignore = [
python_version = "3.8"
ignore_missing_imports = true
check_untyped_defs = true
files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
]
[tool.codespell]
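Together with check_untyped_defs = true already set here, mypy also analyzes the bodies of functions that carry no annotations, which is what surfaces many of the fixes below. A small, hedged sketch with illustrative names (not vLLM code):

    from typing import List

    def tokenize(text: str) -> List[int]:
        return [ord(ch) for ch in text]

    def untyped_caller(data):
        # Without check_untyped_defs this unannotated body would be skipped
        # entirely; with it, mypy flags the bad argument on the next line.
        return tokenize(123)  # error: incompatible type "int"; expected "str"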

View File

@@ -11,4 +11,5 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34 # Requires torch >= 2.1.0
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions

View File

@@ -7,7 +7,7 @@ codespell==2.2.6
isort==5.13.2
# type checking
mypy==0.991
mypy==1.9.0
types-PyYAML
types-requests
types-setuptools

View File

@@ -2,7 +2,7 @@ import enum
import json
import os
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
import torch
from packaging.version import Version
@@ -141,7 +141,7 @@ class ModelConfig:
supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy"
]
rocm_not_supported_load_format = []
rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
@@ -679,6 +679,9 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")
assert (speculative_model is not None
and num_speculative_tokens is not None)
# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision = None
@@ -993,7 +996,7 @@ def _get_and_verify_max_len(
derived_max_model_len *= scaling_factor
if max_model_len is None:
max_model_len = derived_max_model_len
max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input
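Two recurring fixes show up in this file: an empty list needs an explicit element type when nothing nearby lets mypy infer one, and a length derived from float scaling must be cast back to int. A standalone, hedged sketch with illustrative names:

    from typing import List

    rocm_not_supported_load_format: List[str] = []  # a bare `= []` draws
                                                    # 'Need type annotation for ...'

    def clamp_max_len(derived_max_model_len: float) -> int:
        # Rope scaling multiplies the derived length by a float factor,
        # so cast before using it where an int is expected.
        return int(derived_max_model_len)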

View File

@@ -1,5 +1,6 @@
"""A block manager that manages token blocks."""
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set
@@ -231,10 +232,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if self.enable_caching:
logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +589,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for b in takewhile(lambda b: b.computed, block_table[:-1])
]
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
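The allocator change annotates the attribute with the shared base class on its first assignment; otherwise mypy pins the attribute to the concrete type from the first branch and rejects the assignment in the else branch. A self-contained sketch of the pattern (simplified stand-ins, not the real allocators):

    from abc import ABC, abstractmethod

    class AllocatorBase(ABC):
        @abstractmethod
        def allocate(self) -> int: ...

    class CachedAllocator(AllocatorBase):
        def allocate(self) -> int:
            return 0

    class UncachedAllocator(AllocatorBase):
        def allocate(self) -> int:
            return 1

    class Manager:
        def __init__(self, enable_caching: bool) -> None:
            if enable_caching:
                # Annotating with the base class lets both branches type-check;
                # without it the attribute is inferred as CachedAllocator only.
                self.allocator: AllocatorBase = CachedAllocator()
            else:
                self.allocator = UncachedAllocator()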

View File

@@ -1,4 +1,5 @@
"""A block manager that manages token blocks."""
from collections.abc import Sequence as GenericSequence
from typing import Dict, List, Optional
from vllm.core.block.block_table import BlockTable
@@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# as computed.
self.block_allocator.mark_blocks_as_computed()
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.

View File

@@ -1,5 +1,6 @@
import enum
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from typing import Dict, List
from vllm.sequence import Sequence, SequenceGroup
@@ -103,7 +104,8 @@ class BlockSpaceManager(ABC):
pass
@abstractmethod
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
pass
@abstractmethod
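get_common_computed_block_ids now returns collections.abc.Sequence (imported as GenericSequence to avoid clashing with vllm.sequence.Sequence). A read-only Sequence[int] lets implementations return a list, a tuple, or any other sequence type without widening the interface. A brief sketch of why this type-checks:

    from collections.abc import Sequence as GenericSequence
    from typing import List, Tuple

    def common_ids_from_list() -> GenericSequence[int]:
        ids: List[int] = [1, 2, 3]
        return ids  # ok: List[int] is a Sequence[int]

    def common_ids_from_tuple() -> GenericSequence[int]:
        ids: Tuple[int, ...] = (1, 2, 3)
        return ids  # also ok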

View File

@@ -42,8 +42,8 @@ class SchedulingBudget:
"""
token_budget: int
max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
@@ -133,7 +133,7 @@ class SchedulerOutputs:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool:
def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +337,8 @@ class Scheduler:
self.free_seq(seq)
def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ class Scheduler:
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id)
curr_loras.remove(seq_group.lora_int_id)
if running_queue:
# Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ class Scheduler:
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
leftover_swapped = deque()
leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]
@@ -507,7 +508,9 @@ class Scheduler:
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras
assert curr_loras is not None
assert self.lora_config is not None
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
@@ -593,7 +596,7 @@ class Scheduler:
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])
leftover_waiting_sequences = deque()
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0]
@@ -635,6 +638,8 @@ class Scheduler:
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ class Scheduler:
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
curr_loras = set()
curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
@@ -1087,7 +1092,7 @@ class Scheduler:
def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> Tuple[int, bool]:
budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.
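The scheduler edits follow a few recurring patterns: containers created empty get explicit element types (Deque[SequenceGroup], Set[int]), set.pop(x) becomes set.remove(x) because set.pop() takes no argument, and Optional config objects are narrowed with assert before they are dereferenced. A hedged sketch with illustrative names:

    from collections import deque
    from typing import Deque, Optional, Set

    class LoRAConfig:
        max_loras: int = 8

    def admit_request(request_id: str, lora_id: int,
                      lora_config: Optional[LoRAConfig]) -> None:
        leftover: Deque[str] = deque()  # a bare deque() would draw 'Need type annotation'
        curr_loras: Set[int] = set()

        assert lora_config is not None  # narrows Optional[LoRAConfig] to LoRAConfig
        if lora_id > 0 and len(curr_loras) >= lora_config.max_loras:
            leftover.append(request_id)  # no LoRA slot free; defer the request
        elif lora_id > 0:
            curr_loras.add(lora_id)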

View File

@@ -1,5 +1,5 @@
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.distributed import ProcessGroup
@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
@@ -157,10 +157,10 @@ def broadcast_tensor_dict(
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
metadata_list = recv_metadata_list[0]
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,

View File

@@ -1,9 +1,10 @@
import pickle
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
from vllm.worker.worker import Worker
logger = init_logger(__name__)
@@ -18,15 +19,20 @@ try:
if init_cached_hf_modules:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self.worker = None
self._worker: Optional[Worker] = None
# Since the compiled DAG runs the main execution
# in a different thread that calls cuda.set_device,
# this flag indicates whether set_device has been
# called on that thread.
self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
def init_worker(self, worker_init_fn: Callable[[], Worker]):
self._worker = worker_init_fn()
@property
def worker(self) -> Worker:
assert self._worker is not None
return self._worker
def __getattr__(self, name):
return getattr(self.worker, name)
@@ -70,8 +76,8 @@ except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with "
"`pip install ray`.")
ray = None
RayWorkerVllm = None
ray = None # type: ignore
RayWorkerVllm = None # type: ignore
def initialize_ray_cluster(
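RayWorkerVllm now keeps the lazily created worker in a private Optional attribute and exposes it through a property that asserts initialization, so call sites see a non-Optional Worker. A hedged, self-contained version of the pattern (with a simplified Worker stand-in):

    from typing import Callable, Optional

    class Worker:
        def do_work(self) -> int:
            return 42

    class WorkerWrapper:
        def __init__(self) -> None:
            self._worker: Optional[Worker] = None  # not created until init_worker()

        def init_worker(self, worker_init_fn: Callable[[], Worker]) -> None:
            self._worker = worker_init_fn()

        @property
        def worker(self) -> Worker:
            assert self._worker is not None, "init_worker() must be called first"
            return self._worker

    wrapper = WorkerWrapper()
    wrapper.init_worker(Worker)
    print(wrapper.worker.do_work())  # 42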

View File

@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case

View File

@@ -170,8 +170,12 @@ class LLM:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids)
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
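Splitting the conditional expression into explicit branches lets mypy prove that the argument to len() is not None on each path; in the original one-liner, prompt_token_ids was still Optional in the else branch. A minimal sketch:

    from typing import List, Optional

    def count_requests(prompts: Optional[List[str]],
                       prompt_token_ids: Optional[List[List[int]]]) -> int:
        if prompts is not None:
            return len(prompts)
        # At least one of the two inputs must be provided.
        assert prompt_token_ids is not None
        return len(prompt_token_ids)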

View File

@@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import torch
@@ -61,7 +61,7 @@ class CPUExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -66,7 +66,7 @@ class GPUExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -47,7 +47,7 @@ class NeuronExecutor(ExecutorBase):
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""

View File

@@ -3,7 +3,7 @@ import copy
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -197,7 +197,7 @@ class RayGPUExecutor(ExecutorBase):
max_parallel_loading_workers,
)
def determine_num_available_blocks(self) -> tuple[int, int]:
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
@@ -205,7 +205,7 @@ class RayGPUExecutor(ExecutorBase):
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
- Tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
@@ -276,7 +276,7 @@ class RayGPUExecutor(ExecutorBase):
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
@@ -291,6 +291,7 @@ class RayGPUExecutor(ExecutorBase):
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
@@ -369,7 +370,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:

View File

@@ -5,7 +5,8 @@ from functools import cached_property
from typing import Callable, List, Optional, Union
import torch
from pydantic import conint
from pydantic import Field
from typing_extensions import Annotated
_SAMPLING_EPS = 1e-5
@@ -127,7 +128,7 @@ class SamplingParams:
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
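pydantic's conint(ge=1) is a function call, which mypy rejects when used as an annotation; Annotated[int, Field(ge=1)] expresses the same "integer, at least 1" constraint while remaining a plain int to the type checker (typing_extensions provides Annotated on Python 3.8, which is why it was added to requirements.txt above). A hedged sketch with an illustrative helper:

    from typing import List, Optional

    from pydantic import Field
    from typing_extensions import Annotated

    TruncatePromptTokens = Annotated[int, Field(ge=1)]

    def apply_truncation(
            prompt_tokens: List[int],
            truncate_prompt_tokens: Optional[TruncatePromptTokens] = None,
    ) -> List[int]:
        if truncate_prompt_tokens is not None:
            return prompt_tokens[-truncate_prompt_tokens:]  # keep only the last k tokens
        return prompt_tokens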

View File

@@ -171,10 +171,10 @@ class SequenceData:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
def get_prompt_token_ids(self) -> int:
def get_prompt_token_ids(self) -> List[int]:
return self.prompt_token_ids
def get_output_token_ids(self) -> int:
def get_output_token_ids(self) -> List[int]:
return self.output_token_ids
@property
@@ -370,7 +370,7 @@ class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator: Optional = None
generator: Optional = None # type: ignore
class MultiModalData:
@@ -599,7 +599,7 @@ class SequenceGroupMetadata:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def token_chunk_size(self) -> int:
def token_chunk_size(self) -> Optional[int]:
"""Return the number of tokens to be processed (chunk size)."""
return self._token_chunk_size

View File

@@ -2,7 +2,8 @@ from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MPTConfig, RWConfig)
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig,

View File

@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders(
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts = []
current_sub_text = []
sub_texts: List[str] = []
current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
@@ -263,6 +263,7 @@ def detokenize_incrementally(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
assert prev_tokens is not None
# If the new token id is out of bounds, return an empty string.
if new_token_id >= len(tokenizer):
@@ -271,6 +272,8 @@
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
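convert_ids_to_tokens can return either a single string or a list of strings, so the new isinstance check normalizes the result before it is concatenated with prev_tokens; the assert above it similarly narrows prev_tokens from Optional. A small sketch of the narrowing:

    from typing import List, Union

    def normalize_tokens(new_tokens: Union[str, List[str]],
                         prev_tokens: List[str]) -> List[str]:
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]  # both paths are now List[str]
        return prev_tokens + new_tokens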

View File

@@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import *
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async
logger = init_logger(__name__)
@@ -28,7 +28,7 @@ def get_cached_tokenizer(
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
class CachedTokenizer(tokenizer.__class__):
class CachedTokenizer(tokenizer.__class__): # type: ignore
@property
def all_special_ids(self):

View File

@@ -7,7 +7,7 @@ import time
from enum import Enum
from pathlib import Path
from threading import Thread
from typing import Dict, Optional
from typing import Any, Dict, Optional
from uuid import uuid4
import cpuinfo
@@ -124,7 +124,7 @@ class UsageMessage:
def report_usage(self,
model_architecture: str,
usage_context: UsageContext,
extra_kvs: Dict[str, any] = None) -> None:
extra_kvs: Optional[Dict[str, Any]] = None) -> None:
t = Thread(target=self._report_usage_worker,
args=(model_architecture, usage_context, extra_kvs or {}),
daemon=True)
@@ -132,13 +132,13 @@ class UsageMessage:
def _report_usage_worker(self, model_architecture: str,
usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None:
extra_kvs: Dict[str, Any]) -> None:
self._report_usage_once(model_architecture, usage_context, extra_kvs)
self._report_continous_usage()
def _report_usage_once(self, model_architecture: str,
usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None:
extra_kvs: Dict[str, Any]) -> None:
# Platform information
if torch.cuda.is_available():
device_property = torch.cuda.get_device_properties(0)
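Two separate problems are fixed in these signatures: the builtin any is a function, not a type (the valid spelling is typing.Any), and a default of None now needs an explicit Optional because recent mypy releases enable no_implicit_optional by default. A hedged sketch:

    from typing import Any, Dict, Optional

    def report_usage(model_architecture: str,
                     extra_kvs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        data: Dict[str, Any] = {"architecture": model_architecture}
        data.update(extra_kvs or {})  # treat a missing dict as empty
        return data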

View File

@@ -60,7 +60,7 @@ class LRUCache(Generic[T]):
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> T:
def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key)
def __setitem__(self, key: Hashable, value: T) -> None:
@@ -76,7 +76,7 @@
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
if key in self.cache:
value = self.cache[key]
value: Optional[T] = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
@@ -87,7 +87,7 @@
self.cache.move_to_end(key)
self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: T):
def _on_remove(self, key: Hashable, value: Optional[T]):
pass
def remove_oldest(self):
@@ -100,9 +100,11 @@
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache
value = self.cache.pop(key, default_value)
value: Optional[T] = self.cache.pop(key, default_value)
if run_on_remove:
self._on_remove(key, value)
return value
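The LRUCache changes make the Optional contract explicit: when a None default is allowed, get(), pop(), and __getitem__ honestly return Optional[T], and _on_remove accepts Optional[T] accordingly. A compact, hedged version of the typed accessor pattern:

    import collections
    from typing import Generic, Hashable, Optional, OrderedDict, TypeVar

    T = TypeVar("T")

    class TinyLRUCache(Generic[T]):
        def __init__(self, capacity: int) -> None:
            self.capacity = capacity
            self.cache: OrderedDict[Hashable, T] = collections.OrderedDict()

        def get(self, key: Hashable, default: Optional[T] = None) -> Optional[T]:
            if key in self.cache:
                value: Optional[T] = self.cache[key]
                self.cache.move_to_end(key)
                return value
            return default  # may be None, hence Optional[T]

        def put(self, key: Hashable, value: T) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            while len(self.cache) > self.capacity:
                self.cache.popitem(last=False)  # evict the least recently used entry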