[CI/Build] drop support for Python 3.8 EOL (#8464)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Aaron Pham 2024-11-06 02:11:55 -05:00 committed by GitHub
parent 4be3a45158
commit 21063c11c7
GPG Key ID: B5690EEEBB952194
115 changed files with 239 additions and 321 deletions

View File

@@ -56,7 +56,7 @@ serving_column_mapping = {
 def read_markdown(file):
 if os.path.exists(file):
-with open(file, "r") as f:
+with open(file) as f:
 return f.read() + "\n"
 else:
 return f"{file} not found.\n"
@@ -75,14 +75,14 @@ if __name__ == "__main__":
 # collect results
 for test_file in results_folder.glob("*.json"):
-with open(test_file, "r") as f:
+with open(test_file) as f:
 raw_result = json.loads(f.read())
 if "serving" in str(test_file):
 # this result is generated via `benchmark_serving.py`
 # attach the benchmarking command to raw_result
-with open(test_file.with_suffix(".commands"), "r") as f:
+with open(test_file.with_suffix(".commands")) as f:
 command = json.loads(f.read())
 raw_result.update(command)
@@ -97,7 +97,7 @@ if __name__ == "__main__":
 # this result is generated via `benchmark_latency.py`
 # attach the benchmarking command to raw_result
-with open(test_file.with_suffix(".commands"), "r") as f:
+with open(test_file.with_suffix(".commands")) as f:
 command = json.loads(f.read())
 raw_result.update(command)
@@ -119,7 +119,7 @@ if __name__ == "__main__":
 # this result is generated via `benchmark_throughput.py`
 # attach the benchmarking command to raw_result
-with open(test_file.with_suffix(".commands"), "r") as f:
+with open(test_file.with_suffix(".commands")) as f:
 command = json.loads(f.read())
 raw_result.update(command)
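
The repeated `open(file, "r")` to `open(file)` rewrite in this hunk relies on the fact that `"r"` (text read) is already the default mode of the built-in `open()`, so dropping it does not change behavior. A minimal standalone sketch (using a temporary file, not code from the patch):

    import tempfile

    # "r" is the default mode for open(), so both calls below are identical.
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
        tmp.write('{"status": "ok"}')

    with open(tmp.name) as f:          # new spelling
        default_mode = f.read()
    with open(tmp.name, "r") as f:     # old spelling
        explicit_mode = f.read()
    assert default_mode == explicit_mode == '{"status": "ok"}'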

View File

@@ -72,7 +72,7 @@ def main(args):
 # collect results
 for test_file in results_folder.glob("*_nightly_results.json"):
-with open(test_file, "r") as f:
+with open(test_file) as f:
 results = results + json.loads(f.read())
 # generate markdown table
@@ -80,7 +80,7 @@ def main(args):
 md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False)
-with open(args.description, "r") as f:
+with open(args.description) as f:
 description = f.read()
 description = description.format(

View File

@@ -36,11 +36,11 @@ if __name__ == "__main__":
 # collect results
 for test_file in results_folder.glob("*.json"):
-with open(test_file, "r") as f:
+with open(test_file) as f:
 raw_result = json.loads(f.read())
 # attach the benchmarking command to raw_result
-with open(test_file.with_suffix(".commands"), "r") as f:
+with open(test_file.with_suffix(".commands")) as f:
 command = json.loads(f.read())
 raw_result.update(command)

View File

@@ -25,7 +25,7 @@ jobs:
 runs-on: ubuntu-latest
 strategy:
 matrix:
-python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+python-version: ["3.9", "3.10", "3.11", "3.12"]
 steps:
 - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
 - name: Set up Python ${{ matrix.python-version }}

View File

@@ -48,7 +48,7 @@ jobs:
 fail-fast: false
 matrix:
 os: ['ubuntu-20.04']
-python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+python-version: ['3.9', '3.10', '3.11', '3.12']
 pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
 cuda-version: ['11.8', '12.1']

View File

@@ -6,7 +6,7 @@ version: 2
 build:
 os: ubuntu-22.04
 tools:
-python: "3.8"
+python: '3.9'
 sphinx:
 configuration: docs/source/conf.py
@@ -19,4 +19,3 @@ formats: []
 python:
 install:
 - requirements: docs/requirements-docs.txt

View File

@@ -79,7 +79,7 @@ async def async_request_tgi(
 # any data, we should skip it.
 if chunk_bytes.startswith(":"):
 continue
-chunk = remove_prefix(chunk_bytes, "data:")
+chunk = chunk_bytes.removeprefix("data:")
 data = json.loads(chunk)
 timestamp = time.perf_counter()
@@ -144,7 +144,7 @@ async def async_request_trt_llm(
 if not chunk_bytes:
 continue
-chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+chunk = chunk_bytes.decode("utf-8").removeprefix(
 "data:")
 data = json.loads(chunk)
@@ -261,7 +261,7 @@ async def async_request_openai_completions(
 if not chunk_bytes:
 continue
-chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+chunk = chunk_bytes.decode("utf-8").removeprefix(
 "data: ")
 if chunk == "[DONE]":
 latency = time.perf_counter() - st
@@ -349,7 +349,7 @@ async def async_request_openai_chat_completions(
 if not chunk_bytes:
 continue
-chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+chunk = chunk_bytes.decode("utf-8").removeprefix(
 "data: ")
 if chunk == "[DONE]":
 latency = time.perf_counter() - st
@@ -389,14 +389,6 @@ async def async_request_openai_chat_completions(
 return output
-# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
-# introduced in Python 3.9
-def remove_prefix(text: str, prefix: str) -> str:
-if text.startswith(prefix):
-return text[len(prefix):]
-return text
 def get_model(pretrained_model_name_or_path: str) -> str:
 if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
 from modelscope import snapshot_download
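
The deleted `remove_prefix` helper existed only because `str.removeprefix` was added in Python 3.9 (PEP 616); with 3.8 dropped, the stdlib method replaces it. A standalone sketch (not vLLM code) showing the equivalence:

    def remove_prefix(text: str, prefix: str) -> str:
        # The old 3.8-compatible helper, reproduced here only for comparison.
        if text.startswith(prefix):
            return text[len(prefix):]
        return text

    chunk = 'data: {"token": "hi"}'
    # str.removeprefix (3.9+) strips the prefix only when it is present.
    assert chunk.removeprefix("data: ") == remove_prefix(chunk, "data: ")
    assert "no-prefix".removeprefix("data: ") == "no-prefix"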

View File

@@ -269,10 +269,10 @@ def run_square_bench(args):
 def run_range_bench(args):
-m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
-m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
+m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
+m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
 m_increment, k_increment, n_increment = \
-[int(x) for x in args.dim_increment.split(",")]
+(int(x) for x in args.dim_increment.split(","))
 Ms = list(range(m_start, m_end + 1, m_increment))
 Ks = list(range(k_start, k_end + 1, k_increment))
 Ns = list(range(n_start, n_end + 1, n_increment))
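
The list comprehensions replaced above were only unpacked into a fixed number of names, so a parenthesized generator expression does the same job without building a throwaway list (the kind of rewrite the newly enabled pyupgrade rules suggest). A standalone sketch with made-up values:

    dim_start = "16,32,64"

    # Old style: materialize a list, then unpack it.
    m_start, k_start, n_start = [int(x) for x in dim_start.split(",")]

    # New style: unpack straight from a generator expression.
    m_start2, k_start2, n_start2 = (int(x) for x in dim_start.split(","))

    assert (m_start, k_start, n_start) == (m_start2, k_start2, n_start2) == (16, 32, 64)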

View File

@@ -468,7 +468,7 @@ def generate():
 impl_configs = []
 GPTQ_kernel_type_configs = list(
-(TypeConfig(
+TypeConfig(
 element_a=element_a,
 element_b=element_b,
 element_b_scale=element_a,
@@ -476,7 +476,7 @@ def generate():
 element_d=element_a,
 accumulator=DataType.f32,
 ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
-for element_a in (DataType.f16, DataType.bf16)))
+for element_a in (DataType.f16, DataType.bf16))
 GPTQ_kernel_specializations = [
 Specialization(with_C=False, with_zeropoints=False, with_scales=True)
@@ -490,7 +490,7 @@ def generate():
 ]
 AWQ_kernel_type_configs = list(
-(TypeConfig(
+TypeConfig(
 element_a=element_a,
 element_b=element_b,
 element_b_scale=element_a,
@@ -498,7 +498,7 @@ def generate():
 element_d=element_a,
 accumulator=DataType.f32,
 ) for element_b in (DataType.u4, DataType.u8)
-for element_a in (DataType.f16, DataType.bf16))
+for element_a in (DataType.f16, DataType.bf16))

View File

@@ -10,7 +10,7 @@ Requirements
 ============
 * OS: Linux
-* Python: 3.8 - 3.12
+* Python: 3.9 -- 3.12
 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.)
 Install released versions

View File

@@ -34,7 +34,7 @@ select = [
 # Pyflakes
 "F",
 # pyupgrade
-# "UP",
+"UP",
 # flake8-bugbear
 "B",
 # flake8-simplify
@@ -55,7 +55,7 @@ ignore = [
 ]
 [tool.mypy]
-python_version = "3.8"
+python_version = "3.9"
 ignore_missing_imports = true
 check_untyped_defs = true
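
Enabling ruff's `UP` (pyupgrade) rules and raising mypy's `python_version` to 3.9 are what drive most of the mechanical rewrites elsewhere in this diff (redundant open modes, `# coding=utf-8` pragmas, `super(Class, self)`, `%`-formatting) and let the codebase rely on 3.9-only syntax such as PEP 585 built-in generics. A hedged standalone sketch of what the new floor allows (not code from the repository):

    def bucket_lengths(prompts: list[str]) -> dict[int, list[str]]:
        # PEP 585 generics (list[...], dict[...]) are valid annotations on
        # 3.9+ without importing List/Dict from typing.
        buckets: dict[int, list[str]] = {}
        for prompt in prompts:
            buckets.setdefault(len(prompt), []).append(prompt)
        return buckets

    assert bucket_lengths(["hi", "yo", "hello"]) == {2: ["hi", "yo"], 5: ["hello"]}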

View File

@@ -1,5 +1,4 @@
 import importlib.util
-import io
 import logging
 import os
 import re
@@ -327,7 +326,7 @@ def get_neuronxcc_version():
 "__init__.py")
 # Check if the command was executed successfully
-with open(version_file, "rt") as fp:
+with open(version_file) as fp:
 content = fp.read()
 # Extract the version using a regular expression
@@ -404,7 +403,8 @@ def read_readme() -> str:
 """Read the README file if present."""
 p = get_path("README.md")
 if os.path.isfile(p):
-return io.open(get_path("README.md"), "r", encoding="utf-8").read()
+with open(get_path("README.md"), encoding="utf-8") as f:
+return f.read()
 else:
 return ""
@@ -498,7 +498,6 @@ setup(
 "Documentation": "https://vllm.readthedocs.io/en/latest/",
 },
 classifiers=[
-"Programming Language :: Python :: 3.8",
 "Programming Language :: Python :: 3.9",
 "Programming Language :: Python :: 3.10",
 "Programming Language :: Python :: 3.11",
@@ -512,7 +511,7 @@ setup(
 ],
 packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples",
 "tests*")),
-python_requires=">=3.8",
+python_requires=">=3.9",
 install_requires=get_requirements(),
 ext_modules=ext_modules,
 extras_require={
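
The bumped `python_requires=">=3.9"` is the metadata pip checks at install time, so new releases will no longer be selected on a 3.8 interpreter (the dropped trove classifier is informational only). A rough standalone sketch, not part of the patch, of what that floor means expressed as an import-time guard:

    import sys

    # pip enforces Requires-Python from the packaging metadata at install
    # time; an equivalent runtime guard would look like this.
    if sys.version_info < (3, 9):
        raise RuntimeError("vLLM requires Python 3.9 or newer")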

View File

@@ -429,8 +429,8 @@ def benchmark():
 # print in tabular format
 print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph")
 for b in cudagraph_sizes:
-print((f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
-f"\t{piecewise_cudagraph_time[b]:.3f}"))
+print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}"
+f"\t{piecewise_cudagraph_time[b]:.3f}")
 if __name__ == "__main__":

View File

@@ -1,6 +1,5 @@
 import json
 import os
-import sys
 import tempfile
 from collections import UserList
 from enum import Enum
@@ -52,7 +51,7 @@ PromptVideoInput = _PromptMultiModalInput[np.ndarray]
 def _read_prompts(filename: str) -> List[str]:
-with open(filename, "r") as f:
+with open(filename) as f:
 prompts = f.readlines()
 return prompts
@@ -62,13 +61,7 @@ class _ImageAssetPrompts(TypedDict):
 cherry_blossom: str
-if sys.version_info < (3, 9):
-# UserList cannot be subscripted
-class _ImageAssetsBase(UserList):
-pass
-else:
-class _ImageAssetsBase(UserList[ImageAsset]):
+class _ImageAssetsBase(UserList[ImageAsset]):
 pass
@@ -94,13 +87,7 @@ class _VideoAssetPrompts(TypedDict):
 sample_demo_1: str
-if sys.version_info < (3, 9):
-# UserList cannot be subscripted
-class _VideoAssetsBase(UserList):
-pass
-else:
-class _VideoAssetsBase(UserList[VideoAsset]):
+class _VideoAssetsBase(UserList[VideoAsset]):
 pass
@@ -958,7 +945,7 @@ def dummy_opt_path():
 "*.msgpack"
 ])
 assert os.path.exists(json_path)
-with open(json_path, "r") as f:
+with open(json_path) as f:
 config = json.load(f)
 config["architectures"] = ["MyOPTForCausalLM"]
 with open(json_path, "w") as f:
@@ -977,7 +964,7 @@ def dummy_llava_path():
 "*.msgpack"
 ])
 assert os.path.exists(json_path)
-with open(json_path, "r") as f:
+with open(json_path) as f:
 config = json.load(f)
 config["architectures"] = ["MyLlava"]
 with open(json_path, "w") as f:
@@ -996,7 +983,7 @@ def dummy_gemma2_embedding_path():
 "*.msgpack"
 ])
 assert os.path.exists(json_path)
-with open(json_path, "r") as f:
+with open(json_path) as f:
 config = json.load(f)
 config["architectures"] = ["MyGemma2Embedding"]
 with open(json_path, "w") as f:
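
The deleted `sys.version_info < (3, 9)` branches earlier in this file existed because subscripting `collections.UserList` (e.g. `UserList[ImageAsset]`) only became legal at runtime with PEP 585 in Python 3.9. A standalone sketch of the now-unconditional pattern (`ImageAsset` here is a stand-in dataclass, not vLLM's class):

    from collections import UserList
    from dataclasses import dataclass

    @dataclass
    class ImageAsset:
        name: str

    # On 3.9+ UserList is subscriptable at runtime, so the base class can
    # carry its element type without a version check.
    class ImageAssets(UserList[ImageAsset]):
        def names(self) -> list[str]:
            return [asset.name for asset in self.data]

    assets = ImageAssets([ImageAsset("stop_sign"), ImageAsset("cherry_blossom")])
    assert assets.names() == ["stop_sign", "cherry_blossom"]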

View File

@@ -99,13 +99,11 @@ class TestPrefixCachingBlock:
 token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)]
-first_chain, second_chain = [
-TestPrefixCachingBlock.create_chain(
+first_chain, second_chain = (TestPrefixCachingBlock.create_chain(
 block_size=block_size,
 token_ids=token_ids,
 num_empty_trailing_blocks=num_empty_trailing_blocks)
-for _ in range(2)
-]
+for _ in range(2))
 for first_chain_block, second_chain_block in zip(
 first_chain, second_chain):

View File

@@ -510,7 +510,7 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
 for var in (u_ref, delta_ref, B_ref, C_ref, z_ref)
 ]
 for i in range(len(seqlens[0])):
-u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits]
+u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
 if padded_state_indices[i] == PAD_SLOT_ID:
 continue
 out_ref_s, _ = selective_scan_ref(

View File

@@ -104,7 +104,7 @@ def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
 # Sad path tests for the multimodal input processor and mapper, respectively
 @pytest.mark.parametrize("mm_data", [
 {
-"image": torch.rand((5))
+"image": torch.rand(5)
 },
 {
 "image": torch.rand((5, 5, 5, 5, 5))

View File

@@ -413,12 +413,10 @@ class _CorrectnessTestHelper:
 def generate_probs_for_test(
 self, draft_and_target_probs_equal: bool
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-draft_probs, target_probs = [
-F.softmax(
+draft_probs, target_probs = (F.softmax(
 torch.rand(self.vocab_size, dtype=torch.float32),
 dim=-1,
-) for _ in range(2)
-]
+) for _ in range(2))
 num_reference_probs = 100
 reference_probs = F.softmax(

View File

@@ -29,7 +29,7 @@ def test_trace_function_call():
 cur_dir = os.path.dirname(__file__)
 enable_trace_function_call(path, cur_dir)
 f1(1)
-with open(path, 'r') as f:
+with open(path) as f:
 content = f.read()
 assert "f1" in content

View File

@@ -93,10 +93,10 @@ def test_mistral_edge_case(tokenizer, truth):
 def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 if "mistral" in tokenizer_name:
 yield (
-bool(True) if request.param else
+True if request.param else
 pytest.skip("mistral doesn't support skip_special_tokens=False"))
 else:
-yield bool(True) if request.param else bool(False)
+yield bool(request.param)
 @pytest.mark.parametrize("truth", TRUTH)

View File

@@ -46,7 +46,7 @@ if __name__ == "__main__":
 args = parser.parse_args()
-with open(args.json_trace, "r") as f:
+with open(args.json_trace) as f:
 profile_data = json.load(f)
 if args.table == "summary":

View File

@@ -434,7 +434,7 @@ def main(
 f"{', Sparsity ' + sparsity if sparsity else ''}")
 profile_json = None
-with open(json_trace, "r") as f:
+with open(json_trace) as f:
 profile_json = json.load(f)
 assert profile_json is not None

View File

@@ -81,7 +81,7 @@ class Target:
 # Allow for modest floating-point errors
 epsilon = 0.000002
 if (self.weighted_duration > self.Duration() + epsilon):
-print('%s > %s?' % (self.weighted_duration, self.Duration()))
+print('{} > {}?'.format(self.weighted_duration, self.Duration()))
 assert (self.weighted_duration <= self.Duration() + epsilon)
 return self.weighted_duration
@@ -104,7 +104,7 @@ def ReadTargets(log, show_all):
 The result is a list of Target objects."""
 header = log.readline()
 assert header == '# ninja log v5\n', \
-'unrecognized ninja log version %r' % header
+'unrecognized ninja log version {!r}'.format(header)
 targets_dict = {}
 last_end_seen = 0.0
 for line in log:
@@ -254,8 +254,8 @@ def SummarizeEntries(entries, extra_step_types):
 # Warn if the sum of weighted times is off by more than half a second.
 if abs(length - weighted_total) > 500:
 print('Warning: Possible corrupt ninja log, results may be '
-'untrustworthy. Length = %.3f, weighted total = %.3f' %
-(length, weighted_total))
+'untrustworthy. Length = {:.3f}, weighted total = {:.3f}'.format(
+length, weighted_total))
 entries_by_ext = defaultdict(list)
 for target in entries:
@@ -263,16 +263,17 @@ def SummarizeEntries(entries, extra_step_types):
 entries_by_ext[extension].append(target)
 for key, values in entries_by_ext.items():
-print(' Longest build steps for %s:' % key)
+print(' Longest build steps for {}:'.format(key))
 values.sort(key=lambda x: x.WeightedDuration())
 for target in values[-long_count:]:
-print(' %8.1f weighted s to build %s (%.1f s elapsed time)' %
-(target.WeightedDuration(), target.DescribeTargets(),
+print(
+' {:8.1f} weighted s to build {} ({:.1f} s elapsed time)'.
+format(target.WeightedDuration(), target.DescribeTargets(),
 target.Duration()))
-print(' %.1f s weighted time (%.1f s elapsed time sum, %1.1fx '
-'parallelism)' %
-(length, total_cpu_time, total_cpu_time * 1.0 / length))
+print(' {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x '
+'parallelism)'.format(length, total_cpu_time,
+total_cpu_time * 1.0 / length))
 print(' %d build steps completed, average of %1.2f/s' %
 (len(entries), len(entries) / (length)))
@@ -298,11 +299,12 @@ def main():
 long_ext_count += len(args.step_types.split(';'))
 try:
-with open(log_file, 'r') as log:
+with open(log_file) as log:
 entries = ReadTargets(log, False)
 SummarizeEntries(entries, args.step_types)
-except IOError:
-print('Log file %r not found, no build summary created.' % log_file)
+except OSError:
+print('Log file {!r} not found, no build summary created.'.format(
+log_file))
 return errno.ENOENT
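
The `%`-interpolation replaced throughout this script and the new `str.format` calls produce identical text; the newer spelling is what the pyupgrade-style linting prefers (f-strings would be the next step, but the patch stays with `.format` here). A standalone sketch of the equivalent spellings, with made-up numbers:

    length, weighted_total = 12.3456, 12.9876

    old = 'Length = %.3f, weighted total = %.3f' % (length, weighted_total)
    new = 'Length = {:.3f}, weighted total = {:.3f}'.format(length, weighted_total)
    fstr = f'Length = {length:.3f}, weighted total = {weighted_total:.3f}'

    # All three formatting styles round and render the values identically.
    assert old == new == fstr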

View File

@@ -4,7 +4,7 @@ requires_files = glob.glob('requirements*.txt')
 requires_files += ["pyproject.toml"]
 for file in requires_files:
 print(f">>> cleaning {file}")
-with open(file, 'r') as f:
+with open(file) as f:
 lines = f.readlines()
 if "torch" in "".join(lines).lower():
 print("removed:")

View File

@@ -192,10 +192,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
 attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
 q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
-k2, v2 = [
-self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
-for x in [k, v]
-]
+k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
+for x in [k, v])
 spda_output = torch.nn.functional.scaled_dot_product_attention(
 q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
 return self.transpose_and_unpad(spda_output, cu_seqlens)

View File

@@ -668,9 +668,10 @@ class ModelConfig:
 @property
 def is_encoder_decoder_model(self) -> bool:
 """Extract the HF encoder/decoder model flag."""
-return getattr(self.hf_config, "is_encoder_decoder", False) or (
-(hasattr(self.hf_config, "text_config") and getattr(
-self.hf_config.text_config, "is_encoder_decoder", False)))
+return getattr(
+self.hf_config, "is_encoder_decoder",
+False) or (hasattr(self.hf_config, "text_config") and getattr(
+self.hf_config.text_config, "is_encoder_decoder", False))
 @property
 def is_multimodal_model(self) -> bool:

View File

@@ -52,7 +52,7 @@ class Evictor(ABC):
 pass
-class BlockMetaData():
+class BlockMetaData:
 """Data structure for storing key data describe cached block, so that
 evitor could use to make its decision which one to choose for eviction

View File

@@ -240,7 +240,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
 if is_distributed:
 get_world_group().barrier()
 logger.info("reading GPU P2P access cache from %s", path)
-with open(path, "r") as f:
+with open(path) as f:
 cache = json.load(f)
 _gpu_p2p_access_cache = cache
 return _gpu_p2p_access_cache[f"{src}->{tgt}"]

View File

@@ -812,7 +812,7 @@ class AsyncLLMEngine(EngineClient):
 async def run_engine_loop(engine_ref: ReferenceType):
 """We use a weakref to the engine so that the running loop
 doesn't prevent the engine being garbage collected."""
-engine: Optional["AsyncLLMEngine"] = engine_ref()
+engine: Optional[AsyncLLMEngine] = engine_ref()
 if not engine:
 return

View File

@@ -1541,8 +1541,8 @@ class LLMEngine:
 seq_group.state.remaining_steps != ref_remaining_steps
 for seq_group in seq_group_metadata_list[1:]
 ]):
-raise AssertionError(("All running sequence groups should "
-"have the same remaining steps."))
+raise AssertionError("All running sequence groups should "
+"have the same remaining steps.")
 return ref_remaining_steps > 0

View File

@@ -77,7 +77,7 @@ class StatLoggerBase(ABC):
 self.num_generation_tokens: List[int] = []
 self.last_local_log = time.time()
 self.local_interval = local_interval
-self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None
 @abstractmethod
 def log(self, stats: Stats) -> None:

View File

@@ -63,7 +63,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
 single_step_process_prompt_logprob(self, seq_group, output)
 @staticmethod
-@functools.lru_cache()
+@functools.lru_cache
 def _log_prompt_logprob_unsupported_warning_once():
 # Reminder: Please update docs/source/serving/compatibility_matrix.rst
 # If the feature combo become valid
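
Dropping the parentheses from `@functools.lru_cache()` is a pure style cleanup: the bare-decorator form has been allowed since Python 3.8, and both spellings create the same cache with the default `maxsize` of 128. A standalone sketch:

    import functools

    @functools.lru_cache            # bare decorator
    def square(x: int) -> int:
        return x * x

    @functools.lru_cache()          # older spelling with an explicit call
    def cube(x: int) -> int:
        return x * x * x

    assert square(4) == 16 and cube(3) == 27
    # Both expose the usual cache statistics interface.
    assert square.cache_info().maxsize == cube.cache_info().maxsize == 128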

View File

@@ -362,7 +362,7 @@ def load_chat_template(
 if chat_template is None:
 return None
 try:
-with open(chat_template, "r") as f:
+with open(chat_template) as f:
 resolved_chat_template = f.read()
 except OSError as e:
 if isinstance(chat_template, Path):

View File

@@ -120,7 +120,7 @@ async def read_file(path_or_url: str) -> str:
 session.get(path_or_url) as resp:
 return await resp.text()
 else:
-with open(path_or_url, "r", encoding="utf-8") as f:
+with open(path_or_url, encoding="utf-8") as f:
 return f.read()

View File

@@ -32,7 +32,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
 uses_ray: bool = True
 def _init_executor(self) -> None:
-self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
+self.forward_dag: Optional[ray.dag.CompiledDAG] = None
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.

View File

@@ -67,8 +67,7 @@ def _configure_vllm_root_logger() -> None:
 raise RuntimeError(
 "Could not load logging config. File does not exist: %s",
 VLLM_LOGGING_CONFIG_PATH)
-with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8",
-mode="r") as file:
+with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
 custom_config = json.loads(file.read())
 if not isinstance(custom_config, dict):

View File

@@ -343,7 +343,7 @@ class LoRAModelManager(AdapterModelManager):
 # text modules (e.g. ChatGLM)
 and hasattr(self.model, "get_mm_mapping"))
 self.packed_modules: Dict[str, List[str]] = {}
-self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
+self.modules: Dict[str, BaseLayerWithLoRA] = {}
 # Dict instead of a Set for compatibility with LRUCache.
 self._last_mapping: Optional[LoRAMapping] = None
 self._create_lora_modules()
@@ -548,7 +548,7 @@ class LoRAModelManager(AdapterModelManager):
 else:
 parts = module_name.split(".")
 replacements = self.packed_modules_mapping[parts[-1]]
-subloras: List[Optional["LoRALayerWeights"]] = []
+subloras: List[Optional[LoRALayerWeights]] = []
 for i, r in enumerate(replacements):
 lora = LoRALayerWeights.create_dummy_lora_weights(
 module_name + "." + r,

View File

@@ -103,7 +103,7 @@ class CustomOp(nn.Module):
 # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
 # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
 @staticmethod
-@lru_cache()
+@lru_cache
 def default_on() -> bool:
 count_none = envs.VLLM_CUSTOM_OPS.count("none")
 count_all = envs.VLLM_CUSTOM_OPS.count("all")

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -746,7 +746,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 config_file_path = self._get_config_file(qlora_adapter)
-with open(config_file_path, "r") as f:
+with open(config_file_path) as f:
 config = json.load(f)
 self.target_modules = config["target_modules"]

View File

@@ -190,7 +190,7 @@ def get_model(
 kv_cache_dtype: ov.Type,
 **kwargs,
 ) -> torch.nn.Module:
-lora_config = kwargs.get("lora_config", None)
+lora_config = kwargs.get("lora_config")
 ov_core = kwargs.get("ov_core")
 if lora_config:
 raise ValueError(
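
Dropping the explicit `None` from `kwargs.get("lora_config", None)` changes nothing, because `dict.get` already returns `None` when the key is missing. A tiny standalone sketch:

    kwargs = {"ov_core": object()}

    # get() defaults to None, so the explicit default is redundant.
    assert kwargs.get("lora_config", None) is None
    assert kwargs.get("lora_config") is None
    assert kwargs.get("ov_core") is kwargs["ov_core"]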

View File

@@ -280,7 +280,7 @@ class TensorizerAgent:
 self.tensorizer_args = (
 self.tensorizer_config._construct_tensorizer_args())
 self.extra_kwargs = extra_kwargs
-if extra_kwargs.get("quant_config", None) is not None:
+if extra_kwargs.get("quant_config") is not None:
 self.quant_config = extra_kwargs["quant_config"]
 else:
 self.quant_config = quant_config
@@ -380,8 +380,7 @@ def tensorizer_weights_iterator(
 stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
 with TensorDeserializer(stream, **deserializer_args,
 device="cpu") as state:
-for name, param in state.items():
-yield name, param
+yield from state.items()
 del state

View File

@@ -188,7 +188,7 @@ def get_quant_config(model_config: ModelConfig,
 f"{quant_config_files}")
 quant_config_file = quant_config_files[0]
-with open(quant_config_file, "r") as f:
+with open(quant_config_file) as f:
 config = json.load(f)
 if model_config.quantization == "bitsandbytes":
@@ -306,7 +306,7 @@ def filter_duplicate_safetensors_files(hf_weights_files: List[str],
 # Iterate through the weight_map (weight_name: safetensors files)
 # to identify weights that we should use.
-with open(index_file_name, "r") as f:
+with open(index_file_name) as f:
 weight_map = json.load(f)["weight_map"]
 weight_files_in_index = set()
 for weight_name in weight_map:
@@ -382,7 +382,7 @@ def np_cache_weights_iterator(
 with open(weight_names_file, "w") as f:
 json.dump(weight_names, f)
-with open(weight_names_file, "r") as f:
+with open(weight_names_file) as f:
 weight_names = json.load(f)
 for name in weight_names:
@@ -423,8 +423,7 @@ def pt_weights_iterator(
 bar_format=_BAR_FORMAT,
 ):
 state = torch.load(bin_file, map_location="cpu")
-for name, param in state.items():
-yield name, param
+yield from state.items()
 del state
 torch.cuda.empty_cache()
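
The `for name, param in state.items(): yield name, param` loops in these iterators collapse to `yield from state.items()` (PEP 380), which delegates to the underlying iterator and produces the same `(name, tensor)` pairs. A standalone sketch with a plain dict standing in for the deserialized state:

    from typing import Dict, Iterator, Tuple

    def iter_weights_loop(state: Dict[str, int]) -> Iterator[Tuple[str, int]]:
        for name, param in state.items():   # old spelling
            yield name, param

    def iter_weights_delegate(state: Dict[str, int]) -> Iterator[Tuple[str, int]]:
        yield from state.items()            # equivalent delegation (PEP 380)

    state = {"embed_tokens.weight": 1, "lm_head.weight": 2}
    assert list(iter_weights_loop(state)) == list(iter_weights_delegate(state))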

View File

@@ -48,7 +48,7 @@ class ArcticMLP(nn.Module):
 is_residual_mlp: bool = False,
 quant_config: Optional[QuantizationConfig] = None,
 reduce_results: bool = True):
-super(ArcticMLP, self).__init__()
+super().__init__()
 self.hidden_size = config.hidden_size
 self.expert_id = expert_id
 self.layer_id = layer_id
@@ -89,7 +89,7 @@ class ArcticMoE(nn.Module):
 params_dtype: Optional[torch.dtype] = None,
 quant_config: Optional[QuantizationConfig] = None,
 reduce_results: bool = True):
-super(ArcticMoE, self).__init__()
+super().__init__()
 self.tp_size = tp_size or get_tensor_model_parallel_world_size()
 self.hidden_size = config.hidden_size
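
`super(ArcticMLP, self).__init__()` and the zero-argument `super().__init__()` resolve to the same base-class call when used inside a method; the explicit arguments are a leftover Python 2 idiom. A standalone sketch with toy classes (not the vLLM modules):

    class Base:
        def __init__(self) -> None:
            self.initialized = True

    class OldStyle(Base):
        def __init__(self) -> None:
            super(OldStyle, self).__init__()   # explicit, legacy spelling

    class NewStyle(Base):
        def __init__(self) -> None:
            super().__init__()                 # zero-argument form, same MRO lookup

    assert OldStyle().initialized and NewStyle().initialized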

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/THUDM/GLM-4
 """Inference-only ChatGLM model compatible with THUDM weights."""

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 from typing import Iterable, List, Optional, Tuple, Union
 import torch

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 DeciAI Research Team. All rights reserved.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py
 # Copyright 2024 The LG U+ CTO AI Tech Lab.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
 # Copyright 2023 The vLLM team.
 # Copyright 2023 HuggingFace Inc. team. All rights reserved.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2023 The vLLM team.
 # Copyright (c) Google Inc.
 #

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 The vLLM team.
 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
 #

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/THUDM/GLM-4
 """Inference-only GLM-4v model visual encoder compatible with THUDM weights."""

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,5 +1,3 @@
-# coding=utf-8
 # adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py
 # Copyright 2024 The vLLM team.
 # Copyright 2024 the HuggingFace Inc. team. All rights reserved.

View File

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

View File

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 from typing import List, Optional, Tuple, Union
 import torch

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 """Inference-only Jamba model."""
 from typing import Iterable, List, Optional, Tuple

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 """PyTorch MAMBA model."""
 from typing import Iterable, List, Optional, Tuple

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2024 The ModelBest team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -37,7 +37,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
 eps=1e-06,
 elementwise_scale_and_shift=True,
 ):
-super(MLPSpeculatorLayerNorm, self).__init__()
+super().__init__()
 self.elementwise_scale_and_shift = elementwise_scale_and_shift
 if self.elementwise_scale_and_shift:
 self.weight = nn.Parameter(torch.empty(normalized_shape))

View File

@@ -1121,9 +1121,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 batch_size * num_image * num_patch, -1).contiguous()
 image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
-offset = torch.cat(
-[seq_len.new_zeros(
-(1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None]
+offset = torch.cat([seq_len.new_zeros(1),
+seq_len.cumsum(dim=0)[:-1]],
+dim=0)[:, None]
 image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
 image_input_idx = image_input_idx.flatten()[:, None]
 mat = image_input_idx == torch.arange(

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
 from typing import Iterable, List, Optional, Tuple, Union

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
 # Copyright 2024 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
 # Copyright (c) OrionStar Inc.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
 # Copyright 2023 The vLLM team.
 # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from llama.py
 """Inference-only Phi3 model code inherit from Llama.py"""

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 The vLLM team.
 # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
 #

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -136,11 +136,11 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
 if image_token_id not in inputs['prompt_token_ids']:
 raise ValueError(
-(f"You've passed {inputs=} without {image_token_id=}"
+f"You've passed {inputs=} without {image_token_id=}"
 " Make sure to process your input via mistral_common's"
 " tokenizer or pass a chat completion request. For more"
 " For more info, see: "
-"https://github.com/vllm-project/vllm/issues/8411."))
+"https://github.com/vllm-project/vllm/issues/8411.")
 return inputs

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
 # Copyright (c) Alibaba Cloud.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
 # Copyright 2024 The Qwen team.
@@ -417,9 +416,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 and hasattr(config, "max_window_layers")):
 raise ValueError("Sliding window for some but all layers is not "
 "supported. This model uses sliding window "
-"but `max_window_layers` = %s is less than "
-"`num_hidden_layers` = %s. Please open an issue "
-"to discuss this feature." % (
+"but `max_window_layers` = {} is less than "
+"`num_hidden_layers` = {}. Please open an issue "
+"to discuss this feature.".format(
 config.max_window_layers,
 config.num_hidden_layers,
 ))

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
 # Copyright 2024 Kakao Corp. (Kanana-X Team)
@@ -60,9 +59,9 @@ class Qwen2ForSequenceClassification(nn.Module):
 and hasattr(config, "max_window_layers")):
 raise ValueError("Sliding window for some but all layers is not "
 "supported. This model uses sliding window "
-"but `max_window_layers` = %s is less than "
-"`num_hidden_layers` = %s. Please open an issue "
-"to discuss this feature." % (
+"but `max_window_layers` = {} is less than "
+"`num_hidden_layers` = {}. Please open an issue "
+"to discuss this feature.".format(
 config.max_window_layers,
 config.num_hidden_layers,
 ))

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
 # Copyright 2024 The Qwen team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
 # Copyright 2024 The Qwen team.
@@ -71,9 +70,9 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
 and hasattr(config, "max_window_layers")):
 raise ValueError("Sliding window for some but all layers is not "
 "supported. This model uses sliding window "
-"but `max_window_layers` = %s is less than "
-"`num_hidden_layers` = %s. Please open an issue "
-"to discuss this feature." % (
+"but `max_window_layers` = {} is less than "
+"`num_hidden_layers` = {}. Please open an issue "
+"to discuss this feature.".format(
 config.max_window_layers,
 config.num_hidden_layers,
 ))

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
 # Copyright 2024 The Qwen team.
@@ -246,9 +245,8 @@ class Qwen2VisionAttention(nn.Module):
 q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
 batch_size = q.shape[1]
-q, k, v = [
-rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
-]
+q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
+for x in (q, k, v))
 if rotary_pos_emb is not None:
 q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
 k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
@@ -258,7 +256,7 @@ class Qwen2VisionAttention(nn.Module):
 # flash_attn_varlen_func)
 from flash_attn import flash_attn_varlen_func
-q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 output = flash_attn_varlen_func(q,
@@ -276,7 +274,7 @@ class Qwen2VisionAttention(nn.Module):
 b=batch_size)
 elif self.attn_backend == _Backend.TORCH_SDPA:
 seq_length = q.size(1)
-q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
+q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
 attention_mask = torch.zeros([1, seq_length, seq_length],
 device=q.device,
 dtype=torch.bool)

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.

View File

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
 # All rights reserved.
 #

Some files were not shown because too many files have changed in this diff.