[Doc] Add typing hints / mypy types cleanup (#3816)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
parent e46a60aa4c
commit c2b4a1bce9
@@ -27,8 +27,8 @@ class RequestFuncInput:
 class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
-    latency: float = 0
-    ttft: float = 0  # Time to first token
+    latency: float = 0.0
+    ttft: float = 0.0  # Time to first token
     itl: List[float] = field(
         default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
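
The hunk above aligns the defaults with the declared field types: `float` fields now default to `0.0`, and the list default goes through `field(default_factory=list)`. A minimal sketch of the same dataclass pattern, with illustrative field names rather than the full benchmark class:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class TimingResult:  # illustrative stand-in for RequestFuncOutput
    success: bool = False
    latency: float = 0.0   # float-annotated fields get float literals
    ttft: float = 0.0      # time to first token
    itl: List[float] = field(default_factory=list)  # fresh list per instance


r = TimingResult()
r.itl.append(0.012)
print(r)
```

`field(default_factory=list)` keeps instances from sharing one mutable list, and the `0.0` literals make the annotation and the default agree at a glance.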
@@ -58,23 +58,24 @@ async def async_request_tgi(
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len

-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data:")

                         data = json.loads(chunk)
                         timestamp = time.perf_counter()
                         # First token
-                        if ttft == 0:
+                        if ttft == 0.0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft

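
The rename above gives the raw stream items (`chunk_bytes`, a `bytes` value) and the decoded payload (`chunk`, a `str`) separate names, so each variable keeps a single static type instead of being rebound from `bytes` to `str` mid-loop. A small sketch of the same split, using an in-memory stand-in for the HTTP response body:

```python
from typing import Iterable, Iterator


def decode_sse_payloads(stream: Iterable[bytes]) -> Iterator[str]:
    """Yield decoded server-sent-event payloads from a byte stream."""
    for chunk_bytes in stream:               # chunk_bytes: bytes
        chunk_bytes = chunk_bytes.strip()
        if not chunk_bytes:
            continue
        chunk = chunk_bytes.decode("utf-8")  # chunk: str
        yield chunk.removeprefix("data:")    # stdlib str method (3.9+)


print(list(decode_sse_payloads([b'data:{"token": "a"}\n', b"\n"])))
```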
@@ -119,23 +120,24 @@ async def async_request_trt_llm(
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len

-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data:")

                         data = json.loads(chunk)
                         timestamp = time.perf_counter()
                         # First token
-                        if ttft == 0:
+                        if ttft == 0.0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft

@@ -151,7 +153,7 @@ async def async_request_trt_llm(
                     output.success = True

                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
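
`response.reason` on an aiohttp response is an `Optional[str]`, while `RequestFuncOutput.error` is declared as a plain `str`; the `or ""` in this and the following hunks coalesces a possible `None` before the assignment. A reduced sketch of the pattern, with a stand-in for the response object:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Output:
    error: str = ""   # declared non-Optional


def record_failure(output: Output, reason: Optional[str]) -> None:
    # `or ""` turns a possible None into the empty string so the
    # str-typed field never receives None.
    output.error = reason or ""


out = Output()
record_failure(out, None)
print(repr(out.error))   # ''
```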
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
                     output.generated_text = parsed_resp["text"][0]
                     output.success = True
                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data: ")
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
                             if data["choices"][0]["text"]:
                                 timestamp = time.perf_counter()
                                 # First token
-                                if ttft == 0:
+                                if ttft == 0.0:
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft

@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
         output.prompt_len = request_func_input.prompt_len

         generated_text = ""
-        ttft = 0
+        ttft = 0.0
         st = time.perf_counter()
         most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
-                    async for chunk in response.content:
-                        chunk = chunk.strip()
-                        if not chunk:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
                             continue

-                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+                                              "data: ")
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
                             delta = data["choices"][0]["delta"]
                             if delta.get("content", None):
                                 # First token
-                                if ttft == 0:
+                                if ttft == 0.0:
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft

@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
                     output.success = True
                     output.latency = latency
                 else:
-                    output.error = response.reason
+                    output.error = response.reason or ""
                     output.success = False
         except Exception:
             output.success = False
@@ -12,6 +12,7 @@

 import logging
 import sys
+from typing import List

 from sphinx.ext import autodoc

@@ -45,7 +46,7 @@ templates_path = ['_templates']
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = []
+exclude_patterns: List[str] = []

 # Exclude the prompt "$" when copying code
 copybutton_prompt_text = r"\$ "
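
An empty literal like `[]` carries no element type of its own, so the annotation `List[str]` both documents what the Sphinx option holds and gives mypy something to check later appends against. The same idea applies to the `did_config` dict in `setup.py` below. A standalone sketch of annotating empty containers:

```python
from typing import Dict, List

# Annotate empty literals so the element types are explicit.
exclude_patterns: List[str] = []
exclude_patterns.append("_build")   # OK: str

did_config: Dict[str, bool] = {}    # module/class attribute style
did_config["csrc"] = True           # OK: str -> bool

print(exclude_patterns, did_config)
```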
setup.py
@@ -5,7 +5,7 @@ import re
 import subprocess
 import sys
 from shutil import which
-from typing import List
+from typing import Dict, List

 import torch
 from packaging.version import Version, parse
@@ -52,7 +52,7 @@ class CMakeExtension(Extension):

 class cmake_build_ext(build_ext):
     # A dict of extension directories that have been configured.
-    did_config = {}
+    did_config: Dict[str, bool] = {}

     #
     # Determine number of compilation jobs and optionally nvcc compile threads.
@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:

     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
     """
+    assert CUDA_HOME is not None, "CUDA_HOME is not set"
     nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
                                           universal_newlines=True)
     output = nvcc_output.split()
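
`CUDA_HOME` (from `torch.utils.cpp_extension`) is typed as `Optional[str]` because it can be unset on CPU-only machines, so concatenating it directly would be flagged; the added `assert` documents the precondition and narrows the type to `str` for the rest of the function. A tiny sketch of assert-based narrowing with a placeholder constant:

```python
from typing import Optional

CUDA_HOME: Optional[str] = "/usr/local/cuda"   # None on CPU-only hosts


def nvcc_path() -> str:
    # After the assert, type checkers treat CUDA_HOME as str here.
    assert CUDA_HOME is not None, "CUDA_HOME is not set"
    return CUDA_HOME + "/bin/nvcc"


print(nvcc_path())
```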
@@ -1,5 +1,5 @@
-from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, List, Optional, Protocol
+from abc import ABC, abstractmethod
+from typing import Dict, FrozenSet, List, Optional, Protocol

 from vllm.utils import Device

@@ -10,23 +10,28 @@ class Block(ABC):
     def append_token_ids(self, token_ids: List[int]) -> None:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def block_id(self) -> Optional[int]:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def token_ids(self) -> List[int]:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def num_empty_slots(self) -> int:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def is_full(self) -> bool:
         pass

-    @abstractproperty
+    @property
+    @abstractmethod
     def prev_block(self) -> Optional["Block"]:
         pass

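
`abc.abstractproperty` has been deprecated since Python 3.3; stacking `@property` on top of `@abstractmethod` is the supported spelling, and type checkers read it as an abstract read-only property. A minimal sketch of the idiom with a made-up concrete subclass:

```python
from abc import ABC, abstractmethod
from typing import Optional


class Block(ABC):

    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        ...


class SimpleBlock(Block):   # illustrative subclass, not from the diff

    def __init__(self, block_id: Optional[int]) -> None:
        self._block_id = block_id

    @property
    def block_id(self) -> Optional[int]:
        return self._block_id


print(SimpleBlock(7).block_id)   # 7
```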
@@ -47,12 +52,13 @@ class Block(ABC):
 class BlockAllocator(ABC):

     @abstractmethod
-    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         pass

     @abstractmethod
     def allocate_immutable(self, prev_block: Optional[Block],
-                           token_ids: List[int]) -> Block:
+                           token_ids: List[int], device: Device) -> Block:
         pass

     @abstractmethod
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_num_free_blocks(self) -> int:
|
def get_num_free_blocks(self, device: Device) -> int:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractproperty
|
@property
|
||||||
def all_block_ids(self) -> frozenset[int]:
|
@abstractmethod
|
||||||
|
def all_block_ids(self) -> FrozenSet[int]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@@ -1,6 +1,6 @@
 import time
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Protocol

 import numpy as np
 from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -119,6 +119,12 @@ class Stats:
     time_e2e_requests: List[float]


+class SupportsMetricsInfo(Protocol):
+
+    def metrics_info(self) -> Dict[str, str]:
+        ...
+
+
 class StatLogger:
     """StatLogger is used LLMEngine to log to Promethus and Stdout."""

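
`SupportsMetricsInfo` is a `typing.Protocol`: any object with a matching `metrics_info()` method satisfies it structurally, no inheritance required, which is why `StatLogger.info` can now take it instead of a bare `object` and still call `metrics_info()` without a checker error. A self-contained sketch with a stand-in config class:

```python
from typing import Dict, Protocol


class SupportsMetricsInfo(Protocol):

    def metrics_info(self) -> Dict[str, str]:
        ...


class DummyCacheConfig:   # does not subclass the Protocol

    def metrics_info(self) -> Dict[str, str]:
        return {"block_size": "16"}


def log_info(obj: SupportsMetricsInfo) -> None:
    # Accepted because the argument's shape matches the Protocol.
    print(obj.metrics_info())


log_info(DummyCacheConfig())
```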
@@ -135,7 +141,7 @@ class StatLogger:
         self.labels = labels
         self.metrics = Metrics(labelnames=list(labels.keys()))

-    def info(self, type: str, obj: object) -> None:
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         if type == "cache_config":
             self.metrics.info_cache_config.info(obj.metrics_info())

@@ -4,6 +4,7 @@
 import logging
 import os
 import sys
+from typing import Optional

 VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))

@@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter):


 _root_logger = logging.getLogger("vllm")
-_default_handler = None
+_default_handler: Optional[logging.Handler] = None


 def _setup_logger():
@@ -55,7 +56,12 @@ def init_logger(name: str):
     # Use the same settings as above for root logger
     logger = logging.getLogger(name)
     logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
+
     if VLLM_CONFIGURE_LOGGING:
+        if _default_handler is None:
+            raise ValueError(
+                "_default_handler is not set up. This should never happen!"
+                " Please open an issue on Github.")
         logger.addHandler(_default_handler)
         logger.propagate = False
     return logger
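
Annotating the module-level handler as `Optional[logging.Handler]` makes its None-until-configured state explicit, and the new guard fails loudly instead of passing `None` to `addHandler()`; after the `is None` check a type checker also treats it as a real handler. A reduced sketch of the shape of this module (simplified, not the full vllm logger):

```python
import logging
import sys
from typing import Optional

_default_handler: Optional[logging.Handler] = None   # None until configured


def _setup_logger() -> None:
    global _default_handler
    _default_handler = logging.StreamHandler(sys.stdout)


def init_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if _default_handler is None:
        # Fail loudly rather than handing None to addHandler().
        raise ValueError("_default_handler is not set up.")
    logger.addHandler(_default_handler)   # narrowed to logging.Handler
    return logger


_setup_logger()
print(init_logger("demo").handlers)
```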
@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,


 # Find dim range bounds based on rotations
-def _yarn_find_correction_range(low_rot: int,
+def _yarn_find_correction_range(
+        low_rot: int,
         high_rot: int,
         dim: int,
         base: float = 10000,
-        max_position_embeddings: int = 2048) -> int:
+        max_position_embeddings: int = 2048) -> Tuple[int, int]:
     low = math.floor(
         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
     high = math.ceil(
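
The function returns a pair of bounds, so the annotation becomes `Tuple[int, int]` rather than a bare `int`, letting callers unpack the result under the checker's eye. A trimmed illustration of annotating a multi-value return (the real bounds come from `_yarn_find_correction_dim`; the arithmetic here is simplified):

```python
import math
from typing import Tuple


def find_correction_range(low: float, high: float,
                          dim: int) -> Tuple[int, int]:
    # Two values out -> Tuple[int, int], not int.
    return max(math.floor(low), 0), min(math.ceil(high), dim - 1)


lo, hi = find_correction_range(3.2, 60.7, 64)
print(lo, hi)   # 3 61
```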
@@ -293,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
-        beta_fast: float = 32,
-        beta_slow: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -1,10 +1,10 @@
-from typing import Optional
+from typing import Dict, Optional

 from transformers import AutoConfig, PretrainedConfig

 from vllm.transformers_utils.configs import *

-_CONFIG_REGISTRY = {
+_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
     "mpt": MPTConfig,
@@ -12,7 +12,7 @@ from transformers.utils import logging

 logger = logging.get_logger(__name__)

-DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore


 class DbrxAttentionConfig(PretrainedConfig):
@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

-PRETRAINED_VOCAB_FILES_MAP = {
+PRETRAINED_VOCAB_FILES_MAP = {  # type: ignore
     "vocab_file": {},
     "tokenizer_file": {},
 }
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}  # type: ignore


 class BaichuanTokenizer(PreTrainedTokenizer):
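
For vendored code whose module-level constants are intentionally left empty, the diff silences the checker with a per-line `# type: ignore` instead of restructuring borrowed code; when the intended type is actually known, an explicit annotation is the cleaner alternative. A small sketch of both options (the error-code-specific form, e.g. `# type: ignore[var-annotated]`, is an option this diff does not use):

```python
from typing import Dict

# Per-line suppression on an intentionally empty vendored constant.
PRETRAINED_VOCAB_FILES_MAP = {  # type: ignore
    "vocab_file": {},
    "tokenizer_file": {},
}

# Explicit annotation, preferable when the intended type is known.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: Dict[str, int] = {}

print(PRETRAINED_VOCAB_FILES_MAP, PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES)
```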
@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
             `Tuple(str)`: Paths to the files saved.
         """
         if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) "
-                         "should be a directory")
-            return
+            raise ValueError(f"Vocabulary path ({save_directory}) "
+                             "should be a directory")
+
         out_vocab_file = os.path.join(
             save_directory,
             (filename_prefix + "-" if filename_prefix else "") +
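
The old branch logged an error and fell through with a bare `return`, handing callers `None` from a method whose signature promises file paths; raising keeps the declared return type honest and surfaces the misuse immediately. A stripped-down sketch of the corrected behaviour (the function below is an illustration, not the tokenizer method itself):

```python
import os
from typing import Tuple


def save_vocabulary(save_directory: str) -> Tuple[str, ...]:
    if not os.path.isdir(save_directory):
        # Raising preserves the declared return type; a bare `return`
        # would silently give callers None instead of paths.
        raise ValueError(f"Vocabulary path ({save_directory}) "
                         "should be a directory")
    return (os.path.join(save_directory, "tokenizer.model"), )


print(save_vocabulary("."))
```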
@@ -294,7 +294,7 @@ def create_kv_caches_with_random(
     head_size: int,
     cache_dtype: Optional[Union[str, torch.dtype]],
     model_dtype: Optional[Union[str, torch.dtype]] = None,
-    seed: Optional[int] = 0,
+    seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
@@ -400,7 +400,7 @@ class CudaMemoryProfiler:
     gc.collect()


-def str_to_int_tuple(s: str) -> Tuple[int]:
+def str_to_int_tuple(s: str) -> Tuple[int, ...]:
     """Convert a string to a tuple of integers."""
     try:
         return tuple(map(int, s.split(",")))
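
`Tuple[int]` means a tuple of exactly one int, but a comma-separated string can yield any number of elements, so the variadic spelling `Tuple[int, ...]` is the accurate annotation. A short sketch of the difference:

```python
from typing import Tuple


def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    # Tuple[int] would promise exactly one element;
    # Tuple[int, ...] allows any length.
    return tuple(map(int, s.split(",")))


print(str_to_int_tuple("1,2,4"))   # (1, 2, 4)
```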