Re-enable the 80 char line width limit (#3305)

Zhuohan Li 2024-03-10 19:49:14 -07:00 committed by GitHub
parent 4b59f00e91
commit 2f8844ba08
67 changed files with 557 additions and 528 deletions

View File

@ -9,6 +9,10 @@ requires = [
]
build-backend = "setuptools.build_meta"
[tool.ruff]
# Allow lines to be as long as 80 characters.
line-length = 80
[tool.ruff.lint]
select = [
# pycodestyle
@ -29,8 +33,6 @@ ignore = [
"F405", "F403",
# lambda expression assignment
"E731",
# line too long, handled by black formatting
"E501",
# .strip() with multi-character strings
"B005",
# Loop control variable not used within loop body
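With line-length set to 80 and E501 no longer ignored, ruff now flags any line over 80 characters, which drives the mechanical rewrites in the rest of this commit. Below is a minimal sketch of the wrapping idioms the diff relies on; the names are illustrative only and are not code from the repository:

from os.path import (basename,   # parenthesized import instead of a
                     join)       # backslash continuation

# Implicit concatenation of adjacent string literals inside parentheses
MESSAGE = ("built from two adjacent literals so that no single "
           "source line exceeds 80 characters")

# Wrapping a call chain by parenthesizing the whole expression
SLUG = (MESSAGE.strip()
        .replace(" ", "_"))

# Both imports are used only to keep the sketch lint-clean
PATH = join("tmp", basename(SLUG))

# Opting a single unavoidable long line out of the check
LONG = "a deliberately long literal kept on one source line for a test fixture"  # noqa: E501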

View File

@ -142,8 +142,8 @@ def get_pytorch_rocm_arch() -> Set[str]:
# If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
if env_arch_list is None:
command = "rocm_agent_enumerator"
env_arch_list = subprocess.check_output([command]).decode('utf-8')\
.strip().replace("\n", ";")
env_arch_list = (subprocess.check_output(
[command]).decode('utf-8').strip().replace("\n", ";"))
arch_source_str = "rocm_agent_enumerator"
else:
arch_source_str = "PYTORCH_ROCM_ARCH env variable"

View File

@ -73,7 +73,7 @@ def test_load_chat_template():
assert template_content is not None
# Hard coded value for template_chatml.jinja
assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
def test_no_load_chat_template():
@ -117,4 +117,6 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=mock_request.add_generation_prompt)
# Test assertion
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
assert result == expected_output, (
f"The generated prompt does not match the expected output for "
f"model {model} and template {template}")

View File

@ -4,7 +4,8 @@ from typing import List
from vllm import SamplingParams
from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager import BlockAllocator, BlockSpaceManager, AllocStatus
from vllm.core.block_manager import (BlockAllocator, BlockSpaceManager,
AllocStatus)
from vllm.utils import Device
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob

View File

@ -46,8 +46,8 @@ TEST_SCHEMA = {
"required": ["name", "age", "skills", "work history"]
}
TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
def test_guided_logits_processors():
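For context, TEST_REGEX above matches dotted-quad strings (IPv4-style addresses). A quick standalone check, not part of the test file, that the rewrapped pattern behaves as before, since adjacent raw-string literals concatenate back into the original single pattern:

import re

TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

# Each octet alternative covers 0-255, so a valid IPv4 address matches...
assert re.fullmatch(TEST_REGEX, "192.168.0.1") is not None
# ...while an out-of-range octet does not.
assert re.fullmatch(TEST_REGEX, "256.1.1.1") is None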

View File

@ -5,9 +5,12 @@ import time
import sys
import pytest
import requests
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
# imports for guided decoding tests
import json
@ -17,8 +20,11 @@ import re
from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 600 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
TEST_SCHEMA = {
"type": "object",
@ -59,8 +65,8 @@ TEST_SCHEMA = {
"required": ["name", "age", "skills", "work history"]
}
TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
TEST_CHOICE = [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
@ -120,8 +126,9 @@ def server(zephyr_lora_files):
server_runner = ServerRunner.remote([
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16", # use half precision for speed and memory savings in CI environment
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
@ -392,7 +399,8 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for the official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
@ -469,8 +477,8 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=
f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
prompt=f"Give an example JSON for an employee profile "
f"that fits this schema: {TEST_SCHEMA}",
n=3,
temperature=1.0,
max_tokens=500,
@ -489,9 +497,11 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "Give an example JSON for an employee profile that " + \
f"fits this schema: {TEST_SCHEMA}"
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,

View File

@ -57,7 +57,8 @@ def test_fused_moe(
[torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
"Make sure our Mixtral MoE implementation agrees with the one from huggingface."
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
# Instantiate our and huggingface's MoE blocks
config = MixtralConfig()

View File

@ -114,7 +114,8 @@ def test_contexted_kv_attention(
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
# Warm up the Triton kernel by calling it once before actually measuring generation time
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
torch.cuda.synchronize()

View File

@ -11,9 +11,9 @@ from .conftest import cleanup
MODEL_PATH = "Felladrin/Llama-68M-Chat-v1"
PROMPTS = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", # noqa: E501
]

View File

@ -17,14 +17,16 @@ from vllm.lora.layers import (
LoRAMapping,
BaseLayerWithLoRA,
)
from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights
from vllm.lora.models import (LoRALayerWeights, convert_mapping,
PackedLoRALayerWeights)
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.utils import set_random_seed
from .utils import DummyLoRAManager
@ -258,7 +260,8 @@ def test_embeddings(dist_init, num_loras, device) -> None:
@torch.inference_mode()
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.")
# @pytest.mark.skip(
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@ -674,9 +677,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
result = linear(input_)[0]
subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * (
i + 1
)] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
sublora.scaling)
expected_results.append(result)
expected_result = torch.cat(expected_results)

View File

@ -10,12 +10,12 @@ MODEL_PATH = "meta-llama/Llama-2-7b-hf"
def do_sample(llm, lora_path: str, lora_id: int):
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
@ -48,20 +48,20 @@ def test_llama_lora(sql_lora_files, tp_size):
tensor_parallel_size=tp_size)
expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
]
expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ",
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]
print("lora adapter created")
@ -121,7 +121,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files):
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and is more conservative"""
"""Test that the LLM initialization works with a warmup LORA path and
is more conservative"""
@ray.remote(num_gpus=1)
def get_num_gpu_blocks_lora():
@ -132,13 +133,15 @@ def test_llama_lora_warmup(sql_lora_files):
@ray.remote(num_gpus=1)
def get_num_gpu_blocks_no_lora():
llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
num_gpu_blocks_no_lora_warmup = (
llm.llm_engine.cache_config.num_gpu_blocks)
return num_gpu_blocks_no_lora_warmup
num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
num_gpu_blocks_no_lora_warmup = ray.get(
get_num_gpu_blocks_no_lora.remote())
assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, (
"The warmup with lora should be more"
" conservative than without lora, therefore the number of memory blocks for the KV cache should be "
"The warmup with lora should be more "
"conservative than without lora, therefore the number of "
"memory blocks for the KV cache should be "
"less when using lora than when not using lora")

View File

@ -9,9 +9,9 @@ MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
def do_sample(llm, lora_path: str, lora_id: int):
prompts = [
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]", # noqa: E501
"[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]", # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
@ -42,9 +42,9 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
worker_use_ray=True)
expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])",
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])",
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501
]
assert do_sample(llm, mixtral_lora_files,

View File

@ -21,7 +21,8 @@ def test_metric_counter_prompt_tokens(
gpu_memory_utilization=0.4)
tokenizer = vllm_model.model.get_tokenizer()
prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
# This test needs at least 2 prompts in a batch of different lengths to verify their token count is correct despite padding.
# This test needs at least 2 prompts in a batch of different lengths to
# verify their token count is correct despite padding.
assert len(example_prompts) > 1, "at least 2 prompts are required"
assert prompt_token_counts[0] != prompt_token_counts[1], (
"prompts of different lengths are required")
@ -33,8 +34,8 @@ def test_metric_counter_prompt_tokens(
**stat_logger.labels)._value.get()
assert vllm_prompt_token_count == metric_count, (
f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"
)
f"prompt token count: {vllm_prompt_token_count!r}\n"
f"metric: {metric_count!r}")
@pytest.mark.parametrize("model", MODELS)
@ -60,9 +61,10 @@ def test_metric_counter_generation_tokens(
for i in range(len(example_prompts)):
vllm_output_ids, vllm_output_str = vllm_outputs[i]
prompt_ids = tokenizer.encode(example_prompts[i])
# vllm_output_ids contains both prompt tokens and generation tokens. We're interested only in the count of the generation tokens.
# vllm_output_ids contains both prompt tokens and generation tokens.
# We're interested only in the count of the generation tokens.
vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)
assert vllm_generation_count == metric_count, (
f"generation token count: {vllm_generation_count!r}\nmetric: {metric_count!r}"
)
f"generation token count: {vllm_generation_count!r}\n"
f"metric: {metric_count!r}")

View File

@ -1,7 +1,7 @@
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
@ -14,7 +14,8 @@ Run `pytest tests/models/test_marlin.py --forked`.
import pytest
import torch
from dataclasses import dataclass
from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY
from vllm.model_executor.layers.quantization import (
_QUANTIZATION_CONFIG_REGISTRY)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
@ -87,11 +88,11 @@ def test_models(
if marlin_output_id != gptq_output_id:
# Each predicted token must be in top 5 of the other's
assert gptq_output_id in marlin_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
)
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
f"Marlin:\t{marlin_output_str!r}")
assert marlin_output_id in gptq_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
)
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
f"Marlin:\t{marlin_output_str!r}")
# Break out since sequences will now diverge.
break

View File

@ -20,20 +20,23 @@ def test_block_allocator(
num_blocks,
enable_caching=True)
# Allocate two PysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock
# Allocate two PhysicalTokenBlocks with the same hash and check
# that they are the same PhysicalTokenBlock
first_block = block_allocator.allocate(block_hash, 0)
second_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (second_block.ref_count == 2)
# Free the first_block and confirm that the ref_count is correctly decremented on the second block
# Free the first_block and confirm that the ref_count is correctly
# decremented on the second block
block_allocator.free(first_block)
assert (second_block.ref_count == 1)
# Free the second block
block_allocator.free(second_block)
# Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still get the same block back
# Reallocate the first block and confirm that, even after the block
# had its ref_count go to 0, we still get the same block back
first_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)
@ -56,7 +59,8 @@ def test_eviction(num_blocks: int, ):
for block in blocks:
block_allocator.free(block)
# Allocate a new block and confirm that it's the first block freed. I.E The Least Recently Used block
# Allocate a new block and confirm that it's the first block freed,
# i.e. the least recently used block.
new_block_hash = block_size
new_block = block_allocator.allocate(new_block_hash, 0)
assert (new_block == blocks[0])
@ -68,7 +72,8 @@ def test_eviction(num_blocks: int, ):
assert (realloc_block == blocks[realloc_block_hash])
assert (realloc_block.block_hash == realloc_block_hash)
# Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list
# Allocate a new block and confirm that it's not the realloc_block,
# since the realloc_block shouldn't be in the free list
new_block_hash = block_size + 1
new_block = block_allocator.allocate(new_block_hash, 0)
assert (realloc_block != new_block)

View File

@ -70,8 +70,8 @@ def test_get_prompt_logprobs(
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
assert isinstance(sample_logprob.decoded_token, str), \
("The token should be decoded by the time it is returned "
assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is returned "
" to the user.")

View File

@ -255,9 +255,10 @@ def test_sampler_mixed(seed: int, device: str):
if metadata.sampling_params.use_beam_search:
continue
if metadata.sampling_params.seed is not None \
and expected_tokens[i] is None:
# Record seeded random result to compare with results of second invocation
if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None):
# Record seeded random result to compare with results of
# second invocation
expected_tokens[i] = [
nth_output.output_token
for nth_output in sequence_output.samples
@ -265,11 +266,13 @@ def test_sampler_mixed(seed: int, device: str):
continue
for n, nth_output in enumerate(sequence_output.samples):
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n]
else:
# For non-seeded random check that one of the high-logit tokens were chosen
# For non-seeded random check that one of the high-logit
# tokens were chosen
assert nth_output.output_token in expected_tokens[i]
# Test batch
@ -284,8 +287,8 @@ def test_sampler_mixed(seed: int, device: str):
input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index)
# This time, results of seeded random samples will be compared with the corresponding
# sample in the pre-shuffled batch
# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
test_sampling(model_runner)
del model_runner

View File

@ -150,8 +150,10 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert metrics.emitted_tokens == num_emitted_tokens
if has_data:
assert metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens
assert metrics.system_efficiency == num_emitted_tokens / num_possible_tokens
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens /
num_possible_tokens)
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)
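The two reformatted assertions above are plain ratios; a small arithmetic sketch with made-up counts (the real test derives these from mocked workers):

# Illustrative values only
num_draft_tokens = 10
num_accepted_tokens = 8
num_emitted_tokens = 9
num_possible_tokens = 12

draft_acceptance_rate = num_accepted_tokens / num_draft_tokens  # 0.8
system_efficiency = num_emitted_tokens / num_possible_tokens    # 0.75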

View File

@ -3,7 +3,8 @@ import random
import pytest
from unittest.mock import MagicMock
from vllm.spec_decode.multi_step_worker import MultiStepWorker, DraftModelTop1Proposer
from vllm.spec_decode.multi_step_worker import (MultiStepWorker,
DraftModelTop1Proposer)
from vllm.worker.worker import Worker
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput

View File

@ -4,12 +4,15 @@ import pytest
from unittest.mock import MagicMock
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, split_num_cache_blocks_evenly
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from .utils import mock_worker, create_batch, ExecuteModelData, create_sampler_output_list
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics, AsyncMetricsCollector
from .utils import (mock_worker, create_batch, ExecuteModelData,
create_sampler_output_list)
from vllm.spec_decode.metrics import (SpecDecodeWorkerMetrics,
AsyncMetricsCollector)
@pytest.mark.parametrize('k', [1, 2, 6])
@ -391,13 +394,15 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
metrics_collector.maybe_collect_rejsample_metrics.return_value = mock_rejsample_metrics
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
mock_rejsample_metrics)
output = worker.execute_model(**execute_model_data.to_dict(),
num_spec_tokens=k)
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = metrics_collector.maybe_collect_rejsample_metrics.call_args_list
call_args_list = (
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
assert len(call_args_list) == 1
args, kwargs = call_args_list[0]
assert args[0] == k or kwargs.get('k', -1) == k
@ -547,7 +552,8 @@ def test_profile_num_available_blocks(available_gpu_blocks: int,
target_worker.profile_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = target_cache_block_size_bytes
target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,

View File

@ -45,7 +45,7 @@ class ModelConfig:
a tag name, or a commit id. If unspecified, will use the default
version.
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
@ -189,8 +189,8 @@ class ModelConfig:
if is_hip(
) and self.quantization in rocm_not_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not supported "
f"in ROCm.")
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if self.quantization != "marlin":
logger.warning(
f"{self.quantization} quantization is not fully "
@ -321,7 +321,8 @@ class CacheConfig:
self.num_cpu_blocks = None
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus metrics info
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self) -> None:
@ -399,8 +400,9 @@ class ParallelConfig:
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
if is_neuron():
# For Neuron device support, here we assign TP=1 to avoid sharding within vLLM directly.
# Transformer-neuronx would take neuron_tp_degree attribute, and distribute the workload
# For Neuron device support, here we assign TP=1 to avoid sharding
# within vLLM directly. Transformer-neuronx would take
# neuron_tp_degree attribute, and distribute the workload
# to multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size

View File

@ -95,13 +95,15 @@ class BlockAllocator:
del self.cached_blocks[block.block_hash]
def get_num_free_blocks(self) -> int:
return self.num_blocks - self.current_num_blocks + self.evictor.num_blocks
return (self.num_blocks - self.current_num_blocks +
self.evictor.num_blocks)
def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
# If caching is enabled, update the hash of block and the cached_blocks dictionary.
# If caching is enabled, update the hash of block and the
# cached_blocks dictionary.
if self.enable_caching:
assert not self.contains_block(block_hash)
old_hash = block.block_hash
@ -218,10 +220,12 @@ class BlockSpaceManager:
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
# Compute a new hash for the block so that it can be shared by other Sequences
# Compute a new hash for the block so that it can be shared by
# other Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
# if new_hash is already in the cached table, then free last_block and return the cached version
# if new_hash is already in the cached table, then free last_block
# and return the cached version
if self.gpu_allocator.contains_block(new_hash):
self.gpu_allocator.free(last_block)
return self.gpu_allocator.allocate(new_hash)
@ -289,7 +293,8 @@ class BlockSpaceManager:
assert last_block.device == Device.GPU
if last_block.ref_count == 1:
# Not shared with other sequences. Appendable.
# If the last block is now complete, promote it to a full block so that it can be shared
# If the last block is now complete, promote it to a full block so
# that it can be shared
new_block = self._maybe_promote_last_block(seq, last_block)
block_table[-1] = new_block
return None

View File

@ -39,9 +39,9 @@ class Evictor(ABC):
@abstractmethod
def remove(self, block_hash: int) -> PhysicalTokenBlock:
"""Simply removes the block with the hash value block_hash from the
evictor. Caller is responsible for making sure that block_hash is contained
in the evictor before calling remove. Should be used to "bring back" blocks
that have been freed but not evicted yet.
evictor. Caller is responsible for making sure that block_hash is
contained in the evictor before calling remove. Should be used to
"bring back" blocks that have been freed but not evicted yet.
"""
pass

View File

@ -214,8 +214,8 @@ class Scheduler:
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
if (lora_int_id > 0 and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_waiting_sequences.appendleft(seq_group)
@ -309,8 +309,8 @@ class Scheduler:
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
if (lora_int_id > 0 and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_swapped.appendleft(seq_group)

View File

@ -100,7 +100,8 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
f"disable_custom_all_reduce="
f"{parallel_config.disable_custom_all_reduce}, "
f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, "
f"kv_cache_dtype={cache_config.cache_dtype}, "
@ -929,7 +930,8 @@ class LLMEngine:
# Latency Timings.
time_last_iters = []
for seq_group in scheduler_outputs.scheduled_seq_groups:
# Time since last token. (n.b. updates seq_group.metrics.last_token_time)
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
time_last_iters.append(seq_group.get_last_latency(now))
# Time since arrival for all finished requests.
if seq_group.is_finished():
@ -961,16 +963,17 @@ class LLMEngine:
for token_id, sample_logprob in logprobs.items():
if (sample_logprob.decoded_token is None and token_id != -1):
all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
_, new_text, prefix_offset, read_offset = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
(_, new_text, prefix_offset,
read_offset) = detokenize_incrementally(
self.get_tokenizer_for_seq(seq),
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:

View File

@ -1,5 +1,6 @@
from vllm.logger import init_logger
from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics
from prometheus_client import (Counter, Gauge, Histogram, Info, REGISTRY,
disable_created_metrics)
import time
import numpy as np
@ -177,10 +178,12 @@ class StatLogger:
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on the vLLM side.
# Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens
# Which log raw data and calculate summaries using rate() on the grafana/prometheus side.
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
# Support legacy gauge metrics that make throughput calculations on
# the vLLM side. Moving forward, we should use counters like
# counter_prompt_tokens and counter_generation_tokens, which log raw
# data and calculate summaries using rate() on the
# Grafana/Prometheus side. See
# https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels(
@ -188,7 +191,7 @@ class StatLogger:
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to prometheus and tracked stats every iteration.
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
# Log to prometheus.
@ -200,8 +203,8 @@ class StatLogger:
# Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now):
# Compute summary metrics for tracked stats (and log them to promethus if applicable).
# Compute summary metrics for tracked stats (and log them
# to prometheus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens,
now=stats.now)
generation_throughput = self._get_throughput(
@ -213,7 +216,8 @@ class StatLogger:
# Log to stdout.
logger.info(
f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
f"Avg generation throughput: {generation_throughput:.1f} tokens/s, "
f"Avg generation throughput: "
f"{generation_throughput:.1f} tokens/s, "
f"Running: {stats.num_running} reqs, "
f"Swapped: {stats.num_swapped} reqs, "
f"Pending: {stats.num_waiting} reqs, "

View File

@ -1,7 +1,9 @@
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks.
It is not intended for production use. For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead.
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import argparse

View File

@ -18,7 +18,9 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest,
ErrorResponse)
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@ -84,13 +86,11 @@ def parse_args():
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument(
"--api-key",
type=str,
default=None,
help=
"If provided, the server will require this key to be presented in the header."
)
parser.add_argument("--api-key",
type=str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument("--served-model-name",
type=str,
default=None,
@ -103,9 +103,8 @@ def parse_args():
default=None,
nargs='+',
action=LoRAParserAction,
help=
"LoRA module configurations in the format name=path. Multiple modules can be specified."
)
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument("--chat-template",
type=str,
default=None,
@ -138,9 +137,10 @@ def parse_args():
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). "
)
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
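For reference, a minimal sketch of parsing the name=path pairs described by the --lora-modules help text above. NamePathAction is a hypothetical stand-in for LoRAParserAction, and splitting on the first '=' is an assumption, not a statement about the real parser.

import argparse

class NamePathAction(argparse.Action):
    # Hypothetical: collects "name=path" strings into a {name: path} dict.
    def __call__(self, parser, namespace, values, option_string=None):
        modules = {}
        for item in values:
            name, _, path = item.partition("=")
            modules[name] = path
        setattr(namespace, self.dest, modules)

parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=NamePathAction)
args = parser.parse_args(["--lora-modules", "sql-lora=/path/to/adapter"])
assert args.lora_modules == {"sql-lora": "/path/to/adapter"}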
@ -235,9 +235,8 @@ if __name__ == "__main__":
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(
f"Invalid middleware {middleware}. Must be a function or a class."
)
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
logger.info(f"vLLM API server version {vllm.__version__}")
logger.info(f"args: {args}")

View File

@ -12,7 +12,8 @@ from vllm.entrypoints.openai.protocol import (
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
logger = init_logger(__name__)
@ -37,8 +38,9 @@ class OpenAIServingChat(OpenAIServing):
ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
@ -116,7 +118,8 @@ class OpenAIServingChat(OpenAIServing):
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
if first_iteration:
# Send first response for each request.n (index) with the role
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
@ -133,7 +136,8 @@ class OpenAIServingChat(OpenAIServing):
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
# Send response to echo the input portion of the
# last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
@ -145,11 +149,12 @@ class OpenAIServingChat(OpenAIServing):
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content=last_msg_content),
finish_reason=None)
choice_data = (
ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content=last_msg_content),
finish_reason=None))
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,

View File

@ -1,7 +1,8 @@
import asyncio
import time
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple
from typing import (AsyncGenerator, AsyncIterator, Callable, List, Optional,
Dict, Tuple)
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -16,7 +17,8 @@ from vllm.entrypoints.openai.protocol import (
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
logger = init_logger(__name__)
@ -44,9 +46,8 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError(
"prompt must be a string, array of strings, array of tokens, or array of token arrays"
)
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
return prompt_is_tokens, prompts
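A self-contained sketch of the four prompt shapes enumerated by the error message above; classify_prompt is a hypothetical name, not the function shown in this diff.

from typing import List, Tuple, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]

def classify_prompt(prompt: Prompt) -> Tuple[bool, list]:
    if isinstance(prompt, str):
        return False, [prompt]                      # case 1: single string
    if isinstance(prompt, list) and prompt:
        if isinstance(prompt[0], str):
            return False, list(prompt)              # case 2: array of strings
        if isinstance(prompt[0], int):
            return True, [list(prompt)]             # case 3: array of tokens
        if isinstance(prompt[0], list):
            return True, [list(p) for p in prompt]  # case 4: array of token arrays
    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")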
@ -156,7 +157,8 @@ class OpenAIServingCompletion(OpenAIServing):
int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when using beam search.
# results. In addition, we do not stream the results when using
# beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
@ -223,7 +225,8 @@ class OpenAIServingCompletion(OpenAIServing):
for output in res.outputs:
i = output.index + prompt_idx * request.n
# TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
if request.echo and request.max_tokens == 0:
# only return the prompt
@ -231,11 +234,12 @@ class OpenAIServingCompletion(OpenAIServing):
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
elif request.echo and request.max_tokens > 0 and not has_echoed[
i]:
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = res.prompt_token_ids + output.token_ids
delta_token_ids = (res.prompt_token_ids +
output.token_ids)
top_logprobs = res.prompt_logprobs + (output.logprobs
or [])
has_echoed[i] = True
@ -248,7 +252,9 @@ class OpenAIServingCompletion(OpenAIServing):
i]:] if output.logprobs else None
if request.logprobs is not None:
assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,

View File

@ -50,10 +50,12 @@ class OpenAIServing:
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running(
): # If the current process is instantiated by Ray Serve, there is already a running event loop
if event_loop is not None and event_loop.is_running():
# If the current process is instantiated by Ray Serve,
# there is already a running event loop
event_loop.create_task(self._post_init())
else: # When using single vLLM without engine_use_ray
else:
# When using single vLLM without engine_use_ray
asyncio.run(self._post_init())
async def _post_init(self):
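The same run-or-schedule pattern in isolation, as a minimal runnable sketch; the stub names are illustrative.

import asyncio

async def _post_init_stub():
    """Placeholder for the async setup work."""

def run_post_init():
    # Schedule onto an existing loop if one is already running (e.g. under
    # Ray Serve); otherwise drive the coroutine to completion right here.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop is not None and loop.is_running():
        loop.create_task(_post_init_stub())
    else:
        asyncio.run(_post_init_stub())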
@ -178,8 +180,9 @@ class OpenAIServing:
if token_num + request.max_tokens > self.max_model_len:
raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )

View File

@ -20,10 +20,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear,
QKVParallelLinear,
MergedColumnParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim
from vllm.model_executor.parallel_utils.utils import (
split_tensor_along_last_dim)
if TYPE_CHECKING:
pass
@ -84,7 +86,8 @@ def _apply_lora_packed_nslice(
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...), where n is number of slices
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
x = x.view(-1, x.shape[-1])
@ -819,9 +822,8 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if 32000 < self.base_layer.vocab_size > 33024:
raise ValueError(
"When using LoRA, vocab size must be 32000 <= vocab_size <= 33024"
)
raise ValueError("When using LoRA, vocab size must be "
"32000 <= vocab_size <= 33024")
self.lora_a_stacked = torch.zeros(
(
max_loras,

View File

@ -13,7 +13,8 @@ from torch import nn
from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
from_layer_sampler)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule

View File

@ -154,10 +154,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
f"LoRA rank {lora.rank} is greater than max_lora_rank "
f"{self.lora_config.max_lora_rank}.")
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora.extra_vocab_size} is greater than "
f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}."
)
raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
f"is greater than lora_extra_vocab_size "
f"{self.lora_config.lora_extra_vocab_size}.")
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:

View File

@ -8,8 +8,10 @@ from re import escape as regex_escape
from typing import Union, Tuple
from pydantic import BaseModel
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor
from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest)
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
RegexLogitsProcessor)
class GuidedDecodingMode(Enum):

View File

@ -107,12 +107,15 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
Parameters
----------
schema
A JSON schema that encodes the structure we want the model to generate
A JSON schema that encodes the structure we want the model to
generate
tokenizer
The model's tokenizer
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
Pattern to use for JSON syntactic whitespace (doesn't impact
string literals)
Example: allow only a single space or newline with
`whitespace_pattern=r"[\n ]?"`
"""
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
@ -122,8 +125,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
schema_str = schema
else:
raise ValueError(
f"Cannot parse schema {schema}. The schema must be either " +
"a Pydantic object, a dictionary or a string that contains the JSON "
+ "Schema specification")
f"Cannot parse schema {schema}. The schema must be either "
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)

View File

@ -35,12 +35,12 @@ class Attention(nn.Module):
) -> None:
super().__init__()
if _use_flash_attn():
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
else:
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)

View File

@ -30,9 +30,10 @@ def fused_moe_kernel(
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
@ -50,17 +51,30 @@ def fused_moe_kernel(
compute_type: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,
and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids`
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
@ -105,7 +119,8 @@ def fused_moe_kernel(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
@ -139,30 +154,41 @@ def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions align correctly.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ),
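A pure-Python rendition of the worked example in this docstring; align_block_size_reference is a hypothetical helper that only illustrates the index layout, whereas the real function operates on tensors.

def align_block_size_reference(topk_ids, block_size):
    flat = [e for row in topk_ids for e in row]    # flatten topk_ids
    pad_id = len(flat)                             # out-of-range sentinel (12 above)
    sorted_token_ids, expert_ids = [], []
    for expert in sorted(set(flat)):
        idxs = [i for i, e in enumerate(flat) if e == expert]
        while len(idxs) % block_size != 0:         # pad each expert's group
            idxs.append(pad_id)
        sorted_token_ids.extend(idxs)
        expert_ids.extend([expert] * (len(idxs) // block_size))
    return sorted_token_ids, expert_ids, len(sorted_token_ids)

ids, experts, num_padded = align_block_size_reference(
    [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size=4)
assert ids == [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]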
@ -224,13 +250,14 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of batch sizes
to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
size bs, the closest batch size in the grid should be picked and the associated
configuration chosen to invoke the kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the fused_moe kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs directory
# First look up if an optimized configuration is available in the configs
# directory
device_name = torch.cuda.get_device_name().replace(" ", "_")
config_file_path = os.path.join(
@ -243,7 +270,8 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default configuration
# If no optimized configuration is available, we will use the default
# configuration
return None
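The "closest batch size" lookup described here is a nearest-key search over the config grid; a toy sketch with made-up grid points and config names:

configs = {1: "cfg_bs1", 8: "cfg_bs8", 64: "cfg_bs64", 256: "cfg_bs256"}

def pick_config(batch_size: int) -> str:
    # Pick the grid point whose batch size is closest to the requested one.
    return configs[min(configs.keys(), key=lambda bs: abs(bs - batch_size))]

assert pick_config(48) == "cfg_bs64"
assert pick_config(3) == "cfg_bs1"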
@ -258,18 +286,22 @@ def fused_moe(
override_config: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
@ -325,7 +357,8 @@ def fused_moe(
configs = get_moe_configs(E, w2.shape[2])
if configs:
# If an optimal configuration map has been found, look up the optimal config
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config

View File

@ -285,7 +285,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
@ -307,7 +308,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
@ -413,7 +415,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
@ -442,7 +445,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

View File

@ -1,6 +1,7 @@
from typing import Type
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

View File

@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class AWQConfig(QuantizationConfig):
@ -50,7 +51,8 @@ class AWQConfig(QuantizationConfig):
def get_config_filenames() -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
"quantize_config.json",
]
@classmethod

View File

@ -31,8 +31,8 @@ class GPTQConfig(QuantizationConfig):
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is supported for "
f"GPTQ, but got {self.weight_bits} bits.")
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
@ -101,7 +101,8 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
@ -114,7 +115,8 @@ class GPTQLinearMethod(LinearMethodBase):
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if input_size != input_size_per_partition and self.quant_config.group_size != -1:
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED

View File

@ -5,7 +5,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class MarlinConfig(QuantizationConfig):
@ -22,8 +23,9 @@ class MarlinConfig(QuantizationConfig):
self.group_size = group_size
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) is supported for "
f"Marlin, but got group_size of {self.group_size}")
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin, but got group_size of "
f"{self.group_size}")
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
@ -37,7 +39,8 @@ class MarlinConfig(QuantizationConfig):
# Min in_features dim
self.min_k_threads = 128
# Max parallel problems to solve at once (improves large batch performance)
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
# Permutation length used by the marlin kernels.
@ -102,22 +105,26 @@ class MarlinLinearMethod(LinearMethodBase):
# Validate output_size_per_partition
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
)
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
)
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
)
if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
)
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}.")
if (self.quant_config.group_size != -1 and
input_size_per_partition % self.quant_config.group_size != 0):
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
@ -149,7 +156,9 @@ class MarlinLinearMethod(LinearMethodBase):
)
# Determine if channelwise or not
input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
scales = Parameter(
torch.empty(

View File

@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import is_hip

View File

@ -6,7 +6,8 @@ import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput,

View File

@ -333,7 +333,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
# Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
# Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
# Distinguish between Baichuan and Baichuan2 by checking the
# vocab size. This is suggested by

View File

@ -119,7 +119,8 @@ class DeepseekMoE(nn.Module):
linear_method=None)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
self.shared_experts = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
@ -273,8 +274,9 @@ class DeepseekDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
if (config.n_routed_experts is not None and \
layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0):
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
else:
self.mlp = DeepseekMLP(

View File

@ -143,7 +143,8 @@ class GPTJBlock(nn.Module):
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config, linear_method)
self.mlp = GPTJMLP(inner_dim, config, linear_method)

View File

@ -305,7 +305,8 @@ class InternLM2ForCausalLM(nn.Module):
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = config.num_attention_heads // config.num_key_value_heads
kv_groups = (config.num_attention_heads //
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim,

View File

@ -52,7 +52,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, )
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -81,7 +82,8 @@ class SwiGLU(nn.Module):
class OlmoAttention(nn.Module):
"""
This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
@ -94,11 +96,12 @@ class OlmoAttention(nn.Module):
self.config = config
self.hidden_size = config.d_model
assert config.d_model % config.n_heads == 0
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = self.config.n_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // self.total_num_heads
# Layer norms.
@ -158,7 +161,8 @@ class OlmoAttention(nn.Module):
class OlmoMLP(nn.Module):
"""
This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
This is the MLP block where the output is computed as
``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
@ -217,7 +221,8 @@ class OlmoMLP(nn.Module):
class OlmoBlock(nn.Module):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""

View File

@ -170,7 +170,8 @@ class Qwen2DecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers
use_sliding_window = (config.use_sliding_window
and layer_idx < config.max_window_layers)
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,

View File

@ -1,5 +1,6 @@
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,7 +17,8 @@
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
@ -102,9 +104,9 @@ class StablelmAttention(nn.Module):
self.kv_size = self.num_key_value_heads * self.head_dim
self.qkv_bias = getattr(config, "use_qkv_bias", False)
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
raise ValueError(f"hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
self.qkv_proj = QKVParallelLinear(self.hidden_size,
self.head_dim,
@ -192,7 +194,6 @@ class StableLMEpochModel(nn.Module):
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,

View File

@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

View File

@ -34,7 +34,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig
from transformers_neuronx.config import (NeuronConfig,
ContinuousBatchingConfig)
parallel_config = kwargs.get("parallel_config")
scheduler_config = kwargs.get("scheduler_config")

View File

@ -11,7 +11,8 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group,
is_cupy_nccl_enabled_for_all_reduce,
)
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce
from vllm.model_executor.parallel_utils.custom_all_reduce import (
custom_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@ -24,7 +25,7 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
value as the output.
value as the output.
"""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:

View File

@ -114,7 +114,8 @@ class SamplingTensors:
do_penalties = True
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt, we only need to get their logprobs
# For tokens in the prompt, we only need to get
# their logprobs
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1)

View File

@ -74,8 +74,8 @@ class SamplingParams:
stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
include_stop_str_in_output: Whether to include the stop strings in output
text. Defaults to False.
include_stop_str_in_output: Whether to include the stop strings in
output text. Defaults to False.
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
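A hedged usage sketch of the fields documented here; the values are placeholders and only parameters named in this docstring are used.

from vllm import SamplingParams

params = SamplingParams(
    max_tokens=64,                    # cap on generated tokens per sequence
    stop_token_ids=[2],               # placeholder stop token id
    include_stop_str_in_output=True,  # keep stop strings in the returned text
    ignore_eos=False,                 # still honor the EOS token
)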

View File

@ -351,7 +351,8 @@ class SequenceGroup:
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None:
"""Sets the first scheduled time and time in queue for Request level timings."""
"""Sets the first scheduled time and time in queue for Request
level timings."""
if self.metrics.first_scheduled_time is None:
self.metrics.first_scheduled_time = time
self.metrics.time_in_queue = time - self.metrics.arrival_time

View File

@ -5,8 +5,12 @@ import torch
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceData)
from vllm.worker.worker import Worker
from vllm.spec_decode.util import nvtx_range, sampler_output_to_torch, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.interfaces import SpeculativeScorer, SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
get_all_seq_ids,
split_batch_by_proposal_len)
from vllm.spec_decode.interfaces import (SpeculativeScorer,
SpeculativeProposals,
SpeculativeScores)
SeqId = int
TargetSeqId = int
@ -68,11 +72,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list,
proposal_lens_list=proposal_lens_list,
)
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list,
proposal_lens_list=proposal_lens_list,
)
target_sampler_output = self._scorer_worker.execute_model(
seq_group_metadata_list=target_seq_group_metadata_list,
@ -125,7 +130,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs)
return spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_batch(self, original_bs: int,
target_sampler_output: List[SamplerOutput],
@ -306,10 +312,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = sampler_output_to_torch(
[sampler_output])
non_spec_target_token_ids, non_spec_target_probs = (
sampler_output_to_torch([sampler_output]))
return target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs
return (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs)
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:

View File

@ -5,7 +5,8 @@ import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeProposer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
@ -247,8 +248,9 @@ class DraftModelTop1Proposer(SpeculativeProposer):
"""
# Split speculative- and non-speculative- sequences.
proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
(proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices) = self._split_by_max_model_len(
seq_group_metadata_list, max_proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
@ -306,7 +308,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
else:
proposal_lens.append(0)
return proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices
return (proposal_lens, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices)
def _merge_outputs(
self,
@ -356,7 +359,8 @@ class DraftModelTop1Proposer(SpeculativeProposer):
device=self._device)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = entire_proposal_tokens, entire_proposal_probs
proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
proposal_lens = torch.zeros(batch_size,
dtype=torch.long,

View File

@ -10,7 +10,8 @@ from vllm.worker.worker import Worker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.config import CacheConfig
from vllm.spec_decode.util import nvtx_range, get_all_seq_ids, split_batch_by_proposal_len
from vllm.spec_decode.util import (nvtx_range, get_all_seq_ids,
split_batch_by_proposal_len)
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeScorer
@ -25,7 +26,7 @@ class SpecDecodeWorker:
LLM, after which some verification routine determines which (if any) of the
speculative tokens are accepted by the larger LLM.
See https://github.com/vllm-project/vllm/pull/2188 and
See https://github.com/vllm-project/vllm/pull/2188 and
https://github.com/vllm-project/vllm/pull/3103 for more info.
The current implementation has the following limitations:
@ -109,10 +110,12 @@ class SpecDecodeWorker:
block_size, gpu_memory_utilization, cpu_swap_space,
cache_dtype))
scorer_cache_block_size_bytes = self.scorer_worker.get_cache_block_size_bytes(
block_size, cache_dtype)
proposer_cache_block_size_bytes = self.proposer_worker.get_cache_block_size_bytes(
block_size, cache_dtype)
scorer_cache_block_size_bytes = (
self.scorer_worker.get_cache_block_size_bytes(
block_size, cache_dtype))
proposer_cache_block_size_bytes = (
self.proposer_worker.get_cache_block_size_bytes(
block_size, cache_dtype))
new_num_gpu_blocks = split_num_cache_blocks_evenly(
scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
@ -320,8 +323,8 @@ class SpecDecodeWorker:
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
k)
maybe_rejsample_metrics = (
self._metrics.maybe_collect_rejsample_metrics(k))
if maybe_rejsample_metrics is not None:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics

View File

@ -62,62 +62,6 @@ class MPTConfig(PretrainedConfig):
fc_type: str = 'torch',
verbose: Optional[int] = None,
**kwargs: Any):
"""The MPT configuration class.
Args:
d_model (int): The size of the embedding dimension of the model.
n_heads (int): The number of attention heads.
n_layers (int): The number of layers in the model.
expansion_ratio (int): The ratio of the up/down scale in the ffn.
max_seq_len (int): The maximum sequence length of the model.
vocab_size (int): The size of the vocabulary.
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
emb_pdrop (float): The dropout probability for the embedding layer.
learned_pos_emb (bool): Whether to use learned positional embeddings
attn_config (Dict): A dictionary used to configure the model's attention module:
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention
attn_pdrop (float): The dropout probability for the attention layers.
attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
this value.
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
use the default scale of ``1/sqrt(d_keys)``.
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
which sub-sequence each token belongs to.
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
init_device (str): The device to use for parameter initialization.
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
no_bias (bool): Whether to use bias in all layers.
verbose (int): The verbosity level. 0 is silent.
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
norm_type (str): choose type of norm to use
use_cache (bool): Whether or not the model should return the last key/values attentions
init_config (Dict): A dictionary used to configure the model initialization:
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
init_std (float): The standard deviation of the normal distribution used to initialize the model,
if using the baseline_ parameter initialization scheme.
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
---
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
"""
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
@ -139,8 +83,8 @@ class MPTConfig(PretrainedConfig):
self.fc_type = fc_type
if verbose is not None:
warnings.warn(DeprecationWarning(
'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.'
),
'verbose argument for MPTConfig is now ignored and '
'will be removed. Use python_log_level instead.'),
stacklevel=2)
if 'name' in kwargs:
del kwargs['name']
@ -149,7 +93,8 @@ class MPTConfig(PretrainedConfig):
if self.attn_config.get('alibi', False):
self.learned_pos_emb = False
warnings.warn(
f'alibi is turned on, setting `learned_pos_emb` to {self.learned_pos_emb}`',
f'alibi is turned on, setting `learned_pos_emb` '
f'to {self.learned_pos_emb}`',
stacklevel=2)
super().__init__(**kwargs)
self._validate_config()
@ -176,8 +121,8 @@ class MPTConfig(PretrainedConfig):
[self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]
)):
raise ValueError(
"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" # pylint: disable=line-too-long
)
"self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are "
"probabilities and must be between 0 and 1")
if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
raise ValueError(
f"Unknown attn_impl={self.attn_config['attn_impl']}")
@ -193,17 +138,17 @@ class MPTConfig(PretrainedConfig):
if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
'attn_impl'] not in ['torch', 'triton']:
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.' # pylint: disable=line-too-long
)
'attn_uses_sequence_id only implemented with torch '
'and triton attention.')
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' # pylint: disable=line-too-long
)
'model.embedding_fraction must be between 0 (exclusive) '
'and 1 (inclusive)!')
if isinstance(self.logit_scale,
str) and self.logit_scale != 'inv_sqrt_d_model':
raise ValueError(
f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." # pylint: disable=line-too-long
)
f"self.logit_scale={self.logit_scale!r} is not recognized as "
"an option; use numeric value or 'inv_sqrt_d_model'.")
if self.init_config.get('name', None) is None:
raise ValueError(
f"self.init_config={self.init_config!r} 'name' needs to be set."
@ -219,11 +164,11 @@ class MPTConfig(PretrainedConfig):
del te
except Exception as exc:
raise ImportError(
# pylint: disable=line-too-long
'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. '
+
'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n'
+ 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
'TransformerEngine import fail. `fc_type: te` requires '
'TransformerEngine be installed. '
'The required version of transformer_engine also requires '
'FlashAttention v1.0.6 is installed:\n'
'pip install flash-attn==1.0.6 --no-build-isolation \n'
'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156'
) from exc
if self.ffn_config['ffn_type'] == 'mptmlp':

View File

@ -2,78 +2,6 @@ from transformers import PretrainedConfig
class Starcoder2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Starcoder2Model`]. It is used to instantiate a
Starcoder2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the [bigcode/starcoder2-7b_16k](https://huggingface.co/bigcode/starcoder2-7b_16k) model.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 49152):
Vocabulary size of the Starcoder2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Starcoder2Model`]
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 12288):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 30):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 24):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with. Starcoder2's sliding window attention
allows sequence of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
norm_epsilon (`float`, *optional*, defaults to 1e-05):
Epsilon value for the layer norm
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
bos_token_id (`int`, *optional*, defaults to 50256):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 50256):
The id of the "end-of-sequence" token.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `None` (no sliding window).
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
residual_dropout (`float`, *optional*, defaults to 0.0):
Residual connection dropout value.
embedding_dropout (`float`, *optional*, defaults to 0.0):
Embedding dropout.
use_bias (`bool`, *optional*, defaults to `True`):
Whether to use a bias term on the linear layers of the model.
```python
>>> from transformers import Starcoder2Model, Starcoder2Config
>>> # Initializing a Starcoder2 7B style configuration
>>> configuration = Starcoder2Config()
>>> # Initializing a model from the Starcoder2 7B style configuration
>>> model = Starcoder2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "starcoder2"
keys_to_ignore_at_inference = ["past_key_values"]
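A hedged illustration of the `num_key_value_heads` semantics described in the removed docstring; the helper below is not part of transformers or vLLM, and the sample head counts are simply the documented Starcoder2 defaults.

# Illustrative helper, not library code: classify the attention variant the
# docstring describes and report how many query heads share each KV head.
def attention_variant(num_attention_heads: int,
                      num_key_value_heads: int) -> str:
    if num_key_value_heads == num_attention_heads:
        return "MHA"
    if num_key_value_heads == 1:
        return "MQA"
    assert num_attention_heads % num_key_value_heads == 0
    group = num_attention_heads // num_key_value_heads
    return f"GQA ({group} query heads share each KV head)"


if __name__ == "__main__":
    print(attention_variant(24, 2))   # 12 query heads per KV head
    print(attention_variant(24, 24))  # MHA
    print(attention_variant(24, 1))   # MQA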

View File

@ -1,4 +1,3 @@
# yapf: disable
# Adapted from
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/8f6e343d545c503b91429582231d1d354dac2740/tokenization_baichuan.py
# This includes a fix suggested in
@ -13,7 +12,6 @@ import sentencepiece as spm
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
@ -52,27 +50,16 @@ class BaichuanTokenizer(PreTrainedTokenizer):
clean_up_tokenization_spaces=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = (
AddedToken(bos_token, lstrip=False, rstrip=False)
if isinstance(bos_token, str)
else bos_token
)
eos_token = (
AddedToken(eos_token, lstrip=False, rstrip=False)
if isinstance(eos_token, str)
else eos_token
)
unk_token = (
AddedToken(unk_token, lstrip=False, rstrip=False)
if isinstance(unk_token, str)
else unk_token
)
pad_token = (
AddedToken(pad_token, lstrip=False, rstrip=False)
if isinstance(pad_token, str)
else pad_token
)
self.sp_model_kwargs = ({} if sp_model_kwargs is None else
sp_model_kwargs)
bos_token = (AddedToken(bos_token, lstrip=False, rstrip=False)
if isinstance(bos_token, str) else bos_token)
eos_token = (AddedToken(eos_token, lstrip=False, rstrip=False)
if isinstance(eos_token, str) else eos_token)
unk_token = (AddedToken(unk_token, lstrip=False, rstrip=False)
if isinstance(unk_token, str) else unk_token)
pad_token = (AddedToken(pad_token, lstrip=False, rstrip=False)
if isinstance(pad_token, str) else pad_token)
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
@ -107,7 +94,10 @@ class BaichuanTokenizer(PreTrainedTokenizer):
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab = {
self.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
vocab.update(self.added_tokens_encoder)
return vocab
@ -130,7 +120,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
out_string = ""
prev_is_special = False
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
# make sure that special tokens are not decoded using
# sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special and i != 0:
out_string += " "
@ -143,9 +134,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
def save_vocabulary(
self, save_directory, filename_prefix: Optional[str] = None
) -> Tuple[str]:
def save_vocabulary(self,
save_directory,
filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
@ -157,24 +148,24 @@ class BaichuanTokenizer(PreTrainedTokenizer):
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
logger.error(f"Vocabulary path ({save_directory}) "
"should be a directory")
return
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "")
+ VOCAB_FILES_NAMES["vocab_file"],
(filename_prefix + "-" if filename_prefix else "") +
VOCAB_FILES_NAMES["vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file
) and os.path.isfile(self.vocab_file):
out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)
return (out_vocab_file, )
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
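The reformatted `return (out_vocab_file, )` a few lines above is still a one-element tuple: the trailing comma, not the parentheses or the extra space yapf inserts, is what creates the tuple. A tiny check with a placeholder filename:

# Placeholder value; only shows that the space before the closing
# parenthesis does not change the returned object.
out_vocab_file = "tokenizer.model"
assert (out_vocab_file, ) == (out_vocab_file,)
assert isinstance((out_vocab_file, ), tuple)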
@ -194,7 +185,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
already_has_special_tokens: bool = False,
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
Retrieve sequence ids from a token list that has no special tokens
added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
@ -202,11 +194,14 @@ class BaichuanTokenizer(PreTrainedTokenizer):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
already_has_special_tokens (`bool`, *optional*, defaults to
`False`):
Whether or not the token list is already formatted with
special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
`List[int]`: A list of integers in the range [0, 1]:
1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
@ -220,20 +215,16 @@ class BaichuanTokenizer(PreTrainedTokenizer):
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (
bos_token_id
+ ([0] * len(token_ids_0))
+ eos_token_id
+ bos_token_id
+ ([0] * len(token_ids_1))
+ eos_token_id
)
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
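A minimal sketch of the mask layout the docstring above describes, without loading a real Baichuan tokenizer; the function and token ids below are made up, with 1 marking a special token and 0 an ordinary sequence token.

# Stand-in for BaichuanTokenizer.get_special_tokens_mask, assuming BOS and
# EOS are both added; token id values are arbitrary.
from typing import List, Optional


def special_tokens_mask(token_ids_0: List[int],
                        token_ids_1: Optional[List[int]] = None) -> List[int]:
    bos, eos = [1], [1]
    if token_ids_1 is None:
        return bos + [0] * len(token_ids_0) + eos
    return (bos + [0] * len(token_ids_0) + eos + bos +
            [0] * len(token_ids_1) + eos)


if __name__ == "__main__":
    print(special_tokens_mask([11, 12, 13]))
    # -> [1, 0, 0, 0, 1]
    print(special_tokens_mask([11, 12], [21, 22, 23]))
    # -> [1, 0, 0, 1, 1, 0, 0, 0, 1]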
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
Creates a mask from the two sequences passed to be used in a
sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
@ -250,7 +241,8 @@ class BaichuanTokenizer(PreTrainedTokenizer):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
`List[int]`: List of [token type IDs](../glossary#token-type-ids)
according to the given sequence(s).
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
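Similarly, a hedged sketch of the ALBERT-style sequence-pair mask described above, using a made-up stand-in rather than the real tokenizer: 0 covers the first sequence plus its special tokens, 1 covers the second.

# Stand-in for create_token_type_ids_from_sequences, assuming one BOS and
# one EOS wrap each sequence; token id values are arbitrary.
from typing import List, Optional


def token_type_ids(token_ids_0: List[int],
                   token_ids_1: Optional[List[int]] = None) -> List[int]:
    output = [0] * (len(token_ids_0) + 2)
    if token_ids_1 is not None:
        output += [1] * (len(token_ids_1) + 2)
    return output


if __name__ == "__main__":
    print(token_type_ids([11, 12, 13], [21, 22]))
    # -> [0, 0, 0, 0, 0, 1, 1, 1, 1]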

View File

@ -133,9 +133,10 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
# the Neuron-X backend does not have the `cuda_utils` module.
from vllm._C import cuda_utils
max_shared_mem = cuda_utils.get_max_shared_memory_per_block_device_attribute(
gpu)
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py will fail
max_shared_mem = (
cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail
assert max_shared_mem > 0, "max_shared_mem can not be zero"
return int(max_shared_mem)
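The continuation style above, wrapping the right-hand side in parentheses, is another recurring pattern in this commit; a toy example with a made-up stand-in for the `cuda_utils` binding:

# `query_max_shared_memory_per_block` is a made-up stand-in for the real
# vllm._C.cuda_utils binding; the 48 KiB value is illustrative only.
def query_max_shared_memory_per_block(gpu: int) -> int:
    return 48 * 1024


# Parenthesizing the right-hand side lets the call continue on the next
# line without a trailing backslash.
max_shared_mem = (
    query_max_shared_memory_per_block(0))
assert max_shared_mem > 0, "expected a positive value"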
@ -209,9 +210,8 @@ def get_nvcc_cuda_version() -> Optional[Version]:
if not cuda_home:
cuda_home = '/usr/local/cuda'
if os.path.isfile(cuda_home + '/bin/nvcc'):
logger.info(
f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
)
logger.info(f'CUDA_HOME is not found in the environment. '
f'Using {cuda_home} as CUDA_HOME.')
else:
logger.warning(
f'Not found nvcc in {cuda_home}. Skip cuda version check!')
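When a long f-string is split into adjacent literals, as in the logging calls above, each fragment that interpolates a value needs its own `f` prefix; a small check with a placeholder path:

cuda_home = '/usr/local/cuda'  # placeholder path for illustration
message = (f'CUDA_HOME is not found in the environment. '
           f'Using {cuda_home} as CUDA_HOME.')
assert message == ('CUDA_HOME is not found in the environment. '
                   'Using /usr/local/cuda as CUDA_HOME.')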

View File

@ -93,14 +93,13 @@ class ModelRunner:
scheduler_config=self.scheduler_config)
self.model_memory_usage = m.consumed_memory
logger.info(
f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
)
logger.info(f"Loading model weights took "
f"{self.model_memory_usage / float(2**30):.4f} GB")
if self.lora_config:
assert hasattr(
self.model, "supported_lora_modules"
) and self.model.supported_lora_modules, "Model does not support LoRA"
assert hasattr(self.model, "supported_lora_modules"
) and self.model.supported_lora_modules, (
"Model does not support LoRA")
assert hasattr(
self.model,
"embedding_modules"), "Model does not have embedding_modules"

View File

@ -79,7 +79,8 @@ class Worker:
cpu_swap_space: int = 0,
cache_dtype: str = "float16",
) -> Tuple[int, int]:
"""Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks."""
"""Simply returns max_num_seqs as num_gpu_blocks, 0 as
num_cpu_blocks."""
num_gpu_blocks = self.scheduler_config.max_num_seqs
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
@ -177,7 +178,8 @@ def _init_distributed_environment(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
distributed_backend = distributed_backend if distributed_backend else "nccl"
distributed_backend = (distributed_backend
if distributed_backend else "nccl")
torch.distributed.init_process_group(
backend=distributed_backend,
world_size=parallel_config.world_size,