Update deprecated Python 3.8 typing (#13971)
Parent: bf33700ecd
Commit: cf069aa8aa
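
This commit drops the typing aliases that were only needed while Python 3.8 was supported (List, Dict, Tuple, and friends) in favour of the built-in generics standardised by PEP 585, and moves Iterable, AsyncGenerator, and Collection imports from typing to collections.abc. The built-in forms are evaluated at runtime, so they need Python 3.9 or newer unless postponed annotations are enabled. A minimal sketch of the pattern applied throughout the diff below (illustrative function, not code from this commit):

    # Before: Python 3.8-compatible aliases from typing.
    from typing import Dict, List, Optional, Tuple

    def summarize(latencies: List[float],
                  labels: Dict[str, str]) -> Tuple[float, Optional[str]]:
        return sum(latencies) / len(latencies), labels.get("note")

    # After: built-in generics; Optional still comes from typing.
    from typing import Optional

    def summarize(latencies: list[float],
                  labels: dict[str, str]) -> tuple[float, Optional[str]]:
        return sum(latencies) / len(latencies), labels.get("note")
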
@@ -6,7 +6,7 @@ import sys
 import time
 import traceback
 from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from typing import Optional, Union

 import aiohttp
 import huggingface_hub.constants
@@ -41,8 +41,8 @@ class RequestFuncOutput:
 latency: float = 0.0
 output_tokens: int = 0
 ttft: float = 0.0  # Time to first token
-itl: List[float] = field(
-default_factory=list)  # List of inter-token latencies
+itl: list[float] = field(
+default_factory=list)  # list of inter-token latencies
 tpot: float = 0.0  # avg next-token latencies
 prompt_len: int = 0
 error: str = ""
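
The RequestFuncOutput hunk above is the typical dataclass case: only the annotation text changes, while the field(default_factory=list) default is untouched. A small self-contained sketch of the same pattern (hypothetical class, not the vLLM dataclass):

    from dataclasses import dataclass, field

    @dataclass
    class Timing:
        ttft: float = 0.0                               # time to first token
        itl: list[float] = field(default_factory=list)  # inter-token latencies

    t = Timing()
    t.itl.append(0.012)  # each instance gets its own list
    print(t)
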
@@ -6,7 +6,6 @@ import json
 import os
 import random
 import time
-from typing import List

 import datasets
 import pandas as pd
@@ -39,7 +38,7 @@ class SampleRequest:
 completion: str = None


-def run_vllm(requests: List[SampleRequest],
+def run_vllm(requests: list[SampleRequest],
 engine_args: EngineArgs,
 n: int,
 guided_decoding_rate: float = 1.0,
@@ -54,8 +53,8 @@ def run_vllm(requests: List[SampleRequest],
 " prompt_len and expected_output_len for all requests.")

 # Add the requests to the engine.
-prompts: List[str] = []
-sampling_params: List[SamplingParams] = []
+prompts: list[str] = []
+sampling_params: list[SamplingParams] = []
 # create a list containing random selected true or false
 guided_decoding_req_idx = random.sample(
 range(len(requests)), int(len(requests) * guided_decoding_rate))
@@ -110,7 +109,7 @@ def run_vllm(requests: List[SampleRequest],


 async def run_vllm_async(
-requests: List[SampleRequest],
+requests: list[SampleRequest],
 engine_args: AsyncEngineArgs,
 n: int,
 guided_decoding_rate: float = 1.0,
@@ -129,8 +128,8 @@ async def run_vllm_async(
 " prompt_len and expected_output_len for all requests.")

 # Add the requests to the engine.
-prompts: List[str] = []
-sampling_params: List[SamplingParams] = []
+prompts: list[str] = []
+sampling_params: list[SamplingParams] = []
 guided_decoding_req_idx = random.sample(
 range(len(requests)), int(len(requests) * guided_decoding_rate))

@@ -203,7 +202,7 @@ async def run_vllm_async(


 def sample_requests(tokenizer: PreTrainedTokenizerBase,
-args: argparse.Namespace) -> List[SampleRequest]:
+args: argparse.Namespace) -> list[SampleRequest]:
 if args.dataset == 'json':
 if args.json_schema_path is None:
 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -287,7 +286,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

 elif args.dataset == "xgrammar_bench":
 args.warmup = False
-requests: List[SampleRequest] = []
+requests: list[SampleRequest] = []
 dataset = datasets.load_dataset("NousResearch/json-mode-eval",
 split="train")
 print(f"dataset has {len(dataset)} entries")
@@ -7,7 +7,7 @@ import json
 import os
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import numpy as np
 import torch
@@ -22,7 +22,7 @@ from vllm.utils import FlexibleArgumentParser


 def save_to_pytorch_benchmark_format(args: argparse.Namespace,
-results: Dict[str, Any]) -> None:
+results: dict[str, Any]) -> None:
 pt_records = convert_to_pytorch_benchmark_format(
 args=args,
 metrics={"latency": results["latencies"]},
@@ -57,7 +57,7 @@ def main(args: argparse.Namespace):
 dummy_prompt_token_ids = np.random.randint(10000,
 size=(args.batch_size,
 args.input_len))
-dummy_prompts: List[PromptType] = [{
+dummy_prompts: list[PromptType] = [{
 "prompt_token_ids": batch
 } for batch in dummy_prompt_token_ids.tolist()]

@@ -31,7 +31,7 @@ import dataclasses
 import json
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import Optional

 from transformers import PreTrainedTokenizerBase

@@ -77,9 +77,9 @@ def sample_requests_from_dataset(
 dataset_path: str,
 num_requests: int,
 tokenizer: PreTrainedTokenizerBase,
-input_length_range: Tuple[int, int],
+input_length_range: tuple[int, int],
 fixed_output_len: Optional[int],
-) -> List[Request]:
+) -> list[Request]:
 if fixed_output_len is not None and fixed_output_len < 4:
 raise ValueError("output_len too small")

@@ -99,7 +99,7 @@ def sample_requests_from_dataset(
 assert min_len >= 0 and max_len >= min_len, "input_length_range too small"

 # Filter out sequences that are too long or too short
-filtered_requests: List[Request] = []
+filtered_requests: list[Request] = []

 for i in range(len(dataset)):
 if len(filtered_requests) == num_requests:
@@ -122,10 +122,10 @@ def sample_requests_from_dataset(
 def sample_requests_from_random(
 num_requests: int,
 tokenizer: PreTrainedTokenizerBase,
-input_length_range: Tuple[int, int],
+input_length_range: tuple[int, int],
 fixed_output_len: Optional[int],
 prefix_len: int,
-) -> List[Request]:
+) -> list[Request]:

 requests = []
 prefix_token_ids = sample_tokens(tokenizer, prefix_len)
@@ -144,9 +144,9 @@ def sample_requests_from_random(
 return requests


-def repeat_and_sort_requests(requests: List[Request],
+def repeat_and_sort_requests(requests: list[Request],
 repeat_count: int,
-sort: bool = False) -> List[str]:
+sort: bool = False) -> list[str]:
 repeated_requests = requests * repeat_count
 if sort:
 repeated_requests.sort(key=lambda x: x[1])
@@ -5,7 +5,7 @@ import dataclasses
 import json
 import random
 import time
-from typing import List, Optional, Tuple
+from typing import Optional

 from transformers import AutoTokenizer, PreTrainedTokenizerBase

@@ -23,7 +23,7 @@ def sample_requests(
 num_requests: int,
 tokenizer: PreTrainedTokenizerBase,
 fixed_output_len: Optional[int],
-) -> List[Tuple[str, int, int]]:
+) -> list[tuple[str, int, int]]:
 if fixed_output_len is not None and fixed_output_len < 4:
 raise ValueError("output_len too small")

@@ -40,7 +40,7 @@ def sample_requests(
 random.shuffle(dataset)

 # Filter out sequences that are too long or too short
-filtered_dataset: List[Tuple[str, int, int]] = []
+filtered_dataset: list[tuple[str, int, int]] = []
 for i in range(len(dataset)):
 if len(filtered_dataset) == num_requests:
 break
@@ -68,7 +68,7 @@ def sample_requests(


 def run_vllm(
-requests: List[Tuple[str, int, int]],
+requests: list[tuple[str, int, int]],
 n: int,
 engine_args: EngineArgs,
 ) -> float:
@@ -33,9 +33,10 @@ import os
 import random
 import time
 import warnings
+from collections.abc import AsyncGenerator, Collection
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
+from typing import Any, Optional

 import numpy as np
 import pandas as pd
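
Beyond renaming the generics, the hunk above also relocates the ABCs themselves: since Python 3.9 the typing aliases for AsyncGenerator and Collection are deprecated in favour of collections.abc, which is directly subscriptable. A hedged sketch of the resulting style with an annotated async generator (illustrative function, not the benchmark's own get_request):

    import asyncio
    from collections.abc import AsyncGenerator

    async def emit(requests: list[tuple[str, int, int]]
                   ) -> AsyncGenerator[tuple[str, int, int], None]:
        # Yield each (prompt, prompt_len, output_len) tuple in turn.
        for req in requests:
            yield req
            await asyncio.sleep(0)

    async def main() -> None:
        async for prompt, prompt_len, output_len in emit([("hi", 3, 8)]):
            print(prompt, prompt_len, output_len)

    asyncio.run(main())
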
@@ -73,22 +74,22 @@ class BenchmarkMetrics:
 mean_ttft_ms: float
 median_ttft_ms: float
 std_ttft_ms: float
-percentiles_ttft_ms: List[Tuple[float, float]]
+percentiles_ttft_ms: list[tuple[float, float]]
 mean_tpot_ms: float
 median_tpot_ms: float
 std_tpot_ms: float
-percentiles_tpot_ms: List[Tuple[float, float]]
+percentiles_tpot_ms: list[tuple[float, float]]
 mean_itl_ms: float
 median_itl_ms: float
 std_itl_ms: float
-percentiles_itl_ms: List[Tuple[float, float]]
+percentiles_itl_ms: list[tuple[float, float]]
 # E2EL stands for end-to-end latency per request.
 # It is the time taken on the client side from sending
 # a request to receiving a complete response.
 mean_e2el_ms: float
 median_e2el_ms: float
 std_e2el_ms: float
-percentiles_e2el_ms: List[Tuple[float, float]]
+percentiles_e2el_ms: list[tuple[float, float]]


 def sample_sharegpt_requests(
@@ -96,7 +97,7 @@ def sample_sharegpt_requests(
 num_requests: int,
 tokenizer: PreTrainedTokenizerBase,
 fixed_output_len: Optional[int] = None,
-) -> List[Tuple[str, int, int, None]]:
+) -> list[tuple[str, int, int, None]]:
 # Load the dataset.
 with open(dataset_path, encoding='utf-8') as f:
 dataset = json.load(f)
@@ -110,7 +111,7 @@ def sample_sharegpt_requests(
 random.shuffle(dataset)

 # Filter out sequences that are too long or too short
-filtered_dataset: List[Tuple[str, int, int]] = []
+filtered_dataset: list[tuple[str, int, int]] = []
 for i in range(len(dataset)):
 if len(filtered_dataset) == num_requests:
 break
@@ -139,7 +140,7 @@ def sample_burstgpt_requests(
 num_requests: int,
 random_seed: int,
 tokenizer: PreTrainedTokenizerBase,
-) -> List[Tuple[str, int, int, None]]:
+) -> list[tuple[str, int, int, None]]:
 df = pd.read_csv(dataset_path)
 gpt4_df = df[df["Model"] == "GPT-4"]
 # Remove the failed requests (i.e., response length is 0)
@@ -170,7 +171,7 @@ def sample_sonnet_requests(
 output_len: int,
 prefix_len: int,
 tokenizer: PreTrainedTokenizerBase,
-) -> List[Tuple[str, str, int, int, None]]:
+) -> list[tuple[str, str, int, int, None]]:
 assert (
 input_len > prefix_len
 ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
@@ -211,7 +212,7 @@ def sample_sonnet_requests(
 prefix_lines = poem_lines[:num_prefix_lines]

 # Sample the rest of lines per request.
-sampled_requests: List[Tuple[str, int, int]] = []
+sampled_requests: list[tuple[str, int, int]] = []
 for _ in range(num_requests):
 num_lines_needed = num_input_lines - num_prefix_lines
 sampled_lines = "".join(prefix_lines +
@@ -238,8 +239,8 @@ def sample_vision_arena_requests(
 num_requests: int,
 tokenizer: PreTrainedTokenizerBase,
 fixed_output_len: Optional[int] = None,
-) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
-sampled_requests: List[Tuple[str, int, int, Dict[str,
+) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
+sampled_requests: list[tuple[str, int, int, dict[str,
 Collection[str]]]] = []
 for data in dataset:
 if len(sampled_requests) == num_requests:
@@ -285,7 +286,7 @@ def sample_hf_requests(
 tokenizer: PreTrainedTokenizerBase,
 random_seed: int,
 fixed_output_len: Optional[int] = None,
-) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
+) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:

 # Special case for vision_arena dataset
 if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
@@ -307,7 +308,7 @@ def sample_hf_requests(
 "HF Dataset must have 'conversations' column.")
 filter_func = lambda x: len(x["conversations"]) >= 2
 filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
-sampled_requests: List[Tuple[str, int, int, Dict[str,
+sampled_requests: list[tuple[str, int, int, dict[str,
 Collection[str]]]] = []
 for data in filtered_dataset:
 if len(sampled_requests) == num_requests:
@@ -370,7 +371,7 @@ def sample_random_requests(
 num_prompts: int,
 range_ratio: float,
 tokenizer: PreTrainedTokenizerBase,
-) -> List[Tuple[str, int, int]]:
+) -> list[tuple[str, int, int]]:
 prefix_token_ids = np.random.randint(0,
 tokenizer.vocab_size,
 size=prefix_len).tolist()
@@ -399,10 +400,10 @@ def sample_random_requests(


 async def get_request(
-input_requests: List[Tuple[str, int, int]],
+input_requests: list[tuple[str, int, int]],
 request_rate: float,
 burstiness: float = 1.0,
-) -> AsyncGenerator[Tuple[str, int, int], None]:
+) -> AsyncGenerator[tuple[str, int, int], None]:
 """
 Asynchronously generates requests at a specified rate
 with OPTIONAL burstiness.
@@ -443,23 +444,23 @@ async def get_request(


 def calculate_metrics(
-input_requests: List[Tuple[str, int, int]],
-outputs: List[RequestFuncOutput],
+input_requests: list[tuple[str, int, int]],
+outputs: list[RequestFuncOutput],
 dur_s: float,
 tokenizer: PreTrainedTokenizerBase,
-selected_percentile_metrics: List[str],
-selected_percentiles: List[float],
-goodput_config_dict: Dict[str, float],
-) -> Tuple[BenchmarkMetrics, List[int]]:
-actual_output_lens: List[int] = []
+selected_percentile_metrics: list[str],
+selected_percentiles: list[float],
+goodput_config_dict: dict[str, float],
+) -> tuple[BenchmarkMetrics, list[int]]:
+actual_output_lens: list[int] = []
 total_input = 0
 completed = 0
 good_completed = 0
-itls: List[float] = []
-tpots: List[float] = []
-all_tpots: List[float] = []
-ttfts: List[float] = []
-e2els: List[float] = []
+itls: list[float] = []
+tpots: list[float] = []
+all_tpots: list[float] = []
+ttfts: list[float] = []
+e2els: list[float] = []
 for i in range(len(outputs)):
 if outputs[i].success:
 output_len = outputs[i].output_tokens
@@ -557,19 +558,19 @@ async def benchmark(
 model_id: str,
 model_name: str,
 tokenizer: PreTrainedTokenizerBase,
-input_requests: List[Tuple[str, int, int]],
+input_requests: list[tuple[str, int, int]],
 logprobs: Optional[int],
 best_of: int,
 request_rate: float,
 burstiness: float,
 disable_tqdm: bool,
 profile: bool,
-selected_percentile_metrics: List[str],
-selected_percentiles: List[str],
+selected_percentile_metrics: list[str],
+selected_percentiles: list[str],
 ignore_eos: bool,
-goodput_config_dict: Dict[str, float],
+goodput_config_dict: dict[str, float],
 max_concurrency: Optional[int],
-lora_modules: Optional[List[str]],
+lora_modules: Optional[list[str]],
 ):
 if backend in ASYNC_REQUEST_FUNCS:
 request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -652,7 +653,7 @@ async def benchmark(
 pbar=pbar)

 benchmark_start_time = time.perf_counter()
-tasks: List[asyncio.Task] = []
+tasks: list[asyncio.Task] = []
 async for request in get_request(input_requests, request_rate, burstiness):
 prompt, prompt_len, output_len, mm_content = request
 req_model_id, req_model_name = model_id, model_name
@@ -674,7 +675,7 @@ async def benchmark(
 asyncio.create_task(
 limited_request_func(request_func_input=request_func_input,
 pbar=pbar)))
-outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

 if profile:
 print("Stopping profiler...")
@@ -820,7 +821,7 @@ def parse_goodput(slo_pairs):


 def save_to_pytorch_benchmark_format(args: argparse.Namespace,
-results: Dict[str, Any],
+results: dict[str, Any],
 file_name: str) -> None:
 metrics = [
 "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
@@ -974,7 +975,7 @@ def main(args: argparse.Namespace):

 # Save config and results to json
 if args.save_result:
-result_json: Dict[str, Any] = {}
+result_json: dict[str, Any] = {}

 # Setup
 current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -30,8 +30,9 @@ import os
 import random
 import time
 import warnings
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
-from typing import AsyncGenerator, Dict, List, Optional, Tuple
+from typing import Optional

 import datasets
 import numpy as np
@@ -66,22 +67,22 @@ class BenchmarkMetrics:
 mean_ttft_ms: float
 median_ttft_ms: float
 std_ttft_ms: float
-percentiles_ttft_ms: List[Tuple[float, float]]
+percentiles_ttft_ms: list[tuple[float, float]]
 mean_tpot_ms: float
 median_tpot_ms: float
 std_tpot_ms: float
-percentiles_tpot_ms: List[Tuple[float, float]]
+percentiles_tpot_ms: list[tuple[float, float]]
 mean_itl_ms: float
 median_itl_ms: float
 std_itl_ms: float
-percentiles_itl_ms: List[Tuple[float, float]]
+percentiles_itl_ms: list[tuple[float, float]]
 # E2EL stands for end-to-end latency per request.
 # It is the time taken on the client side from sending
 # a request to receiving a complete response.
 mean_e2el_ms: float
 median_e2el_ms: float
 std_e2el_ms: float
-percentiles_e2el_ms: List[Tuple[float, float]]
+percentiles_e2el_ms: list[tuple[float, float]]


 @dataclasses.dataclass
@@ -104,7 +105,7 @@ class SampleRequest:


 def sample_requests(tokenizer: PreTrainedTokenizerBase,
-args: argparse.Namespace) -> List[SampleRequest]:
+args: argparse.Namespace) -> list[SampleRequest]:
 if args.dataset == 'json':
 if args.json_schema_path is None:
 dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -187,7 +188,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
 ]

 elif args.dataset == "xgrammar_bench":
-requests: List[SampleRequest] = []
+requests: list[SampleRequest] = []
 dataset = datasets.load_dataset("NousResearch/json-mode-eval",
 split="train")
 print(f"dataset has {len(dataset)} entries")
@@ -214,10 +215,10 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,


 async def get_request(
-input_requests: List[SampleRequest],
+input_requests: list[SampleRequest],
 request_rate: float,
 burstiness: float = 1.0,
-) -> AsyncGenerator[Tuple[int, SampleRequest], None]:
+) -> AsyncGenerator[tuple[int, SampleRequest], None]:
 """
 Asynchronously generates requests at a specified rate
 with OPTIONAL burstiness.
@@ -258,23 +259,23 @@ async def get_request(


 def calculate_metrics(
-input_requests: List[Tuple[str, int, int]],
-outputs: List[RequestFuncOutput],
+input_requests: list[tuple[str, int, int]],
+outputs: list[RequestFuncOutput],
 dur_s: float,
 tokenizer: PreTrainedTokenizerBase,
-selected_percentile_metrics: List[str],
-selected_percentiles: List[float],
-goodput_config_dict: Optional[Dict[str, float]] = None,
-) -> Tuple[BenchmarkMetrics, List[int]]:
-actual_output_lens: List[int] = []
+selected_percentile_metrics: list[str],
+selected_percentiles: list[float],
+goodput_config_dict: Optional[dict[str, float]] = None,
+) -> tuple[BenchmarkMetrics, list[int]]:
+actual_output_lens: list[int] = []
 total_input = 0
 completed = 0
 good_completed = 0
-itls: List[float] = []
-tpots: List[float] = []
-all_tpots: List[float] = []
-ttfts: List[float] = []
-e2els: List[float] = []
+itls: list[float] = []
+tpots: list[float] = []
+all_tpots: list[float] = []
+ttfts: list[float] = []
+e2els: list[float] = []
 for i in range(len(outputs)):
 if outputs[i].success:
 # We use the tokenizer to count the number of output tokens for all
@@ -368,18 +369,18 @@ async def benchmark(
 base_url: str,
 model_id: str,
 tokenizer: PreTrainedTokenizerBase,
-input_requests: List[SampleRequest],
+input_requests: list[SampleRequest],
 request_rate: float,
 burstiness: float,
 disable_tqdm: bool,
 profile: bool,
-selected_percentile_metrics: List[str],
-selected_percentiles: List[str],
+selected_percentile_metrics: list[str],
+selected_percentiles: list[str],
 ignore_eos: bool,
 max_concurrency: Optional[int],
 guided_decoding_ratio: float,
 guided_decoding_backend: str,
-goodput_config_dict: Optional[Dict[str, float]] = None,
+goodput_config_dict: Optional[dict[str, float]] = None,
 ):
 if backend in ASYNC_REQUEST_FUNCS:
 request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -459,8 +460,8 @@ async def benchmark(
 pbar=pbar)

 benchmark_start_time = time.perf_counter()
-tasks: List[asyncio.Task] = []
-expected: List[str] = []
+tasks: list[asyncio.Task] = []
+expected: list[str] = []
 async for i, request in get_request(input_requests, request_rate,
 burstiness):
 extra_body = prepare_extra_body(
@@ -479,7 +480,7 @@ async def benchmark(
 asyncio.create_task(
 limited_request_func(request_func_input=request_func_input,
 pbar=pbar)))
-outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

 if profile:
 print("Stopping profiler...")
@@ -7,7 +7,7 @@ import os
 import random
 import time
 from functools import cache
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional

 import torch
 import uvloop
@@ -74,12 +74,12 @@ def lora_path_on_disk(lora_path: str) -> str:
 return get_adapter_absolute_path(lora_path)


-lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
+lora_tokenizer_cache: dict[int, AnyTokenizer] = {}


 def get_random_lora_request(
 args: argparse.Namespace
-) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
+) -> tuple[LoRARequest, Optional[AnyTokenizer]]:
 global lora_tokenizer_cache
 lora_id = random.randint(1, args.max_loras)
 lora_request = LoRARequest(lora_name=str(lora_id),
@@ -91,7 +91,7 @@ def get_random_lora_request(


 def sample_requests(tokenizer: PreTrainedTokenizerBase,
-args: argparse.Namespace) -> List[SampleRequest]:
+args: argparse.Namespace) -> list[SampleRequest]:

 dataset_path: str = args.dataset
 num_requests: int = args.num_prompts
@@ -109,7 +109,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
 random.shuffle(dataset)

 # Filter out sequences that are too long or too short
-filtered_dataset: List[SampleRequest] = []
+filtered_dataset: list[SampleRequest] = []
 for data in tqdm(dataset,
 total=len(filtered_dataset),
 desc="sampling requests"):
@@ -165,7 +165,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,


 def run_vllm(
-requests: List[SampleRequest],
+requests: list[SampleRequest],
 n: int,
 engine_args: EngineArgs,
 ) -> float:
@@ -178,8 +178,8 @@ def run_vllm(
 "Please ensure that max_model_len is greater than the sum of"
 " prompt_len and expected_output_len for all requests.")
 # Add the requests to the engine.
-prompts: List[TextPrompt] = []
-sampling_params: List[SamplingParams] = []
+prompts: list[TextPrompt] = []
+sampling_params: list[SamplingParams] = []
 for request in requests:
 prompts.append(
 TextPrompt(prompt=request.prompt,
@@ -192,7 +192,7 @@ def run_vllm(
 ignore_eos=True,
 max_tokens=request.expected_output_len,
 ))
-lora_requests: Optional[List[LoRARequest]] = None
+lora_requests: Optional[list[LoRARequest]] = None
 if engine_args.enable_lora:
 lora_requests = [request.lora_request for request in requests]

@@ -225,7 +225,7 @@ def run_vllm(


 async def run_vllm_async(
-requests: List[SampleRequest],
+requests: list[SampleRequest],
 n: int,
 engine_args: AsyncEngineArgs,
 disable_frontend_multiprocessing: bool = False,
@@ -242,9 +242,9 @@ async def run_vllm_async(
 " prompt_len and expected_output_len for all requests.")

 # Add the requests to the engine.
-prompts: List[TextPrompt] = []
-sampling_params: List[SamplingParams] = []
-lora_requests: List[Optional[LoRARequest]] = []
+prompts: list[TextPrompt] = []
+sampling_params: list[SamplingParams] = []
+lora_requests: list[Optional[LoRARequest]] = []
 for request in requests:
 prompts.append(
 TextPrompt(prompt=request.prompt,
@@ -276,7 +276,7 @@ async def run_vllm_async(


 def run_hf(
-requests: List[SampleRequest],
+requests: list[SampleRequest],
 model: str,
 tokenizer: PreTrainedTokenizerBase,
 n: int,
@@ -292,7 +292,7 @@ def run_hf(

 pbar = tqdm(total=len(requests))
 start = time.perf_counter()
-batch: List[str] = []
+batch: list[str] = []
 max_prompt_len = 0
 max_output_len = 0
 for i in range(len(requests)):
@@ -334,7 +334,7 @@ def run_hf(


 def run_mii(
-requests: List[SampleRequest],
+requests: list[SampleRequest],
 model: str,
 tensor_parallel_size: int,
 output_len: int,
@@ -352,7 +352,7 @@ def run_mii(


 def save_to_pytorch_benchmark_format(args: argparse.Namespace,
-results: Dict[str, Any]) -> None:
+results: dict[str, Any]) -> None:
 pt_records = convert_to_pytorch_benchmark_format(
 args=args,
 metrics={
@@ -479,8 +479,8 @@ if __name__ == "__main__":
 type=str,
 default=None,
 help="Path to the dataset. The dataset is expected to "
-"be a json in form of List[Dict[..., conversations: "
-"List[Dict[..., value: <prompt_or_response>]]]]")
+"be a json in form of list[dict[..., conversations: "
+"list[dict[..., value: <prompt_or_response>]]]]")
 parser.add_argument("--input-len",
 type=int,
 default=None,
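
Note that the rewrite above preserves two different nullable shapes: run_vllm keeps Optional[list[LoRARequest]] = None (the whole list may be absent), while run_vllm_async builds a list[Optional[LoRARequest]] = [] (one possibly-missing request per prompt). A tiny sketch of that distinction with plain strings (illustrative values only):

    from typing import Optional

    # Optional[list[str]]: either a list of strings or None.
    maybe_names: Optional[list[str]] = None

    # list[Optional[str]]: always a list, but individual entries may be None.
    sparse_names: list[Optional[str]] = ["alice", None, "bob"]
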
@@ -4,12 +4,12 @@ import argparse
 import json
 import math
 import os
-from typing import Any, Dict, List
+from typing import Any


 def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
-metrics: Dict[str, List],
-extra_info: Dict[str, Any]) -> List:
+metrics: dict[str, list],
+extra_info: dict[str, Any]) -> list:
 """
 Save the benchmark results in the format used by PyTorch OSS benchmark with
 on metric per record
@@ -64,6 +64,6 @@ class InfEncoder(json.JSONEncoder):
 return super().iterencode(self.clear_inf(o), *args, **kwargs)


-def write_to_json(filename: str, records: List) -> None:
+def write_to_json(filename: str, records: list) -> None:
 with open(filename, "w") as f:
 json.dump(records, f, cls=InfEncoder)
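
The helpers above also show the unparameterised case: a bare List or Dict annotation simply becomes the lowercase built-in, and no typing import is needed for it at all. A minimal sketch under that assumption (hypothetical helper, not the file's own code):

    import json

    def write_records(filename: str, records: list) -> None:
        # Any JSON-serialisable records will do; the annotation needs no import.
        with open(filename, "w") as f:
            json.dump(records, f)

    write_records("records.json", [{"latency_ms": 12.5}])
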
@@ -5,7 +5,8 @@ import copy
 import itertools
 import pickle as pkl
 import time
-from typing import Callable, Iterable, List, Tuple
+from collections.abc import Iterable
+from typing import Callable

 import torch
 import torch.utils.benchmark as TBenchmark
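
The same split appears again here: Iterable now comes from collections.abc, while Callable (and, in other files, Optional) is still imported from typing in this commit. A minimal sketch of a signature written in that style (illustrative function, not the benchmark code itself):

    from collections.abc import Iterable

    def total_flops(mkns: Iterable[tuple[int, int, int]]) -> int:
        # 2*M*K*N multiply-accumulates per GEMM, summed over all shapes.
        return sum(2 * m * k * n for m, k, n in mkns)

    print(total_flops([(16, 4096, 4096), (32, 4096, 11008)]))
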
@@ -228,7 +229,7 @@ def print_timers(timers: Iterable[TMeasurement]):


 def run(dtype: torch.dtype,
-MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
 results = []
 for m, k, n in MKNs:
 timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
@@ -241,7 +242,7 @@ def run(dtype: torch.dtype,

 # output makers
 def make_output(data: Iterable[TMeasurement],
-MKNs: Iterable[Tuple[int, int, int]],
+MKNs: Iterable[tuple[int, int, int]],
 base_description: str,
 timestamp=None):
 print(f"== All Results {base_description} ====")
@@ -282,7 +283,7 @@ def run_model_bench(args):
 for i, model in enumerate(args.models):
 print(f"[{i}] {model}")

-def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
+def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
 KNs = []
 for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
 KN[tp_split_dim] = KN[tp_split_dim] // tp_size
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 # Cutlass bench utils
-from typing import Iterable, Tuple
+from collections.abc import Iterable

 import torch

@@ -27,7 +27,7 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:


 def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
-k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+k: int) -> tuple[torch.Tensor, torch.Tensor]:
 a = torch.randn((m, k), device='cuda') * 5
 b = torch.randn((n, k), device='cuda').t() * 5

@@ -63,7 +63,7 @@ def prune_to_2_4(tensor):


 def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
-k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+k: int) -> tuple[torch.Tensor, torch.Tensor]:
 a = torch.randn((m, k), device='cuda') * 5
 b = torch.randn((n, k), device='cuda').t() * 5

@@ -88,7 +88,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,

 def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
 m: int, n: int, k: int) -> \
-Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
+tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
 ABs = []
 for _ in range(num_tensors):
 b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
@@ -5,7 +5,8 @@ import copy
 import itertools
 import pickle as pkl
 import time
-from typing import Callable, Iterable, List, Optional, Tuple
+from collections.abc import Iterable
+from typing import Callable, Optional

 import torch
 import torch.utils.benchmark as TBenchmark
@@ -49,7 +50,7 @@ def bench_int8(
 n: int,
 label: str,
 sub_label: str,
-bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
+bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
 """Benchmark INT8-based kernels."""
 assert dtype == torch.int8
 a, b = make_rand_tensors(torch.int8, m, n, k)
@@ -101,7 +102,7 @@ def bench_fp8(
 n: int,
 label: str,
 sub_label: str,
-bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
+bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
 """Benchmark FP8-based kernels."""
 assert dtype == torch.float8_e4m3fn
 a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
@@ -180,7 +181,7 @@ def bench(dtype: torch.dtype,
 n: int,
 label: str,
 sub_label: str,
-bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
+bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
 if dtype == torch.int8:
 return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
 if dtype == torch.float8_e4m3fn:
@@ -195,8 +196,8 @@ def print_timers(timers: Iterable[TMeasurement]):


 def run(dtype: torch.dtype,
-MKNs: Iterable[Tuple[int, int, int]],
-bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
+MKNs: Iterable[tuple[int, int, int]],
+bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
 results = []
 for m, k, n in MKNs:
 timers = bench(dtype,
@@ -212,7 +213,7 @@ def run(dtype: torch.dtype,


 def make_output(data: Iterable[TMeasurement],
-MKNs: Iterable[Tuple[int, int, int]],
+MKNs: Iterable[tuple[int, int, int]],
 base_description: str,
 timestamp=None):
 print(f"== All Results {base_description} ====")
@@ -248,7 +249,7 @@ def run_model_bench(args):
 for i, model in enumerate(args.models):
 print(f"[{i}] {model}")

-def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
+def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
 KNs = []
 for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
 KN[tp_split_dim] = KN[tp_split_dim] // tp_size
@@ -2,9 +2,10 @@

 import pickle as pkl
 import time
+from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import product
-from typing import Callable, Iterable, List, Optional
+from typing import Callable, Optional

 import torch
 import torch.utils.benchmark as TBenchmark
@@ -29,7 +30,7 @@ class bench_params_t:
 f'x DT {self.dtype}')


-def get_bench_params() -> List[bench_params_t]:
+def get_bench_params() -> list[bench_params_t]:
 ## Test Fixtures
 NUM_TOKENS = [2**x for x in range(11)]
 HIDDEN_SIZES = list(range(1024, 8129, 1024))
@@ -9,7 +9,7 @@ from dataclasses import dataclass
 from enum import Enum, auto
 from itertools import product
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional

 import torch
 import torch.utils.benchmark as TBenchmark
@@ -61,15 +61,15 @@ def make_rand_lora_weight_tensor(k: int,


 def make_rand_tensors(
-a_shape: Tuple[int],
-b_shape: Tuple[int],
-c_shape: Tuple[int],
+a_shape: tuple[int],
+b_shape: tuple[int],
+c_shape: tuple[int],
 a_dtype: torch.dtype,
 b_dtype: torch.dtype,
 c_dtype: torch.dtype,
 num_slices: int,
 device: str = "cuda",
-) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
+) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
 """
 Make LoRA input/output matrices.
 """
@@ -135,7 +135,7 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int,


 def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
-lora_weights: List[torch.Tensor],
+lora_weights: list[torch.Tensor],
 seq_lens_cpu: torch.Tensor,
 prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
 add_inputs: Optional[bool]):
@@ -204,7 +204,7 @@ class OpType(Enum):
 def is_expand_slice_fn(self) -> bool:
 return self in [OpType.BGMV_EXPAND_SLICE]

-def num_slices(self) -> List[int]:
+def num_slices(self) -> list[int]:
 if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
 # SGMV kernels supports slices
 return [1, 2, 3]
@@ -215,7 +215,7 @@ class OpType(Enum):
 raise ValueError(f"Unrecognized OpType {self}")

 def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
-lora_rank: int) -> Tuple[int, int, int]:
+lora_rank: int) -> tuple[int, int, int]:
 num_tokens = batch_size * seq_length
 if self.is_shrink_fn():
 m = num_tokens
@@ -230,7 +230,7 @@ class OpType(Enum):

 def matmul_dtypes(
 self, op_dtype: torch.dtype
-) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:
+) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
 """
 return a type, b type and c type for A x B = C
 """
@@ -243,7 +243,7 @@ class OpType(Enum):
 def matmul_shapes(
 self, batch_size: int, seq_length: int, hidden_size: int,
 lora_rank: int, num_loras: int,
-num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
+num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
 """
 Given num_slices, return the shapes of the A, B, and C matrices
 in A x B = C, for the op_type
@@ -268,7 +268,7 @@ class OpType(Enum):

 def bench_fn(self) -> Callable:

-def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
+def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
 for x in kwargs_list:
 bgmv_expand_slice(**x)

@@ -285,7 +285,7 @@ class OpType(Enum):
 raise ValueError(f"Unrecognized optype {self}")

 def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
-lora_weights: List[torch.Tensor],
+lora_weights: list[torch.Tensor],
 **kwargs) -> Callable:
 """Each benchmark operation expected the input, lora_weights and outputs
 in a slightly different format. Refer to self.matmul_shapes().
@@ -384,7 +384,7 @@ class BenchmarkTensors:
 """
 # matmul tensors
 input: torch.Tensor
-lora_weights_lst: List[torch.Tensor]
+lora_weights_lst: list[torch.Tensor]
 output: torch.Tensor
 # metadata tensors
 seq_lens: torch.Tensor
@@ -469,7 +469,7 @@ class BenchmarkTensors:
 for i in range(len(self.lora_weights_lst)):
 self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

-def metadata(self) -> Tuple[int, int, int]:
+def metadata(self) -> tuple[int, int, int]:
 """
 Return num_seqs, num_tokens and max_seq_len
 """
@@ -505,7 +505,7 @@ class BenchmarkTensors:
 self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
 self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)

-def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]:
+def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
 self.convert_to_sgmv_benchmark_tensors()
 self.sanity_check()
 self.to_device(self.input.device)
@@ -540,7 +540,7 @@ class BenchmarkTensors:
 'scaling': 1.0,
 }

-def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
+def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:

 self.convert_to_sgmv_benchmark_tensors()
 self.sanity_check()
@@ -578,7 +578,7 @@ class BenchmarkTensors:
 'add_inputs': add_inputs,
 }

-def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]:
+def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
 assert len(self.lora_weights_lst) == 1
 self.to_device(self.input.device)

@@ -634,7 +634,7 @@ class BenchmarkTensors:
 'add_inputs': add_inputs
 }

-def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
+def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:

 _, num_tokens, _, num_slices = self.metadata()
 # Sanity check shapes
@@ -670,7 +670,7 @@ class BenchmarkTensors:

 def bench_fn_kwargs(self,
 op_type: OpType,
-add_inputs: Optional[bool] = None) -> Dict[str, Any]:
+add_inputs: Optional[bool] = None) -> dict[str, Any]:
 if op_type.is_shrink_fn():
 assert add_inputs is None
 else:
@@ -734,7 +734,7 @@ def bench_optype(ctx: BenchmarkContext,
 assert expand_fn_add_inputs is not None

 # BenchmarkContext -> BenchmarkTensors
-bench_tensors : List[BenchmarkTensors] = \
+bench_tensors : list[BenchmarkTensors] = \
 [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)]
 for bt in bench_tensors:
 bt.sanity_check()
@@ -746,7 +746,7 @@ def bench_optype(ctx: BenchmarkContext,
 for bt in bench_tensors
 ])

-# BenchmarkTensors -> Dict (kwargs)
+# BenchmarkTensors -> dict (kwargs)
 kwargs_list = [
 bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
 for bt in bench_tensors
@@ -841,7 +841,7 @@ def use_cuda_graph_recommendation() -> str:
 """


-def print_timers(timers: List[TMeasurement],
+def print_timers(timers: list[TMeasurement],
 args: Optional[argparse.Namespace] = None):
 compare = TBenchmark.Compare(timers)
 compare.print()
@@ -861,7 +861,7 @@ def print_timers(timers: List[TMeasurement],
 "small num_loras the goal should be to match the torch.mm numbers.")


-def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
+def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
|
||||||
|
|
||||||
if args.cuda_graph_nops is not None:
|
if args.cuda_graph_nops is not None:
|
||||||
assert args.cuda_graph_nops > 0
|
assert args.cuda_graph_nops > 0
|
||||||
@ -873,7 +873,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
|
|||||||
timers = []
|
timers = []
|
||||||
for bench_ctx in bench_ctxs:
|
for bench_ctx in bench_ctxs:
|
||||||
for seq_len in args.seq_lengths:
|
for seq_len in args.seq_lengths:
|
||||||
bench_ops: List[OpType] = []
|
bench_ops: list[OpType] = []
|
||||||
if seq_len == 1:
|
if seq_len == 1:
|
||||||
# bench all decode ops
|
# bench all decode ops
|
||||||
bench_ops = [op for op in args.op_types if op.is_decode_op()]
|
bench_ops = [op for op in args.op_types if op.is_decode_op()]
|
||||||
@ -921,10 +921,10 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
|
|||||||
pickle.dump(timers, f)
|
pickle.dump(timers, f)
|
||||||
|
|
||||||
|
|
||||||
def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int],
|
def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
|
||||||
args: argparse.Namespace) -> List[BenchmarkContext]:
|
args: argparse.Namespace) -> list[BenchmarkContext]:
|
||||||
|
|
||||||
ctxs: List[BenchmarkContext] = []
|
ctxs: list[BenchmarkContext] = []
|
||||||
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
|
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
|
||||||
args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
|
args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
|
||||||
args.sort_by_lora_id):
|
args.sort_by_lora_id):
|
||||||
@ -954,7 +954,7 @@ def run_list_bench(args: argparse.Namespace):
|
|||||||
f" LoRA Ranks {args.lora_ranks}")
|
f" LoRA Ranks {args.lora_ranks}")
|
||||||
|
|
||||||
# Get all benchmarking contexts
|
# Get all benchmarking contexts
|
||||||
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
|
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
|
||||||
hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)
|
hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)
|
||||||
|
|
||||||
run(args, bench_contexts)
|
run(args, bench_contexts)
|
||||||
@ -975,7 +975,7 @@ def run_range_bench(args: argparse.Namespace):
|
|||||||
f" LoRA Ranks {lora_ranks}")
|
f" LoRA Ranks {lora_ranks}")
|
||||||
|
|
||||||
# Get all benchmarking contexts
|
# Get all benchmarking contexts
|
||||||
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
|
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
|
||||||
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)
|
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)
|
||||||
|
|
||||||
run(args, bench_contexts)
|
run(args, bench_contexts)
|
||||||
@ -1002,7 +1002,7 @@ def run_model_bench(args: argparse.Namespace):
|
|||||||
f" LoRA Ranks {args.lora_ranks}")
|
f" LoRA Ranks {args.lora_ranks}")
|
||||||
|
|
||||||
# Get all benchmarking contexts
|
# Get all benchmarking contexts
|
||||||
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
|
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
|
||||||
hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)
|
hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)
|
||||||
|
|
||||||
run(args, bench_contexts)
|
run(args, bench_contexts)
|
||||||
|
@ -7,9 +7,10 @@ import math
|
|||||||
import os
|
import os
|
||||||
import pickle as pkl
|
import pickle as pkl
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Callable, Iterable, List, Optional, Tuple
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
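The hunk above applies the pattern used throughout this commit: `Iterable` is imported from `collections.abc` rather than `typing`, and the PEP 585 builtin generics (`list`, `tuple`, `dict`) replace `typing.List`, `typing.Tuple` and `typing.Dict` in annotations. A minimal sketch of the resulting style on Python 3.9+ (the `summarize` helper is a hypothetical example, not code from this repository):

from collections.abc import Iterable
from typing import Optional


def summarize(values: Iterable[int],
              limit: Optional[int] = None) -> tuple[int, list[int]]:
    # Count the values and return them as a (possibly truncated) list.
    items = list(values)
    if limit is not None:
        items = items[:limit]
    return len(items), items


print(summarize(range(5), limit=3))  # (3, [0, 1, 2])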
@ -102,8 +103,8 @@ def quantize_and_pack(atype: torch.dtype,
|
|||||||
return w_ref, w_q, w_s, w_zp
|
return w_ref, w_q, w_s, w_zp
|
||||||
|
|
||||||
|
|
||||||
def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
|
def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
|
||||||
group_size: Optional[int]) -> List[BenchmarkTensors]:
|
group_size: Optional[int]) -> list[BenchmarkTensors]:
|
||||||
m, n, k = shape
|
m, n, k = shape
|
||||||
|
|
||||||
# we want to make sure that weights don't fit into L2 cache between runs so
|
# we want to make sure that weights don't fit into L2 cache between runs so
|
||||||
@ -114,7 +115,7 @@ def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
|
|||||||
|
|
||||||
a = rand_data((m, k), types.act_type, scale=5)
|
a = rand_data((m, k), types.act_type, scale=5)
|
||||||
|
|
||||||
benchmark_tensors: List[BenchmarkTensors] = []
|
benchmark_tensors: list[BenchmarkTensors] = []
|
||||||
for _ in range(num_weights):
|
for _ in range(num_weights):
|
||||||
w = rand_data((k, n), types.act_type, scale=5)
|
w = rand_data((k, n), types.act_type, scale=5)
|
||||||
|
|
||||||
@ -276,7 +277,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
|
|||||||
|
|
||||||
|
|
||||||
def bench_fns(label: str, sub_label: str, description: str,
|
def bench_fns(label: str, sub_label: str, description: str,
|
||||||
fns: List[Callable]):
|
fns: list[Callable]):
|
||||||
|
|
||||||
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
||||||
res = TBenchmark.Timer(
|
res = TBenchmark.Timer(
|
||||||
@ -311,7 +312,7 @@ def bench(types: TypeConfig,
|
|||||||
n: int,
|
n: int,
|
||||||
label: str,
|
label: str,
|
||||||
sub_label: str,
|
sub_label: str,
|
||||||
sweep_schedules: bool = True) -> List[TMeasurement]:
|
sweep_schedules: bool = True) -> list[TMeasurement]:
|
||||||
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
||||||
sub_label += f", L={len(benchmark_tensors)}"
|
sub_label += f", L={len(benchmark_tensors)}"
|
||||||
|
|
||||||
@ -414,12 +415,12 @@ def bench(types: TypeConfig,
|
|||||||
|
|
||||||
|
|
||||||
# runner
|
# runner
|
||||||
def print_timers(timers: List[TMeasurement]):
|
def print_timers(timers: list[TMeasurement]):
|
||||||
compare = TBenchmark.Compare(timers)
|
compare = TBenchmark.Compare(timers)
|
||||||
compare.print()
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||||
types = TypeConfig(
|
types = TypeConfig(
|
||||||
act_type=args.act_type,
|
act_type=args.act_type,
|
||||||
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
|
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
|
||||||
@ -431,7 +432,7 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
|||||||
token_scale_type=args.token_scale_type,
|
token_scale_type=args.token_scale_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
results: List[TMeasurement] = []
|
results: list[TMeasurement] = []
|
||||||
for m, k, n in MKNs:
|
for m, k, n in MKNs:
|
||||||
timers = bench(types,
|
timers = bench(types,
|
||||||
args.group_size,
|
args.group_size,
|
||||||
@ -449,8 +450,8 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
|||||||
|
|
||||||
# output makers
|
# output makers
|
||||||
def make_output(
|
def make_output(
|
||||||
data: List[TMeasurement],
|
data: list[TMeasurement],
|
||||||
MKNs: Iterable[Tuple[int, int, int]],
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
base_description: str,
|
base_description: str,
|
||||||
timestamp=None,
|
timestamp=None,
|
||||||
):
|
):
|
||||||
@ -497,7 +498,7 @@ def run_model_bench(args):
|
|||||||
for i, model in enumerate(args.models):
|
for i, model in enumerate(args.models):
|
||||||
print(f"[{i}] {model}")
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||||
KNs = []
|
KNs = []
|
||||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.benchmark as benchmark
|
import torch.utils.benchmark as benchmark
|
||||||
from benchmark_shapes import WEIGHT_SHAPES
|
from benchmark_shapes import WEIGHT_SHAPES
|
||||||
@ -31,7 +29,7 @@ ACT_ORDER_OPTS = [False, True]
|
|||||||
K_FULL_OPTS = [False, True]
|
K_FULL_OPTS = [False, True]
|
||||||
|
|
||||||
|
|
||||||
def bench_run(results: List[benchmark.Measurement], model: str,
|
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||||
act_order: bool, is_k_full: bool, quant_type: ScalarType,
|
act_order: bool, is_k_full: bool, quant_type: ScalarType,
|
||||||
group_size: int, size_m: int, size_k: int, size_n: int):
|
group_size: int, size_m: int, size_k: int, size_n: int):
|
||||||
label = "Quant Matmul"
|
label = "Quant Matmul"
|
||||||
@ -221,7 +219,7 @@ def main(args):
|
|||||||
for i, model in enumerate(args.models):
|
for i, model in enumerate(args.models):
|
||||||
print(f"[{i}] {model}")
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
results: List[benchmark.Measurement] = []
|
results: list[benchmark.Measurement] = []
|
||||||
|
|
||||||
for model in args.models:
|
for model in args.models:
|
||||||
for layer in WEIGHT_SHAPES[model]:
|
for layer in WEIGHT_SHAPES[model]:
|
||||||
|
@ -4,7 +4,7 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Any, Dict, List, Tuple, TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
@ -132,7 +132,7 @@ def benchmark_config(
|
|||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
latencies: List[float] = []
|
latencies: list[float] = []
|
||||||
for i in range(num_iters):
|
for i in range(num_iters):
|
||||||
prepare(i)
|
prepare(i)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -175,8 +175,8 @@ def get_rocm_tuning_space(use_fp16):
|
|||||||
return param_ranges
|
return param_ranges
|
||||||
|
|
||||||
|
|
||||||
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
|
def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
|
||||||
configs: List[BenchmarkConfig] = []
|
configs: list[BenchmarkConfig] = []
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
param_ranges = get_rocm_tuning_space(use_fp16)
|
param_ranges = get_rocm_tuning_space(use_fp16)
|
||||||
@ -335,7 +335,7 @@ class BenchmarkWorker:
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
) -> Tuple[Dict[str, int], float]:
|
) -> tuple[dict[str, int], float]:
|
||||||
current_platform.seed_everything(self.seed)
|
current_platform.seed_everything(self.seed)
|
||||||
dtype_str = get_config_dtype_str(dtype,
|
dtype_str = get_config_dtype_str(dtype,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
@ -371,8 +371,8 @@ class BenchmarkWorker:
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
search_space: List[Dict[str, int]],
|
search_space: list[dict[str, int]],
|
||||||
) -> Dict[str, int]:
|
) -> dict[str, int]:
|
||||||
best_config = None
|
best_config = None
|
||||||
best_time = float("inf")
|
best_time = float("inf")
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
@ -434,7 +434,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||||
shard_intermediate_size: int, hidden_size: int, topk: int,
|
shard_intermediate_size: int, hidden_size: int, topk: int,
|
||||||
dtype: torch.dtype, use_fp8_w8a8: bool,
|
dtype: torch.dtype, use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool) -> None:
|
use_int8_w8a16: bool) -> None:
|
||||||
@ -498,7 +498,7 @@ def main(args: argparse.Namespace):
|
|||||||
num_gpus = int(ray.available_resources()["GPU"])
|
num_gpus = int(ray.available_resources()["GPU"])
|
||||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||||
|
|
||||||
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
|
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||||
outputs = []
|
outputs = []
|
||||||
worker_idx = 0
|
worker_idx = 0
|
||||||
for input_args in inputs:
|
for input_args in inputs:
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ def main(
|
|||||||
|
|
||||||
# Create the block tables.
|
# Create the block tables.
|
||||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||||
block_tables_lst: List[List[int]] = []
|
block_tables_lst: list[list[int]] = []
|
||||||
for _ in range(num_seqs):
|
for _ in range(num_seqs):
|
||||||
block_table = [
|
block_table = [
|
||||||
random.randint(0, NUM_BLOCKS - 1)
|
random.randint(0, NUM_BLOCKS - 1)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@ -22,7 +22,7 @@ class HuggingFaceRMSNorm(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor] = None,
|
residual: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
x = x.to(torch.float32)
|
x = x.to(torch.float32)
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import nvtx
|
import nvtx
|
||||||
import torch
|
import torch
|
||||||
@ -39,7 +39,7 @@ def benchmark_rope_kernels_multi_lora(
|
|||||||
})
|
})
|
||||||
# non-batched RoPE takes only one scaling factor, we create multiple
|
# non-batched RoPE takes only one scaling factor, we create multiple
|
||||||
# instances to simulate the same behavior
|
# instances to simulate the same behavior
|
||||||
non_batched_ropes: List[RotaryEmbedding] = []
|
non_batched_ropes: list[RotaryEmbedding] = []
|
||||||
for scaling_factor in scaling_factors:
|
for scaling_factor in scaling_factors:
|
||||||
non_batched_ropes.append(
|
non_batched_ropes.append(
|
||||||
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
|
@ -4,7 +4,6 @@ import math
|
|||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -23,7 +22,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
with open(args.filename, 'rb') as f:
|
with open(args.filename, 'rb') as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
raw_results: List[TMeasurement] = data["results"]
|
raw_results: list[TMeasurement] = data["results"]
|
||||||
|
|
||||||
results = defaultdict(lambda: list())
|
results = defaultdict(lambda: list())
|
||||||
for v in raw_results:
|
for v in raw_results:
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Callable, Iterable, Optional
|
from collections.abc import Iterable
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.benchmark as TBenchmark
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from typing import Dict, Union
|
from typing import Union
|
||||||
|
|
||||||
from cutlass_library import *
|
from cutlass_library import *
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum):
|
|||||||
TmaWarpSpecializedCooperative = enum_auto()
|
TmaWarpSpecializedCooperative = enum_auto()
|
||||||
|
|
||||||
|
|
||||||
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
**DataTypeNames, # type: ignore
|
**DataTypeNames, # type: ignore
|
||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: "u4b8",
|
VLLMDataType.u4b8: "u4b8",
|
||||||
@ -29,7 +29,7 @@ VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
**DataTypeTag, # type: ignore
|
**DataTypeTag, # type: ignore
|
||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
||||||
@ -37,7 +37,7 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
|
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
||||||
**DataTypeSize, # type: ignore
|
**DataTypeSize, # type: ignore
|
||||||
**{
|
**{
|
||||||
VLLMDataType.u4b8: 4,
|
VLLMDataType.u4b8: 4,
|
||||||
@ -45,7 +45,7 @@ VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
VLLMDataType.u4b8: "vllm::kU4B8",
|
VLLMDataType.u4b8: "vllm::kU4B8",
|
||||||
VLLMDataType.u8b128: "vllm::kU8B128",
|
VLLMDataType.u8b128: "vllm::kU8B128",
|
||||||
DataType.u4: "vllm::kU4",
|
DataType.u4: "vllm::kU4",
|
||||||
@ -56,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
DataType.bf16: "vllm::kBfloat16",
|
DataType.bf16: "vllm::kBfloat16",
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||||
DataType.u8: "at::ScalarType::Byte",
|
DataType.u8: "at::ScalarType::Byte",
|
||||||
DataType.s8: "at::ScalarType::Char",
|
DataType.s8: "at::ScalarType::Char",
|
||||||
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
|
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
|
||||||
@ -66,7 +66,7 @@ VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
|||||||
DataType.f32: "at::ScalarType::Float",
|
DataType.f32: "at::ScalarType::Float",
|
||||||
}
|
}
|
||||||
|
|
||||||
VLLMKernelScheduleTag: Dict[Union[
|
VLLMKernelScheduleTag: dict[Union[
|
||||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||||
**KernelScheduleTag, # type: ignore
|
**KernelScheduleTag, # type: ignore
|
||||||
**{
|
**{
|
||||||
|
@ -8,7 +8,7 @@ from collections.abc import Iterable
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
@ -247,8 +247,8 @@ TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ScheduleConfig:
|
class ScheduleConfig:
|
||||||
tile_shape_mn: Tuple[int, int]
|
tile_shape_mn: tuple[int, int]
|
||||||
cluster_shape_mnk: Tuple[int, int, int]
|
cluster_shape_mnk: tuple[int, int, int]
|
||||||
kernel_schedule: MixedInputKernelScheduleType
|
kernel_schedule: MixedInputKernelScheduleType
|
||||||
epilogue_schedule: EpilogueScheduleType
|
epilogue_schedule: EpilogueScheduleType
|
||||||
tile_scheduler: TileSchedulerType
|
tile_scheduler: TileSchedulerType
|
||||||
@ -277,8 +277,8 @@ class PrepackTypeConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ImplConfig:
|
class ImplConfig:
|
||||||
types: TypeConfig
|
types: TypeConfig
|
||||||
schedules: List[ScheduleConfig]
|
schedules: list[ScheduleConfig]
|
||||||
heuristic: List[Tuple[Optional[str], ScheduleConfig]]
|
heuristic: list[tuple[Optional[str], ScheduleConfig]]
|
||||||
|
|
||||||
|
|
||||||
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||||
@ -333,7 +333,7 @@ def is_power_of_two(n):
|
|||||||
return (n != 0) and (n & (n - 1) == 0)
|
return (n != 0) and (n & (n - 1) == 0)
|
||||||
|
|
||||||
|
|
||||||
def to_cute_constant(value: List[int]):
|
def to_cute_constant(value: list[int]):
|
||||||
|
|
||||||
def _to_cute_constant(value: int):
|
def _to_cute_constant(value: int):
|
||||||
if is_power_of_two(value):
|
if is_power_of_two(value):
|
||||||
@ -347,7 +347,7 @@ def to_cute_constant(value: List[int]):
|
|||||||
return _to_cute_constant(value)
|
return _to_cute_constant(value)
|
||||||
|
|
||||||
|
|
||||||
def unique_schedules(impl_configs: List[ImplConfig]):
|
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||||
return list(
|
return list(
|
||||||
set(sch for impl_config in impl_configs
|
set(sch for impl_config in impl_configs
|
||||||
for sch in impl_config.schedules))
|
for sch in impl_config.schedules))
|
||||||
@ -391,7 +391,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
|
|||||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||||
|
|
||||||
|
|
||||||
def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
|
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||||
sources = []
|
sources = []
|
||||||
|
|
||||||
sources.append((
|
sources.append((
|
||||||
@ -435,7 +435,7 @@ def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
|
|||||||
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
||||||
num_impls_per_file = math.ceil(num_impls / num_impl_files)
|
num_impls_per_file = math.ceil(num_impls / num_impl_files)
|
||||||
|
|
||||||
files_impls: List[List[ImplConfig]] = [[]]
|
files_impls: list[list[ImplConfig]] = [[]]
|
||||||
|
|
||||||
curr_num_impls_assigned = 0
|
curr_num_impls_assigned = 0
|
||||||
curr_impl_in_file = 0
|
curr_impl_in_file = 0
|
||||||
@ -515,7 +515,7 @@ def generate():
|
|||||||
for cond, tile_config in default_tile_heuristic_config.items()
|
for cond, tile_config in default_tile_heuristic_config.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
|
def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
|
||||||
# Do not use schedules = list(set(...)) because we need to make sure
|
# Do not use schedules = list(set(...)) because we need to make sure
|
||||||
# the output list is deterministic; otherwise the generated kernel file
|
# the output list is deterministic; otherwise the generated kernel file
|
||||||
# will be non-deterministic and causes ccache miss.
|
# will be non-deterministic and causes ccache miss.
|
||||||
|
@ -17,7 +17,6 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from sphinx.ext import autodoc
|
from sphinx.ext import autodoc
|
||||||
@ -58,7 +57,7 @@ templates_path = ['_templates']
|
|||||||
# List of patterns, relative to source directory, that match files and
|
# List of patterns, relative to source directory, that match files and
|
||||||
# directories to ignore when looking for source files.
|
# directories to ignore when looking for source files.
|
||||||
# This pattern also affects html_static_path and html_extra_path.
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
exclude_patterns: List[str] = ["**/*.template.md", "**/*.inc.md"]
|
exclude_patterns: list[str] = ["**/*.template.md", "**/*.inc.md"]
|
||||||
|
|
||||||
# Exclude the prompt "$" when copying code
|
# Exclude the prompt "$" when copying code
|
||||||
copybutton_prompt_text = r"\$ "
|
copybutton_prompt_text = r"\$ "
|
||||||
|
@ -123,7 +123,7 @@ class ExampleParser(ReasoningParser):
|
|||||||
|
|
||||||
def extract_reasoning_content(
|
def extract_reasoning_content(
|
||||||
self, model_output: str, request: ChatCompletionRequest
|
self, model_output: str, request: ChatCompletionRequest
|
||||||
) -> Tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Extract reasoning content from a complete model-generated string.
|
Extract reasoning content from a complete model-generated string.
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ class ExampleParser(ReasoningParser):
|
|||||||
The request object that was used to generate the model_output.
|
The request object that was used to generate the model_output.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Optional[str], Optional[str]]
|
tuple[Optional[str], Optional[str]]
|
||||||
A tuple containing the reasoning content and the content.
|
A tuple containing the reasoning content and the content.
|
||||||
"""
|
"""
|
||||||
```
|
```
|
||||||
|
@ -193,7 +193,7 @@ class Step(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class MathResponse(BaseModel):
|
class MathResponse(BaseModel):
|
||||||
steps: List[Step]
|
steps: list[Step]
|
||||||
final_answer: str
|
final_answer: str
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ class Example:
|
|||||||
path (Path): The path to the main directory or file.
|
path (Path): The path to the main directory or file.
|
||||||
category (str): The category of the document.
|
category (str): The category of the document.
|
||||||
main_file (Path): The main file in the directory.
|
main_file (Path): The main file in the directory.
|
||||||
other_files (list[Path]): List of other files in the directory.
|
other_files (list[Path]): list of other files in the directory.
|
||||||
title (str): The title of the document.
|
title (str): The title of the document.
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
|
@ -6,7 +6,7 @@ distributively on a multi-nodes cluster.
|
|||||||
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
|
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ray
|
import ray
|
||||||
@ -36,13 +36,13 @@ class LLMPredictor:
|
|||||||
self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
tensor_parallel_size=tensor_parallel_size)
|
tensor_parallel_size=tensor_parallel_size)
|
||||||
|
|
||||||
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
|
def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]:
|
||||||
# Generate texts from the prompts.
|
# Generate texts from the prompts.
|
||||||
# The output is a list of RequestOutput objects that contain the prompt,
|
# The output is a list of RequestOutput objects that contain the prompt,
|
||||||
# generated text, and other information.
|
# generated text, and other information.
|
||||||
outputs = self.llm.generate(batch["text"], sampling_params)
|
outputs = self.llm.generate(batch["text"], sampling_params)
|
||||||
prompt: List[str] = []
|
prompt: list[str] = []
|
||||||
generated_text: List[str] = []
|
generated_text: list[str] = []
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
prompt.append(output.prompt)
|
prompt.append(output.prompt)
|
||||||
generated_text.append(' '.join([o.text for o in output.outputs]))
|
generated_text.append(' '.join([o.text for o in output.outputs]))
|
||||||
@ -72,7 +72,7 @@ def scheduling_strategy_fn():
|
|||||||
pg, placement_group_capture_child_tasks=True))
|
pg, placement_group_capture_child_tasks=True))
|
||||||
|
|
||||||
|
|
||||||
resources_kwarg: Dict[str, Any] = {}
|
resources_kwarg: dict[str, Any] = {}
|
||||||
if tensor_parallel_size == 1:
|
if tensor_parallel_size == 1:
|
||||||
# For tensor_parallel_size == 1, we simply set num_gpus=1.
|
# For tensor_parallel_size == 1, we simply set num_gpus=1.
|
||||||
resources_kwarg["num_gpus"] = 1
|
resources_kwarg["num_gpus"] = 1
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
|
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
|
||||||
"""Create a list of test prompts with their sampling parameters."""
|
"""Create a list of test prompts with their sampling parameters."""
|
||||||
return [
|
return [
|
||||||
("A robot may not injure a human being",
|
("A robot may not injure a human being",
|
||||||
@ -24,7 +23,7 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
|
|||||||
|
|
||||||
|
|
||||||
def process_requests(engine: LLMEngine,
|
def process_requests(engine: LLMEngine,
|
||||||
test_prompts: List[Tuple[str, SamplingParams]]):
|
test_prompts: list[tuple[str, SamplingParams]]):
|
||||||
"""Continuously process a list of prompts and handle the outputs."""
|
"""Continuously process a list of prompts and handle the outputs."""
|
||||||
request_id = 0
|
request_id = 0
|
||||||
|
|
||||||
@ -34,7 +33,7 @@ def process_requests(engine: LLMEngine,
|
|||||||
engine.add_request(str(request_id), prompt, sampling_params)
|
engine.add_request(str(request_id), prompt, sampling_params)
|
||||||
request_id += 1
|
request_id += 1
|
||||||
|
|
||||||
request_outputs: List[RequestOutput] = engine.step()
|
request_outputs: list[RequestOutput] = engine.step()
|
||||||
|
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
if request_output.finished:
|
if request_output.finished:
|
||||||
|
@ -7,7 +7,7 @@ Requires HuggingFace credentials for access.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
@ -18,7 +18,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
|
|
||||||
def create_test_prompts(
|
def create_test_prompts(
|
||||||
lora_path: str
|
lora_path: str
|
||||||
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
||||||
return [
|
return [
|
||||||
# this is an example of using quantization without LoRA
|
# this is an example of using quantization without LoRA
|
||||||
("My name is",
|
("My name is",
|
||||||
@ -49,7 +49,7 @@ def create_test_prompts(
|
|||||||
|
|
||||||
|
|
||||||
def process_requests(engine: LLMEngine,
|
def process_requests(engine: LLMEngine,
|
||||||
test_prompts: List[Tuple[str, SamplingParams,
|
test_prompts: list[tuple[str, SamplingParams,
|
||||||
Optional[LoRARequest]]]):
|
Optional[LoRARequest]]]):
|
||||||
"""Continuously process a list of prompts and handle the outputs."""
|
"""Continuously process a list of prompts and handle the outputs."""
|
||||||
request_id = 0
|
request_id = 0
|
||||||
@ -63,7 +63,7 @@ def process_requests(engine: LLMEngine,
|
|||||||
lora_request=lora_request)
|
lora_request=lora_request)
|
||||||
request_id += 1
|
request_id += 1
|
||||||
|
|
||||||
request_outputs: List[RequestOutput] = engine.step()
|
request_outputs: list[RequestOutput] = engine.step()
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
if request_output.finished:
|
if request_output.finished:
|
||||||
print("----------------------------------------------------")
|
print("----------------------------------------------------")
|
||||||
|
@ -2,12 +2,11 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
def time_generation(llm: LLM, prompts: List[str],
|
def time_generation(llm: LLM, prompts: list[str],
|
||||||
sampling_params: SamplingParams):
|
sampling_params: SamplingParams):
|
||||||
# Generate texts from the prompts. The output is a list of RequestOutput
|
# Generate texts from the prompts. The output is a list of RequestOutput
|
||||||
# objects that contain the prompt, generated text, and other information.
|
# objects that contain the prompt, generated text, and other information.
|
||||||
|
@ -6,7 +6,7 @@ for offline inference.
|
|||||||
Requires HuggingFace credentials for access to Llama2.
|
Requires HuggingFace credentials for access to Llama2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
|
|
||||||
def create_test_prompts(
|
def create_test_prompts(
|
||||||
lora_path: str
|
lora_path: str
|
||||||
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
||||||
"""Create a list of test prompts with their sampling parameters.
|
"""Create a list of test prompts with their sampling parameters.
|
||||||
|
|
||||||
2 requests for base model, 4 requests for the LoRA. We define 2
|
2 requests for base model, 4 requests for the LoRA. We define 2
|
||||||
@ -56,7 +56,7 @@ def create_test_prompts(
|
|||||||
|
|
||||||
|
|
||||||
def process_requests(engine: LLMEngine,
|
def process_requests(engine: LLMEngine,
|
||||||
test_prompts: List[Tuple[str, SamplingParams,
|
test_prompts: list[tuple[str, SamplingParams,
|
||||||
Optional[LoRARequest]]]):
|
Optional[LoRARequest]]]):
|
||||||
"""Continuously process a list of prompts and handle the outputs."""
|
"""Continuously process a list of prompts and handle the outputs."""
|
||||||
request_id = 0
|
request_id = 0
|
||||||
@ -70,7 +70,7 @@ def process_requests(engine: LLMEngine,
|
|||||||
lora_request=lora_request)
|
lora_request=lora_request)
|
||||||
request_id += 1
|
request_id += 1
|
||||||
|
|
||||||
request_outputs: List[RequestOutput] = engine.step()
|
request_outputs: list[RequestOutput] = engine.step()
|
||||||
|
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
if request_output.finished:
|
if request_output.finished:
|
||||||
|
@ -21,7 +21,7 @@ import argparse
|
|||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import List, Union
|
from typing import Union
|
||||||
|
|
||||||
import albumentations
|
import albumentations
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -260,9 +260,9 @@ def _convert_np_uint8(float_image: torch.Tensor):
|
|||||||
|
|
||||||
|
|
||||||
def load_example(
|
def load_example(
|
||||||
file_paths: List[str],
|
file_paths: list[str],
|
||||||
mean: List[float] = None,
|
mean: list[float] = None,
|
||||||
std: List[float] = None,
|
std: list[float] = None,
|
||||||
indices: Union[list[int], None] = None,
|
indices: Union[list[int], None] = None,
|
||||||
):
|
):
|
||||||
"""Build an input example by loading images in *file_paths*.
|
"""Build an input example by loading images in *file_paths*.
|
||||||
|
@ -5,8 +5,9 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from argparse import RawTextHelpFormatter
|
from argparse import RawTextHelpFormatter
|
||||||
|
from collections.abc import Generator
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Any, Dict, Generator, List, Optional, TypeAlias
|
from typing import Any, Optional, TypeAlias
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -42,8 +43,8 @@ def get_dtype(dtype: str):
|
|||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
|
|
||||||
OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
|
OutputLen_NumReqs_Map: TypeAlias = dict[int, int]
|
||||||
def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
|
def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
|
||||||
-> OutputLen_NumReqs_Map:
|
-> OutputLen_NumReqs_Map:
|
||||||
"""
|
"""
|
||||||
Given the number of requests, batch_size, and the number of requests
|
Given the number of requests, batch_size, and the number of requests
|
||||||
@ -63,7 +64,7 @@ def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
|
|||||||
Args:
|
Args:
|
||||||
batch_size (int): Number of requests submitted for profile. This is
|
batch_size (int): Number of requests submitted for profile. This is
|
||||||
args.batch_size.
|
args.batch_size.
|
||||||
step_requests (List[int]): step_requests[i] is the number of requests
|
step_requests (list[int]): step_requests[i] is the number of requests
|
||||||
that the ith engine step should process.
|
that the ith engine step should process.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -114,7 +115,7 @@ def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
|
|||||||
return ol_nr
|
return ol_nr
|
||||||
|
|
||||||
|
|
||||||
def determine_requests_per_step(context: ProfileContext) -> List[int]:
|
def determine_requests_per_step(context: ProfileContext) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Determine number of requests each engine step should process.
|
Determine number of requests each engine step should process.
|
||||||
If context.num_steps is set, then all engine steps process the
|
If context.num_steps is set, then all engine steps process the
|
||||||
@ -130,7 +131,7 @@ def determine_requests_per_step(context: ProfileContext) -> List[int]:
|
|||||||
context: ProfileContext object.
|
context: ProfileContext object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[int]: Number of requests to process for all engine-steps.
|
list[int]: Number of requests to process for all engine-steps.
|
||||||
output[i], contains the number of requests that the ith step
|
output[i], contains the number of requests that the ith step
|
||||||
should process.
|
should process.
|
||||||
"""
|
"""
|
||||||
@ -170,7 +171,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
|||||||
for key, value in asdict(context).items():
|
for key, value in asdict(context).items():
|
||||||
print(f" {key} = {value}")
|
print(f" {key} = {value}")
|
||||||
|
|
||||||
requests_per_step: List[int] = determine_requests_per_step(context)
|
requests_per_step: list[int] = determine_requests_per_step(context)
|
||||||
|
|
||||||
ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
|
ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
|
||||||
context.batch_size, requests_per_step)
|
context.batch_size, requests_per_step)
|
||||||
|
@ -4,7 +4,6 @@ import argparse
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch_xla.debug.profiler as xp
|
import torch_xla.debug.profiler as xp
|
||||||
@ -35,7 +34,7 @@ def main(args: argparse.Namespace):
|
|||||||
dummy_prompt_token_ids = np.random.randint(10000,
|
dummy_prompt_token_ids = np.random.randint(10000,
|
||||||
size=(args.batch_size,
|
size=(args.batch_size,
|
||||||
args.input_len))
|
args.input_len))
|
||||||
dummy_prompts: List[PromptType] = [{
|
dummy_prompts: list[PromptType] = [{
|
||||||
"prompt_token_ids": batch
|
"prompt_token_ids": batch
|
||||||
} for batch in dummy_prompt_token_ids.tolist()]
|
} for batch in dummy_prompt_token_ids.tolist()]
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ multi-image input on vision language models for text generation,
|
|||||||
using the chat template defined by the model.
|
using the chat template defined by the model.
|
||||||
"""
|
"""
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from typing import List, NamedTuple, Optional
|
from typing import NamedTuple, Optional
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from transformers import AutoProcessor, AutoTokenizer
|
from transformers import AutoProcessor, AutoTokenizer
|
||||||
@ -24,8 +24,8 @@ IMAGE_URLS = [
|
|||||||
class ModelRequestData(NamedTuple):
|
class ModelRequestData(NamedTuple):
|
||||||
llm: LLM
|
llm: LLM
|
||||||
prompt: str
|
prompt: str
|
||||||
stop_token_ids: Optional[List[int]]
|
stop_token_ids: Optional[list[int]]
|
||||||
image_data: List[Image]
|
image_data: list[Image]
|
||||||
chat_template: Optional[str]
|
chat_template: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
@ -34,7 +34,7 @@ class ModelRequestData(NamedTuple):
|
|||||||
# Unless specified, these settings have been tested to work on a single L4.
|
# Unless specified, these settings have been tested to work on a single L4.
|
||||||
|
|
||||||
|
|
||||||
def load_aria(question, image_urls: List[str]) -> ModelRequestData:
|
def load_aria(question, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "rhymes-ai/Aria"
|
model_name = "rhymes-ai/Aria"
|
||||||
llm = LLM(model=model_name,
|
llm = LLM(model=model_name,
|
||||||
tokenizer_mode="slow",
|
tokenizer_mode="slow",
|
||||||
@ -55,7 +55,7 @@ def load_aria(question, image_urls: List[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_deepseek_vl2(question: str, image_urls: List[str]):
|
def load_deepseek_vl2(question: str, image_urls: list[str]):
|
||||||
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
||||||
|
|
||||||
llm = LLM(model=model_name,
|
llm = LLM(model=model_name,
|
||||||
@ -77,7 +77,7 @@ def load_deepseek_vl2(question: str, image_urls: List[str]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData:
|
def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "h2oai/h2ovl-mississippi-800m"
|
model_name = "h2oai/h2ovl-mississippi-800m"
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
@ -111,7 +111,7 @@ def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
|
def load_idefics3(question, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
|
model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
|
||||||
|
|
||||||
# The configuration below has been confirmed to launch on a single L40 GPU.
|
# The configuration below has been confirmed to launch on a single L40 GPU.
|
||||||
@ -142,7 +142,7 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
|
def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "OpenGVLab/InternVL2-2B"
|
model_name = "OpenGVLab/InternVL2-2B"
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
@ -179,7 +179,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
|
def load_mllama(question, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||||
|
|
||||||
# The configuration below has been confirmed to launch on a single L40 GPU.
|
# The configuration below has been confirmed to launch on a single L40 GPU.
|
||||||
@ -201,7 +201,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_nvlm_d(question: str, image_urls: List[str]):
|
def load_nvlm_d(question: str, image_urls: list[str]):
|
||||||
model_name = "nvidia/NVLM-D-72B"
|
model_name = "nvidia/NVLM-D-72B"
|
||||||
|
|
||||||
# Adjust this as necessary to fit in GPU
|
# Adjust this as necessary to fit in GPU
|
||||||
@ -234,7 +234,7 @@ def load_nvlm_d(question: str, image_urls: List[str]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
|
def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "mistral-community/pixtral-12b"
|
model_name = "mistral-community/pixtral-12b"
|
||||||
|
|
||||||
# Adjust this as necessary to fit in GPU
|
# Adjust this as necessary to fit in GPU
|
||||||
@ -259,7 +259,7 @@ def load_pixtral_hf(question: str, image_urls: List[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
|
def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||||
# num_crops is an override kwarg to the multimodal image processor;
|
# num_crops is an override kwarg to the multimodal image processor;
|
||||||
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
|
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
|
||||||
# to use 16 for single frame scenarios, and 4 for multi-frame.
|
# to use 16 for single frame scenarios, and 4 for multi-frame.
|
||||||
@ -295,7 +295,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
|
|||||||
|
|
||||||
|
|
||||||
def load_qwen_vl_chat(question: str,
|
def load_qwen_vl_chat(question: str,
|
||||||
image_urls: List[str]) -> ModelRequestData:
|
image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "Qwen/Qwen-VL-Chat"
|
model_name = "Qwen/Qwen-VL-Chat"
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
@ -336,7 +336,7 @@ def load_qwen_vl_chat(question: str,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
|
def load_qwen2_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||||
try:
|
try:
|
||||||
from qwen_vl_utils import process_vision_info
|
from qwen_vl_utils import process_vision_info
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData:
|
def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
|
||||||
try:
|
try:
|
||||||
from qwen_vl_utils import process_vision_info
|
from qwen_vl_utils import process_vision_info
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
@ -466,7 +466,7 @@ model_example_map = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def run_generate(model, question: str, image_urls: List[str]):
|
def run_generate(model, question: str, image_urls: list[str]):
|
||||||
req_data = model_example_map[model](question, image_urls)
|
req_data = model_example_map[model](question, image_urls)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.0,
|
sampling_params = SamplingParams(temperature=0.0,
|
||||||
@ -487,7 +487,7 @@ def run_generate(model, question: str, image_urls: List[str]):
|
|||||||
print(generated_text)
|
print(generated_text)
|
||||||
|
|
||||||
|
|
||||||
def run_chat(model: str, question: str, image_urls: List[str]):
|
def run_chat(model: str, question: str, image_urls: list[str]):
|
||||||
req_data = model_example_map[model](question, image_urls)
|
req_data = model_example_map[model](question, image_urls)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.0,
|
sampling_params = SamplingParams(temperature=0.0,
|
||||||
|
@ -7,7 +7,7 @@ For production use, we recommend `vllm serve` and the OpenAI client API.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
from typing import Iterable, List
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ def post_http_request(prompt: str,
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
|
def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
|
||||||
for chunk in response.iter_lines(chunk_size=8192,
|
for chunk in response.iter_lines(chunk_size=8192,
|
||||||
decode_unicode=False,
|
decode_unicode=False,
|
||||||
delimiter=b"\0"):
|
delimiter=b"\0"):
|
||||||
@ -49,7 +49,7 @@ def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
|
|||||||
yield output
|
yield output
|
||||||
|
|
||||||
|
|
||||||
def get_response(response: requests.Response) -> List[str]:
|
def get_response(response: requests.Response) -> list[str]:
|
||||||
data = json.loads(response.content)
|
data = json.loads(response.content)
|
||||||
output = data["text"]
|
output = data["text"]
|
||||||
return output
|
return output
|
||||||
|
@ -24,4 +24,4 @@ responses = client.embeddings.create(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for data in responses.data:
|
for data in responses.data:
|
||||||
print(data.embedding) # list of float of len 4096
|
print(data.embedding) # List of float of len 4096
|
||||||
@@ -65,6 +65,32 @@ exclude = [
[tool.ruff.lint.per-file-ignores]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
+# Python 3.8 typing. TODO: Remove these excludes after v1.0.0
+"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
+"vllm/attention/**/*.py" = ["UP006", "UP035"]
+"vllm/compilation/**/*.py" = ["UP006", "UP035"]
+"vllm/core/**/*.py" = ["UP006", "UP035"]
+"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
+"vllm/distributed/**/*.py" = ["UP006", "UP035"]
+"vllm/engine/**/*.py" = ["UP006", "UP035"]
+"vllm/executor/**/*.py" = ["UP006", "UP035"]
+"vllm/inputs/**/*.py" = ["UP006", "UP035"]
+"vllm/logging_utils/**/*.py" = ["UP006", "UP035"]
+"vllm/lora/**/*.py" = ["UP006", "UP035"]
+"vllm/model_executor/**/*.py" = ["UP006", "UP035"]
+"vllm/multimodal/**/*.py" = ["UP006", "UP035"]
+"vllm/platforms/**/*.py" = ["UP006", "UP035"]
+"vllm/plugins/**/*.py" = ["UP006", "UP035"]
+"vllm/profiler/**/*.py" = ["UP006", "UP035"]
+"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
+"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
+"vllm/third_party/**/*.py" = ["UP006", "UP035"]
+"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
+"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
+"vllm/usage/**/*.py" = ["UP006", "UP035"]
+"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
+"vllm/assets/**/*.py" = ["UP006", "UP035"]
+"vllm/worker/**/*.py" = ["UP006", "UP035"]

[tool.ruff.lint]
select = [
@@ -91,8 +117,6 @@ ignore = [
"B007",
# f-string format
"UP032",
-# Python 3.8 typing
-"UP006", "UP035",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
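For context on the pyproject.toml hunks above: UP006 flags annotations that use the deprecated `typing` aliases instead of the PEP 585 builtin generics, and UP035 flags importing those aliases in the first place. Dropping the two codes from the `ignore` list stops suppressing them globally, while the new per-file-ignores exempt the `vllm/*` packages that have not been migrated yet; UP007 stays ignored because `X | Y` unions need Python 3.10. A small before/after sketch of what the enabled rules ask for (hypothetical code, not taken from the diff):

    # flagged by UP035 (import) and UP006 (annotation):
    #   from typing import Dict, List
    #   def tally(names: List[str]) -> Dict[str, int]: ...

    # accepted on Python >= 3.9; Optional stays until UP007 is enabled
    from typing import Optional

    def tally(names: list[str]) -> dict[str, int]:
        # Count how many times each name appears.
        counts: dict[str, int] = {}
        for name in names:
            counts[name] = counts.get(name, 0) + 1
        return counts

    def first(names: list[str]) -> Optional[str]:
        return names[0] if names else None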
setup.py
@@ -9,7 +9,6 @@ import subprocess
import sys
from pathlib import Path
from shutil import which
-from typing import Dict, List

import torch
from packaging.version import Version, parse
@@ -78,7 +77,7 @@ class CMakeExtension(Extension):

class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
-did_config: Dict[str, bool] = {}
+did_config: dict[str, bool] = {}

#
# Determine number of compilation jobs and optionally nvcc compile threads.
@@ -548,10 +547,10 @@ def get_vllm_version() -> str:
return version


-def get_requirements() -> List[str]:
+def get_requirements() -> list[str]:
"""Get Python package dependencies from requirements.txt."""

-def _read_requirements(filename: str) -> List[str]:
+def _read_requirements(filename: str) -> list[str]:
with open(get_path(filename)) as f:
requirements = f.read().strip().split("\n")
resolved_requirements = []
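One reason these rewrites ride along with dropping Python 3.8: annotations at module and class level, such as `did_config: dict[str, bool] = {}` in setup.py, are evaluated at import time, and subscripting the builtin `dict`/`list`/`tuple` types only works on Python 3.9+ (PEP 585). A hedged sketch of the trade-off, reusing the attribute name from the hunk above purely for illustration:

    # relies on Python >= 3.9, as this commit now can
    did_config: dict[str, bool] = {}

    # the 3.8-compatible alternative would be to postpone annotation
    # evaluation (PEP 563); it only helps annotations, not runtime uses:
    #   from __future__ import annotations
    #   did_config: dict[str, bool] = {}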
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""vllm.entrypoints.api_server with some extra logging for testing."""
-from typing import Any, Dict, Iterable
+from collections.abc import Iterable
+from typing import Any

import uvicorn
from fastapi.responses import JSONResponse, Response
@@ -24,7 +25,7 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine):
self._num_aborts += len(ids)
await super()._engine_abort(ids)

-def testing_stats(self) -> Dict[str, Any]:
+def testing_stats(self) -> dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}

@@ -6,7 +6,7 @@ import uuid
from asyncio import CancelledError
from copy import copy
from dataclasses import dataclass
-from typing import List, Optional
+from typing import Optional

import pytest
import pytest_asyncio
@@ -254,7 +254,7 @@ async def test_output_kinds(async_engine, stop):
params.output_kind = RequestOutputKind.DELTA

prompt_tokens = None
-output_tokens: List[int] = []
+output_tokens: list[int] = []
output_text = ""
output_count = 0
final_output = None
@@ -8,7 +8,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
from dataclasses import dataclass
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional

import torch
from torch import nn
@@ -56,7 +56,7 @@ class LlamaConfig:
random_seed: int = 0

def compute_hash(self) -> str:
-factors: List[Any] = []
+factors: list[Any] = []
for k, v in self.__dict__.items():
if k == "random_seed":
continue
@@ -174,7 +174,7 @@ class LlamaDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
"""
For tractable computation:
- if residual is None, the outputs are:
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import dataclasses
-from typing import Dict, List, Optional
+from typing import Optional

import pytest

@@ -14,7 +14,7 @@ from ..utils import compare_all_settings
@dataclasses.dataclass
class TestSetting:
model: str
-model_args: List[str]
+model_args: list[str]
pp_size: int
tp_size: int
attn_backend: str
@@ -108,8 +108,8 @@ def test_compile_correctness(test_setting: TestSetting):
final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
["-tp", str(tp_size)]

-all_args: List[List[str]] = []
+all_args: list[list[str]] = []
-all_envs: List[Optional[Dict[str, str]]] = []
+all_envs: list[Optional[dict[str, str]]] = []

for level in [
CompilationLevel.NO_COMPILATION,
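As the `all_envs: list[Optional[dict[str, str]]] = []` change above illustrates, the builtin generics nest freely and still combine with the `typing` constructs that remain in use (here `Optional`, pending the UP007 cleanup). A brief sketch of the same shape with invented names, just to show the composition:

    from typing import Optional

    # one environment override per test run; None means inherit the parent env
    all_envs: list[Optional[dict[str, str]]] = []

    def add_run(env: Optional[dict[str, str]] = None) -> None:
        all_envs.append(env)

    add_run({"EXAMPLE_FLAG": "1"})  # hypothetical variable name
    add_run(None)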
@ -5,8 +5,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
from collections import UserList
|
from collections import UserList
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
|
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
|
||||||
TypedDict, TypeVar, Union)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -47,14 +46,14 @@ _SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt")
|
|||||||
|
|
||||||
_M = TypeVar("_M")
|
_M = TypeVar("_M")
|
||||||
|
|
||||||
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
|
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]
|
||||||
|
|
||||||
PromptImageInput = _PromptMultiModalInput[Image.Image]
|
PromptImageInput = _PromptMultiModalInput[Image.Image]
|
||||||
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
|
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
|
||||||
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
|
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
|
||||||
|
|
||||||
|
|
||||||
def _read_prompts(filename: str) -> List[str]:
|
def _read_prompts(filename: str) -> list[str]:
|
||||||
with open(filename) as f:
|
with open(filename) as f:
|
||||||
prompts = f.readlines()
|
prompts = f.readlines()
|
||||||
return prompts
|
return prompts
|
||||||
@ -77,7 +76,7 @@ class _ImageAssets(_ImageAssetsBase):
|
|||||||
ImageAsset("cherry_blossom"),
|
ImageAsset("cherry_blossom"),
|
||||||
])
|
])
|
||||||
|
|
||||||
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
|
def prompts(self, prompts: _ImageAssetPrompts) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Convenience method to define the prompt for each test image.
|
Convenience method to define the prompt for each test image.
|
||||||
|
|
||||||
@ -102,7 +101,7 @@ class _VideoAssets(_VideoAssetsBase):
|
|||||||
VideoAsset("sample_demo_1.mp4"),
|
VideoAsset("sample_demo_1.mp4"),
|
||||||
])
|
])
|
||||||
|
|
||||||
def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
|
def prompts(self, prompts: _VideoAssetPrompts) -> list[str]:
|
||||||
return [prompts["sample_demo_1"]]
|
return [prompts["sample_demo_1"]]
|
||||||
|
|
||||||
|
|
||||||
@ -175,7 +174,7 @@ def dynamo_reset():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def example_prompts() -> List[str]:
|
def example_prompts() -> list[str]:
|
||||||
prompts = []
|
prompts = []
|
||||||
for filename in _TEST_PROMPTS:
|
for filename in _TEST_PROMPTS:
|
||||||
prompts += _read_prompts(filename)
|
prompts += _read_prompts(filename)
|
||||||
@ -197,7 +196,7 @@ class DecoderPromptType(Enum):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def example_encoder_decoder_prompts(
|
def example_encoder_decoder_prompts(
|
||||||
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
|
) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
|
||||||
'''
|
'''
|
||||||
Returns an encoder prompt list and a decoder prompt list, wherein each pair
|
Returns an encoder prompt list and a decoder prompt list, wherein each pair
|
||||||
of same-index entries in both lists corresponds to an (encoder prompt,
|
of same-index entries in both lists corresponds to an (encoder prompt,
|
||||||
@ -229,7 +228,7 @@ def example_encoder_decoder_prompts(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def example_long_prompts() -> List[str]:
|
def example_long_prompts() -> list[str]:
|
||||||
prompts = []
|
prompts = []
|
||||||
for filename in _LONG_PROMPTS:
|
for filename in _LONG_PROMPTS:
|
||||||
prompts += _read_prompts(filename)
|
prompts += _read_prompts(filename)
|
||||||
@ -273,11 +272,11 @@ class HfRunner:
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
dtype: str = "half",
|
dtype: str = "half",
|
||||||
*,
|
*,
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
model_kwargs: Optional[dict[str, Any]] = None,
|
||||||
is_sentence_transformer: bool = False,
|
is_sentence_transformer: bool = False,
|
||||||
is_cross_encoder: bool = False,
|
is_cross_encoder: bool = False,
|
||||||
skip_tokenizer_init: bool = False,
|
skip_tokenizer_init: bool = False,
|
||||||
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
||||||
postprocess_inputs: Callable[..., BatchEncoding] = identity,
|
postprocess_inputs: Callable[..., BatchEncoding] = identity,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||||
@ -334,11 +333,11 @@ class HfRunner:
|
|||||||
|
|
||||||
def get_inputs(
|
def get_inputs(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
) -> List[BatchEncoding]:
|
) -> list[BatchEncoding]:
|
||||||
if images is not None:
|
if images is not None:
|
||||||
assert len(prompts) == len(images)
|
assert len(prompts) == len(images)
|
||||||
|
|
||||||
@ -348,9 +347,9 @@ class HfRunner:
|
|||||||
if audios is not None:
|
if audios is not None:
|
||||||
assert len(prompts) == len(audios)
|
assert len(prompts) == len(audios)
|
||||||
|
|
||||||
all_inputs: List[BatchEncoding] = []
|
all_inputs: list[BatchEncoding] = []
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
processor_kwargs: Dict[str, Any] = {
|
processor_kwargs: dict[str, Any] = {
|
||||||
"text": prompt,
|
"text": prompt,
|
||||||
"return_tensors": "pt",
|
"return_tensors": "pt",
|
||||||
}
|
}
|
||||||
@ -370,7 +369,7 @@ class HfRunner:
|
|||||||
|
|
||||||
return all_inputs
|
return all_inputs
|
||||||
|
|
||||||
def classify(self, prompts: List[str]) -> List[str]:
|
def classify(self, prompts: list[str]) -> list[str]:
|
||||||
# output is final logits
|
# output is final logits
|
||||||
all_inputs = self.get_inputs(prompts)
|
all_inputs = self.get_inputs(prompts)
|
||||||
outputs = []
|
outputs = []
|
||||||
@ -383,18 +382,18 @@ class HfRunner:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(prompts,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
audios=audios)
|
audios=audios)
|
||||||
|
|
||||||
outputs: List[Tuple[List[List[int]], List[str]]] = []
|
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
output_ids = self.model.generate(
|
output_ids = self.model.generate(
|
||||||
**self.wrap_device(inputs, device=self.model.device.type),
|
**self.wrap_device(inputs, device=self.model.device.type),
|
||||||
@ -412,13 +411,13 @@ class HfRunner:
|
|||||||
|
|
||||||
def generate_greedy(
|
def generate_greedy(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[List[int], str]]:
|
) -> list[tuple[list[int], str]]:
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(prompts,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
@ -432,10 +431,10 @@ class HfRunner:
|
|||||||
|
|
||||||
def generate_beam_search(
|
def generate_beam_search(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
beam_width: int,
|
beam_width: int,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(prompts,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_new_tokens=max_tokens,
|
max_new_tokens=max_tokens,
|
||||||
@ -453,19 +452,19 @@ class HfRunner:
|
|||||||
|
|
||||||
def generate_greedy_logprobs(
|
def generate_greedy_logprobs(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[List[torch.Tensor]]:
|
) -> list[list[torch.Tensor]]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(prompts,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
audios=audios)
|
audios=audios)
|
||||||
|
|
||||||
all_logprobs: List[List[torch.Tensor]] = []
|
all_logprobs: list[list[torch.Tensor]] = []
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
output = self.model.generate(
|
output = self.model.generate(
|
||||||
**self.wrap_device(inputs, device=self.model.device.type),
|
**self.wrap_device(inputs, device=self.model.device.type),
|
||||||
@ -483,11 +482,11 @@ class HfRunner:
|
|||||||
|
|
||||||
def _hidden_states_to_seq_logprobs(
|
def _hidden_states_to_seq_logprobs(
|
||||||
self,
|
self,
|
||||||
hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
|
hidden_states: tuple[tuple[torch.Tensor, ...], ...],
|
||||||
) -> List[torch.Tensor]:
|
) -> list[torch.Tensor]:
|
||||||
output_embeddings = self.model.get_output_embeddings()
|
output_embeddings = self.model.get_output_embeddings()
|
||||||
|
|
||||||
seq_logprobs: List[torch.Tensor] = []
|
seq_logprobs: list[torch.Tensor] = []
|
||||||
for _, hidden_state in enumerate(hidden_states):
|
for _, hidden_state in enumerate(hidden_states):
|
||||||
last_hidden_states = hidden_state[-1][0]
|
last_hidden_states = hidden_state[-1][0]
|
||||||
logits = torch.matmul(
|
logits = torch.matmul(
|
||||||
@ -503,14 +502,14 @@ class HfRunner:
|
|||||||
|
|
||||||
def _hidden_states_to_logprobs(
|
def _hidden_states_to_logprobs(
|
||||||
self,
|
self,
|
||||||
hidden_states: Tuple[Tuple[torch.Tensor, ...], ...],
|
hidden_states: tuple[tuple[torch.Tensor, ...], ...],
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
) -> Tuple[List[Dict[int, float]], int]:
|
) -> tuple[list[dict[int, float]], int]:
|
||||||
seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
|
seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states)
|
||||||
output_len = len(hidden_states)
|
output_len = len(hidden_states)
|
||||||
|
|
||||||
# convert to dict
|
# convert to dict
|
||||||
seq_logprobs_lst: List[Dict[int, float]] = []
|
seq_logprobs_lst: list[dict[int, float]] = []
|
||||||
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
|
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
|
||||||
# drop prompt logprobs
|
# drop prompt logprobs
|
||||||
if tok_idx == 0:
|
if tok_idx == 0:
|
||||||
@ -530,22 +529,22 @@ class HfRunner:
|
|||||||
|
|
||||||
def generate_greedy_logprobs_limit(
|
def generate_greedy_logprobs_limit(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[TokensTextLogprobs]:
|
) -> list[TokensTextLogprobs]:
|
||||||
all_inputs = self.get_inputs(prompts,
|
all_inputs = self.get_inputs(prompts,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
audios=audios)
|
audios=audios)
|
||||||
|
|
||||||
all_logprobs: List[List[Dict[int, float]]] = []
|
all_logprobs: list[list[dict[int, float]]] = []
|
||||||
all_output_ids: List[List[int]] = []
|
all_output_ids: list[list[int]] = []
|
||||||
all_output_strs: List[str] = []
|
all_output_strs: list[str] = []
|
||||||
|
|
||||||
for inputs in all_inputs:
|
for inputs in all_inputs:
|
||||||
output = self.model.generate(
|
output = self.model.generate(
|
||||||
@ -577,23 +576,23 @@ class HfRunner:
|
|||||||
|
|
||||||
def generate_encoder_decoder_greedy_logprobs_limit(
|
def generate_encoder_decoder_greedy_logprobs_limit(
|
||||||
self,
|
self,
|
||||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[TokensTextLogprobs]:
|
) -> list[TokensTextLogprobs]:
|
||||||
'''
|
'''
|
||||||
Greedy logprobs generation for vLLM encoder/decoder models
|
Greedy logprobs generation for vLLM encoder/decoder models
|
||||||
'''
|
'''
|
||||||
|
|
||||||
all_logprobs: List[List[Dict[int, float]]] = []
|
all_logprobs: list[list[dict[int, float]]] = []
|
||||||
all_output_ids: List[List[int]] = []
|
all_output_ids: list[list[int]] = []
|
||||||
all_output_strs: List[str] = []
|
all_output_strs: list[str] = []
|
||||||
|
|
||||||
for i, (encoder_prompt, decoder_prompt) in enumerate(
|
for i, (encoder_prompt, decoder_prompt) in enumerate(
|
||||||
to_enc_dec_tuple_list(encoder_decoder_prompts)):
|
to_enc_dec_tuple_list(encoder_decoder_prompts)):
|
||||||
processor_kwargs: Dict[str, Any] = {
|
processor_kwargs: dict[str, Any] = {
|
||||||
"text": encoder_prompt,
|
"text": encoder_prompt,
|
||||||
"return_tensors": "pt",
|
"return_tensors": "pt",
|
||||||
}
|
}
|
||||||
@ -641,10 +640,10 @@ class HfRunner:
|
|||||||
return [(output_ids, output_str, output_logprobs)
|
return [(output_ids, output_str, output_logprobs)
|
||||||
for output_ids, output_str, output_logprobs in outputs]
|
for output_ids, output_str, output_logprobs in outputs]
|
||||||
|
|
||||||
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
|
def encode(self, prompts: list[str]) -> list[list[torch.Tensor]]:
|
||||||
return self.model.encode(prompts)
|
return self.model.encode(prompts)
|
||||||
|
|
||||||
def predict(self, prompts: List[List[str]]) -> torch.Tensor:
|
def predict(self, prompts: list[list[str]]) -> torch.Tensor:
|
||||||
return self.model.predict(prompts, convert_to_tensor=True)
|
return self.model.predict(prompts, convert_to_tensor=True)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -699,11 +698,11 @@ class VllmRunner:
|
|||||||
|
|
||||||
def get_inputs(
|
def get_inputs(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
) -> List[TextPrompt]:
|
) -> list[TextPrompt]:
|
||||||
if images is not None:
|
if images is not None:
|
||||||
assert len(prompts) == len(images)
|
assert len(prompts) == len(images)
|
||||||
|
|
||||||
@ -733,13 +732,13 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
inputs = self.get_inputs(prompts,
|
inputs = self.get_inputs(prompts,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
@ -749,12 +748,12 @@ class VllmRunner:
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
outputs: List[Tuple[List[List[int]], List[str]]] = []
|
outputs: list[tuple[list[list[int]], list[str]]] = []
|
||||||
for req_output in req_outputs:
|
for req_output in req_outputs:
|
||||||
prompt_str = req_output.prompt
|
prompt_str = req_output.prompt
|
||||||
prompt_ids = req_output.prompt_token_ids
|
prompt_ids = req_output.prompt_token_ids
|
||||||
req_sample_output_ids: List[List[int]] = []
|
req_sample_output_ids: list[list[int]] = []
|
||||||
req_sample_output_strs: List[str] = []
|
req_sample_output_strs: list[str] = []
|
||||||
for sample in req_output.outputs:
|
for sample in req_output.outputs:
|
||||||
output_str = sample.text
|
output_str = sample.text
|
||||||
output_ids = list(sample.token_ids)
|
output_ids = list(sample.token_ids)
|
||||||
@ -765,9 +764,9 @@ class VllmRunner:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _final_steps_generate_w_logprobs(
|
def _final_steps_generate_w_logprobs(
|
||||||
req_outputs: List[RequestOutput],
|
req_outputs: list[RequestOutput],
|
||||||
) -> List[TokensTextLogprobsPromptLogprobs]:
|
) -> list[TokensTextLogprobsPromptLogprobs]:
|
||||||
outputs: List[TokensTextLogprobsPromptLogprobs] = []
|
outputs: list[TokensTextLogprobsPromptLogprobs] = []
|
||||||
for req_output in req_outputs:
|
for req_output in req_outputs:
|
||||||
assert len(req_output.outputs) > 0
|
assert len(req_output.outputs) > 0
|
||||||
for sample in req_output.outputs:
|
for sample in req_output.outputs:
|
||||||
@ -780,14 +779,14 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate_w_logprobs(
|
def generate_w_logprobs(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[List[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs],
|
||||||
List[TokensTextLogprobsPromptLogprobs]]:
|
list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
inputs = self.get_inputs(prompts,
|
inputs = self.get_inputs(prompts,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
@ -806,10 +805,10 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate_encoder_decoder_w_logprobs(
|
def generate_encoder_decoder_w_logprobs(
|
||||||
self,
|
self,
|
||||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
) -> Union[List[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs],
|
||||||
List[TokensTextLogprobsPromptLogprobs]]:
|
list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
'''
|
'''
|
||||||
Logprobs generation for vLLM encoder/decoder models
|
Logprobs generation for vLLM encoder/decoder models
|
||||||
'''
|
'''
|
||||||
@ -826,13 +825,13 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate_greedy(
|
def generate_greedy(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[List[int], str]]:
|
) -> list[tuple[list[int], str]]:
|
||||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||||
outputs = self.generate(prompts,
|
outputs = self.generate(prompts,
|
||||||
greedy_params,
|
greedy_params,
|
||||||
@ -845,18 +844,18 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate_greedy_logprobs(
|
def generate_greedy_logprobs(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
num_prompt_logprobs: Optional[int] = None,
|
num_prompt_logprobs: Optional[int] = None,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Union[List[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs],
|
||||||
List[TokensTextLogprobsPromptLogprobs]]:
|
list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
greedy_logprobs_params = SamplingParams(
|
greedy_logprobs_params = SamplingParams(
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
@ -874,12 +873,12 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate_encoder_decoder_greedy_logprobs(
|
def generate_encoder_decoder_greedy_logprobs(
|
||||||
self,
|
self,
|
||||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
num_prompt_logprobs: Optional[int] = None,
|
num_prompt_logprobs: Optional[int] = None,
|
||||||
) -> Union[List[TokensTextLogprobs],
|
) -> Union[list[TokensTextLogprobs],
|
||||||
List[TokensTextLogprobsPromptLogprobs]]:
|
list[TokensTextLogprobsPromptLogprobs]]:
|
||||||
greedy_logprobs_params = SamplingParams(
|
greedy_logprobs_params = SamplingParams(
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
@ -895,10 +894,10 @@ class VllmRunner:
|
|||||||
|
|
||||||
def generate_beam_search(
|
def generate_beam_search(
|
||||||
self,
|
self,
|
||||||
prompts: Union[List[str], List[List[int]]],
|
prompts: Union[list[str], list[list[int]]],
|
||||||
beam_width: int,
|
beam_width: int,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
) -> list[tuple[list[list[int]], list[str]]]:
|
||||||
if is_list_of(prompts, str, check="all"):
|
if is_list_of(prompts, str, check="all"):
|
||||||
prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
|
prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||||
else:
|
else:
|
||||||
@ -915,17 +914,17 @@ class VllmRunner:
|
|||||||
returned_outputs.append((token_ids, texts))
|
returned_outputs.append((token_ids, texts))
|
||||||
return returned_outputs
|
return returned_outputs
|
||||||
|
|
||||||
def classify(self, prompts: List[str]) -> List[List[float]]:
|
def classify(self, prompts: list[str]) -> list[list[float]]:
|
||||||
req_outputs = self.model.classify(prompts)
|
req_outputs = self.model.classify(prompts)
|
||||||
return [req_output.outputs.probs for req_output in req_outputs]
|
return [req_output.outputs.probs for req_output in req_outputs]
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: list[str],
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
videos: Optional[PromptVideoInput] = None,
|
videos: Optional[PromptVideoInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
) -> List[List[float]]:
|
) -> list[list[float]]:
|
||||||
inputs = self.get_inputs(prompts,
|
inputs = self.get_inputs(prompts,
|
||||||
images=images,
|
images=images,
|
||||||
videos=videos,
|
videos=videos,
|
||||||
@ -936,9 +935,9 @@ class VllmRunner:
|
|||||||
|
|
||||||
def score(
|
def score(
|
||||||
self,
|
self,
|
||||||
text_1: Union[str, List[str]],
|
text_1: Union[str, list[str]],
|
||||||
text_2: Union[str, List[str]],
|
text_2: Union[str, list[str]],
|
||||||
) -> List[float]:
|
) -> list[float]:
|
||||||
req_outputs = self.model.score(text_1, text_2)
|
req_outputs = self.model.score(text_1, text_2)
|
||||||
return [req_output.outputs.score for req_output in req_outputs]
|
return [req_output.outputs.score for req_output in req_outputs]
|
||||||
|
|
||||||
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

-from typing import Callable, Iterable, Optional
+from collections.abc import Iterable
+from typing import Callable, Optional

import pytest

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import random
-from typing import List

import pytest

@@ -137,9 +136,9 @@ def prep_prompts(batch_size: int):
The prompt is just under 10k tokens; sliding window is 4k
so the answer is outside sliding window, but should still be correct.
"""
-prompts: List[str] = []
+prompts: list[str] = []
-answer: List[int] = []
+answer: list[int] = []
-indices: List[int] = []
+indices: list[int] = []
random.seed(1)
for _ in range(batch_size):
idx = random.randint(30, 90)
@@ -158,7 +157,7 @@ def prep_prompts(batch_size: int):
return prompts, answer, indices


-def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
+def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
@@ -170,7 +169,7 @@ def check_answers(indices: List[int], answer: List[int], outputs: List[str]):
assert frac_ok > 0.7


-def check_window(prompts: List[str]):
+def check_window(prompts: list[str]):

def inner(llm: LLM):
sliding_window = llm.llm_engine.model_config.get_sliding_window()
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

-from typing import List
-
import pytest

from vllm.core.block.block_table import BlockTable
@@ -32,7 +30,7 @@ def test_allocate_naive(block_size: int, sequence_len: int):
token_ids = list(range(sequence_len))
num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))

-block_tables: List[BlockTable] = []
+block_tables: list[BlockTable] = []
for i in range(5):
assert allocator.get_num_free_blocks(
device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc
@@ -77,7 +75,7 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int):
num_immutable_blocks_per_alloc = len(
chunked_tokens) - num_mutable_blocks_per_alloc

-block_tables: List[BlockTable] = []
+block_tables: list[BlockTable] = []
for alloc_i in range(1, 6):

block_tables.append(
@@ -272,7 +270,7 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
)
block_table.allocate(token_ids=token_ids, device=Device.GPU)

-appended_so_far: List[int] = []
+appended_so_far: list[int] = []
for append in chunk_list(token_ids_to_append, append_size):
block_table.append_token_ids(append)
appended_so_far.extend(append)
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

-from typing import List, Optional
+from typing import Optional

import pytest

@@ -14,7 +14,7 @@ class TestNaiveBlockAllocator:
def create_allocate_lambda(allocate_type: str,
allocator: NaiveBlockAllocator,
prev_block: Optional[Block],
-token_ids: List[int]):
+token_ids: list[int]):
if allocate_type == "immutable":
allocate_block = lambda: allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=token_ids)
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -123,11 +123,11 @@ class TestPrefixCachingBlock:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_chain(block_size: int,
|
def create_chain(block_size: int,
|
||||||
token_ids: List[int],
|
token_ids: list[int],
|
||||||
num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]:
|
num_empty_trailing_blocks=0) -> list[PrefixCachingBlock]:
|
||||||
"""Helper method which creates a chain of blocks.
|
"""Helper method which creates a chain of blocks.
|
||||||
"""
|
"""
|
||||||
blocks: List[PrefixCachingBlock] = []
|
blocks: list[PrefixCachingBlock] = []
|
||||||
num_blocks = math.ceil(
|
num_blocks = math.ceil(
|
||||||
len(token_ids) / block_size) + num_empty_trailing_blocks
|
len(token_ids) / block_size) + num_empty_trailing_blocks
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ class TestPrefixCachingBlockAllocator:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
|
def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
|
||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
token_ids: List[int]):
|
token_ids: list[int]):
|
||||||
if allocate_type == "immutable":
|
if allocate_type == "immutable":
|
||||||
allocate_block = lambda: allocator.allocate_immutable_block(
|
allocate_block = lambda: allocator.allocate_immutable_block(
|
||||||
prev_block=prev_block, token_ids=token_ids)
|
prev_block=prev_block, token_ids=token_ids)
|
||||||
@ -839,13 +839,13 @@ class TestPrefixCachingBlockAllocator:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def create_immutable_chain(
|
def create_immutable_chain(
|
||||||
block_size: int,
|
block_size: int,
|
||||||
token_ids: List[int],
|
token_ids: list[int],
|
||||||
allocator: PrefixCachingBlockAllocator,
|
allocator: PrefixCachingBlockAllocator,
|
||||||
extra_hash: Optional[int] = None,
|
extra_hash: Optional[int] = None,
|
||||||
) -> List[PrefixCachingBlock]:
|
) -> list[PrefixCachingBlock]:
|
||||||
"""Helper method which creates a chain of blocks.
|
"""Helper method which creates a chain of blocks.
|
||||||
"""
|
"""
|
||||||
blocks: List[Block] = []
|
blocks: list[Block] = []
|
||||||
num_blocks = math.ceil(len(token_ids) / block_size)
|
num_blocks = math.ceil(len(token_ids) / block_size)
|
||||||
|
|
||||||
if num_blocks == 0:
|
if num_blocks == 0:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest # noqa
|
import pytest # noqa
|
||||||
@ -46,7 +45,7 @@ def test_simple():
|
|||||||
cache_config.num_cpu_blocks = 8
|
cache_config.num_cpu_blocks = 8
|
||||||
cache_config.num_gpu_blocks = 8
|
cache_config.num_gpu_blocks = 8
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(num_seq_group):
|
for i in range(num_seq_group):
|
||||||
@ -93,7 +92,7 @@ def test_chunk():
|
|||||||
cache_config.num_cpu_blocks = 32
|
cache_config.num_cpu_blocks = 32
|
||||||
cache_config.num_gpu_blocks = 32
|
cache_config.num_gpu_blocks = 32
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
@ -145,7 +144,7 @@ def test_concurrent_chunking():
|
|||||||
cache_config.num_cpu_blocks = 32
|
cache_config.num_cpu_blocks = 32
|
||||||
cache_config.num_gpu_blocks = 32
|
cache_config.num_gpu_blocks = 32
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
@ -226,8 +225,8 @@ def test_short_prompts_jump_long_prompts_in_queue():
|
|||||||
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
|
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
|
||||||
cache_config.num_gpu_blocks = 3200
|
cache_config.num_gpu_blocks = 3200
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
long_seqs: List[SequenceGroup] = []
|
long_seqs: list[SequenceGroup] = []
|
||||||
short_seqs: List[SequenceGroup] = []
|
short_seqs: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add 2 large seq groups to scheduler.
|
# Add 2 large seq groups to scheduler.
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
@ -368,7 +367,7 @@ def test_complex():
|
|||||||
cache_config.num_cpu_blocks = 64
|
cache_config.num_cpu_blocks = 64
|
||||||
cache_config.num_gpu_blocks = 64
|
cache_config.num_gpu_blocks = 64
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
@ -439,7 +438,7 @@ def test_maximal_decoding():
|
|||||||
cache_config.num_cpu_blocks = 8
|
cache_config.num_cpu_blocks = 8
|
||||||
cache_config.num_gpu_blocks = 8
|
cache_config.num_gpu_blocks = 8
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
@ -533,7 +532,7 @@ def test_prompt_limit():
|
|||||||
cache_config.num_cpu_blocks = 16
|
cache_config.num_cpu_blocks = 16
|
||||||
cache_config.num_gpu_blocks = 16
|
cache_config.num_gpu_blocks = 16
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
_, seq_group = create_dummy_prompt("1",
|
_, seq_group = create_dummy_prompt("1",
|
||||||
prompt_length=48,
|
prompt_length=48,
|
||||||
@ -565,7 +564,7 @@ def test_prompt_limit_exceed():
|
|||||||
cache_config.num_cpu_blocks = 16
|
cache_config.num_cpu_blocks = 16
|
||||||
cache_config.num_gpu_blocks = 16
|
cache_config.num_gpu_blocks = 16
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
_, seq_group = create_dummy_prompt("2",
|
_, seq_group = create_dummy_prompt("2",
|
||||||
prompt_length=48,
|
prompt_length=48,
|
||||||
block_size=block_size)
|
block_size=block_size)
|
||||||
@ -699,7 +698,7 @@ def test_chunked_prefill_max_seqs():
|
|||||||
cache_config.num_cpu_blocks = 128
|
cache_config.num_cpu_blocks = 128
|
||||||
cache_config.num_gpu_blocks = 128
|
cache_config.num_gpu_blocks = 128
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
_, seq_group = create_dummy_prompt("1",
|
_, seq_group = create_dummy_prompt("1",
|
||||||
prompt_length=65,
|
prompt_length=65,
|
||||||
@ -758,7 +757,7 @@ def test_prefix_caching():
|
|||||||
cache_config.num_cpu_blocks = 0
|
cache_config.num_cpu_blocks = 0
|
||||||
cache_config.num_gpu_blocks = 32
|
cache_config.num_gpu_blocks = 32
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
@ -800,7 +799,7 @@ def test_prefix_caching_with_concurrent_partial_prefills():
|
|||||||
cache_config.num_cpu_blocks = 0
|
cache_config.num_cpu_blocks = 0
|
||||||
cache_config.num_gpu_blocks = 32
|
cache_config.num_gpu_blocks = 32
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import List, Set, Tuple
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest # noqa
|
import pytest # noqa
|
||||||
@ -57,7 +56,7 @@ def test_scheduler_abort_seq_group():
|
|||||||
|
|
||||||
# Add multiple seq groups to scheduler.
|
# Add multiple seq groups to scheduler.
|
||||||
num_seq_group = 4
|
num_seq_group = 4
|
||||||
request_ids: Set[str] = set()
|
request_ids: set[str] = set()
|
||||||
for i in range(num_seq_group):
|
for i in range(num_seq_group):
|
||||||
_, seq_group = create_dummy_prompt(str(i), block_size)
|
_, seq_group = create_dummy_prompt(str(i), block_size)
|
||||||
scheduler.add_seq_group(seq_group)
|
scheduler.add_seq_group(seq_group)
|
||||||
@ -83,7 +82,7 @@ def test_scheduler_schedule_simple():
|
|||||||
cache_config.num_cpu_blocks = 8
|
cache_config.num_cpu_blocks = 8
|
||||||
cache_config.num_gpu_blocks = 8
|
cache_config.num_gpu_blocks = 8
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
running: List[SequenceGroup] = []
|
running: list[SequenceGroup] = []
|
||||||
|
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(num_seq_group):
|
for i in range(num_seq_group):
|
||||||
@ -221,7 +220,7 @@ def test_scheduler_max_seqs():
|
|||||||
cache_config.num_gpu_blocks = 8
|
cache_config.num_gpu_blocks = 8
|
||||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||||
|
|
||||||
all_seq_groups: List[SequenceGroup] = []
|
all_seq_groups: list[SequenceGroup] = []
|
||||||
# Add seq groups to scheduler.
|
# Add seq groups to scheduler.
|
||||||
for i in range(num_seq_group):
|
for i in range(num_seq_group):
|
||||||
_, seq_group = create_dummy_prompt(str(i),
|
_, seq_group = create_dummy_prompt(str(i),
|
||||||
@ -480,7 +479,7 @@ def test_prefill_schedule_max_lora():
|
|||||||
num_cpu_blocks=64,
|
num_cpu_blocks=64,
|
||||||
num_gpu_blocks=64)
|
num_gpu_blocks=64)
|
||||||
budget = create_token_budget(token_budget=120)
|
budget = create_token_budget(token_budget=120)
|
||||||
curr_loras: Set[int] = set()
|
curr_loras: set[int] = set()
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
_, seq_group = create_dummy_prompt(str(i),
|
_, seq_group = create_dummy_prompt(str(i),
|
||||||
prompt_length=60,
|
prompt_length=60,
|
||||||
@ -651,8 +650,8 @@ def test_schedule_swapped_max_loras():
|
|||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_cpu_blocks=32,
|
num_cpu_blocks=32,
|
||||||
num_gpu_blocks=32)
|
num_gpu_blocks=32)
|
||||||
curr_loras: Set[int] = set()
|
curr_loras: set[int] = set()
|
||||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
blocks_to_swap_out: list[tuple[int, int]] = []
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
_, seq_group = create_dummy_prompt(str(i),
|
_, seq_group = create_dummy_prompt(str(i),
|
||||||
prompt_length=60,
|
prompt_length=60,
|
||||||
@ -683,7 +682,7 @@ def test_schedule_swapped_cannot_swap_in():
|
|||||||
num_cpu_blocks=32,
|
num_cpu_blocks=32,
|
||||||
num_gpu_blocks=32)
|
num_gpu_blocks=32)
|
||||||
curr_loras = None
|
curr_loras = None
|
||||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
blocks_to_swap_out: list[tuple[int, int]] = []
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
_, seq_group = create_dummy_prompt(str(i),
|
_, seq_group = create_dummy_prompt(str(i),
|
||||||
prompt_length=60,
|
prompt_length=60,
|
||||||
@ -714,7 +713,7 @@ def test_infeasible_swap():
|
|||||||
num_cpu_blocks=32,
|
num_cpu_blocks=32,
|
||||||
num_gpu_blocks=32)
|
num_gpu_blocks=32)
|
||||||
curr_loras = None
|
curr_loras = None
|
||||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
blocks_to_swap_out: list[tuple[int, int]] = []
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
_, seq_group = create_dummy_prompt(str(i),
|
_, seq_group = create_dummy_prompt(str(i),
|
||||||
prompt_length=60,
|
prompt_length=60,
|
||||||
@ -752,7 +751,7 @@ def test_schedule_swapped_blocks_to_copy():
|
|||||||
block_size=block_size)
|
block_size=block_size)
|
||||||
scheduler._allocate_and_set_running(seq_group)
|
scheduler._allocate_and_set_running(seq_group)
|
||||||
append_new_token_seq_group(60, seq_group, 1)
|
append_new_token_seq_group(60, seq_group, 1)
|
||||||
blocks_to_swap_out: List[Tuple[int, int]] = []
|
blocks_to_swap_out: list[tuple[int, int]] = []
|
||||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||||
scheduler._add_seq_group_to_swapped(seq_group)
|
scheduler._add_seq_group_to_swapped(seq_group)
|
||||||
|
|
||||||
|
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

-from typing import List
-
import pytest  # noqa

from vllm.config import CacheConfig, SchedulerConfig
@@ -48,7 +46,7 @@ def test_scheduler_schedule_simple_encoder_decoder():
cache_config.num_cpu_blocks = 16  # enc and dec prompts per seq_group
cache_config.num_gpu_blocks = 16  # enc and dec prompts per seq_group
scheduler = Scheduler(scheduler_config, cache_config, None)
-running: List[SequenceGroup] = []
+running: list[SequenceGroup] = []

# Add seq groups to scheduler.
req_id_list = []
@ -2,9 +2,8 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List, Optional
|
from collections.abc import Sequence as GenericSequence
|
||||||
from typing import Sequence as GenericSequence
|
from typing import Any, Optional
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||||
@ -20,10 +19,10 @@ def create_dummy_prompt(
|
|||||||
block_size: Optional[int] = None,
|
block_size: Optional[int] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
best_of: int = 1,
|
best_of: int = 1,
|
||||||
prompt_tokens: Optional[List[int]] = None,
|
prompt_tokens: Optional[list[int]] = None,
|
||||||
min_tokens: int = 0,
|
min_tokens: int = 0,
|
||||||
max_tokens: int = 16,
|
max_tokens: int = 16,
|
||||||
) -> Tuple[Sequence, SequenceGroup]:
|
) -> tuple[Sequence, SequenceGroup]:
|
||||||
if not block_size:
|
if not block_size:
|
||||||
block_size = prompt_length
|
block_size = prompt_length
|
||||||
|
|
||||||
@ -48,7 +47,7 @@ def create_dummy_prompt(
|
|||||||
return prompt, seq_group
|
return prompt, seq_group
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_lora_sequence(request_id: int, token_ids: List[int],
|
def create_dummy_lora_sequence(request_id: int, token_ids: list[int],
|
||||||
block_size: int, lora_int_id: int) -> Sequence:
|
block_size: int, lora_int_id: int) -> Sequence:
|
||||||
return Sequence(seq_id=request_id,
|
return Sequence(seq_id=request_id,
|
||||||
inputs=token_inputs(token_ids),
|
inputs=token_inputs(token_ids),
|
||||||
@ -58,7 +57,7 @@ def create_dummy_lora_sequence(request_id: int, token_ids: List[int],
|
|||||||
lora_int_id=lora_int_id))
|
lora_int_id=lora_int_id))
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_sequence(request_id: int, token_ids: List[int],
|
def create_dummy_sequence(request_id: int, token_ids: list[int],
|
||||||
block_size: int) -> Sequence:
|
block_size: int) -> Sequence:
|
||||||
return Sequence(
|
return Sequence(
|
||||||
seq_id=request_id,
|
seq_id=request_id,
|
||||||
@ -74,7 +73,7 @@ def create_dummy_prompt_encoder_decoder(
|
|||||||
block_size: Optional[int] = None,
|
block_size: Optional[int] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
best_of: int = 1,
|
best_of: int = 1,
|
||||||
) -> Tuple[Sequence, Sequence, SequenceGroup]:
|
) -> tuple[Sequence, Sequence, SequenceGroup]:
|
||||||
if not block_size:
|
if not block_size:
|
||||||
block_size = decoder_prompt_length
|
block_size = decoder_prompt_length
|
||||||
|
|
||||||
@ -125,7 +124,7 @@ def create_seq_group(
|
|||||||
|
|
||||||
prompt_token_ids = [0] * seq_prompt_len
|
prompt_token_ids = [0] * seq_prompt_len
|
||||||
|
|
||||||
seqs: List[Sequence] = []
|
seqs: list[Sequence] = []
|
||||||
for seq_id_offset, output_len in enumerate(seq_output_lens):
|
for seq_id_offset, output_len in enumerate(seq_output_lens):
|
||||||
seq = Sequence(
|
seq = Sequence(
|
||||||
seq_id=seq_id_start + seq_id_offset,
|
seq_id=seq_id_start + seq_id_offset,
|
||||||
@ -241,7 +240,7 @@ class SchedulerProxy:
|
|||||||
|
|
||||||
def __init__(self, scheduler: Scheduler):
|
def __init__(self, scheduler: Scheduler):
|
||||||
self.scheduler_ = scheduler
|
self.scheduler_ = scheduler
|
||||||
self.call_history: Dict[str, List[Any]] = defaultdict(list)
|
self.call_history: dict[str, list[Any]] = defaultdict(list)
|
||||||
|
|
||||||
def __getattr__(self, name: str) -> Any:
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
|
||||||
@ -253,6 +252,6 @@ class SchedulerProxy:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def last_schedule_ret(
|
def last_schedule_ret(
|
||||||
self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
|
self, ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]:
|
||||||
_, _, ret = self.call_history["schedule"][-1]
|
_, _, ret = self.call_history["schedule"][-1]
|
||||||
return ret
|
return ret
|
||||||
|
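The utils hunks above also move Sequence from typing to collections.abc, since the typing alias has been deprecated since Python 3.9, while Optional and Union stay in typing. A small hedged sketch with a hypothetical helper:

# Sketch only (hypothetical helper): typing.Sequence is a deprecated alias of
# collections.abc.Sequence since Python 3.9, so the abc is imported directly,
# while Optional is still taken from typing.
from collections.abc import Sequence as GenericSequence
from typing import Optional


def first_or_none(items: GenericSequence[int]) -> Optional[int]:
    # Accepts any sequence type (list, tuple, range, ...).
    return items[0] if items else None


print(first_or_none(range(3)))  # 0
print(first_or_none(()))        # None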
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import Literal, NamedTuple, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -28,8 +28,8 @@ class EPTestOptions(NamedTuple):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EPTestSettings:
|
class EPTestSettings:
|
||||||
parallel_setups: List[ParallelSetup]
|
parallel_setups: list[ParallelSetup]
|
||||||
distributed_backends: List[str]
|
distributed_backends: list[str]
|
||||||
task: TaskOption
|
task: TaskOption
|
||||||
test_options: EPTestOptions
|
test_options: EPTestOptions
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Literal, NamedTuple, Optional
|
from typing import Literal, NamedTuple, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -38,14 +38,14 @@ class PPTestOptions(NamedTuple):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PPTestSettings:
|
class PPTestSettings:
|
||||||
parallel_setups: List[ParallelSetup]
|
parallel_setups: list[ParallelSetup]
|
||||||
# NOTE: the length of distributed_backends and
|
# NOTE: the length of distributed_backends and
|
||||||
# vllm_major_versions should be the same, and they
|
# vllm_major_versions should be the same, and they
|
||||||
# are first zipped together to iterate over all
|
# are first zipped together to iterate over all
|
||||||
# test settings.
|
# test settings.
|
||||||
distributed_backends: List[str]
|
distributed_backends: list[str]
|
||||||
# vllm major version: "0" for V0, "1" for V1
|
# vllm major version: "0" for V0, "1" for V1
|
||||||
vllm_major_versions: List[str]
|
vllm_major_versions: list[str]
|
||||||
task: TaskOption
|
task: TaskOption
|
||||||
test_options: PPTestOptions
|
test_options: PPTestOptions
|
||||||
|
|
||||||
|
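The EP/PP settings hunks show the same rewrite inside dataclass field annotations. A hedged sketch with hypothetical fields:

# Hypothetical dataclass mirroring the shape of the settings classes above:
# field annotations use the built-in generics directly on Python 3.9+.
from dataclasses import dataclass, field


@dataclass
class TestSettings:
    distributed_backends: list[str]
    vllm_major_versions: list[str]
    env_overrides: dict[str, str] = field(default_factory=dict)


s = TestSettings(distributed_backends=["mp", "ray"], vllm_major_versions=["0", "1"])
print(s.env_overrides)  # {}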
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -20,9 +19,9 @@ from vllm.utils import update_environment_variables
|
|||||||
|
|
||||||
def distributed_run(fn, world_size):
|
def distributed_run(fn, world_size):
|
||||||
number_of_processes = world_size
|
number_of_processes = world_size
|
||||||
processes: List[multiprocessing.Process] = []
|
processes: list[multiprocessing.Process] = []
|
||||||
for i in range(number_of_processes):
|
for i in range(number_of_processes):
|
||||||
env: Dict[str, str] = {}
|
env: dict[str, str] = {}
|
||||||
env['RANK'] = str(i)
|
env['RANK'] = str(i)
|
||||||
env['LOCAL_RANK'] = str(i)
|
env['LOCAL_RANK'] = str(i)
|
||||||
env['WORLD_SIZE'] = str(number_of_processes)
|
env['WORLD_SIZE'] = str(number_of_processes)
|
||||||
|
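The distributed test hunk above annotates local variables the same way; list[...] and dict[...] behave exactly like the old typing.List and typing.Dict spellings. Sketch with hypothetical names:

# Hypothetical sketch of the local-variable annotations changed above.
import multiprocessing


def launch(world_size: int) -> list[multiprocessing.Process]:
    processes: list[multiprocessing.Process] = []
    for rank in range(world_size):
        env: dict[str, str] = {"RANK": str(rank), "WORLD_SIZE": str(world_size)}
        processes.append(multiprocessing.Process(target=print, args=(env,)))
    return processes


if __name__ == "__main__":
    procs = launch(2)
    for p in procs:
        p.start()
    for p in procs:
        p.join()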
@ -3,7 +3,6 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -13,7 +12,7 @@ from vllm.distributed.utils import StatelessProcessGroup
|
|||||||
from vllm.utils import get_ip, get_open_port, update_environment_variables
|
from vllm.utils import get_ip, get_open_port, update_environment_variables
|
||||||
|
|
||||||
|
|
||||||
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
|
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
sizes = np.random.randint(1, 10_000, n)
|
sizes = np.random.randint(1, 10_000, n)
|
||||||
# on average, each array will have 5k elements
|
# on average, each array will have 5k elements
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
|
Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoModelForSeq2SeqLM
|
from transformers import AutoModelForSeq2SeqLM
|
||||||
@ -22,7 +22,7 @@ LIST_ENC_DEC_SUPPORTED_BACKENDS = [
|
|||||||
|
|
||||||
|
|
||||||
def vllm_to_hf_output(
|
def vllm_to_hf_output(
|
||||||
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
|
vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
|
||||||
decoder_prompt_type: DecoderPromptType,
|
decoder_prompt_type: DecoderPromptType,
|
||||||
):
|
):
|
||||||
"""Sanitize vllm output to be comparable with hf output."""
|
"""Sanitize vllm output to be comparable with hf output."""
|
||||||
|
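The vllm_to_hf_output hunk above rewrites a nested return annotation. A hedged sketch (hypothetical function) of the same shape:

# Sketch: Tuple[List[int], str, Optional[...]] becomes
# tuple[list[int], str, Optional[...]] with no change in meaning.
from typing import Optional


def sanitize_output(
    output: tuple[list[int], str, Optional[str]],
    prefix: str,
) -> tuple[list[int], str, Optional[str]]:
    token_ids, text, stop_reason = output
    return token_ids, prefix + text, stop_reason


print(sanitize_output(([1, 2, 3], "hello", None), ">> "))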
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -22,8 +22,8 @@ class CustomUniExecutor(UniProcExecutor):
|
|||||||
def collective_rpc(self,
|
def collective_rpc(self,
|
||||||
method: Union[str, Callable],
|
method: Union[str, Callable],
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
args: Tuple = (),
|
args: tuple = (),
|
||||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
kwargs: Optional[dict] = None) -> list[Any]:
|
||||||
# Drop marker to show that this was ran
|
# Drop marker to show that this was ran
|
||||||
with open(".marker", "w"):
|
with open(".marker", "w"):
|
||||||
...
|
...
|
||||||
|
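In the collective_rpc hunk above, the unparameterized Tuple and Dict annotations become the bare built-ins, which carry the same meaning (any tuple, any dict). Hypothetical sketch:

# Hypothetical stand-in for the executor method above: bare tuple/dict
# annotations replace bare Tuple/Dict.
from typing import Any, Optional


def collective_rpc(method: str,
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict] = None) -> list[Any]:
    # Just echo the call for illustration.
    return [(method, timeout, args, kwargs or {})]


print(collective_rpc("ping", args=(1, 2)))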
@ -4,7 +4,7 @@ import asyncio
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Any, List, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ from vllm.worker.worker_base import WorkerWrapperBase
|
|||||||
class DummyWorkerWrapper(WorkerWrapperBase):
|
class DummyWorkerWrapper(WorkerWrapperBase):
|
||||||
"""Dummy version of vllm.worker.worker.Worker"""
|
"""Dummy version of vllm.worker.worker.Worker"""
|
||||||
|
|
||||||
def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
|
def worker_method(self, worker_input: Any) -> tuple[int, Any]:
|
||||||
sleep(0.05)
|
sleep(0.05)
|
||||||
|
|
||||||
if isinstance(worker_input, Exception):
|
if isinstance(worker_input, Exception):
|
||||||
@ -27,7 +27,7 @@ class DummyWorkerWrapper(WorkerWrapperBase):
|
|||||||
return self.rpc_rank, input
|
return self.rpc_rank, input
|
||||||
|
|
||||||
|
|
||||||
def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
|
def _start_workers() -> tuple[list[ProcessWorkerWrapper], WorkerMonitor]:
|
||||||
result_handler = ResultHandler()
|
result_handler = ResultHandler()
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
workers = [
|
workers = [
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -21,8 +21,8 @@ def vllm_model(vllm_runner):
|
|||||||
def _test_stopping(llm_engine: LLMEngine,
|
def _test_stopping(llm_engine: LLMEngine,
|
||||||
expected_output: str,
|
expected_output: str,
|
||||||
expected_reason: Any,
|
expected_reason: Any,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
include_in_output: bool = False,
|
include_in_output: bool = False,
|
||||||
use_async_output_proc: bool = False) -> None:
|
use_async_output_proc: bool = False) -> None:
|
||||||
llm_engine.add_request(
|
llm_engine.add_request(
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
@ -63,7 +61,7 @@ def test_multi_chat():
|
|||||||
|
|
||||||
@pytest.mark.parametrize("image_urls",
|
@pytest.mark.parametrize("image_urls",
|
||||||
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
|
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
|
||||||
def test_chat_multi_image(image_urls: List[str]):
|
def test_chat_multi_image(image_urls: list[str]):
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model="microsoft/Phi-3.5-vision-instruct",
|
model="microsoft/Phi-3.5-vision-instruct",
|
||||||
dtype="bfloat16",
|
dtype="bfloat16",
|
||||||
|
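Where list[...] was the only thing pulled from typing, the import line is dropped outright, as in the hunk above. Hedged sketch of the resulting signature (body is hypothetical):

# Sketch: once List[str] becomes list[str], no typing import is needed at all.
def test_chat_multi_image(image_urls: list[str]) -> None:
    # Hypothetical body; the real test constructs an LLM and sends the images.
    assert all(url.startswith("http") for url in image_urls)


test_chat_multi_image(["https://example.com/a.jpg", "https://example.com/b.jpg"])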
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import weakref
|
import weakref
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -45,8 +44,8 @@ def llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
def assert_outputs_equal(o1: List[PoolingRequestOutput],
|
def assert_outputs_equal(o1: list[PoolingRequestOutput],
|
||||||
o2: List[PoolingRequestOutput]):
|
o2: list[PoolingRequestOutput]):
|
||||||
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import weakref
|
import weakref
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -43,7 +42,7 @@ def llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
|
def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]):
|
||||||
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,7 +10,6 @@ import asyncio
|
|||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
from statistics import mean, median
|
from statistics import mean, median
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import pytest
|
import pytest
|
||||||
@ -67,7 +66,7 @@ async def process_dataset(model, client, data, concurrent_request):
|
|||||||
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
|
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
|
||||||
_ = await bound_transcribe(model, sem, client, (audio, sr), "")
|
_ = await bound_transcribe(model, sem, client, (audio, sr), "")
|
||||||
|
|
||||||
tasks: List[asyncio.Task] = []
|
tasks: list[asyncio.Task] = []
|
||||||
for sample in data:
|
for sample in data:
|
||||||
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
@ -180,7 +178,7 @@ def test_reasoning(
|
|||||||
):
|
):
|
||||||
output = tokenizer.tokenize(param_dict["output"])
|
output = tokenizer.tokenize(param_dict["output"])
|
||||||
# decode everything to tokens
|
# decode everything to tokens
|
||||||
output_tokens: List[str] = [
|
output_tokens: list[str] = [
|
||||||
tokenizer.convert_tokens_to_string([token]) for token in output
|
tokenizer.convert_tokens_to_string([token]) for token in output
|
||||||
]
|
]
|
||||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
DeltaMessage)
|
DeltaMessage)
|
||||||
@ -33,10 +33,10 @@ class StreamingReasoningReconstructor:
|
|||||||
|
|
||||||
def run_reasoning_extraction(
|
def run_reasoning_extraction(
|
||||||
reasoning_parser: ReasoningParser,
|
reasoning_parser: ReasoningParser,
|
||||||
model_output: List[str],
|
model_output: list[str],
|
||||||
request: Union[ChatCompletionRequest, None] = None,
|
request: Union[ChatCompletionRequest, None] = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
) -> Tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
if streaming:
|
if streaming:
|
||||||
reconstructor = run_reasoning_extraction_streaming(
|
reconstructor = run_reasoning_extraction_streaming(
|
||||||
reasoning_parser,
|
reasoning_parser,
|
||||||
@ -55,9 +55,9 @@ def run_reasoning_extraction(
|
|||||||
|
|
||||||
def run_reasoning_extraction_nonstreaming(
|
def run_reasoning_extraction_nonstreaming(
|
||||||
reasoning_parser: ReasoningParser,
|
reasoning_parser: ReasoningParser,
|
||||||
model_output: List[str],
|
model_output: list[str],
|
||||||
request: Union[ChatCompletionRequest, None] = None,
|
request: Union[ChatCompletionRequest, None] = None,
|
||||||
) -> Tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||||
return reasoning_parser.extract_reasoning_content(
|
return reasoning_parser.extract_reasoning_content(
|
||||||
model_output=''.join(model_output), request=request)
|
model_output=''.join(model_output), request=request)
|
||||||
@ -65,13 +65,13 @@ def run_reasoning_extraction_nonstreaming(
|
|||||||
|
|
||||||
def run_reasoning_extraction_streaming(
|
def run_reasoning_extraction_streaming(
|
||||||
reasoning_parser: ReasoningParser,
|
reasoning_parser: ReasoningParser,
|
||||||
model_deltas: List[str],
|
model_deltas: list[str],
|
||||||
request: Union[ChatCompletionRequest, None] = None,
|
request: Union[ChatCompletionRequest, None] = None,
|
||||||
) -> StreamingReasoningReconstructor:
|
) -> StreamingReasoningReconstructor:
|
||||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||||
reconstructor = StreamingReasoningReconstructor()
|
reconstructor = StreamingReasoningReconstructor()
|
||||||
previous_text = ""
|
previous_text = ""
|
||||||
previous_tokens: List[int] = []
|
previous_tokens: list[int] = []
|
||||||
for delta in model_deltas:
|
for delta in model_deltas:
|
||||||
token_delta = [
|
token_delta = [
|
||||||
reasoning_parser.vocab.get(token)
|
reasoning_parser.vocab.get(token)
|
||||||
|
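For context only, not something this commit does: on interpreters older than 3.9 the new spelling can still be used as long as it appears only in annotations, by enabling lazy annotation evaluation. Assumed example names:

# Lazy annotations make list[...] legal in annotations even on older Pythons;
# this commit does not need it because Python 3.8 support is being dropped.
from __future__ import annotations


def chunk_tokens(tokens: list[int], size: int) -> list[list[int]]:
    return [tokens[i:i + size] for i in range(0, len(tokens), size)]


print(chunk_tokens([1, 2, 3, 4, 5], 2))  # [[1, 2], [3, 4], [5]]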
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@ -41,7 +39,7 @@ async def client(server):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def base64_encoded_audio() -> Dict[str, str]:
|
def base64_encoded_audio() -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
audio_url: encode_audio_base64(*fetch_audio(audio_url))
|
audio_url: encode_audio_base64(*fetch_audio(audio_url))
|
||||||
for audio_url in TEST_AUDIO_URLS
|
for audio_url in TEST_AUDIO_URLS
|
||||||
@ -107,7 +105,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_single_chat_session_audio_base64encoded(
|
async def test_single_chat_session_audio_base64encoded(
|
||||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
||||||
base64_encoded_audio: Dict[str, str]):
|
base64_encoded_audio: dict[str, str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
@ -165,7 +163,7 @@ async def test_single_chat_session_audio_base64encoded(
|
|||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_single_chat_session_input_audio(
|
async def test_single_chat_session_input_audio(
|
||||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
|
||||||
base64_encoded_audio: Dict[str, str]):
|
base64_encoded_audio: dict[str, str]):
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
@ -255,7 +253,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
chunks: List[str] = []
|
chunks: list[str] = []
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
delta = chunk.choices[0].delta
|
delta = chunk.choices[0].delta
|
||||||
@ -277,7 +275,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
||||||
model_name: str, audio_url: str,
|
model_name: str, audio_url: str,
|
||||||
base64_encoded_audio: Dict[str,
|
base64_encoded_audio: dict[str,
|
||||||
str]):
|
str]):
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
@ -315,7 +313,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
chunks: List[str] = []
|
chunks: list[str] = []
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
delta = chunk.choices[0].delta
|
delta = chunk.choices[0].delta
|
||||||
@ -337,7 +335,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
|
||||||
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
|
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
|
||||||
audio_url: str,
|
audio_url: str,
|
||||||
base64_encoded_audio: Dict[str, str]):
|
base64_encoded_audio: dict[str, str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
@ -17,7 +16,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def server_args(request: pytest.FixtureRequest) -> List[str]:
|
def server_args(request: pytest.FixtureRequest) -> list[str]:
|
||||||
""" Provide extra arguments to the server via indirect parametrization
|
""" Provide extra arguments to the server via indirect parametrization
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
# imports for guided decoding tests
|
# imports for guided decoding tests
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
import openai # use the official client for correctness check
|
import openai # use the official client for correctness check
|
||||||
@ -190,7 +190,7 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
|
|||||||
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
prompt_logprobs: Optional[int]):
|
prompt_logprobs: Optional[int]):
|
||||||
params: Dict = {
|
params: dict = {
|
||||||
"messages": [{
|
"messages": [{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": "You are a helpful assistant."
|
"content": "You are a helpful assistant."
|
||||||
@ -232,7 +232,7 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
|||||||
)
|
)
|
||||||
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
||||||
model_name: str):
|
model_name: str):
|
||||||
params: Dict = {
|
params: dict = {
|
||||||
"messages": [{
|
"messages": [{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": "You are a helpful assistant."
|
"content": "You are a helpful assistant."
|
||||||
@ -343,7 +343,7 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
chunks: List[str] = []
|
chunks: list[str] = []
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
delta = chunk.choices[0].delta
|
delta = chunk.choices[0].delta
|
||||||
|
@ -5,7 +5,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Dict, List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
import openai # use the official client for correctness check
|
import openai # use the official client for correctness check
|
||||||
@ -287,7 +287,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
|
|||||||
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
prompt_logprobs: Optional[int]):
|
prompt_logprobs: Optional[int]):
|
||||||
params: Dict = {
|
params: dict = {
|
||||||
"prompt": ["A robot may not injure another robot", "My name is"],
|
"prompt": ["A robot may not injure another robot", "My name is"],
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
}
|
}
|
||||||
@ -331,7 +331,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
|
|||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True)
|
stream=True)
|
||||||
chunks: List[str] = []
|
chunks: list[str] = []
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunks.append(chunk.choices[0].text)
|
chunks.append(chunk.choices[0].text)
|
||||||
@ -364,7 +364,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
n=n,
|
n=n,
|
||||||
stream=True)
|
stream=True)
|
||||||
chunks: List[List[str]] = [[] for i in range(n)]
|
chunks: list[list[str]] = [[] for i in range(n)]
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
index = chunk.choices[0].index
|
index = chunk.choices[0].index
|
||||||
|
@ -86,7 +86,7 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
|
async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
|
||||||
# test List[str]
|
# test list[str]
|
||||||
input_texts = [
|
input_texts = [
|
||||||
"The cat sat on the mat.", "A feline was resting on a rug.",
|
"The cat sat on the mat.", "A feline was resting on a rug.",
|
||||||
"Stars twinkle brightly in the night sky."
|
"Stars twinkle brightly in the night sky."
|
||||||
@ -106,7 +106,7 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
assert embeddings.usage.prompt_tokens == 33
|
assert embeddings.usage.prompt_tokens == 33
|
||||||
assert embeddings.usage.total_tokens == 33
|
assert embeddings.usage.total_tokens == 33
|
||||||
|
|
||||||
# test List[List[int]]
|
# test list[list[int]]
|
||||||
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||||
[25, 32, 64, 77]]
|
[25, 32, 64, 77]]
|
||||||
embedding_response = await client.embeddings.create(
|
embedding_response = await client.embeddings.create(
|
||||||
|
@ -84,7 +84,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||||
# test List[str]
|
# test list[str]
|
||||||
input_texts = [
|
input_texts = [
|
||||||
"The cat sat on the mat.", "A feline was resting on a rug.",
|
"The cat sat on the mat.", "A feline was resting on a rug.",
|
||||||
"Stars twinkle brightly in the night sky."
|
"Stars twinkle brightly in the night sky."
|
||||||
@ -107,7 +107,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
|||||||
assert poolings.usage.prompt_tokens == 25
|
assert poolings.usage.prompt_tokens == 25
|
||||||
assert poolings.usage.total_tokens == 25
|
assert poolings.usage.total_tokens == 25
|
||||||
|
|
||||||
# test List[List[int]]
|
# test list[list[int]]
|
||||||
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||||
[25, 32, 64, 77]]
|
[25, 32, 64, 77]]
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
from typing import Any, List, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
import openai # use the official client for correctness check
|
import openai # use the official client for correctness check
|
||||||
import pytest
|
import pytest
|
||||||
@ -40,7 +40,7 @@ def server():
|
|||||||
|
|
||||||
class TestCase(NamedTuple):
|
class TestCase(NamedTuple):
|
||||||
model_name: str
|
model_name: str
|
||||||
base_url: List[str]
|
base_url: list[str]
|
||||||
api_key: str
|
api_key: str
|
||||||
expected_error: Any
|
expected_error: Any
|
||||||
|
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@ -49,7 +47,7 @@ async def client(server):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def base64_encoded_video() -> Dict[str, str]:
|
def base64_encoded_video() -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
video_url: encode_video_base64(fetch_video(video_url))
|
video_url: encode_video_base64(fetch_video(video_url))
|
||||||
for video_url in TEST_VIDEO_URLS
|
for video_url in TEST_VIDEO_URLS
|
||||||
@ -151,7 +149,7 @@ async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||||
async def test_single_chat_session_video_base64encoded(
|
async def test_single_chat_session_video_base64encoded(
|
||||||
client: openai.AsyncOpenAI, model_name: str, video_url: str,
|
client: openai.AsyncOpenAI, model_name: str, video_url: str,
|
||||||
base64_encoded_video: Dict[str, str]):
|
base64_encoded_video: dict[str, str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
@ -209,7 +207,7 @@ async def test_single_chat_session_video_base64encoded(
|
|||||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||||
async def test_single_chat_session_video_base64encoded_beamsearch(
|
async def test_single_chat_session_video_base64encoded_beamsearch(
|
||||||
client: openai.AsyncOpenAI, model_name: str, video_url: str,
|
client: openai.AsyncOpenAI, model_name: str, video_url: str,
|
||||||
base64_encoded_video: Dict[str, str]):
|
base64_encoded_video: dict[str, str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
@ -279,7 +277,7 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI,
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
chunks: List[str] = []
|
chunks: list[str] = []
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
delta = chunk.choices[0].delta
|
delta = chunk.choices[0].delta
|
||||||
@ -302,7 +300,7 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI,
|
|||||||
"video_urls",
|
"video_urls",
|
||||||
[TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))])
|
[TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))])
|
||||||
async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str,
|
async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str,
|
||||||
video_urls: List[str]):
|
video_urls: list[str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@ -50,7 +48,7 @@ async def client(server):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def base64_encoded_image() -> Dict[str, str]:
|
def base64_encoded_image() -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
image_url: encode_image_base64(fetch_image(image_url))
|
image_url: encode_image_base64(fetch_image(image_url))
|
||||||
for image_url in TEST_IMAGE_URLS
|
for image_url in TEST_IMAGE_URLS
|
||||||
@ -152,7 +150,7 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
|
|||||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||||
async def test_single_chat_session_image_base64encoded(
|
async def test_single_chat_session_image_base64encoded(
|
||||||
client: openai.AsyncOpenAI, model_name: str, image_url: str,
|
client: openai.AsyncOpenAI, model_name: str, image_url: str,
|
||||||
base64_encoded_image: Dict[str, str]):
|
base64_encoded_image: dict[str, str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
@ -210,7 +208,7 @@ async def test_single_chat_session_image_base64encoded(
|
|||||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||||
async def test_single_chat_session_image_base64encoded_beamsearch(
|
async def test_single_chat_session_image_base64encoded_beamsearch(
|
||||||
client: openai.AsyncOpenAI, model_name: str, image_url: str,
|
client: openai.AsyncOpenAI, model_name: str, image_url: str,
|
||||||
base64_encoded_image: Dict[str, str]):
|
base64_encoded_image: dict[str, str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
@ -280,7 +278,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
|
|||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
chunks: List[str] = []
|
chunks: list[str] = []
|
||||||
finish_reason_count = 0
|
finish_reason_count = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
delta = chunk.choices[0].delta
|
delta = chunk.choices[0].delta
|
||||||
@ -303,7 +301,7 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
|
|||||||
"image_urls",
|
"image_urls",
|
||||||
[TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))])
|
[TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))])
|
||||||
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
||||||
image_urls: List[str]):
|
image_urls: list[str]):
|
||||||
|
|
||||||
messages = [{
|
messages = [{
|
||||||
"role":
|
"role":
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@ -49,7 +47,7 @@ def server():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def base64_encoded_image() -> Dict[str, str]:
|
def base64_encoded_image() -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
image_url: encode_image_base64(fetch_image(image_url))
|
image_url: encode_image_base64(fetch_image(image_url))
|
||||||
for image_url in TEST_IMAGE_URLS
|
for image_url in TEST_IMAGE_URLS
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -125,7 +124,7 @@ TEST_CASES = [
|
|||||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
|
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
|
||||||
TEST_CASES)
|
TEST_CASES)
|
||||||
def test_tool_call(streaming: bool, model_output: str,
|
def test_tool_call(streaming: bool, model_output: str,
|
||||||
expected_tool_calls: List[FunctionCall]):
|
expected_tool_calls: list[FunctionCall]):
|
||||||
mock_tokenizer = MagicMock()
|
mock_tokenizer = MagicMock()
|
||||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||||
mock_tokenizer)
|
mock_tokenizer)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Iterable, List, Tuple, Union
|
from collections.abc import Iterable
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
DeltaMessage,
|
DeltaMessage,
|
||||||
@ -12,7 +13,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser
|
|||||||
class StreamingToolReconstructor:
|
class StreamingToolReconstructor:
|
||||||
|
|
||||||
def __init__(self, assert_one_tool_per_delta: bool = True):
|
def __init__(self, assert_one_tool_per_delta: bool = True):
|
||||||
self.tool_calls: List[ToolCall] = []
|
self.tool_calls: list[ToolCall] = []
|
||||||
self.other_content: str = ""
|
self.other_content: str = ""
|
||||||
self._assert_one_tool_per_delta = assert_one_tool_per_delta
|
self._assert_one_tool_per_delta = assert_one_tool_per_delta
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ def run_tool_extraction(
|
|||||||
request: Union[ChatCompletionRequest, None] = None,
|
request: Union[ChatCompletionRequest, None] = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
assert_one_tool_per_delta: bool = True,
|
assert_one_tool_per_delta: bool = True,
|
||||||
) -> Tuple[Union[str, None], List[ToolCall]]:
|
) -> tuple[Union[str, None], list[ToolCall]]:
|
||||||
if streaming:
|
if streaming:
|
||||||
reconstructor = run_tool_extraction_streaming(
|
reconstructor = run_tool_extraction_streaming(
|
||||||
tool_parser,
|
tool_parser,
|
||||||
@ -106,7 +107,7 @@ def run_tool_extraction_streaming(
|
|||||||
reconstructor = StreamingToolReconstructor(
|
reconstructor = StreamingToolReconstructor(
|
||||||
assert_one_tool_per_delta=assert_one_tool_per_delta)
|
assert_one_tool_per_delta=assert_one_tool_per_delta)
|
||||||
previous_text = ""
|
previous_text = ""
|
||||||
previous_tokens: List[int] = []
|
previous_tokens: list[int] = []
|
||||||
for delta in model_deltas:
|
for delta in model_deltas:
|
||||||
token_delta = [
|
token_delta = [
|
||||||
tool_parser.vocab.get(token)
|
tool_parser.vocab.get(token)
|
||||||
|
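The tool-parser utils hunk above splits the old import so that Iterable comes from collections.abc (its typing alias is deprecated) while Union stays in typing. Hedged sketch with a hypothetical helper:

# Sketch of the import split above.
from collections.abc import Iterable
from typing import Union


def join_deltas(deltas: Iterable[Union[str, int]]) -> str:
    # Accepts any iterable of str/int deltas (list, generator, ...).
    return "".join(str(d) for d in deltas)


print(join_deltas(["a", 1, "b"]))  # a1b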
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
|
|||||||
def ref_dynamic_per_token_quant(x: torch.tensor,
|
def ref_dynamic_per_token_quant(x: torch.tensor,
|
||||||
quant_dtype: torch.dtype,
|
quant_dtype: torch.dtype,
|
||||||
scale_ub: Optional[torch.tensor] = None) \
|
scale_ub: Optional[torch.tensor] = None) \
|
||||||
-> Tuple[torch.tensor, torch.tensor]:
|
-> tuple[torch.tensor, torch.tensor]:
|
||||||
|
|
||||||
assert quant_dtype in [torch.int8, FP8_DTYPE]
|
assert quant_dtype in [torch.int8, FP8_DTYPE]
|
||||||
if scale_ub is not None:
|
if scale_ub is not None:
|
||||||
@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
|
|||||||
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
|
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
|
||||||
# kernel
|
# kernel
|
||||||
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
|
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
|
||||||
-> Tuple[torch.tensor, torch.tensor]:
|
-> tuple[torch.tensor, torch.tensor]:
|
||||||
|
|
||||||
fp8_traits = torch.finfo(FP8_DTYPE)
|
fp8_traits = torch.finfo(FP8_DTYPE)
|
||||||
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
|
fp8_traits_max = ROCM_FP8_MAX if current_platform.is_rocm() \
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -86,7 +85,7 @@ def test_act_and_mul(
|
|||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_activation(
|
def test_activation(
|
||||||
activation: Type[torch.nn.Module],
|
activation: type[torch.nn.Module],
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
d: int,
|
d: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
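The activation test hunk above switches Type[...] to the subscriptable built-in type[...], which annotates a class object rather than an instance. Hedged sketch assuming torch is available:

# Sketch: type[torch.nn.Module] means "a Module subclass", same as the old
# Type[torch.nn.Module].
import torch


def build_activation(activation: type[torch.nn.Module]) -> torch.nn.Module:
    # The parameter is a class (e.g. torch.nn.SiLU), not an instance.
    return activation()


print(build_activation(torch.nn.SiLU))  # SiLU()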
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention(
|
|||||||
block_table = block_tables_lst[i]
|
block_table = block_tables_lst[i]
|
||||||
seq_len = int(seq_lens_lst[i])
|
seq_len = int(seq_lens_lst[i])
|
||||||
|
|
||||||
keys_lst: List[torch.Tensor] = []
|
keys_lst: list[torch.Tensor] = []
|
||||||
values_lst: List[torch.Tensor] = []
|
values_lst: list[torch.Tensor] = []
|
||||||
for j in range(seq_len):
|
for j in range(seq_len):
|
||||||
block_number = int(block_table[j // block_size])
|
block_number = int(block_table[j // block_size])
|
||||||
block_offset = j % block_size
|
block_offset = j % block_size
|
||||||
@ -133,7 +133,7 @@ def test_paged_attention(
|
|||||||
kv_cache_factory,
|
kv_cache_factory,
|
||||||
version: str,
|
version: str,
|
||||||
num_seqs: int,
|
num_seqs: int,
|
||||||
num_heads: Tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
use_alibi: bool,
|
use_alibi: bool,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
@ -166,7 +166,7 @@ def test_paged_attention(
|
|||||||
|
|
||||||
# Create the block tables.
|
# Create the block tables.
|
||||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||||
block_tables_lst: List[List[int]] = []
|
block_tables_lst: list[list[int]] = []
|
||||||
for _ in range(num_seqs):
|
for _ in range(num_seqs):
|
||||||
block_table = [
|
block_table = [
|
||||||
random.randint(0, NUM_BLOCKS - 1)
|
random.randint(0, NUM_BLOCKS - 1)
|
||||||
@ -334,7 +334,7 @@ def test_paged_attention(
|
|||||||
|
|
||||||
|
|
||||||
def ref_multi_query_kv_attention(
|
def ref_multi_query_kv_attention(
|
||||||
cu_seq_lens: List[int],
|
cu_seq_lens: list[int],
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
@ -342,7 +342,7 @@ def ref_multi_query_kv_attention(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_seqs = len(cu_seq_lens) - 1
|
num_seqs = len(cu_seq_lens) - 1
|
||||||
ref_outputs: List[torch.Tensor] = []
|
ref_outputs: list[torch.Tensor] = []
|
||||||
for i in range(num_seqs):
|
for i in range(num_seqs):
|
||||||
start_idx = cu_seq_lens[i]
|
start_idx = cu_seq_lens[i]
|
||||||
end_idx = cu_seq_lens[i + 1]
|
end_idx = cu_seq_lens[i + 1]
|
||||||
@ -378,7 +378,7 @@ def ref_multi_query_kv_attention(
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_multi_query_kv_attention(
|
def test_multi_query_kv_attention(
|
||||||
num_seqs: int,
|
num_seqs: int,
|
||||||
num_heads: Tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
|
|||||||
block_table = block_tables_lst[i]
|
block_table = block_tables_lst[i]
|
||||||
seq_len = int(seq_lens_lst[i])
|
seq_len = int(seq_lens_lst[i])
|
||||||
|
|
||||||
keys_lst: List[torch.Tensor] = []
|
keys_lst: list[torch.Tensor] = []
|
||||||
values_lst: List[torch.Tensor] = []
|
values_lst: list[torch.Tensor] = []
|
||||||
for j in range(seq_len):
|
for j in range(seq_len):
|
||||||
block_number = int(block_table[j // block_size])
|
block_number = int(block_table[j // block_size])
|
||||||
block_offset = j % block_size
|
block_offset = j % block_size
|
||||||
@ -162,7 +162,7 @@ def test_paged_attention(
|
|||||||
kv_cache_factory,
|
kv_cache_factory,
|
||||||
version: str,
|
version: str,
|
||||||
num_seqs: int,
|
num_seqs: int,
|
||||||
num_heads: Tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
use_alibi: bool,
|
use_alibi: bool,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
@ -331,7 +331,7 @@ def test_paged_attention(
|
|||||||
|
|
||||||
|
|
||||||
def ref_multi_query_kv_attention(
|
def ref_multi_query_kv_attention(
|
||||||
cu_seq_lens: List[int],
|
cu_seq_lens: list[int],
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_varlen_blocksparse_attention_prefill(
|
def test_varlen_blocksparse_attention_prefill(
|
||||||
num_seqs: int,
|
num_seqs: int,
|
||||||
num_heads: Tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
blocksparse_local_blocks: int,
|
blocksparse_local_blocks: int,
|
||||||
blocksparse_vert_stride: int,
|
blocksparse_vert_stride: int,
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -74,7 +73,7 @@ def test_copy_blocks(
|
|||||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||||
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||||
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
||||||
block_mapping: List[Tuple[int, int]] = []
|
block_mapping: list[tuple[int, int]] = []
|
||||||
for i in range(num_mappings):
|
for i in range(num_mappings):
|
||||||
src = src_blocks[i]
|
src = src_blocks[i]
|
||||||
dst1 = dst_blocks[2 * i]
|
dst1 = dst_blocks[2 * i]
|
||||||
@ -342,7 +341,7 @@ def test_reshape_and_cache_flash(
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_swap_blocks(
|
def test_swap_blocks(
|
||||||
kv_cache_factory,
|
kv_cache_factory,
|
||||||
direction: Tuple[str, str],
|
direction: tuple[str, str],
|
||||||
num_mappings: int,
|
num_mappings: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_merge_kernel(
|
def test_merge_kernel(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_heads: Tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
):
|
):
|
||||||
@ -85,8 +85,8 @@ CASES = [
|
|||||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_cascade(
|
def test_cascade(
|
||||||
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
|
seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
|
||||||
num_heads: Tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
|
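The cascade attention hunk above shows that the built-in generics nest just like the typing aliases did. Hedged sketch with a hypothetical helper:

# Sketch: Tuple[List[Tuple[int, int]], int] becomes
# tuple[list[tuple[int, int]], int].
def split_common_prefix(
    seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
) -> list[tuple[int, int]]:
    seq_lens, common_prefix_len = seq_lens_and_common_prefix
    # Hypothetical helper: subtract the shared prefix from each query length.
    return [(q - common_prefix_len, kv) for q, kv in seq_lens]


print(split_common_prefix(([(16, 128), (32, 256)], 8)))  # [(8, 128), (24, 256)]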
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
Run `pytest tests/kernels/test_cutlass.py`.
|
Run `pytest tests/kernels/test_cutlass.py`.
|
||||||
"""
|
"""
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int,
|
|||||||
a_scale_group_shape: tuple,
|
a_scale_group_shape: tuple,
|
||||||
b_scale_group_shape: tuple,
|
b_scale_group_shape: tuple,
|
||||||
use_bias: bool,
|
use_bias: bool,
|
||||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||||
device: str = "cuda"):
|
device: str = "cuda"):
|
||||||
# Test for a cutlass kernel with per-token activation quantization
|
# Test for a cutlass kernel with per-token activation quantization
|
||||||
# and per-output channel weight quantization.
|
# and per-output channel weight quantization.
|
||||||
@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int,
|
|||||||
a_scale_group_shape: tuple,
|
a_scale_group_shape: tuple,
|
||||||
b_scale_group_shape: tuple,
|
b_scale_group_shape: tuple,
|
||||||
use_bias: bool,
|
use_bias: bool,
|
||||||
out_dtype: Type[torch.dtype] = torch.bfloat16,
|
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||||
device: str = "cuda"):
|
device: str = "cuda"):
|
||||||
# Test for a cutlass kernel with per-token activation quantization
|
# Test for a cutlass kernel with per-token activation quantization
|
||||||
# and per-output channel weight quantization.
|
# and per-output channel weight quantization.
|
||||||
@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
|||||||
@pytest.mark.parametrize("use_bias", [True, False])
|
@pytest.mark.parametrize("use_bias", [True, False])
|
||||||
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
||||||
b_scale_group_shape,
|
b_scale_group_shape,
|
||||||
out_dtype: Type[torch.dtype],
|
out_dtype: type[torch.dtype],
|
||||||
use_bias: bool):
|
use_bias: bool):
|
||||||
cutlass_int8_gemm_helper(512,
|
cutlass_int8_gemm_helper(512,
|
||||||
512,
|
512,
|
||||||
@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
|||||||
reason="FP8 is not supported on this GPU type.")
|
reason="FP8 is not supported on this GPU type.")
|
||||||
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
||||||
b_scale_group_shape,
|
b_scale_group_shape,
|
||||||
out_dtype: Type[torch.dtype],
|
out_dtype: type[torch.dtype],
|
||||||
use_bias: bool):
|
use_bias: bool):
|
||||||
cutlass_fp8_gemm_helper(512,
|
cutlass_fp8_gemm_helper(512,
|
||||||
512,
|
512,
|
||||||
@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
|||||||
reason="FP8 blockwise is not supported on this GPU type.")
|
reason="FP8 blockwise is not supported on this GPU type.")
|
||||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
|
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
|
||||||
b_scale_group_shape,
|
b_scale_group_shape,
|
||||||
out_dtype: Type[torch.dtype],
|
out_dtype: type[torch.dtype],
|
||||||
use_bias: bool):
|
use_bias: bool):
|
||||||
cutlass_fp8_gemm_helper(512,
|
cutlass_fp8_gemm_helper(512,
|
||||||
512,
|
512,
|
||||||
|
@@ -3,7 +3,6 @@

Run `pytest tests/kernels/test_semi_structured.py`.
"""
- from typing import Tuple, Type

import pytest
import torch
@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,

def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t()

@@ -167,7 +166,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
- def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype],
+ def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
use_bias: bool):

# Create tensors
@@ -243,7 +243,7 @@ def _decoder_attn_setup(
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
- ) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
+ ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
'''
Set up test vectors & data structures for self-attention test.

@@ -421,7 +421,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
test_pt: TestPoint,
test_rsrcs: TestResources,
block_base_addr: int = 0,
- ) -> Tuple[PhaseTestParameters, PhaseTestParameters]:
+ ) -> tuple[PhaseTestParameters, PhaseTestParameters]:
'''
Set up test vectors & data structures for cross-attention test.
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

- from typing import List, Optional, Tuple
+ from typing import Optional

import pytest
import torch
@@ -24,8 +24,8 @@ def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
- query_lens: List[int],
+ query_lens: list[int],
- kv_lens: List[int],
+ kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
@@ -35,7 +35,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape

- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
@@ -88,8 +88,8 @@ def ref_paged_attn(
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
- kv_lens: List[int],
+ kv_lens: list[int],
- num_heads: Tuple[int, int],
+ num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
@@ -174,8 +174,8 @@ def test_flash_attn_with_paged_kv(
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
- seq_lens: List[Tuple[int, int]],
+ seq_lens: list[tuple[int, int]],
- num_heads: Tuple[int, int],
+ num_heads: tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

- from typing import List, Optional, Tuple
+ from typing import Optional

import flashinfer
import pytest
@@ -19,8 +19,8 @@ def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
- query_lens: List[int],
+ query_lens: list[int],
- kv_lens: List[int],
+ kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
@@ -30,7 +30,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape

- outputs: List[torch.Tensor] = []
+ outputs: list[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
@@ -78,8 +78,8 @@ def ref_paged_attn(
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(
- kv_lens: List[int],
+ kv_lens: list[int],
- num_heads: Tuple[int, int],
+ num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
@@ -168,8 +168,8 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
- def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
+ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
- num_heads: Tuple[int, int],
+ num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float]) -> None:
@@ -270,7 +270,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
def test_flashinfer_prefill_with_paged_fp8_kv(
- seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
+ seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
pytest.skip("TODO: fix the accuracy issue")
@@ -378,8 +378,8 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_fp8_kv(
- kv_lens: List[int],
+ kv_lens: list[int],
- num_heads: Tuple[int, int],
+ num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

- from typing import Optional, Tuple, Union
+ from typing import Optional, Union

import pytest
import torch
@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def ref_rms_norm(rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor]) \
- -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, residual = rms_norm_layer.forward_native(x, residual)
@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
- -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn

@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
- -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
residual, scale_ub)

@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
- -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
- -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
scale_ub)
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
- from typing import List

import pytest
import torch
@@ -16,7 +15,7 @@ GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")

def get_gguf_sample_tensors(
hidden_size: int,
- quant_type: GGMLQuantizationType) -> List[ReaderTensor]:
+ quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.

import math
from dataclasses import dataclass, fields
- from typing import List, Optional, Tuple
+ from typing import Optional

import pytest
import torch
@@ -45,7 +45,7 @@ MNK_SHAPES = [
(1024, 8192, 4096),
]

- GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1]
+ GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]


@dataclass
@@ -75,7 +75,7 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
- TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype],
+ TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TEST_TYPES = [
# GPTQ style
@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))


- def group_size_valid(shape: Tuple[int, int, int],
+ def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0

@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_machete, w_s, w_zp


- def create_test_tensors(shape: Tuple[int, int, int],
+ def create_test_tensors(shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors:
@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):

- group_sizes: List[Optional[int]] = []
+ group_sizes: list[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
- group_sizes: List[Optional[int]] = []
+ group_sizes: list[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import unittest
- from typing import Tuple

import pytest
import torch
@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
def test_mixer2_gated_norm_multi_gpu(
batch_size: int,
seq_len: int,
- hidden_size_n_groups: Tuple[int, int],
+ hidden_size_n_groups: tuple[int, int],
dtype: torch.dtype,
device: str = 'cuda',
):
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

- from typing import Dict, Tuple

import pytest
import torch
import torch.nn.functional as F
@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
- def get_continuous_batch(example_lens: Tuple[int, ...]):
+ def get_continuous_batch(example_lens: tuple[int, ...]):

indices = []
for i, x in enumerate(example_lens):
@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,

# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
- last_taken: Dict = {} # map: eg -> pointer to last taken sample
+ last_taken: dict = {} # map: eg -> pointer to last taken sample
- exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted
+ exhausted: dict = {} # map: eg -> boolean indicating example is exhausted

states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from itertools import accumulate, product
- from typing import Callable, Dict, List, Optional
+ from typing import Callable, Optional

import pytest
import torch
@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
- scaling_factors: List[int] = [1, 2, 4]
+ scaling_factors: list[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"rope_type": "linear",
"factor": tuple(scaling_factors)
@@ -234,7 +234,7 @@ def test_rope_module_cache():
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES)
- rope_setting_id_map: Dict[str, int] = {}
+ rope_setting_id_map: dict[str, int] = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
@@ -4,7 +4,7 @@
Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
import importlib
- from typing import Optional, Type
+ from typing import Optional

import pytest
import torch
@@ -18,7 +18,7 @@ def scaled_mm_torch(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
- out_dtype: Type[torch.dtype],
+ out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out
@@ -4,9 +4,9 @@
import itertools
import random
import unittest
+ from collections.abc import Sequence
from numbers import Number
- from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
- Type, Union)
+ from typing import Any, NamedTuple, Optional, Union

import pytest
import torch
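One detail in the hunk above: `Sequence` is not simply dropped from the `typing` import but re-imported from `collections.abc`, since PEP 585 deprecates the `typing` aliases for the abstract container types as well. A small, hypothetical sketch of the resulting style (`tag_all` is an invented helper, not code from this diff; assumes Python 3.9 or newer):

    from collections.abc import Sequence  # replaces typing.Sequence


    def tag_all(items: Sequence[str], prefix: str) -> list[str]:
        # collections.abc.Sequence accepts any sequence type (list, tuple, ...)
        # and is subscriptable in annotations, just like the old typing alias.
        return [f"{prefix}{item}" for item in items]


    print(tag_all(("a", "b"), "x-"))  # ['x-a', 'x-b']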
@@ -20,13 +20,13 @@ from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
- DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+ DEFAULT_OPCHECK_TEST_UTILS: tuple[str, ...] = (
"test_schema",
"test_autograd_registration",
"test_faketensor",
)

- ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+ ALL_OPCHECK_TEST_UTILS: tuple[str, ...] = (
"test_schema",
"test_autograd_registration",
"test_faketensor",
@@ -50,8 +50,8 @@ class QKVInputs(NamedTuple):
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
- q_seq_lens: List[int]
+ q_seq_lens: list[int]
- kv_seq_lens: List[int]
+ kv_seq_lens: list[int]


class QKVO(NamedTuple):
@@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple):
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
- q_start_loc_list: Optional[List[int]]
+ q_start_loc_list: Optional[list[int]]
- kv_start_loc_list: Optional[List[int]]
+ kv_start_loc_list: Optional[list[int]]
- q_seq_lens: Optional[List[int]]
+ q_seq_lens: Optional[list[int]]
- kv_seq_lens: Optional[List[int]]
+ kv_seq_lens: Optional[list[int]]


class PackedQKVO(NamedTuple):
@@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple):


def maybe_make_int_tensor(
- _list: Optional[List[int]],
+ _list: Optional[list[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
@@ -162,7 +162,7 @@ def maybe_make_int_tensor(


def maybe_make_long_tensor(
- _list: Optional[List[int]],
+ _list: Optional[list[int]],
device: Union[torch.device, str],
) -> torch.Tensor:
'''
@@ -177,7 +177,7 @@ def maybe_make_long_tensor(
_list, dtype=torch.long, device=device)


- def maybe_max(_list: Optional[List]) -> Optional[Number]:
+ def maybe_max(_list: Optional[list]) -> Optional[Number]:
'''
Returns:

@@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor,
value: torch.Tensor,
scale: float,
custom_mask: Optional[torch.Tensor] = None,
- q_seq_lens: Optional[List] = None,
+ q_seq_lens: Optional[list] = None,
- kv_seq_lens: Optional[List] = None) -> torch.Tensor:
+ kv_seq_lens: Optional[list] = None) -> torch.Tensor:
'''
"Golden" masked attention reference. Supports two types of masking:

@@ -295,10 +295,10 @@ def make_qkv(
num_heads: int,
head_size: int,
device: Union[torch.device, str],
- force_kv_seq_lens: Optional[List[int]] = None,
+ force_kv_seq_lens: Optional[list[int]] = None,
attn_type: AttentionType = AttentionType.ENCODER_DECODER,
force_max_len: bool = False,
- ) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
+ ) -> tuple[QKVInputs, QKVInputs, QKVInputs]:
'''
Construct QKV test tensors for self- and cross-attention.

@@ -429,8 +429,8 @@ def make_qkv(


def pack_tensor(
- unpacked_tensor: torch.Tensor, seq_lens: List[int],
+ unpacked_tensor: torch.Tensor, seq_lens: list[int],
- device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
+ device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]:
'''
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
unpadded number_of_tokens x num_heads x head_size tensor, where
@@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend:


def _make_metadata_tensors(
- seq_lens: Optional[List[int]],
+ seq_lens: Optional[list[int]],
- context_lens: Optional[List[int]],
+ context_lens: Optional[list[int]],
- encoder_seq_lens: Optional[List[int]],
+ encoder_seq_lens: Optional[list[int]],
device: Union[torch.device, str],
- ) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
torch.Tensor, torch.Tensor, Optional[int]]:
'''
Build scalar & tensor values required to build attention metadata structure.
@@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]):
return torch.tensor([], device=device)


- def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
+ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int],
device: Union[torch.device, str]):
'''
Split a slot mapping into valid prefill- and decode-phase slot mappings.
@@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],

Arguments:

- * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
+ * slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
post-decode sequences
- * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
+ * seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
description above)
* device: cuda, cpu, etc.

@@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],

def make_block_tables_slot_mapping(
block_size: int,
- seq_lens: List[int],
+ seq_lens: list[int],
device: Union[torch.device, str],
- block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
+ block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]:
'''
Construct fake block tables & slot mappings.

@@ -794,7 +794,7 @@ def make_block_tables_slot_mapping(
def make_test_metadata(
attn_backend: _Backend,
is_prompt: bool,
- seq_lens: Optional[List[int]],
+ seq_lens: Optional[list[int]],
decoder_test_params: Optional[PhaseTestParameters],
device: Union[torch.device, str],
encoder_test_params: Optional[PhaseTestParameters] = None,
@@ -1043,7 +1043,7 @@ def fp8_allclose(
# Marlin MoE test utils


- def stack_and_dev(tensors: List[torch.Tensor]):
+ def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)

@@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk):
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef],
- args: Tuple[Any, ...],
+ args: tuple[Any, ...],
- kwargs: Optional[Dict[str, Any]] = None,
+ kwargs: Optional[dict[str, Any]] = None,
*,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
- cond: bool = True) -> Dict[str, str]:
+ cond: bool = True) -> dict[str, str]:
with unittest.mock.patch('torch.allclose', new=fp8_allclose):
return torch.library.opcheck(
op,
@@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
- out_dtype: Type[torch.dtype],
+ out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

# We treat N-dimensional group scaling as extended numpy-style broadcasting