[Doc] Add typing hints / mypy types cleanup (#3816)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Author: Michael Feil
Date: 2024-04-11 17:17:21 -07:00
Committed by: GitHub (GPG Key ID: B5690EEEBB952194)
Parent: e46a60aa4c
Commit: c2b4a1bce9
11 changed files with 90 additions and 64 deletions

View File

@@ -27,8 +27,8 @@ class RequestFuncInput:
class RequestFuncOutput:
generated_text: str = ""
success: bool = False
latency: float = 0
ttft: float = 0 # Time to first token
latency: float = 0.0
ttft: float = 0.0 # Time to first token
itl: List[float] = field(
default_factory=list) # List of inter-token latencies
prompt_len: int = 0
@@ -58,23 +58,24 @@ async def async_request_tgi(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -119,23 +120,24 @@ async def async_request_trt_llm(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -151,7 +153,7 @@ async def async_request_trt_llm(
output.success = True
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
output.generated_text = parsed_resp["text"][0]
output.success = True
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
delta = data["choices"][0]["delta"]
if delta.get("content", None):
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
output.success = True
output.latency = latency
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
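
The changes in this benchmark client all follow one pattern: keep each name at a single static type. `chunk_bytes` stays `bytes` while the decoded `str` gets its own name, timing fields default to `0.0` so mypy infers `float`, and `response.reason or ""` collapses an `Optional[str]` into the plain `str` the dataclass field expects. A minimal sketch of the same pattern (the `StreamOutput` dataclass and `consume` helper are illustrative names, not part of the benchmark code):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StreamOutput:
    generated_text: str = ""
    latency: float = 0.0   # 0.0 rather than 0 so the field is inferred as float
    error: str = ""


def consume(chunks: List[bytes], reason: Optional[str]) -> StreamOutput:
    out = StreamOutput()
    for chunk_bytes in chunks:               # bytes keeps its own name...
        chunk = chunk_bytes.decode("utf-8")  # ...so chunk is always str
        out.generated_text += chunk
    out.error = reason or ""                 # Optional[str] -> str
    return out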

View File

@@ -12,6 +12,7 @@
import logging
import sys
from typing import List
from sphinx.ext import autodoc
@@ -45,7 +46,7 @@ templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: List[str] = []
# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
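
An empty literal gives mypy nothing to infer an element type from, so the Sphinx setting gets an explicit annotation. A tiny sketch (the appended pattern is illustrative only):

from typing import List

# Without the annotation, mypy cannot tell what an empty list will hold.
exclude_patterns: List[str] = []
exclude_patterns.append("_build/*")  # illustrative pattern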

View File

@@ -5,7 +5,7 @@ import re
import subprocess
import sys
from shutil import which
from typing import List
from typing import Dict, List
import torch
from packaging.version import Version, parse
@@ -52,7 +52,7 @@ class CMakeExtension(Extension):
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config = {}
did_config: Dict[str, bool] = {}
#
# Determine number of compilation jobs and optionally nvcc compile threads.
@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
assert CUDA_HOME is not None, "CUDA_HOME is not set"
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
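
Class-level mutable defaults have the same problem as empty module-level containers: without an annotation, `did_config = {}` is a dict of unknowns. A reduced sketch of the idea (the class and method names below are made up for illustration):

from typing import Dict


class BuildExtSketch:
    # Annotating the empty class-level dict gives mypy the key/value types.
    did_config: Dict[str, bool] = {}

    def configure(self, ext_name: str) -> None:
        if not self.did_config.get(ext_name, False):
            # ... run the (hypothetical) cmake configure step here ...
            self.did_config[ext_name] = True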

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, List, Optional, Protocol
from abc import ABC, abstractmethod
from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device
@@ -10,23 +10,28 @@ class Block(ABC):
def append_token_ids(self, token_ids: List[int]) -> None:
pass
@abstractproperty
@property
@abstractmethod
def block_id(self) -> Optional[int]:
pass
@abstractproperty
@property
@abstractmethod
def token_ids(self) -> List[int]:
pass
@abstractproperty
@property
@abstractmethod
def num_empty_slots(self) -> int:
pass
@abstractproperty
@property
@abstractmethod
def is_full(self) -> bool:
pass
@abstractproperty
@property
@abstractmethod
def prev_block(self) -> Optional["Block"]:
pass
@@ -47,12 +52,13 @@ class Block(ABC):
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
token_ids: List[int], device: Device) -> Block:
pass
@abstractmethod
@@ -64,11 +70,12 @@ class BlockAllocator(ABC):
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
def get_num_free_blocks(self, device: Device) -> int:
pass
@abstractproperty
def all_block_ids(self) -> frozenset[int]:
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
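
`abstractproperty` has been deprecated since Python 3.3 in favour of stacking `@property` on top of `@abstractmethod`, and `frozenset[int]` is only usable as an annotation at runtime on Python 3.9+, hence `typing.FrozenSet[int]`. A trimmed-down sketch of the pattern (class names are illustrative, not the real interfaces):

from abc import ABC, abstractmethod
from typing import FrozenSet, List, Optional


class BlockSketch(ABC):
    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        ...

    @property
    @abstractmethod
    def token_ids(self) -> List[int]:
        ...


class AllocatorSketch(ABC):
    @property
    @abstractmethod
    def all_block_ids(self) -> FrozenSet[int]:
        # typing.FrozenSet works on Python 3.8; frozenset[int] needs 3.9+.
        ...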

View File

@@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Protocol
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -119,6 +119,12 @@ class Stats:
time_e2e_requests: List[float]
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
...
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
@@ -135,7 +141,7 @@ class StatLogger:
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None:
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
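
`Protocol` replaces the bare `object` parameter with a structural type: anything exposing a matching `metrics_info()` is accepted, with no inheritance required. A small self-contained sketch (the `CacheConfigSketch` class and `log_info` helper are stand-ins, not vLLM code):

from typing import Dict, Protocol


class SupportsMetricsInfo(Protocol):
    def metrics_info(self) -> Dict[str, str]:
        ...


class CacheConfigSketch:
    def metrics_info(self) -> Dict[str, str]:
        return {"block_size": "16", "cache_dtype": "auto"}


def log_info(obj: SupportsMetricsInfo) -> None:
    # Structural typing: any object with a matching metrics_info() signature
    # is accepted, without subclassing the Protocol.
    print(obj.metrics_info())


log_info(CacheConfigSketch())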

View File

@@ -4,6 +4,7 @@
import logging
import os
import sys
from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
@@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter):
_root_logger = logging.getLogger("vllm")
_default_handler = None
_default_handler: Optional[logging.Handler] = None
def _setup_logger():
@@ -55,7 +56,12 @@ def init_logger(name: str):
# Use the same settings as above for root logger
logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
if VLLM_CONFIGURE_LOGGING:
if _default_handler is None:
raise ValueError(
"_default_handler is not set up. This should never happen!"
" Please open an issue on Github.")
logger.addHandler(_default_handler)
logger.propagate = False
return logger
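
Once `_default_handler` is declared `Optional[logging.Handler]`, mypy flags the unconditional `addHandler` call, so the explicit `None` check both narrows the type and fails loudly if the module was never initialised. A condensed sketch under the same assumption (function names are illustrative):

import logging
import sys
from typing import Optional

_default_handler: Optional[logging.Handler] = None


def _setup() -> None:
    global _default_handler
    _default_handler = logging.StreamHandler(sys.stdout)


def init(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if _default_handler is None:
        # The explicit check narrows Optional[Handler] to Handler for the
        # addHandler() call below, and a mis-initialised module fails loudly.
        raise ValueError("_default_handler is not set up")
    logger.addHandler(_default_handler)
    return logger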

View File

@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,
# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> int:
def _yarn_find_correction_range(
low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
@@ -293,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: float = 32,
beta_slow: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
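
The old signature declared a bare `int` return even though the function yields a low/high pair, so the annotation is widened to `Tuple[int, int]` (and the beta defaults become the `int`s they are used as). A minimal illustration with a made-up helper name:

import math
from typing import Tuple


def find_correction_range_sketch(low: float, high: float,
                                 dim: int) -> Tuple[int, int]:
    # Two values come back, so the return type is Tuple[int, int], not int.
    return max(math.floor(low), 0), min(math.ceil(high), dim - 1)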

View File

@@ -1,10 +1,10 @@
from typing import Optional
from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
_CONFIG_REGISTRY = {
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,

View File

@@ -12,7 +12,7 @@ from transformers.utils import logging
logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
class DbrxAttentionConfig(PretrainedConfig):

View File

@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = {
PRETRAINED_VOCAB_FILES_MAP = { # type: ignore
"vocab_file": {},
"tokenizer_file": {},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore
class BaichuanTokenizer(PreTrainedTokenizer):
@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) "
"should be a directory")
return
raise ValueError(f"Vocabulary path ({save_directory}) "
"should be a directory")
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") +
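
The module-level vocab maps keep `# type: ignore` markers because they mirror untyped upstream transformers code, and the save-path check now raises instead of logging and falling through, which would have returned `None` despite the declared return type. A hedged sketch of that last point (function name and default file name are illustrative):

import os
from typing import Tuple


def save_vocabulary_sketch(save_directory: str,
                           vocab_file: str = "tokenizer.model") -> Tuple[str]:
    # Raising keeps the declared return type honest: logging an error and
    # falling through would implicitly return None.
    if not os.path.isdir(save_directory):
        raise ValueError(f"Vocabulary path ({save_directory}) "
                         "should be a directory")
    return (os.path.join(save_directory, vocab_file),)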

View File

@@ -294,7 +294,7 @@ def create_kv_caches_with_random(
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = 0,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
@@ -400,7 +400,7 @@ class CudaMemoryProfiler:
gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int]:
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
"""Convert a string to a tuple of integers."""
try:
return tuple(map(int, s.split(",")))
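
`Tuple[int]` means a tuple of exactly one `int`; the comma-separated parse returns an arbitrary-length tuple, so `Tuple[int, ...]` is the right spelling. Likewise `seed` drops its `Optional` because it is passed straight to `torch.random.manual_seed`, which expects an `int`. A small sketch (the function name is illustrative):

from typing import Tuple


def str_to_int_tuple_sketch(s: str) -> Tuple[int, ...]:
    # Tuple[int, ...] is the variable-length homogeneous tuple actually returned.
    return tuple(int(part) for part in s.split(","))


assert str_to_int_tuple_sketch("1,2,3") == (1, 2, 3)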