[Doc] Add typing hints / mypy types cleanup (#3816)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Author: Michael Feil, 2024-04-11 17:17:21 -07:00 (committed by GitHub)
parent e46a60aa4c
commit c2b4a1bce9
11 changed files with 90 additions and 64 deletions

View File

@@ -27,8 +27,8 @@ class RequestFuncInput:
 class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
-    latency: float = 0
-    ttft: float = 0  # Time to first token
+    latency: float = 0.0
+    ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
@@ -58,23 +58,24 @@ async def async_request_tgi(
     output = RequestFuncOutput()
     output.prompt_len = request_func_input.prompt_len
-    ttft = 0
+    ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload) as response:
             if response.status == 200:
-                async for chunk in response.content:
-                    chunk = chunk.strip()
-                    if not chunk:
+                async for chunk_bytes in response.content:
+                    chunk_bytes = chunk_bytes.strip()
+                    if not chunk_bytes:
                         continue
-                    chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                    chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                          "data:")
                     data = json.loads(chunk)
                     timestamp = time.perf_counter()
                     # First token
-                    if ttft == 0:
+                    if ttft == 0.0:
                         ttft = time.perf_counter() - st
                         output.ttft = ttft
@@ -119,23 +120,24 @@ async def async_request_trt_llm(
     output = RequestFuncOutput()
     output.prompt_len = request_func_input.prompt_len
-    ttft = 0
+    ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload) as response:
             if response.status == 200:
-                async for chunk in response.content:
-                    chunk = chunk.strip()
-                    if not chunk:
+                async for chunk_bytes in response.content:
+                    chunk_bytes = chunk_bytes.strip()
+                    if not chunk_bytes:
                         continue
-                    chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                    chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                          "data:")
                     data = json.loads(chunk)
                     timestamp = time.perf_counter()
                     # First token
-                    if ttft == 0:
+                    if ttft == 0.0:
                         ttft = time.perf_counter() - st
                         output.ttft = ttft
@@ -151,7 +153,7 @@ async def async_request_trt_llm(
                     output.success = True
             else:
-                output.error = response.reason
+                output.error = response.reason or ""
                 output.success = False
     except Exception:
         output.success = False
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
                 output.generated_text = parsed_resp["text"][0]
                 output.success = True
             else:
-                output.error = response.reason
+                output.error = response.reason or ""
                 output.success = False
     except Exception:
         output.success = False
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
     output.prompt_len = request_func_input.prompt_len
     generated_text = ""
-    ttft = 0
+    ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload,
                                 headers=headers) as response:
             if response.status == 200:
-                async for chunk in response.content:
-                    chunk = chunk.strip()
-                    if not chunk:
+                async for chunk_bytes in response.content:
+                    chunk_bytes = chunk_bytes.strip()
+                    if not chunk_bytes:
                         continue
-                    chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                    chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                          "data: ")
                     if chunk == "[DONE]":
                         latency = time.perf_counter() - st
                     else:
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
                         if data["choices"][0]["text"]:
                             timestamp = time.perf_counter()
                             # First token
-                            if ttft == 0:
+                            if ttft == 0.0:
                                 ttft = time.perf_counter() - st
                                 output.ttft = ttft
@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
     output.prompt_len = request_func_input.prompt_len
     generated_text = ""
-    ttft = 0
+    ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
     try:
         async with session.post(url=api_url, json=payload,
                                 headers=headers) as response:
             if response.status == 200:
-                async for chunk in response.content:
-                    chunk = chunk.strip()
-                    if not chunk:
+                async for chunk_bytes in response.content:
+                    chunk_bytes = chunk_bytes.strip()
+                    if not chunk_bytes:
                         continue
-                    chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                    chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                          "data: ")
                     if chunk == "[DONE]":
                         latency = time.perf_counter() - st
                     else:
@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
                         delta = data["choices"][0]["delta"]
                         if delta.get("content", None):
                             # First token
-                            if ttft == 0:
+                            if ttft == 0.0:
                                 ttft = time.perf_counter() - st
                                 output.ttft = ttft
@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
                 output.success = True
                 output.latency = latency
             else:
-                output.error = response.reason
+                output.error = response.reason or ""
                 output.success = False
     except Exception:
         output.success = False
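The pattern repeated across these request functions is worth a note: the raw bytes chunk and its decoded text now live under different names, because re-binding one variable to a different type is rejected by mypy, and the time-to-first-token sentinel becomes a float literal so the attribute keeps a single inferred type. A minimal sketch of the idea, with remove_prefix assumed here to be a local backport of str.removeprefix (built into Python 3.9+):

from typing import AsyncIterable, List


def remove_prefix(text: str, prefix: str) -> str:
    # Assumed helper, equivalent to str.removeprefix on newer Pythons.
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


async def collect_events(content: AsyncIterable[bytes]) -> List[str]:
    events: List[str] = []
    async for chunk_bytes in content:          # chunk_bytes stays bytes ...
        chunk_bytes = chunk_bytes.strip()
        if not chunk_bytes:
            continue
        # ... and the decoded text gets its own name, so neither variable
        # changes type mid-function under a strict mypy run.
        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
        events.append(chunk)
    return events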

View File

@@ -12,6 +12,7 @@
 import logging
 import sys
+from typing import List
 from sphinx.ext import autodoc
@@ -45,7 +46,7 @@ templates_path = ['_templates']
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = []
+exclude_patterns: List[str] = []
 # Exclude the prompt "$" when copying code
 copybutton_prompt_text = r"\$ "
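The one-line change above is the standard fix for mypy's "Need type annotation" error on empty literals; a short illustration (the append calls are only there to show what the annotation buys):

from typing import List

# Without the annotation, mypy cannot infer an element type for the empty
# list; with it, later appends are checked against str.
exclude_patterns: List[str] = []
exclude_patterns.append("_build")   # OK
exclude_patterns.append(42)         # mypy error: expected "str", got "int"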

View File

@@ -5,7 +5,7 @@ import re
 import subprocess
 import sys
 from shutil import which
-from typing import List
+from typing import Dict, List
 import torch
 from packaging.version import Version, parse
@@ -52,7 +52,7 @@ class CMakeExtension(Extension):
 class cmake_build_ext(build_ext):
     # A dict of extension directories that have been configured.
-    did_config = {}
+    did_config: Dict[str, bool] = {}
     #
     # Determine number of compilation jobs and optionally nvcc compile threads.
@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:
     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
     """
+    assert CUDA_HOME is not None, "CUDA_HOME is not set"
     nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
                                           universal_newlines=True)
     output = nvcc_output.split()
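Two typing idioms appear in this file: an annotated class-level dict and an assert that narrows torch's Optional CUDA_HOME before the string concatenation. A trimmed sketch, where CUDA_HOME is a stand-in constant and the class omits its real build_ext base:

from typing import Dict, Optional

CUDA_HOME: Optional[str] = "/usr/local/cuda"  # stand-in for torch's Optional[str] constant


class cmake_build_ext:
    # An empty class-level dict needs an explicit value type for mypy.
    did_config: Dict[str, bool] = {}


def nvcc_path() -> str:
    # The assert narrows Optional[str] to str, so the concatenation below
    # type-checks; it also fails loudly if CUDA is genuinely missing.
    assert CUDA_HOME is not None, "CUDA_HOME is not set"
    return CUDA_HOME + "/bin/nvcc"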

View File

@@ -1,5 +1,5 @@
-from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, List, Optional, Protocol
+from abc import ABC, abstractmethod
+from typing import Dict, FrozenSet, List, Optional, Protocol
 from vllm.utils import Device
@@ -10,23 +10,28 @@ class Block(ABC):
     def append_token_ids(self, token_ids: List[int]) -> None:
         pass
-    @abstractproperty
+    @property
+    @abstractmethod
     def block_id(self) -> Optional[int]:
         pass
-    @abstractproperty
+    @property
+    @abstractmethod
     def token_ids(self) -> List[int]:
         pass
-    @abstractproperty
+    @property
+    @abstractmethod
     def num_empty_slots(self) -> int:
         pass
-    @abstractproperty
+    @property
+    @abstractmethod
     def is_full(self) -> bool:
         pass
-    @abstractproperty
+    @property
+    @abstractmethod
     def prev_block(self) -> Optional["Block"]:
         pass
@@ -47,12 +52,13 @@ class Block(ABC):
 class BlockAllocator(ABC):
     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
        pass
     @abstractmethod
     def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int]) -> Block:
+                           token_ids: List[int], device: Device) -> Block:
         pass
     @abstractmethod
@@ -64,11 +70,12 @@ class BlockAllocator(ABC):
         pass
     @abstractmethod
-    def get_num_free_blocks(self) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
         pass
-    @abstractproperty
-    def all_block_ids(self) -> frozenset[int]:
+    @property
+    @abstractmethod
+    def all_block_ids(self) -> FrozenSet[int]:
         pass
     @abstractmethod
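abstractproperty has been deprecated since Python 3.3, and frozenset[int] as a subscripted builtin needs Python 3.9, hence FrozenSet from typing. Stacking @property on top of @abstractmethod is the documented replacement; a small sketch with a hypothetical concrete subclass:

from abc import ABC, abstractmethod
from typing import Optional


class Block(ABC):

    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        ...


class _ExampleBlock(Block):
    # Hypothetical implementation used only to show the override shape.

    def __init__(self, block_id: Optional[int] = None) -> None:
        self._block_id = block_id

    @property
    def block_id(self) -> Optional[int]:
        return self._block_id


print(_ExampleBlock(block_id=7).block_id)  # -> 7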

View File

@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Protocol
 import numpy as np
 from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -119,6 +119,12 @@ class Stats:
     time_e2e_requests: List[float]
+
+
+class SupportsMetricsInfo(Protocol):
+
+    def metrics_info(self) -> Dict[str, str]:
+        ...
 class StatLogger:
     """StatLogger is used LLMEngine to log to Promethus and Stdout."""
@@ -135,7 +141,7 @@ class StatLogger:
         self.labels = labels
         self.metrics = Metrics(labelnames=list(labels.keys()))
-    def info(self, type: str, obj: object) -> None:
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         if type == "cache_config":
             self.metrics.info_cache_config.info(obj.metrics_info())
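SupportsMetricsInfo is a structural type: any object exposing a matching metrics_info() method satisfies it without inheriting from it, which lets StatLogger.info accept a config object while keeping a checkable signature. A minimal sketch with a hypothetical caller:

from typing import Dict, Protocol


class SupportsMetricsInfo(Protocol):

    def metrics_info(self) -> Dict[str, str]:
        ...


class _FakeCacheConfig:
    # Hypothetical object: it never declares the Protocol as a base, but the
    # matching method signature is enough for mypy to accept it structurally.
    def metrics_info(self) -> Dict[str, str]:
        return {"block_size": "16"}


def log_info(obj: SupportsMetricsInfo) -> None:
    print(obj.metrics_info())


log_info(_FakeCacheConfig())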

View File

@@ -4,6 +4,7 @@
 import logging
 import os
 import sys
+from typing import Optional
 VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
@@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter):
 _root_logger = logging.getLogger("vllm")
-_default_handler = None
+_default_handler: Optional[logging.Handler] = None
 def _setup_logger():
@@ -55,7 +56,12 @@ def init_logger(name: str):
     # Use the same settings as above for root logger
     logger = logging.getLogger(name)
     logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
     if VLLM_CONFIGURE_LOGGING:
+        if _default_handler is None:
+            raise ValueError(
+                "_default_handler is not set up. This should never happen!"
+                " Please open an issue on Github.")
         logger.addHandler(_default_handler)
         logger.propagate = False
     return logger
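Once _default_handler is annotated Optional[logging.Handler], passing it straight to addHandler() is a type error, so the new None check both narrows the type for mypy and turns a silent misconfiguration into an explicit failure. A reduced sketch of the same shape:

import logging
from typing import Optional

_default_handler: Optional[logging.Handler] = None


def init_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if _default_handler is None:
        # Raising here narrows the type: past this branch mypy treats
        # _default_handler as logging.Handler, so addHandler() type-checks.
        raise ValueError("_default_handler is not set up")
    logger.addHandler(_default_handler)
    return logger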

View File

@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,
 # Find dim range bounds based on rotations
-def _yarn_find_correction_range(low_rot: int,
+def _yarn_find_correction_range(
+        low_rot: int,
         high_rot: int,
         dim: int,
         base: float = 10000,
-        max_position_embeddings: int = 2048) -> int:
+        max_position_embeddings: int = 2048) -> Tuple[int, int]:
     low = math.floor(
         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
     high = math.ceil(
@@ -293,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
-        beta_fast: float = 32,
-        beta_slow: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
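The old signature declared -> int for a function that returns a pair, so this is a genuine annotation bug fix rather than a cosmetic one. A hypothetical stand-in showing the fixed-length tuple annotation:

import math
from typing import Tuple


def find_correction_range(low: float, high: float, dim: int) -> Tuple[int, int]:
    # Hypothetical stand-in for _yarn_find_correction_range: the function
    # returns a pair of indices, so the return type is a two-element tuple,
    # not the bare "int" the old signature claimed.
    return max(math.floor(low), 0), min(math.ceil(high), dim - 1)


print(find_correction_range(1.3, 12.7, 8))  # (1, 7)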

View File

@@ -1,10 +1,10 @@
-from typing import Optional
+from typing import Dict, Optional
 from transformers import AutoConfig, PretrainedConfig
 from vllm.transformers_utils.configs import *
-_CONFIG_REGISTRY = {
+_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,

View File

@@ -12,7 +12,7 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
-DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore
 class DbrxAttentionConfig(PretrainedConfig):
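The archive-map constant mirrors an upstream transformers convention and stays an empty bare assignment, so the commit silences mypy's annotation error locally instead of inventing a type. Both options, sketched with hypothetical names:

from typing import Dict

# Two ways to quiet mypy's "Need type annotation" error on an empty dict:

# (a) annotate it explicitly ...
ARCHIVE_MAP_ANNOTATED: Dict[str, str] = {}

# (b) ... or keep the bare assignment, matching the upstream style, and
# suppress the single error on that line.
ARCHIVE_MAP_IGNORED = {}  # type: ignore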

View File

@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)
 VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
-PRETRAINED_VOCAB_FILES_MAP = {
+PRETRAINED_VOCAB_FILES_MAP = {  # type: ignore
     "vocab_file": {},
     "tokenizer_file": {},
 }
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}  # type: ignore
 class BaichuanTokenizer(PreTrainedTokenizer):
@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
             `Tuple(str)`: Paths to the files saved.
         """
         if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) "
-                         "should be a directory")
-            return
+            raise ValueError(f"Vocabulary path ({save_directory}) "
+                             "should be a directory")
         out_vocab_file = os.path.join(
             save_directory,
             (filename_prefix + "-" if filename_prefix else "") +
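save_vocabulary declares a tuple of paths as its return type, so the old log-and-return path implicitly returned None and contradicted its own signature; raising keeps the annotation truthful. A trimmed, hypothetical sketch of the shape:

import os
from typing import Tuple


def save_vocabulary(save_directory: str,
                    vocab_file: str = "tokenizer.model") -> Tuple[str]:
    # Hypothetical reduced version of the tokenizer method. Logging and then
    # returning None would contradict the declared Tuple[str] return type,
    # so the invalid-path case raises instead.
    if not os.path.isdir(save_directory):
        raise ValueError(f"Vocabulary path ({save_directory}) "
                         "should be a directory")
    return (os.path.join(save_directory, vocab_file),)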

View File

@@ -294,7 +294,7 @@ def create_kv_caches_with_random(
     head_size: int,
     cache_dtype: Optional[Union[str, torch.dtype]],
     model_dtype: Optional[Union[str, torch.dtype]] = None,
-    seed: Optional[int] = 0,
+    seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
@@ -400,7 +400,7 @@ class CudaMemoryProfiler:
     gc.collect()
-def str_to_int_tuple(s: str) -> Tuple[int]:
+def str_to_int_tuple(s: str) -> Tuple[int, ...]:
     """Convert a string to a tuple of integers."""
     try:
         return tuple(map(int, s.split(",")))
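Tuple[int] means a tuple of exactly one int, so the ellipsis form is the correct annotation for the comma-split result; the seed parameter likewise drops Optional because the value goes straight into torch.random.manual_seed. A short sketch:

from typing import Tuple


def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    # Tuple[int] would mean "a tuple of exactly one int"; the ellipsis form
    # describes the variable-length, homogeneous tuple actually produced.
    return tuple(map(int, s.split(",")))


assert str_to_int_tuple("1,2,3") == (1, 2, 3)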