[Quality] Add code formatter and linter (#326)

Author: Zhuohan Li, 2023-07-03 11:31:55 -07:00 (committed via GitHub)
parent 0ffded812a
commit d6fa1be3a8
47 changed files with 1547 additions and 617 deletions

.pylintrc (new file, 434 lines)

@@ -0,0 +1,434 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
logging-fstring-interpolation, # added by vLLM
logging-not-lazy, # added by vLLM
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-class-docstring, # TODO (vLLM): enable
missing-function-docstring,
missing-module-docstring, # TODO (vLLM): enable
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
unspecified-encoding,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException
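Pylint picks this rcfile up automatically when run from the repository root. As a rough illustration only (the explicit --rcfile flag is optional here, and the `vllm` package target mirrors what format.sh below ends up running):

# Lint the vllm package against the repository's pylint configuration.
pylint --rcfile=.pylintrc vllm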


@@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possible.
 In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
+We include a formatting script [`format.sh`](./format.sh) to format the code.
 ### Pull Requests
 When submitting a pull request:
 1. Make sure your code has been rebased on top of the latest commit on the main branch.
-2. Include a detailed description of the changes in the pull request.
+2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+3. Include a detailed description of the changes in the pull request.
 Explain why you made the changes you did.
 If your pull request fixes an open issue, please include a reference to it in the description.
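A minimal sketch of the workflow described above, assuming the dev tools from requirements-dev.txt are already installed and that `origin` points at the upstream repository:

# Rebase onto the latest main branch, then format the files you touched.
git fetch origin
git rebase origin/main
bash format.sh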


@@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
     print(LINE_UP, end=LINE_CLEAR, flush=True)
-def post_http_request(prompt: str, api_url: str, n: int = 1,
+def post_http_request(prompt: str,
+                      api_url: str,
+                      n: int = 1,
                       stream: bool = False) -> requests.Response:
     headers = {"User-Agent": "Test Client"}
     pload = {
@@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
 def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
                                      delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))


@@ -12,9 +12,14 @@ def http_bot(prompt):
         "stream": True,
         "max_tokens": 128,
     }
-    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
+    response = requests.post(args.model_url,
+                             headers=headers,
+                             json=pload,
+                             stream=True)
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
+                                     delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
             output = data["text"][0]
@@ -23,11 +28,11 @@ def http_bot(prompt):
 def build_demo():
     with gr.Blocks() as demo:
-        gr.Markdown(
-            "# vLLM text completion demo\n"
-        )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
-        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
+        gr.Markdown("# vLLM text completion demo\n")
+        inputbox = gr.Textbox(label="Input",
+                              placeholder="Enter text and press ENTER")
+        outputbox = gr.Textbox(label="Output",
+                               placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
     return demo
@@ -36,7 +41,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
+    parser.add_argument("--model-url",
+                        type=str,
+                        default="http://localhost:8000/generate")
     args = parser.parse_args()
     demo = build_demo()


@@ -14,9 +14,14 @@ def main(args: argparse.Namespace):
        ("To be or not to be,",
         SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
-        SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+        SamplingParams(n=2,
+                       best_of=5,
+                       temperature=0.8,
+                       top_p=0.95,
+                       frequency_penalty=0.1)),
        ("It is only with the heart that one can see rightly",
-        SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
+        SamplingParams(n=3, best_of=3, use_beam_search=True,
+                       temperature=0.0)),
    ]
    # Run the engine by calling `engine.step()` manually.


@@ -1,6 +1,5 @@
 from vllm import LLM, SamplingParams
 # Sample prompts.
 prompts = [
     "Hello, my name is",


@@ -12,8 +12,13 @@ print("Models:", models)
 # Test completion API
 stream = True
 completion = openai.Completion.create(
-    model=model, prompt="A robot may not injure a human being", echo=False, n=2,
-    best_of=3, stream=stream, logprobs=3)
+    model=model,
+    prompt="A robot may not injure a human being",
+    echo=False,
+    n=2,
+    best_of=3,
+    stream=stream,
+    logprobs=3)
 # print the completion
 if stream:

format.sh (new executable file, 108 lines)

@@ -0,0 +1,108 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' 'vllm/model_executor/parallel_utils/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi
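The three entry points at the bottom of the script give it a small CLI. A quick sketch of how they are typically invoked (the file path below is only a placeholder):

bash format.sh                          # format files that differ from origin/main
bash format.sh --files vllm/config.py   # format only the listed files
bash format.sh --all                    # format the entire vllm directory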


@@ -1,2 +1,12 @@
-mypy
+# formatting
+yapf==0.32.0
+pylint==2.8.2
+# type checking
+mypy==0.991
+types-PyYAML
+types-requests
+types-setuptools
+# testing
 pytest
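Since format.sh exits if the installed tool versions do not match these pins, a simple setup step (assuming pip installs into the same Python environment you use for vLLM) is:

# Install the pinned formatter, linter, type-checking stubs, and test tools.
pip install -r requirements-dev.txt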


@@ -60,7 +60,7 @@ def ref_single_query_cached_kv_attention(
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
-        scale = 1.0 / (head_size ** 0.5)
+        scale = 1.0 / (head_size**0.5)
        out = ref_masked_attention(q, keys, values, scale)
        out = out.view(num_heads, head_size)
        output[i].copy_(out, non_blocking=True)
@@ -74,7 +74,7 @@ def ref_multi_query_kv_attention(
    dtype: torch.dtype,
) -> torch.Tensor:
    head_size = query.shape[-1]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
@@ -84,8 +84,8 @@ def ref_multi_query_kv_attention(
        seq_len = end_idx - start_idx
        # Create attention mask.
-        attn_mask = torch.triu(
-            torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
@@ -113,7 +113,7 @@ def ref_multi_query_cached_kv_attention(
    num_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
    num_queries = len(cu_query_lens) - 1
    ref_outputs = []
@@ -125,8 +125,8 @@ def ref_multi_query_cached_kv_attention(
        block_table = block_tables[i]
        # Create attention mask
-        attn_mask = torch.triu(
-            torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
+        attn_mask = torch.triu(torch.ones(query_len, context_len),
+                               diagonal=context_len - query_len + 1) * -1e5
        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
        keys = []
@@ -165,22 +165,28 @@ def run_single_query_cached_kv_attention(
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
    qkv.uniform_(-1e-3, 1e-3)
    query, _, _ = qkv.unbind(dim=1)
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(
-        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
+                            dtype=dtype,
+                            device='cuda')
    key_cache.uniform_(-1e-3, 1e-3)
    value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(
-        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
+    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
+                              dtype=dtype,
+                              device='cuda')
    value_cache.uniform_(-1e-3, 1e-3)
    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@@ -194,9 +200,12 @@ def run_single_query_cached_kv_attention(
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
-    scale = float(1.0 / (head_size ** 0.5))
+    scale = float(1.0 / (head_size**0.5))
-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    output = torch.empty(num_tokens,
+                         num_heads,
+                         head_size,
+                         dtype=dtype,
+                         device='cuda')
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
@@ -235,9 +244,13 @@ def run_multi_query_kv_attention(
    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
    num_tokens = sum(seq_lens)
-    scale = float(1.0 / (head_size ** 0.5))
+    scale = float(1.0 / (head_size**0.5))
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
    qkv.uniform_(-1e-3, 1e-3)
    query, key, value = qkv.unbind(dim=1)


@@ -26,8 +26,9 @@ def run_copy_blocks(
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
-        key_cache = torch.randn(
-            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_cache = torch.randn(size=key_cache_shape,
+                                dtype=dtype,
+                                device='cuda')
        key_caches.append(key_cache)
    cloned_key_caches = []
    for key_cache in key_caches:
@@ -36,8 +37,9 @@ def run_copy_blocks(
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
-        value_cache = torch.randn(
-            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_cache = torch.randn(size=value_cache_shape,
+                                  dtype=dtype,
+                                  device='cuda')
        value_caches.append(value_cache)
    cloned_value_caches = []
    for value_cache in value_caches:
@@ -49,15 +51,18 @@ def run_copy_blocks(
    # Reference implementation.
    for src, dsts in block_mapping.items():
        for dst in dsts:
-            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+            for key_cache, cloned_key_cache in zip(key_caches,
+                                                   cloned_key_caches):
                cloned_key_cache[dst] = cloned_key_cache[src]
-            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+            for value_cache, cloned_value_cache in zip(value_caches,
+                                                       cloned_value_caches):
                cloned_value_cache[dst] = cloned_value_cache[src]
    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
-    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+    for value_cache, cloned_value_cache in zip(value_caches,
+                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)
@@ -74,8 +79,12 @@ def run_reshape_and_cache(
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
    _, key, value = qkv.unbind(dim=1)
    x = 16 // torch.tensor([], dtype=dtype).element_size()
@@ -84,15 +93,19 @@ def run_reshape_and_cache(
    cloned_key_cache = key_cache.clone()
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(
-        size=value_cache_shape, dtype=dtype, device='cuda')
+    value_cache = torch.randn(size=value_cache_shape,
+                              dtype=dtype,
+                              device='cuda')
    cloned_value_cache = value_cache.clone()
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
+                                slot_mapping)
    for i in range(num_tokens):
        reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
+        block_idx = torch.div(slot_mapping[i],
+                              block_size,
+                              rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]
@@ -114,8 +127,12 @@ def run_gather_cached_kv(
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
    _, key, value = qkv.unbind(dim=1)
    qkv_clone = qkv.clone()
@@ -126,15 +143,20 @@ def run_gather_cached_kv(
    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(
-        size=value_cache_shape, dtype=dtype, device='cuda')
+    value_cache = torch.randn(size=value_cache_shape,
+                              dtype=dtype,
+                              device='cuda')
-    cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
+    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
+                               slot_mapping)
    # Reference implementation.
    for i in range(num_tokens):
-        reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
+        reshaped_key = cloned_key.reshape(num_tokens, num_heads,
+                                          head_size // x, x)
+        block_idx = torch.div(slot_mapping[i],
+                              block_size,
+                              rounding_mode='floor')
        block_offset = slot_mapping[i] % block_size
        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
        cloned_value[i] = value_cache[block_idx, :, :, block_offset]
@@ -145,20 +167,30 @@ def run_gather_cached_kv(
def test_copy_blocks() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_copy_blocks(
-            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
-            block_size=8, num_blocks=1024, dtype=dtype)
+        run_copy_blocks(num_mappings=23,
+                        num_layers=7,
+                        num_heads=17,
+                        head_size=16,
+                        block_size=8,
+                        num_blocks=1024,
+                        dtype=dtype)
def test_reshape_and_cache() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_reshape_and_cache(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
+        run_reshape_and_cache(num_tokens=3,
+                              num_heads=2,
+                              head_size=16,
+                              block_size=8,
+                              num_blocks=2,
+                              dtype=dtype)
def test_gather_cached_kv() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_gather_cached_kv(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
+        run_gather_cached_kv(num_tokens=3,
+                             num_heads=2,
+                             head_size=16,
+                             block_size=8,
+                             num_blocks=2,
+                             dtype=dtype)


@@ -14,8 +14,10 @@ class RefRMSNorm(nn.Module):
        self.variance_epsilon = eps
    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
+                                                               keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance +
+                                                    self.variance_epsilon)
        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


@@ -8,8 +8,8 @@ from vllm import pos_encoding_ops
def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)
@@ -38,7 +38,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
        self.max_position_embeddings = max_position_embeddings
        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
+        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
@@ -49,16 +49,15 @@ class RefRotaryEmbeddingNeox(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,  # [num_tokens]
        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
        query_rot = query_rot.transpose(0, 1)
        key_rot = key_rot.transpose(0, 1)
@@ -85,12 +84,18 @@ def run_rotary_embedding_neox(
    dtype: torch.dtype,
    base: int = 10000,
) -> None:
-    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
-    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
-    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
+    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    query = torch.randn(num_tokens,
+                        num_heads * head_size,
+                        dtype=dtype,
+                        device='cuda')
+    key = torch.randn(num_tokens,
+                      num_heads * head_size,
+                      dtype=dtype,
+                      device='cuda')
    # Create the rotary embedding.
-    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
    cos = freqs.cos()


@@ -1,3 +1,5 @@
+"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine


@@ -35,7 +35,8 @@ class LogicalTokenBlock:
    def append_tokens(self, token_ids: List[int]) -> None:
        assert len(token_ids) <= self.get_num_empty_slots()
-        self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
+        curr_idx = self.num_tokens
+        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)
    def get_token_ids(self) -> List[int]:


@@ -8,7 +8,7 @@ from vllm.utils import get_cpu_memory
logger = init_logger(__name__)
-_GiB = 1 << 30
+_GB = 1 << 30
class ModelConfig:
@@ -106,6 +106,7 @@ class CacheConfig:
            vLLM execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
    """
    def __init__(
        self,
        block_size: int,
@@ -114,7 +115,7 @@ class CacheConfig:
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space_bytes = swap_space * _GiB
+        self.swap_space_bytes = swap_space * _GB
        self._verify_args()
        # Will be set after profiling.
@@ -137,14 +138,13 @@ class CacheConfig:
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
-        msg = (
-            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
-            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
-            "allocated for the swap space.")
+        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
+               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
+               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
-            logger.warn("Possibly too large swap space. " + msg)
+            logger.warning("Possibly too large swap space. " + msg)
class ParallelConfig:
@@ -157,6 +157,7 @@ class ParallelConfig:
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
    """
    def __init__(
        self,
        pipeline_parallel_size: int,
@@ -189,12 +190,9 @@ class SchedulerConfig:
        max_seq_len: Maximum length of a sequence (including prompt
            and generated text).
    """
-    def __init__(
-        self,
-        max_num_batched_tokens: int,
-        max_num_seqs: int,
-        max_seq_len: int
-    ) -> None:
+    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
+                 max_seq_len: int) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        self.max_seq_len = max_seq_len
@@ -241,7 +239,7 @@ def _get_and_verify_dtype(
        pass
    else:
        # Casting between float16 and bfloat16 is allowed with a warning.
-        logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
+        logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:


@@ -27,8 +27,9 @@ class BlockAllocator:
        # Initialize the free blocks.
        self.free_blocks: List[PhysicalTokenBlock] = []
        for i in range(num_blocks):
-            block = PhysicalTokenBlock(
-                device=device, block_number=i, block_size=block_size)
+            block = PhysicalTokenBlock(device=device,
+                                       block_number=i,
+                                       block_size=block_size)
            self.free_blocks.append(block)
    def allocate(self) -> PhysicalTokenBlock:
@@ -84,10 +85,12 @@ class BlockSpaceManager:
        num_required_blocks = len(seq.logical_token_blocks)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        # Use watermark to avoid frequent cache eviction.
-        return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks
+        return (num_free_gpu_blocks - num_required_blocks >=
+                self.watermark_blocks)
    def allocate(self, seq_group: SequenceGroup) -> None:
-        # NOTE: Here we assume that all sequences in the group have the same prompt.
+        # NOTE: Here we assume that all sequences in the group have the same
+        # prompt.
        seq = seq_group.get_seqs()[0]
        # Allocate new physical token blocks that will store the prompt tokens.
@@ -143,7 +146,8 @@ class BlockSpaceManager:
        for block in src_block_table:
            block.ref_count += 1
-    def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
+    def _get_physical_blocks(
+            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()


@@ -43,8 +43,7 @@ class SchedulerOutputs:
        assert not (blocks_to_swap_in and blocks_to_swap_out)
    def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in
-                and not self.blocks_to_swap_out
+        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
                and not self.blocks_to_copy)
@@ -61,7 +60,7 @@ class Scheduler:
        self.log_stats = log_stats
        # Instantiate the scheduling policy.
-        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
+        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
@@ -102,7 +101,8 @@ class Scheduler:
    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)
-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(
+            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
        # Blocks that need to be swaped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
@@ -160,7 +160,8 @@ class Scheduler:
            num_curr_seqs = sum(
                seq_group.num_seqs(status=SequenceStatus.RUNNING)
                for seq_group in self.running)
-            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+            if (num_curr_seqs + num_new_seqs >
+                    self.scheduler_config.max_num_seqs):
                break
            seq_group = self.swapped.pop(0)
@@ -170,8 +171,7 @@ class Scheduler:
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
-            for seq_group in self.running
-        )
+            for seq_group in self.running)
        # Join waiting sequences if possible.
        prompt_group_ids: List[str] = []
@@ -191,7 +191,7 @@ class Scheduler:
            num_prompt_tokens = seq_group.get_seqs()[0].get_len()
            if num_prompt_tokens >= self.scheduler_config.max_seq_len:
-                logger.warn(
+                logger.warning(
                    f"Input prompt ({num_prompt_tokens} tokens) is too long"
                    " and exceeds limit of "
                    f"{self.scheduler_config.max_seq_len}")
@@ -206,17 +206,19 @@ class Scheduler:
                break
            # If the number of batched tokens exceeds the limit, stop.
-            if (num_batched_tokens + num_prompt_tokens
-                    > self.scheduler_config.max_num_batched_tokens):
+            if (num_batched_tokens + num_prompt_tokens >
+                    self.scheduler_config.max_num_batched_tokens):
                break
            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
+            num_new_seqs = seq_group.num_seqs(
+                status=SequenceStatus.WAITING)
            num_curr_seqs = sum(
                seq_group.num_seqs(status=SequenceStatus.RUNNING)
                for seq_group in self.running)
-            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+            if (num_curr_seqs + num_new_seqs >
+                    self.scheduler_config.max_num_seqs):
                break
            seq_group = self.waiting.pop(0)
@@ -240,12 +242,11 @@ class Scheduler:
        elapsed_time = now - self.last_logging_time
        if elapsed_time > _LOGGING_INTERVAL_SEC:
            self.last_logging_time = now
-            self.num_input_tokens = [
-                (t, n) for t, n in self.num_input_tokens
-                if now - t < _LOGGING_INTERVAL_SEC
-            ]
+            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
+                                     if now - t < _LOGGING_INTERVAL_SEC]
            if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
+                total_num_tokens = sum(n
+                                       for _, n in self.num_input_tokens[:-1])
                window = now - self.num_input_tokens[0][0]
                avg_throughput = total_num_tokens / window
            else:
@@ -258,26 +259,30 @@ class Scheduler:
            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
            if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
+                num_free_cpu_blocks = (
+                    self.block_manager.get_num_free_cpu_blocks())
                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
            else:
                cpu_cache_usage = 0.0
-            logger.info(
-                f"Throughput: {avg_throughput:.1f} tokens/s, "
-                f"Running: {len(self.running)} reqs, "
-                f"Swapped: {len(self.swapped)} reqs, "
-                f"Pending: {len(self.waiting)} reqs, "
-                f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+            logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
+                        f"Running: {len(self.running)} reqs, "
+                        f"Swapped: {len(self.swapped)} reqs, "
+                        f"Pending: {len(self.waiting)} reqs, "
+                        f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                        f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
+    def schedule(
+        self
+    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
+               List[SequenceGroup]]:
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()
+        (scheduler_outputs, prompt_group_ids,
+         ignored_seq_groups) = self._schedule()
        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -311,8 +316,8 @@ class Scheduler:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
-                    # The sequence is a fork of the parent sequence (beam search).
-                    # Free the current sequence.
+                    # The sequence is a fork of the parent sequence (beam
+                    # search). Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
@@ -385,7 +390,7 @@ class Scheduler:
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
-            assert False, 'Invalid preemption mode.'
+            assert False, "Invalid preemption mode."
    def _preempt_by_recompute(
        self,


@ -12,11 +12,11 @@ class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str model: str
tokenizer: Optional[str] = None tokenizer: Optional[str] = None
tokenizer_mode: str = "auto" tokenizer_mode: str = 'auto'
download_dir: Optional[str] = None download_dir: Optional[str] = None
use_np_weights: bool = False use_np_weights: bool = False
use_dummy_weights: bool = False use_dummy_weights: bool = False
dtype: str = "auto" dtype: str = 'auto'
seed: int = 0 seed: int = 0
worker_use_ray: bool = False worker_use_ray: bool = False
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
@@ -35,76 +35,101 @@ class EngineArgs:
    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine."""
        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument('--use-np-weights',
                            action='store_true',
                            help='save a numpy copy of model weights for '
                            'faster loading. This can increase the disk '
                            'usage by up to 2x.')
        parser.add_argument('--use-dummy-weights',
                            action='store_true',
                            help='use dummy values for model weights')
        # TODO(woosuk): Support FP32.
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=['auto', 'half', 'bfloat16', 'float'],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument('--gpu-memory-utilization',
                            type=float,
                            default=EngineArgs.gpu_memory_utilization,
                            help='the percentage of GPU memory to be used for'
                            'the model executor')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
@@ -115,18 +140,19 @@ class EngineArgs:
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        # Initialize the configs.
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.download_dir,
                                   self.use_np_weights, self.use_dummy_weights,
                                   self.dtype, self.seed)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space)
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray)
        model_max_len = getattr(model_config.hf_config,
                                'max_position_embeddings', float('inf'))
        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs, max_seq_len)
        return model_config, cache_config, parallel_config, scheduler_config
@@ -140,12 +166,13 @@ class AsyncEngineArgs(EngineArgs):
    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='disable logging requests')
        return parser
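The two add_cli_args hooks above are meant to be composed with an existing argparse parser. A minimal sketch of that flow, assuming the module path vllm.engine.arg_utils (the path itself is not shown in this diff):

import argparse

from vllm.engine.arg_utils import EngineArgs  # module path assumed

parser = argparse.ArgumentParser(description='toy launcher')
parser = EngineArgs.add_cli_args(parser)       # registers --model, --dtype, ...
args = parser.parse_args(['--model', 'facebook/opt-125m', '--dtype', 'half'])

engine_args = EngineArgs.from_cli_args(args)   # dataclass built from parsed attrs
configs = engine_args.create_engine_configs()  # (model, cache, parallel, scheduler)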

@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams
logger = init_logger(__name__) logger = init_logger(__name__)
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
class AsyncLLMEngine: class AsyncLLMEngine:
@@ -35,8 +35,13 @@ class AsyncLLMEngine:
        log_requests: Whether to log the requests.
        *args, *kwargs: Arguments for LLMEngine.
    """

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
@ -76,12 +81,11 @@ class AsyncLLMEngine:
self.request_events[request_id].set() self.request_events[request_id].set()
async def generate( async def generate(
self, self,
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
) -> RequestOutput:
"""Generate outputs for a request. """Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the Generate outputs for a request. This method is a coroutine. It adds the
@@ -117,14 +121,17 @@ class AsyncLLMEngine:
        # Add the request into the vLLM engine's waiting queue.
        if self.engine_use_ray:
            await self.engine.add_request.remote(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)
        else:
            self.engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids=prompt_token_ids,
                                    arrival_time=arrival_time)
# The vLLM engine does not have a background loop that keeps # The vLLM engine does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking # processing incoming requests. Therefore, we need to keep kicking
@ -200,7 +207,8 @@ class AsyncLLMEngine:
self.kicking_request_id = None self.kicking_request_id = None
@classmethod @classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine": def from_engine_args(cls,
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
@@ -211,8 +219,9 @@ class AsyncLLMEngine:
        # Create the async LLM engine.
        engine = cls(engine_args.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     distributed_init_method,
                     devices,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats)
        return engine
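For reference, the call pattern the API servers below build on: construct the engine with from_engine_args, then iterate the generate coroutine. A hedged sketch; the names follow the hunks in this diff, and the streaming behaviour is inferred from the FastAPI servers further down.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs       # module path assumed
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model='facebook/opt-125m'))
    results = engine.generate('Hello, my name is',
                              SamplingParams(max_tokens=16),
                              random_uuid())
    async for request_output in results:   # same pattern as the servers below
        print(request_output.outputs[0].text)

# asyncio.run(main())  # requires a GPU and a model download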

@ -67,8 +67,7 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, " f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})" f"seed={model_config.seed})")
)
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
@ -78,8 +77,8 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
self._verify_args() self._verify_args()
self.tokenizer = get_tokenizer(model_config.tokenizer, self.tokenizer = get_tokenizer(
model_config.tokenizer_mode) model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
self.seq_counter = Counter() self.seq_counter = Counter()
# Create the parallel GPU workers. # Create the parallel GPU workers.
@ -129,8 +128,8 @@ class LLMEngine:
num_gpu_blocks = min(b[0] for b in num_blocks) num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log. # FIXME(woosuk): Change to debug log.
logger.info(f'# GPU blocks: {num_gpu_blocks}, ' logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f'# CPU blocks: {num_cpu_blocks}') f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0: if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. " raise ValueError("No available memory for the cache blocks. "
@ -152,7 +151,9 @@ class LLMEngine:
# Initialize the cluster. # Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config) distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM engine. # Create the LLM engine.
engine = cls(*engine_configs, distributed_init_method, devices, engine = cls(*engine_configs,
distributed_init_method,
devices,
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats)
return engine return engine
@ -226,8 +227,10 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule() (seq_group_metadata_list, scheduler_outputs,
if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups): ignored_seq_groups) = self.scheduler.schedule()
if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
and (not ignored_seq_groups)):
# Nothing to do. # Nothing to do.
return [] return []
@ -281,8 +284,8 @@ class LLMEngine:
# Truncate the output text so that the stop string is # Truncate the output text so that the stop string is
# not included in the output. # not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)] seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(seq, self.scheduler.free_seq(
SequenceStatus.FINISHED_STOPPED) seq, SequenceStatus.FINISHED_STOPPED)
stopped = True stopped = True
break break
if stopped: if stopped:
@ -290,7 +293,7 @@ class LLMEngine:
# Check if the sequence has reached max_seq_len. # Check if the sequence has reached max_seq_len.
if (seq.get_len() >= if (seq.get_len() >=
self.scheduler.scheduler_config.max_seq_len): self.scheduler.scheduler_config.max_seq_len):
self.scheduler.free_seq( self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED) seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue continue
@@ -302,15 +305,15 @@ class LLMEngine:
            # Check if the sequence has generated the EOS token.
            if not sampling_params.ignore_eos:
                if seq.get_last_token_id() == self.tokenizer.eos_token_id:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_STOPPED)
                    continue

    def _run_workers(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""

@@ -8,7 +8,8 @@ except ImportError:
from vllm.config import ParallelConfig

# rank, node resource (node IP), device id
DeviceID = Tuple[int, Optional[str], int]
def initialize_cluster( def initialize_cluster(
@ -53,15 +54,15 @@ def initialize_cluster(
valid_node_resources = [] valid_node_resources = []
num_devices_per_node = None num_devices_per_node = None
for node in ray.nodes(): for node in ray.nodes():
if (not node['Alive']) or node['Resources']['GPU'] <= 0: if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
continue continue
if num_devices_per_node is None: if num_devices_per_node is None:
num_devices_per_node = node['Resources']['GPU'] num_devices_per_node = node["Resources"]["GPU"]
else: else:
assert num_devices_per_node == node['Resources']['GPU'], ( assert num_devices_per_node == node["Resources"]["GPU"], (
"The number of GPUs per node is not uniform.") "The number of GPUs per node is not uniform.")
for key in node['Resources']: for key in node["Resources"]:
if key.startswith('node:'): if key.startswith("node:"):
valid_node_resources.append(key) valid_node_resources.append(key)
# Verify the parallel config. # Verify the parallel config.

@ -11,8 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI() app = FastAPI()
@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator: async for request_output in results_generator:
prompt = request_output.prompt prompt = request_output.prompt
text_outputs = [ text_outputs = [
prompt + output.text prompt + output.text for output in request_output.outputs
for output in request_output.outputs
] ]
ret = {"text": text_outputs} ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
@ -63,10 +62,7 @@ async def generate(request: Request) -> Response:
assert final_output is not None assert final_output is not None
prompt = final_output.prompt prompt = final_output.prompt
text_outputs = [ text_outputs = [prompt + output.text for output in final_output.outputs]
prompt + output.text
for output in final_output.outputs
]
ret = {"text": text_outputs} ret = {"text": text_outputs}
return Response(content=json.dumps(ret)) return Response(content=json.dumps(ret))
@ -81,5 +77,8 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug", uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE) timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
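Once this server is running (python -m vllm.entrypoints.api_server --model ...), it can be exercised as below; the /generate route name is not visible in this hunk and is assumed, and the JSON fields follow the prompt/stream handling shown above.

import json

import requests

resp = requests.post('http://localhost:8000/generate',   # '/generate' route assumed
                     json={
                         'prompt': 'San Francisco is a',
                         'max_tokens': 32,
                         'stream': False,
                     })
print(json.loads(resp.text)['text'])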

@ -63,8 +63,7 @@ class LLM:
self.request_counter = Counter() self.request_counter = Counter()
def get_tokenizer( def get_tokenizer(
self, self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer return self.llm_engine.tokenizer
def set_tokenizer( def set_tokenizer(

@ -1,4 +1,5 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse import argparse
from http import HTTPStatus from http import HTTPStatus
@ -29,7 +30,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__) logger = init_logger(__name__)
served_model = None served_model = None
@@ -38,14 +39,13 @@ app = fastapi.FastAPI()
def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse(ErrorResponse(message=message,
                                      type="invalid_request_error").dict(),
                        status_code=status_code.value)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
@ -126,8 +126,11 @@ async def check_length(request, prompt, engine):
@app.get("/v1/models") @app.get("/v1/models")
async def show_available_models(): async def show_available_models():
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model, model_cards = [
permission=[ModelPermission()])] ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards) return ModelList(data=model_cards)
@ -144,12 +147,14 @@ def create_logprobs(token_ids: List[int],
if len(logprobs.text_offset) == 0: if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset) logprobs.text_offset.append(initial_text_offset)
else: else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token) last_token_len = len(token)
logprobs.top_logprobs.append( logprobs.top_logprobs.append({
{tokenizer.convert_ids_to_tokens(i): p tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()}) for i, p in id_logprob.items()
})
return logprobs return logprobs
@ -348,7 +353,7 @@ async def create_completion(raw_request: Request):
if request.suffix is not None: if request.suffix is not None:
# The language models we currently support do not support suffix. # The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported") "suffix is not currently supported")
if request.logit_bias is not None: if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine. # TODO: support logit_bias in vLLM engine.
@ -387,22 +392,23 @@ async def create_completion(raw_request: Request):
except ValueError as e: except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, result_generator = engine.generate(prompt, sampling_params, request_id)
request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search. # results. In addition, we do not stream the results when use beam search.
stream = (request.stream and stream = (request.stream
(request.best_of is None or request.n == request.best_of) and and (request.best_of is None or request.n == request.best_of)
not request.use_beam_search) and not request.use_beam_search)
async def abort_request() -> None: async def abort_request() -> None:
await engine.abort(request_id) await engine.abort(request_id)
def create_stream_response_json(index: int, def create_stream_response_json(
text: str, index: int,
logprobs: Optional[LogProbs] = None, text: str,
finish_reason: Optional[str] = None) -> str: logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice( choice_data = CompletionResponseStreamChoice(
index=index, index=index,
text=text, text=text,
@ -443,7 +449,8 @@ async def create_completion(raw_request: Request):
) )
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
if output.finish_reason is not None: if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None logprobs = (LogProbs()
if request.logprobs is not None else None)
response_json = create_stream_response_json( response_json = create_stream_response_json(
index=i, index=i,
text="", text="",
@ -487,8 +494,8 @@ async def create_completion(raw_request: Request):
choices.append(choice_data) choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids) num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) num_generated_tokens = sum(
for output in final_res.outputs) len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens, completion_tokens=num_generated_tokens,
@ -506,9 +513,11 @@ async def create_completion(raw_request: Request):
# When user requests streaming but we don't stream, we still need to # When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event. # return a streaming response with a single event.
response_json = response.json(ensure_ascii=False) response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]: async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(), return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream") media_type="text/event-stream")
@ -517,26 +526,34 @@ async def create_completion(raw_request: Request):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server." description="vLLM OpenAI-Compatible RESTful API server.")
) parser.add_argument("--host",
parser.add_argument("--host", type=str, default="localhost", help="host name") type=str,
default="localhost",
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument( parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials" "--served-model-name",
) type=str,
parser.add_argument( default=None,
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins" help="The model name used in the API. If not specified, "
) "the model name will be the same as the "
parser.add_argument( "huggingface name.")
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
)
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument("--served-model-name", type=str, default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
@ -556,7 +573,11 @@ if __name__ == "__main__":
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings. # A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode) tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode)
uvicorn.run(app, host=args.host, port=args.port, log_level="info", uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE) timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
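A hedged example of exercising the OpenAI-compatible server above; the /v1/completions route and field names mirror the OpenAI API that this entrypoint emulates, with the host/port defaults taken from the argparse options in this hunk.

import requests

resp = requests.post('http://localhost:8000/v1/completions',
                     json={
                         'model': 'facebook/opt-125m',
                         'prompt': 'Say this is a test',
                         'max_tokens': 7,
                         'temperature': 0.0,
                     })
print(resp.json()['choices'][0]['text'])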

@ -1,4 +1,5 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time import time
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
@ -98,7 +99,8 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(BaseModel):

@ -1,9 +1,9 @@
# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py # Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import logging import logging
import sys import sys
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S" _DATE_FORMAT = "%m-%d %H:%M:%S"

@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
__all__ = [ __all__ = [
"InputMetadata", "InputMetadata",
"get_model", "get_model",

@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
class InputMetadata: class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
"""
def __init__( def __init__(
self, self,
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params). seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData. seq_data: Dict[int, SequenceData],
prompt_lens: List[int], prompt_lens: List[int],
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
context_lens: torch.Tensor, context_lens: torch.Tensor,

@ -6,9 +6,10 @@ from vllm import activation_ops
_ACTIVATION_REGISTRY = { _ACTIVATION_REGISTRY = {
"gelu": nn.GELU(), "gelu": nn.GELU(),
"gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. # NOTE: The following GELU functions may introduce small rounding errors.
"gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. "gelu_new": nn.GELU(approximate="tanh"),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. "gelu_fast": nn.GELU(approximate="tanh"),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(), "relu": nn.ReLU(),
} }
@@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
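The docstring above fully specifies the kernel, so a plain-PyTorch reference is easy to state (a sketch with illustrative shapes; the custom op should agree with it):

import torch
import torch.nn.functional as F

x = torch.randn(4, 2 * 11)          # (num_tokens, 2 * d)
d = x.shape[1] // 2
ref = F.silu(x[:, :d]) * x[:, d:]   # (num_tokens, d)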

@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
class PagedAttention(nn.Module): class PagedAttention(nn.Module):
# pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention. """GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The This class takes flattened 1D query, key, and value tensors as input. The
@ -54,12 +55,20 @@ class PagedAttention(nn.Module):
def multi_query_kv_attention( def multi_query_kv_attention(
self, self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] output: torch.Tensor,
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] query: torch.Tensor,
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] key: torch.Tensor,
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] value: torch.Tensor,
attn_bias: xops.AttentionBias, attn_bias: xops.AttentionBias,
) -> torch.Tensor: ) -> torch.Tensor:
"""Normal attention for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
"""
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward( out = xops.memory_efficient_attention_forward(
query.unsqueeze(0), query.unsqueeze(0),
@ -76,12 +85,22 @@ class PagedAttention(nn.Module):
def single_query_cached_kv_attention( def single_query_cached_kv_attention(
self, self,
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size] output: torch.Tensor,
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size] query: torch.Tensor,
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] key_cache: torch.Tensor,
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] value_cache: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> None: ) -> None:
"""PagedAttention for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention( attention_ops.single_query_cached_kv_attention(
output, output,
@ -97,16 +116,32 @@ class PagedAttention(nn.Module):
def forward( def forward(
self, self,
query: torch.Tensor, # [num_tokens, num_heads * head_size] query: torch.Tensor,
key: torch.Tensor, # [num_tokens, num_heads * head_size] key: torch.Tensor,
value: torch.Tensor, # [num_tokens, num_heads * head_size] value: torch.Tensor,
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x] key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size] value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size] ) -> torch.Tensor:
# NOTE: The query, key, and value tensors must be sliced from a qkv """PagedAttention forward pass.
# tensor of shape [num_tokens, 3 * num_heads * head_size].
NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [num_tokens, 3 * num_heads * head_size].
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
# and value vectors will not be cached. # and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None): and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv. # The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache( cache_ops.reshape_and_cache(
key[:num_valid_tokens], key[:num_valid_tokens],
@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
if input_metadata.num_generation_tokens > 0: if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, ( assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when " "key_cache and value_cache must be provided when "
"generating tokens." "generating tokens.")
)
# Compute the attention op for generation tokens. # Compute the attention op for generation tokens.
self.single_query_cached_kv_attention( self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens], output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens], query[num_prompt_tokens:num_valid_tokens], key_cache,
key_cache, value_cache, input_metadata)
value_cache,
input_metadata)
# Reshape the output tensor. # Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings. # NOTE(woosuk): The output tensor may include paddings.
@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
super().__init__(num_heads, head_size, scale) super().__init__(num_heads, head_size, scale)
# Create the cos and sin cache. # Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float() t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
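Spelled out with illustrative sizes, the cos/sin cache construction above amounts to:

import torch

rotary_dim, max_position, base = 64, 2048, 10000        # illustrative values
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())   # (max_position, rotary_dim // 2)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)    # (max_position, rotary_dim)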
@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):
def forward( def forward(
self, self,
positions: torch.Tensor, # [num_tokens] positions: torch.Tensor,
query: torch.Tensor, # [num_tokens, num_heads * head_size] query: torch.Tensor,
key: torch.Tensor, # [num_tokens, num_heads * head_size] key: torch.Tensor,
value: torch.Tensor, # [num_tokens, num_heads * head_size] value: torch.Tensor,
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] key_cache: torch.Tensor,
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] value_cache: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size] ) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Apply rotary embedding to the query and key before passing them # Apply rotary embedding to the query and key before passing them
# to the attention op. # to the attention op.
pos_encoding_ops.rotary_embedding_neox( pos_encoding_ops.rotary_embedding_neox(

@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
class Sampler(nn.Module): class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs. """Samples the next tokens from the model's outputs.
@ -50,19 +51,20 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties. # Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata) output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0] assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(input_metadata) presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
assert len(presence_penalties) == logits.shape[0] assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties( logits = _apply_penalties(logits, output_tokens, presence_penalties,
logits, output_tokens, presence_penalties, frequency_penalties, frequency_penalties, self.vocab_size)
self.vocab_size)
# Apply temperature scaling. # Apply temperature scaling.
temperatures = _get_temperatures(input_metadata) temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0] assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures): if any(t != 1.0 for t in temperatures):
t = torch.tensor( t = torch.tensor(temperatures,
temperatures, dtype=logits.dtype, device=logits.device) dtype=logits.dtype,
device=logits.device)
# Use in-place division to avoid creating a new tensor. # Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1)) logits.div_(t.unsqueeze(dim=1))
@@ -75,7 +77,9 @@ class Sampler(nn.Module):
        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == probs.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens. # Sample the next tokens.
@ -97,8 +101,7 @@ def _prune_hidden_states(
def _get_penalties( def _get_penalties(
input_metadata: InputMetadata, input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties. # Collect the presence and frequency penalties.
presence_penalties: List[float] = [] presence_penalties: List[float] = []
frequency_penalties: List[float] = [] frequency_penalties: List[float] = []
@ -117,9 +120,7 @@ def _get_penalties(
return presence_penalties, frequency_penalties return presence_penalties, frequency_penalties
def _get_output_tokens( def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
input_metadata: InputMetadata,
) -> List[List[int]]:
output_tokens: List[List[int]] = [] output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group seq_ids, _ = seq_group
@ -169,11 +170,13 @@ def _apply_penalties(
device=logits.device) device=logits.device)
frequency_penalties = [frequency_penalties[i] for i in indices] frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor( frequency_penalties = torch.tensor(frequency_penalties,
frequency_penalties, dtype=logits.dtype, device=logits.device) dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices] presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor( presence_penalties = torch.tensor(presence_penalties,
presence_penalties, dtype=logits.dtype, device=logits.device) dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API. # We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details # Refer to https://platform.openai.com/docs/api-reference/parameter-details
@ -183,9 +186,7 @@ def _apply_penalties(
return logits return logits
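The OpenAI definition referenced above reduces each token's logit by frequency_penalty times its occurrence count, plus presence_penalty if the token occurred at all. A toy sketch of that rule:

import torch

vocab_size = 8
logits = torch.zeros(vocab_size)
output_tokens = [3, 3, 5]                      # previously generated token ids
frequency_penalty, presence_penalty = 0.5, 1.0

counts = torch.bincount(torch.tensor(output_tokens), minlength=vocab_size)
logits -= frequency_penalty * counts
logits -= presence_penalty * (counts > 0).float()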
def _get_temperatures( def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
input_metadata: InputMetadata,
) -> List[float]:
# Collect the temperatures for the logits. # Collect the temperatures for the logits.
temperatures: List[float] = [] temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
@ -252,8 +253,9 @@ def _apply_top_p_top_k(
probs_sort[top_k_mask] = 0.0 probs_sort[top_k_mask] = 0.0
# Re-sort the probabilities. # Re-sort the probabilities.
probs = torch.gather( probs = torch.gather(probs_sort,
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1)) dim=-1,
index=torch.argsort(probs_idx, dim=-1))
return probs return probs
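The sort / mask / re-sort trick used by _apply_top_p_top_k above, shown for a small batch (a sketch with uniform top_p/top_k rather than per-sequence values):

import torch

probs = torch.softmax(torch.randn(2, 10), dim=-1)
top_p, top_k = 0.9, 5

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = probs_sort.cumsum(dim=-1)
probs_sort[(probs_sum - probs_sort) > top_p] = 0.0   # top-p: zero out the tail
probs_sort[:, top_k:] = 0.0                          # top-k: keep the k largest
# Undo the sort so each probability returns to its original column.
probs = torch.gather(probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))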
@ -296,8 +298,9 @@ def _sample_from_prompt(
# Random sampling. # Random sampling.
# Sample `best_of` tokens for the prompt. # Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial( next_token_ids = torch.multinomial(prob,
prob, num_samples=num_seqs, replacement=True) num_samples=num_seqs,
replacement=True)
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
return next_token_ids return next_token_ids
@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
if sampling_params.use_beam_search: if sampling_params.use_beam_search:
# Beam search. # Beam search.
# Add cumulative logprobs for the sequences in the group. # Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor( seq_logprobs = torch.tensor(seq_logprobs,
seq_logprobs, dtype=torch.float, device=logprobs.device) dtype=torch.float,
device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1) logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1) vocab_size = logprobs.size(-1)
@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
else: else:
# Random sampling. # Random sampling.
# Sample 1 token for each sequence in the group. # Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial( next_token_ids = torch.multinomial(probs,
probs, num_samples=1, replacement=True) num_samples=1,
replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist() next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids return parent_seq_ids, next_token_ids
@ -381,15 +386,16 @@ def _sample(
# Sample the next tokens. # Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params) next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens. # Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs( next_logprobs = _get_topk_logprobs(logprob,
logprob, sampling_params.logprobs) sampling_params.logprobs)
# Build the output. # Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids): for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy() output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item() output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs( seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
seq_id, seq_id, next_token_id, output_logprobs) next_token_id,
output_logprobs)
else: else:
# Generate the next tokens for generation tokens. # Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)] prob = probs[idx:idx + len(seq_ids)]
@ -399,22 +405,24 @@ def _sample(
# Sample the next tokens. # Sample the next tokens.
seq_logprobs = [ seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprob input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids] for seq_id in seq_ids
]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens( parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params) seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens. # Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {} next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids): for j, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs( next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.logprobs) logprob[j], sampling_params.logprobs)
# Build the output. # Build the output.
for seq_id, parent_seq_id, next_token_id in zip( for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids): seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id) j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy() output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item() output_logprobs[next_token_id] = logprob[j,
next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs( seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id,
parent_seq_id, parent_seq_id,

@ -6,8 +6,9 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM, from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
LlamaForCausalLM, OPTForCausalLM) GPTNeoXForCausalLM, LlamaForCausalLM,
OPTForCausalLM)
from vllm.model_executor.weight_utils import initialize_dummy_weights from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes. # TODO(woosuk): Lazy-load the model classes.
@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
return _MODEL_REGISTRY[arch] return _MODEL_REGISTRY[arch]
raise ValueError( raise ValueError(
f"Model architectures {architectures} are not supported for now. " f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}" f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
)
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig) -> nn.Module:
@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
initialize_dummy_weights(model) initialize_dummy_weights(model)
else: else:
# Load the weights from the cached or downloaded files. # Load the weights from the cached or downloaded files.
model.load_weights( model.load_weights(model_config.model, model_config.download_dir,
model_config.model, model_config.download_dir, model_config.use_np_weights)
model_config.use_np_weights)
model = model.cuda() model = model.cuda()
return model.eval() return model.eval()
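A hedged sketch of how get_model above is driven end to end, reusing the EngineArgs plumbing from earlier in this diff (import paths assumed; dummy weights avoid a checkpoint download but a GPU is still required):

from vllm.engine.arg_utils import EngineArgs   # module path assumed
from vllm.model_executor import get_model

engine_args = EngineArgs(model='facebook/opt-125m', use_dummy_weights=True)
model_config, _, _, _ = engine_args.create_engine_configs()
model = get_model(model_config)                # nn.Module on GPU, in eval() mode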

@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [ __all__ = [
"GPT2LMHeadModel", "GPT2LMHeadModel",
"GPTBigCodeForCausalLM", "GPTBigCodeForCausalLM",

@ -1,5 +1,6 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0 assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, self.c_attn = ColumnParallelLinear(self.hidden_size,
bias=True, gather_output=False, 3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, self.c_proj = RowParallelLinear(self.hidden_size,
bias=True, input_is_parallel=True, self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale) scale=self.scale)
def forward( def forward(
@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(q, k, v, key_cache, value_cache,
q, k, v, key_cache, value_cache, input_metadata, cache_event) input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size, self.c_fc = ColumnParallelLinear(hidden_size,
bias=True, gather_output=False, intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size, self.c_proj = RowParallelLinear(intermediate_size,
bias=True, input_is_parallel=True, hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.act = get_act_fn(config.activation_function) self.act = get_act_fn(config.activation_function)
@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config): def __init__(self, config: GPT2Config):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config) self.attn = GPT2Attention(config)
@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config): def __init__(self, config: GPT2Config):
super().__init__() super().__init__()
self.config = config self.config = config
assert config.add_cross_attention == False assert not config.add_cross_attention
assert config.scale_attn_by_inverse_layer_idx == False assert not config.scale_attn_by_inverse_layer_idx
assert config.reorder_and_upcast_attn == False assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it # Optimization: While the vocab size of GPT-2 is 50257, we extend it
@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
hidden_states, kv_caches[i], input_metadata, cache_event) cache_event)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer( hidden_states = self.transformer(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head_weight, hidden_states,
self.lm_head_weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"] _row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
# linear layer. # linear layer.
@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
if name == "transformer.wte.weight": if name == "transformer.wte.weight":
# Consider padding in the vocab size. # Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size padded_vocab_size = (param.shape[0] *
tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight) extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
# For the fused QKV linear layer, manually shard the weights. # For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name: if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. # GPT-2's fused QKV has the shape of
# When tensor parallelism is used, we shard the weights along the head dimension. # [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads head_size = hidden_size // total_num_heads
@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"): if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :] loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size) loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"): elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size) loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :] loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1) loaded_weight = loaded_weight.reshape(-1)
else: else:
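For readers tracing the c_attn changes above: the fused QKV weight is stored as [3 * num_heads * head_size, hidden_size], and each tensor-parallel rank keeps only its own contiguous slice of heads before flattening back. A minimal sketch of that slicing (the function name and example shapes are illustrative, not part of this commit):

    import torch

    def shard_fused_qkv(weight: torch.Tensor, total_num_heads: int,
                        head_size: int, hidden_size: int,
                        rank: int, world_size: int) -> torch.Tensor:
        """Slice a fused QKV weight of shape [3 * H * D, hidden] for one TP rank."""
        num_heads = total_num_heads // world_size
        head_start = rank * num_heads
        head_end = (rank + 1) * num_heads
        # [3 * H * D, hidden] -> [3, H, D, hidden], slice the head axis, flatten back.
        w = weight.view(3, total_num_heads, head_size, hidden_size)
        w = w[:, head_start:head_end, :, :]
        return w.reshape(-1, hidden_size)

    # e.g. GPT-2 small (12 heads, head size 64, hidden 768) on 2 ranks:
    # shard_fused_qkv(torch.randn(3 * 768, 768), 12, 64, 768, rank=0, world_size=2)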
@ -1,5 +1,6 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil # Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0 assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, self.c_attn = ColumnParallelLinear(self.hidden_size,
bias=True, gather_output=False, 3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, self.c_proj = RowParallelLinear(self.hidden_size,
bias=True, input_is_parallel=True, self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale) scale=self.scale)
def forward( def forward(
@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(q, k, v, key_cache, value_cache,
q, k, v, key_cache, value_cache, input_metadata, cache_event) input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output) attn_output, _ = self.c_proj(attn_output)
return attn_output return attn_output
@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size, self.c_fc = ColumnParallelLinear(hidden_size,
bias=True, gather_output=False, intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size, self.c_proj = RowParallelLinear(intermediate_size,
bias=True, input_is_parallel=True, hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.act = get_act_fn(config.activation_function) self.act = get_act_fn(config.activation_function)
@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(self, config: GPTBigCodeConfig):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config) self.attn = GPTBigCodeAttention(config)
@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig): def __init__(self, config: GPTBigCodeConfig):
super().__init__() super().__init__()
self.config = config self.config = config
assert config.add_cross_attention == False assert not config.add_cross_attention
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
hidden_states, kv_caches[i], input_metadata, cache_event) cache_event)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer( hidden_states = self.transformer(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head_weight, hidden_states,
self.lm_head_weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"] _row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
# linear layer. # linear layer.
@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
if name == "transformer.wte.weight": if name == "transformer.wte.weight":
# Consider padding in the vocab size. # Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight) extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
qkv_array = qkv_array.numpy() qkv_array = qkv_array.numpy()
dims_q = n_head * head_dim dims_q = n_head * head_dim
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0) # pylint: disable=unbalanced-tuple-unpacking
# q is fine, but k & v have not replicated shape along the first axis q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
# as long as MQA is not natively supported, increase memory and replicated axis=0)
# (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim) # q is fine, but k & v have not replicated shape along the first
# axis as long as MQA is not natively supported, increase memory
# and replicated (head_dim, hidden_dim) to
# (n_heads * head_dim, hidden_dim)
if k.ndim == 2 and v.ndim == 2: if k.ndim == 2 and v.ndim == 2:
replication = (n_head, 1) # weights replication = (n_head, 1) # weights
else: else:
replication = n_head # biases replication = n_head # biases
# replicate n_head times for q, v # replicate n_head times for q, v
k, v = np.tile(k, replication), np.tile(v, replication) k, v = np.tile(k, replication), np.tile(v, replication)
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) # concat q, k, v along the first axis
# (n_heads * head_dim, hidden_dim)
# to (3 * n_heads * head_dim, hidden_dim) # to (3 * n_heads * head_dim, hidden_dim)
qkv_array = np.concatenate((q, k, v), axis=0) qkv_array = np.concatenate((q, k, v), axis=0)
return torch.from_numpy(qkv_array) return torch.from_numpy(qkv_array)
# For the fused QKV linear layer, manually shard the weights. # For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name: if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. # GPT-2's fused QKV has the shape of
# When tensor parallelism is used, we shard the weights along the head dimension. # [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads head_size = hidden_size // total_num_heads
@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"): if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) loaded_weight = _expand_mqa_mha(loaded_weight,
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :] loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size) loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"): elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size) loaded_weight = _expand_mqa_mha(loaded_weight,
loaded_weight = loaded_weight.view(3, total_num_heads, head_size) n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :] loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1) loaded_weight = loaded_weight.reshape(-1)
else: else:
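The comments in the hunk above describe expanding a multi-query-attention (MQA) checkpoint into a multi-head layout by tiling the single shared K/V head. A rough, self-contained sketch of that expansion, assuming the same [q | k | v] packing along the first axis (the helper name is illustrative):

    import numpy as np

    def expand_mqa_to_mha(qkv: np.ndarray, n_head: int, head_dim: int) -> np.ndarray:
        dims_q = n_head * head_dim
        q, k, v = np.split(qkv, (dims_q, dims_q + head_dim), axis=0)
        # Weights are 2-D (rows, hidden); biases are 1-D.
        reps = (n_head, 1) if k.ndim == 2 else n_head
        # Replicate the single K/V head n_head times to match MHA.
        k, v = np.tile(k, reps), np.tile(v, reps)
        return np.concatenate((q, k, v), axis=0)  # [3 * n_head * head_dim, ...]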
@ -1,5 +1,6 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
# #
@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = ColumnParallelLinear(config.hidden_size, self.query_key_value = ColumnParallelLinear(
3 * config.hidden_size, config.hidden_size,
gather_output=False, 3 * config.hidden_size,
perform_initialization=False) gather_output=False,
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size, perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size,
config.hidden_size,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
scaling = self.head_size ** -0.5 scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct) rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0 assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size, self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.dense(attn_output) output, _ = self.dense(attn_output)
return output return output
@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
config.intermediate_size, config.intermediate_size,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size, self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.act = get_act_fn(config.hidden_act) self.act = get_act_fn(config.hidden_act)
@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig): def __init__(self, config: GPTNeoXConfig):
super().__init__() super().__init__()
self.use_parallel_residual = config.use_parallel_residual self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.input_layernorm = nn.LayerNorm(config.hidden_size,
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config) self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config) self.mlp = GPTNeoXMLP(config)
@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size, self.embed_in = VocabParallelEmbedding(config.vocab_size,
config.hidden_size,
perform_initialization=False) perform_initialization=False)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList(
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward( def forward(
self, self,
@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.gpt_neox = GPTNeoXModel(config) self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size, self.embed_out = ColumnParallelLinear(config.hidden_size,
bias=False, gather_output=False, config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.gpt_neox( hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.embed_out.weight, hidden_states,
self.embed_out.weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"] _column_parallel_weights = [
"embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name): or "rotary_emb.inv_freq" in name):
continue continue
param = state_dict[name] param = state_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
# required shape is [3 * num_heads * head_size, hidden_size]. # required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion. # Thus, we need weight conversion.
shard_size = param.shape[0] shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank loaded_weight = loaded_weight[
:shard_size * (tensor_model_parallel_rank + 1)] shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads head_size = hidden_size // num_heads
if 'query_key_value.weight' in name: if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size) loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size) loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif 'query_key_value.bias' in name: elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size) loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1) loaded_weight = loaded_weight.reshape(-1)
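GPT-NeoX checkpoints pack query_key_value per head, whereas the fused layer here expects the three projections grouped first, hence the view/transpose/reshape above. A minimal sketch of the same conversion on a weight tensor (the standalone function name is illustrative):

    import torch

    def convert_neox_qkv(w: torch.Tensor, head_size: int,
                         hidden_size: int) -> torch.Tensor:
        w = w.view(-1, 3, head_size, hidden_size)   # [H, 3, D, hidden]
        w = w.transpose(0, 1)                       # [3, H, D, hidden]
        return w.reshape(-1, hidden_size)           # [3 * H * D, hidden]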
@ -1,5 +1,6 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# #
@ -30,7 +31,6 @@ import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
hidden_act: str, hidden_act: str,
): ):
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size, self.gate_up_proj = ColumnParallelLinear(hidden_size,
bias=False, gather_output=False, 2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size, self.down_proj = RowParallelLinear(intermediate_size,
bias=False, input_is_parallel=True, hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
if hidden_act != 'silu': if hidden_act != "silu":
raise ValueError(f'Unsupported activation: {hidden_act}. ' raise ValueError(f"Unsupported activation: {hidden_act}. "
'Only silu is supported for now.') "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x):
@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = num_heads self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ColumnParallelLinear(
hidden_size, hidden_size,
@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False, perform_initialization=False,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.attn = PagedAttentionWithRoPE(self.num_heads,
self.scaling, rotary_dim=self.head_dim) self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward( def forward(
self, self,
@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size,
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward( def forward(
self, self,
@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, self.embed_tokens = VocabParallelEmbedding(
perform_initialization=False) config.vocab_size,
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
@ -209,6 +223,7 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.model( hidden_states = self.model(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head.weight, hidden_states,
self.lm_head.weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight", _column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"up_proj.weight"] "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"] _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
is_attention_weight = False is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name: if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3 shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank shard_size * tensor_model_parallel_rank:shard_size *
:shard_size * (tensor_model_parallel_rank + 1)] (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id param_slice = param.data[shard_size * stride_id:shard_size *
:shard_size * (stride_id + 1)] (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight) param_slice.copy_(loaded_weight)
is_attention_weight = True is_attention_weight = True
@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
param = state_dict[name.replace(weight_name, "gate_up_proj")] param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2 shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank shard_size * tensor_model_parallel_rank:shard_size *
:shard_size * (tensor_model_parallel_rank + 1)] (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id param_slice = param.data[shard_size * stride_id:shard_size *
:shard_size * (stride_id + 1)] (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight) param_slice.copy_(loaded_weight)
is_gate_up_weight = True is_gate_up_weight = True
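The loops above copy each unfused checkpoint tensor (q/k/v, or gate/up) into its stride of the fused qkv_proj or gate_up_proj parameter after taking this rank's shard. A simplified sketch of that copy, with illustrative names and without the surrounding name matching:

    import torch

    def load_into_fused(fused_param: torch.Tensor, loaded: torch.Tensor,
                        stride_id: int, num_strides: int, rank: int) -> None:
        # Rows owned by one projection on one tensor-parallel rank.
        shard_size = fused_param.shape[0] // num_strides
        shard = loaded[shard_size * rank:shard_size * (rank + 1)]
        dst = fused_param.data[shard_size * stride_id:shard_size * (stride_id + 1)]
        assert dst.shape == shard.shape
        dst.copy_(shard)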
@ -1,7 +1,9 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int): def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # OPT is set up so that if padding_idx is specified then offset the
# and adjust num_embeddings appropriately. Other models don't have this hack # embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2 self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim) super().__init__(num_embeddings + self.offset, embedding_dim)
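OPT reserves two extra rows in its learned positional table and shifts every position id by that offset, as the comment above notes. A minimal standalone sketch, assuming the usual OPT convention of adding the offset in forward (the class name here is made up):

    import torch
    import torch.nn as nn

    class OffsetPositionalEmbedding(nn.Embedding):
        def __init__(self, num_embeddings: int, embedding_dim: int, offset: int = 2):
            self.offset = offset
            # Allocate extra rows so offset ids stay in range.
            super().__init__(num_embeddings + offset, embedding_dim)

        def forward(self, positions: torch.Tensor) -> torch.Tensor:
            return super().forward(positions + self.offset)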
@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0 assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias, self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling) scale=self.scaling)
def forward( def forward(
@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(q, k, v, key_cache, value_cache,
q, k, v, key_cache, value_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
return output return output
@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
self.activation_fn = get_act_fn(config.activation_function) self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm( self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.embed_dim,
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias, bias=config.enable_bias,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim, self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias, bias=config.enable_bias,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
def forward( def forward(
self, self,
@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before: if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn( hidden_states = self.self_attn(hidden_states=hidden_states,
hidden_states=hidden_states, kv_cache=kv_cache,
kv_cache=kv_cache, input_metadata=input_metadata,
input_metadata=input_metadata, cache_event=cache_event)
cache_event=cache_event)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention # 350m applies layer norm AFTER attention
if not self.do_layer_norm_before: if not self.do_layer_norm_before:
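The reformatted attention call above sits inside a decoder layer that supports both LayerNorm placements: most OPT sizes normalize before attention (pre-norm), while OPT-350m normalizes afterwards. A compact sketch of that branching with made-up module names:

    import torch
    import torch.nn as nn

    def attention_block(x: torch.Tensor, attn: nn.Module, ln: nn.LayerNorm,
                        do_layer_norm_before: bool) -> torch.Tensor:
        residual = x
        if do_layer_norm_before:      # 125m, 1.3B, ..., 175B
            x = ln(x)
        x = attn(x)
        x = residual + x
        if not do_layer_norm_before:  # 350m
            x = ln(x)
        return x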
@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
self.max_target_positions = config.max_position_embeddings self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(
config.word_embed_proj_dim, config.vocab_size,
perform_initialization=False) config.word_embed_proj_dim,
perform_initialization=False)
# Positional embeddings are replicated (not sharded). # Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding( self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size) config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist. # Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) self.project_out = nn.Linear(config.hidden_size,
config.word_embed_proj_dim,
bias=False)
else: else:
self.project_out = None self.project_out = None
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) self.project_in = nn.Linear(config.word_embed_proj_dim,
config.hidden_size,
bias=False)
else: else:
self.project_in = None self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility # Note that the only purpose of `config._remove_final_layer_norm` is to
# with checkpoints that have been fine-tuned before transformers v4.20.1 # keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164 # see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm: if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm( self.final_layer_norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine config.hidden_size,
) elementwise_affine=config.layer_norm_elementwise_affine)
else: else:
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward( def forward(
self, self,
@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
else: else:
cache_event = cache_events[i] cache_event = cache_events[i]
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
hidden_states, kv_caches[i], input_metadata, cache_event) cache_event)
if self.final_layer_norm is not None: if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
@ -241,8 +260,8 @@ class OPTModel(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor: ) -> torch.Tensor:
return self.decoder( return self.decoder(input_ids, positions, kv_caches, input_metadata,
input_ids, positions, kv_caches, input_metadata, cache_events) cache_events)
class OPTForCausalLM(nn.Module): class OPTForCausalLM(nn.Module):
@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
hidden_states = self.model( hidden_states = self.model(input_ids, positions, kv_caches,
input_ids, positions, kv_caches, input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler( next_tokens = self.sampler(self.lm_head_weight, hidden_states,
self.lm_head_weight, hidden_states, input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"] _column_parallel_weights = [
"embed_tokens.weight", "fc1.weight", "fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"] _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, model_name_or_path: str, def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_np_cache: bool = False): use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict() state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache): model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name: if "lm_head.weight" in name:
continue continue
@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
name = "model." + name name = "model." + name
is_attention_weight = False is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name: if att_weight_name not in name:
continue continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")] param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3 shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank shard_size * tensor_model_parallel_rank:shard_size *
:shard_size * (tensor_model_parallel_rank + 1)] (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id param_slice = param.data[shard_size * stride_id:shard_size *
:shard_size * (stride_id + 1)] (stride_id + 1)]
assert param_slice.shape == loaded_weight.shape assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight) param_slice.copy_(loaded_weight)
is_attention_weight = True is_attention_weight = True
@ -44,9 +44,9 @@ def hf_model_weights_iterator(
if use_np_cache: if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for # Convert the model weights from torch tensors to numpy arrays for
# faster loading. # faster loading.
np_folder = os.path.join(hf_folder, 'np') np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True) os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json') weight_names_file = os.path.join(np_folder, "weight_names.json")
with lock: with lock:
if not os.path.exists(weight_names_file): if not os.path.exists(weight_names_file):
weight_names = [] weight_names = []
@ -57,10 +57,10 @@ def hf_model_weights_iterator(
with open(param_path, "wb") as f: with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy()) np.save(f, param.cpu().detach().numpy())
weight_names.append(name) weight_names.append(name)
with open(weight_names_file, 'w') as f: with open(weight_names_file, "w") as f:
json.dump(weight_names, f) json.dump(weight_names, f)
with open(weight_names_file, 'r') as f: with open(weight_names_file, "r") as f:
weight_names = json.load(f) weight_names = json.load(f)
for name in weight_names: for name in weight_names:
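The numpy cache above converts each checkpoint tensor to an .npy file once and records the tensor names in weight_names.json so later loads can skip torch deserialization. A condensed sketch of building such a cache (file layout and function name are illustrative):

    import json
    import os
    import numpy as np
    import torch

    def build_np_cache(state_dict: dict, np_folder: str) -> None:
        os.makedirs(np_folder, exist_ok=True)
        names = []
        for name, param in state_dict.items():
            # np.save appends the .npy suffix automatically.
            np.save(os.path.join(np_folder, name), param.cpu().detach().numpy())
            names.append(name)
        with open(os.path.join(np_folder, "weight_names.json"), "w") as f:
            json.dump(names, f)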
@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
for p in column_parallel_weight_names: for p in column_parallel_weight_names:
if p in param_name: if p in param_name:
shard_size = param.shape[0] shard_size = param.shape[0]
loaded_weight = loaded_weight[ start_idx = tensor_model_parallel_rank * shard_size
shard_size * tensor_model_parallel_rank end_idx = (tensor_model_parallel_rank + 1) * shard_size
:shard_size * (tensor_model_parallel_rank + 1)] loaded_weight = loaded_weight[start_idx:end_idx]
break break
for p in row_parallel_weight_names: for p in row_parallel_weight_names:
if p in param_name: if p in param_name:
shard_size = param.shape[1] shard_size = param.shape[1]
loaded_weight = loaded_weight[ start_idx = tensor_model_parallel_rank * shard_size
:, end_idx = (tensor_model_parallel_rank + 1) * shard_size
shard_size * tensor_model_parallel_rank loaded_weight = loaded_weight[:, start_idx:end_idx]
:shard_size * (tensor_model_parallel_rank + 1)]
break break
assert param.shape == loaded_weight.shape, ( assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: " f"{param_name} shape mismatch between model and checkpoint: "
@ -55,6 +55,7 @@ class RequestOutput:
outputs: The output sequences of the request. outputs: The output sequences of the request.
finished: Whether the whole request is finished. finished: Whether the whole request is finished.
""" """
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
@ -75,8 +76,9 @@ class RequestOutput:
n = seq_group.sampling_params.n n = seq_group.sampling_params.n
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
assert n <= len(seqs) assert n <= len(seqs)
sorted_seqs = sorted( sorted_seqs = sorted(seqs,
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True) key=lambda seq: seq.get_cumulative_logprob(),
reverse=True)
top_n_seqs = sorted_seqs[:n] top_n_seqs = sorted_seqs[:n]
# Create the outputs. # Create the outputs.
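The reflowed sort above keeps the n most likely sequences of a group by ordering them on cumulative log-probability. A tiny illustration with plain tuples standing in for Sequence objects:

    seqs = [("a", -1.2), ("b", -0.3), ("c", -2.0)]
    top_n = sorted(seqs, key=lambda s: s[1], reverse=True)[:2]
    # top_n == [("b", -0.3), ("a", -1.2)]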
@ -3,6 +3,7 @@ from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
class SamplingParams: class SamplingParams:
"""Sampling parameters for text generation. """Sampling parameters for text generation.
@ -51,7 +52,7 @@ class SamplingParams:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
use_beam_search: bool = False, use_beam_search: bool = False,
stop: Union[str, List[str]] = [], stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False, ignore_eos: bool = False,
max_tokens: int = 16, max_tokens: int = 16,
logprobs: Optional[int] = None, logprobs: Optional[int] = None,
@ -64,7 +65,12 @@ class SamplingParams:
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
self.use_beam_search = use_beam_search self.use_beam_search = use_beam_search
self.stop = [stop] if isinstance(stop, str) else list(stop) if stop is None:
self.stop = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = list(stop)
self.ignore_eos = ignore_eos self.ignore_eos = ignore_eos
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.logprobs = logprobs self.logprobs = logprobs
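The `stop` default changes from a list literal to None above, presumably to satisfy the linter's warning about mutable default arguments, which are evaluated once and shared across calls. A self-contained illustration of the pitfall being avoided:

    def bad(item, acc=[]):      # one shared list for every call
        acc.append(item)
        return acc

    def good(item, acc=None):   # fresh list per call
        acc = [] if acc is None else acc
        acc.append(item)
        return acc

    assert bad(1) == [1] and bad(2) == [1, 2]    # state leaks between calls
    assert good(1) == [1] and good(2) == [2]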
@ -1,3 +1,4 @@
"""Sequence and its related classes."""
import copy import copy
import enum import enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams
class SequenceStatus(enum.Enum): class SequenceStatus(enum.Enum):
"""Status of a sequence."""
WAITING = enum.auto() WAITING = enum.auto()
RUNNING = enum.auto() RUNNING = enum.auto()
SWAPPED = enum.auto() SWAPPED = enum.auto()
@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum):
SequenceStatus.FINISHED_STOPPED, SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED, SequenceStatus.FINISHED_LENGTH_CAPPED,
SequenceStatus.FINISHED_ABORTED, SequenceStatus.FINISHED_ABORTED,
SequenceStatus.FINISHED_IGNORED SequenceStatus.FINISHED_IGNORED,
] ]
@staticmethod @staticmethod
@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum):
class SequenceData: class SequenceData:
"""Data associated with a sequence.
Args:
prompt_token_ids: The token IDs of the prompt.
Attributes:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
def __init__( def __init__(
self, self,
@ -75,6 +88,15 @@ class SequenceData:
class Sequence: class Sequence:
"""Stores the data, status, and block information of a sequence.
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
"""
def __init__( def __init__(
self, self,
@ -149,19 +171,27 @@ class Sequence:
def is_finished(self) -> bool: def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status) return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: 'Sequence') -> None: def fork(self, child_seq: "Sequence") -> None:
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks) child_seq.logical_token_blocks = copy.deepcopy(
self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
         child_seq.data = copy.deepcopy(self.data)
         return None

     def __repr__(self) -> str:
-        return (f'Sequence(seq_id={self.seq_id}, '
-                f'status={self.status.name}, '
-                f'num_blocks={len(self.logical_token_blocks)})')
+        return (f"Sequence(seq_id={self.seq_id}, "
+                f"status={self.status.name}, "
+                f"num_blocks={len(self.logical_token_blocks)})")


 class SequenceGroup:
+    """A group of sequences that are generated from the same prompt.
+
+    Args:
+        request_id: The ID of the request.
+        seqs: The list of sequences.
+        sampling_params: The sampling parameters used to generate the outputs.
+        arrival_time: The arrival time of the request.
+    """

     def __init__(
         self,

@@ -191,7 +221,7 @@ class SequenceGroup:
         for seq in self.seqs:
             if seq.seq_id == seq_id:
                 return seq
-        raise ValueError(f'Sequence {seq_id} not found.')
+        raise ValueError(f"Sequence {seq_id} not found.")

     def is_finished(self) -> bool:
         return all(seq.is_finished() for seq in self.seqs)

@@ -203,14 +233,25 @@ class SequenceGroup:
 class SequenceGroupMetadata:
+    """Metadata for a sequence group. Used to create `InputMetadata`.
+
+    Args:
+        request_id: The ID of the request.
+        is_prompt: Whether the request is at prompt stage.
+        seq_data: The sequence data. (Seq id -> sequence data)
+        sampling_params: The sampling parameters used to generate the outputs.
+        block_tables: The block tables. (Seq id -> list of physical block
+            numbers)
+    """

     def __init__(
         self,
         request_id: str,
         is_prompt: bool,
-        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
+        seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
+        block_tables: Dict[int, List[int]],
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt

@@ -220,13 +261,23 @@ class SequenceGroupMetadata:
 class SequenceOutputs:
+    """The model output associated with a sequence.
+
+    Args:
+        seq_id: The ID of the sequence.
+        parent_seq_id: The ID of the parent sequence (for forking in beam
+            search).
+        output_token: The output token ID.
+        logprobs: The logprobs of the output token.
+            (Token id -> logP(x_i+1 | x_0, ..., x_i))
+    """

     def __init__(
         self,
         seq_id: int,
         parent_seq_id: int,
         output_token: int,
-        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
+        logprobs: Dict[int, float],
     ) -> None:
         self.seq_id = seq_id
         self.parent_seq_id = parent_seq_id

@@ -234,15 +285,15 @@ class SequenceOutputs:
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f'SequenceOutputs(seq_id={self.seq_id}, '
-                f'parent_seq_id={self.parent_seq_id}, '
-                f'output_token={self.output_token}), '
-                f'logprobs={self.logprobs}')
+        return (f"SequenceOutputs(seq_id={self.seq_id}, "
+                f"parent_seq_id={self.parent_seq_id}, "
+                f"output_token={self.output_token}), "
+                f"logprobs={self.logprobs}")

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
             return NotImplemented
-        return (self.seq_id == other.seq_id and
-                self.parent_seq_id == other.parent_seq_id and
-                self.output_token == other.output_token and
-                self.logprobs == other.logprobs)
+        return (self.seq_id == other.seq_id
+                and self.parent_seq_id == other.parent_seq_id
+                and self.output_token == other.output_token
+                and self.logprobs == other.logprobs)
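With the new docstrings in place, the `SequenceOutputs` container can be exercised directly. A minimal usage sketch, based only on the constructor, `__repr__`, and `__eq__` shown above; the numeric values and the `vllm.sequence` import path are illustrative assumptions:

# Illustrative values only; nothing here comes from a real run.
from vllm.sequence import SequenceOutputs  # assumed module path

# Two identical outputs for sequence 3, forked from parent sequence 1.
out_a = SequenceOutputs(seq_id=3, parent_seq_id=1, output_token=42,
                        logprobs={42: -0.7, 13: -1.9})
out_b = SequenceOutputs(seq_id=3, parent_seq_id=1, output_token=42,
                        logprobs={42: -0.7, 13: -1.9})

assert out_a == out_b  # __eq__ compares seq_id, parent_seq_id, token, logprobs.
print(out_a)           # __repr__ reports the same fields.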

View File

@@ -13,8 +13,8 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

 def get_tokenizer(
     tokenizer_name: str,
-    tokenizer_mode: str = "auto",
     *args,
+    tokenizer_mode: str = "auto",
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""

@@ -73,7 +73,8 @@ def detokenize_incrementally(
         output_text = tokenizer.convert_tokens_to_string(output_tokens)
         return new_token, output_text

-    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # Adapted from
+    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
     # NOTE(woosuk): The following code is slow because it runs a for loop over
     # the output_tokens. In Python, running a for loop over a list can be slow
     # even when the loop body is very simple.
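With `tokenizer_mode` now declared after `*args`, it is keyword-only, which also sidesteps pylint's keyword-arg-before-vararg warning (the likely motivation for the reorder). A call sketch; the module path and model name are assumptions, not taken from this diff:

# Hypothetical usage of the signature shown above.
from vllm.transformers_utils.tokenizer import get_tokenizer  # assumed path

tokenizer = get_tokenizer("facebook/opt-125m", tokenizer_mode="auto")
text = tokenizer.decode([2, 100, 101])  # standard HF tokenizer API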

View File

@@ -17,9 +17,9 @@ class Counter:
         self.counter = start

     def __next__(self) -> int:
-        id = self.counter
+        i = self.counter
         self.counter += 1
-        return id
+        return i

     def reset(self) -> None:
         self.counter = 0

@@ -38,6 +38,7 @@ def get_cpu_memory() -> int:
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)

+
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
     return "microsoft" in " ".join(uname()).lower()

View File

@@ -93,8 +93,8 @@ class CacheEngine:
         if not pin_memory:
             # Pinning memory in WSL is not supported.
             # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
-            logger.warn("Using 'pin_memory=False' as WSL is detected. "
-                        "This may slow down the performance.")
+            logger.warning("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down the performance.")
         for _ in range(self.num_layers):
             key_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *key_block_shape),

@@ -120,11 +120,10 @@ class CacheEngine:
             src_key_cache, src_value_cache = src[i]
             dst_key_cache, dst_value_cache = dst[i]
             # Copy the key blocks.
-            cache_ops.swap_blocks(
-                src_key_cache, dst_key_cache, src_to_dst)
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
             # Copy the value blocks.
-            cache_ops.swap_blocks(
-                src_value_cache, dst_value_cache, src_to_dst)
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache,
+                                  src_to_dst)
             event = self.events[i]
             event.record(stream=self.cache_stream)
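`cache_ops.swap_blocks` is a custom CUDA op, so its body is not visible in this diff. As a rough mental model only (the block-indexed layout and the `Dict[int, int]` block-number mapping are assumptions, not the real kernel), its effect is equivalent to copying whole blocks between two cache tensors:

import torch
from typing import Dict


def swap_blocks_reference(src_cache: torch.Tensor,
                          dst_cache: torch.Tensor,
                          src_to_dst: Dict[int, int]) -> None:
    # Assumes dim 0 of both caches indexes physical block numbers.
    for src_block, dst_block in src_to_dst.items():
        dst_cache[dst_block].copy_(src_cache[src_block])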

View File

@@ -73,8 +73,8 @@ class Worker:
         # number of tokens equal to max_num_batched_tokens.
         # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99,
-                                         top_k=self.model.config.vocab_size - 1)
+        vocab_size = self.model.config.vocab_size
+        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
         seqs = []

@@ -91,7 +91,8 @@ class Worker:
             )
             seqs.append(seq)

-        input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(
+            seqs)

         # Execute the model.
         num_layers = self.model_config.get_num_layers(self.parallel_config)

@@ -110,8 +111,9 @@ class Worker:
         total_gpu_memory = get_gpu_memory()
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
-        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
-                              - peak_memory) // cache_block_size)
+        num_gpu_blocks = int(
+            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
+            cache_block_size)
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
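For intuition about the block-count computation above, a back-of-the-envelope example with made-up numbers (none of these values come from a real profiling run):

# Hypothetical numbers, purely to illustrate the formula above.
total_gpu_memory = 80 * 1024**3      # 80 GiB of device memory
gpu_memory_utilization = 0.90        # fraction of the GPU vLLM may use
peak_memory = 30 * 1024**3           # peak usage observed while profiling
cache_block_size = 2 * 1024**2       # bytes needed per KV-cache block

num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) //
    cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
print(num_gpu_blocks)  # 21504 blocks available for the KV cache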
@@ -125,8 +127,8 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
-        self.cache_engine = CacheEngine(
-            self.cache_config, self.model_config, self.parallel_config)
+        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
+                                        self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache

@@ -202,8 +204,8 @@ class Worker:
                 generation_block_tables.append(block_table)

                 max_context_len = max(max_context_len, context_len)
-                max_num_blocks_per_seq = max(
-                    max_num_blocks_per_seq, len(block_table))
+                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
+                                             len(block_table))
                 context_lens.append(context_len)

                 block_number = block_table[position // self.block_size]

@@ -223,7 +225,8 @@ class Worker:
         context_lens_tensor = torch.cuda.IntTensor(context_lens)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables]
+            for block_table in generation_block_tables
+        ]
         block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)

         seq_data: Dict[int, SequenceData] = {}
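`_pad_to_max` itself is not shown in this diff; a plausible implementation matching its use above, where the padding value of 0 is an assumption:

from typing import List


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    # Right-pad a block table so every row in the batch has the same length.
    return x + [0] * (max_len - len(x))


padded = _pad_to_max([7, 12, 3], max_len=5)  # -> [7, 12, 3, 0, 0]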