[Quality] Add code formatter and linter (#326)

This commit is contained in:
Zhuohan Li 2023-07-03 11:31:55 -07:00 committed by GitHub
parent 0ffded812a
commit d6fa1be3a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 1547 additions and 617 deletions

434
.pylintrc Normal file
View File

@ -0,0 +1,434 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
logging-fstring-interpolation, # added by vLLM
logging-not-lazy, # added by vLLM
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-class-docstring, # TODO (vLLM): enable
missing-function-docstring,
missing-module-docstring, # TODO (vLLM): enable
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
unspecified-encoding,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externaly-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException

View File

@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possi
In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
We include a formatting script [`format.sh`](./format.sh) to format the code.
### Pull Requests
When submitting a pull request:
1. Make sure your code has been rebased on top of the latest commit on the main branch.
2. Include a detailed description of the changes in the pull request.
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
3. Include a detailed description of the changes in the pull request.
Explain why you made the changes you did.
If your pull request fixes an open issue, please include a reference to it in the description.

View File

@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(prompt: str, api_url: str, n: int = 1,
def post_http_request(prompt: str,
api_url: str,
n: int = 1,
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))

View File

@ -12,9 +12,14 @@ def http_bot(prompt):
"stream": True,
"max_tokens": 128,
}
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
response = requests.post(args.model_url,
headers=headers,
json=pload,
stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"][0]
@ -23,11 +28,11 @@ def http_bot(prompt):
def build_demo():
with gr.Blocks() as demo:
gr.Markdown(
"# vLLM text completion demo\n"
)
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
gr.Markdown("# vLLM text completion demo\n")
inputbox = gr.Textbox(label="Input",
placeholder="Enter text and press ENTER")
outputbox = gr.Textbox(label="Output",
placeholder="Generated result from the model")
inputbox.submit(http_bot, [inputbox], [outputbox])
return demo
@ -36,7 +41,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
parser.add_argument("--model-url",
type=str,
default="http://localhost:8000/generate")
args = parser.parse_args()
demo = build_demo()

View File

@ -14,9 +14,14 @@ def main(args: argparse.Namespace):
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
SamplingParams(n=2,
best_of=5,
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
]
# Run the engine by calling `engine.step()` manually.

View File

@ -1,6 +1,5 @@
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",

View File

@ -12,8 +12,13 @@ print("Models:", models)
# Test completion API
stream = True
completion = openai.Completion.create(
model=model, prompt="A robot may not injure a human being", echo=False, n=2,
best_of=3, stream=stream, logprobs=3)
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
best_of=3,
stream=stream,
logprobs=3)
# print the completion
if stream:

108
format.sh Executable file
View File

@ -0,0 +1,108 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' 'vllm/model_executor/parallel_utils/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi

View File

@ -1,2 +1,12 @@
mypy
# formatting
yapf==0.32.0
pylint==2.8.2
# type checking
mypy==0.991
types-PyYAML
types-requests
types-setuptools
# testing
pytest

View File

@ -60,7 +60,7 @@ def ref_single_query_cached_kv_attention(
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
scale = 1.0 / (head_size ** 0.5)
scale = 1.0 / (head_size**0.5)
out = ref_masked_attention(q, keys, values, scale)
out = out.view(num_heads, head_size)
output[i].copy_(out, non_blocking=True)
@ -74,7 +74,7 @@ def ref_multi_query_kv_attention(
dtype: torch.dtype,
) -> torch.Tensor:
head_size = query.shape[-1]
scale = 1.0 / (head_size ** 0.5)
scale = 1.0 / (head_size**0.5)
num_seqs = len(cu_seq_lens) - 1
ref_outputs = []
@ -84,8 +84,8 @@ def ref_multi_query_kv_attention(
seq_len = end_idx - start_idx
# Create attention mask.
attn_mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
@ -113,7 +113,7 @@ def ref_multi_query_cached_kv_attention(
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
block_size = value_cache.shape[3]
scale = 1.0 / (head_size ** 0.5)
scale = 1.0 / (head_size**0.5)
num_queries = len(cu_query_lens) - 1
ref_outputs = []
@ -125,8 +125,8 @@ def ref_multi_query_cached_kv_attention(
block_table = block_tables[i]
# Create attention mask
attn_mask = torch.triu(
torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
attn_mask = torch.triu(torch.ones(query_len, context_len),
diagonal=context_len - query_len + 1) * -1e5
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
keys = []
@ -165,22 +165,28 @@ def run_single_query_cached_kv_attention(
num_blocks: int,
dtype: torch.dtype,
) -> None:
qkv = torch.empty(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
qkv = torch.empty(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, _, _ = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.empty(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
key_cache = torch.empty(size=(num_blocks, *key_block_shape),
dtype=dtype,
device='cuda')
key_cache.uniform_(-1e-3, 1e-3)
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.empty(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
value_cache = torch.empty(size=(num_blocks, *value_block_shape),
dtype=dtype,
device='cuda')
value_cache.uniform_(-1e-3, 1e-3)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
@ -194,9 +200,12 @@ def run_single_query_cached_kv_attention(
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty(
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
scale = float(1.0 / (head_size**0.5))
output = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
attention_ops.single_query_cached_kv_attention(
output,
query,
@ -235,9 +244,13 @@ def run_multi_query_kv_attention(
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size ** 0.5))
qkv = torch.empty(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
scale = float(1.0 / (head_size**0.5))
qkv = torch.empty(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1)

View File

@ -26,8 +26,9 @@ def run_copy_blocks(
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
for _ in range(num_layers):
key_cache = torch.randn(
size=key_cache_shape, dtype=dtype, device='cuda')
key_cache = torch.randn(size=key_cache_shape,
dtype=dtype,
device='cuda')
key_caches.append(key_cache)
cloned_key_caches = []
for key_cache in key_caches:
@ -36,8 +37,9 @@ def run_copy_blocks(
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = []
for _ in range(num_layers):
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')
value_cache = torch.randn(size=value_cache_shape,
dtype=dtype,
device='cuda')
value_caches.append(value_cache)
cloned_value_caches = []
for value_cache in value_caches:
@ -49,15 +51,18 @@ def run_copy_blocks(
# Reference implementation.
for src, dsts in block_mapping.items():
for dst in dsts:
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
for key_cache, cloned_key_cache in zip(key_caches,
cloned_key_caches):
cloned_key_cache[dst] = cloned_key_cache[src]
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
cloned_value_cache[dst] = cloned_value_cache[src]
# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
assert torch.allclose(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
assert torch.allclose(value_cache, cloned_value_cache)
@ -74,8 +79,12 @@ def run_reshape_and_cache(
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
_, key, value = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
@ -84,15 +93,19 @@ def run_reshape_and_cache(
cloned_key_cache = key_cache.clone()
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')
value_cache = torch.randn(size=value_cache_shape,
dtype=dtype,
device='cuda')
cloned_value_cache = value_cache.clone()
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping)
for i in range(num_tokens):
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
block_idx = torch.div(slot_mapping[i],
block_size,
rounding_mode='floor')
block_offset = slot_mapping[i] % block_size
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
@ -114,8 +127,12 @@ def run_gather_cached_kv(
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
_, key, value = qkv.unbind(dim=1)
qkv_clone = qkv.clone()
@ -126,15 +143,20 @@ def run_gather_cached_kv(
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')
value_cache = torch.randn(size=value_cache_shape,
dtype=dtype,
device='cuda')
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
slot_mapping)
# Reference implementation.
for i in range(num_tokens):
reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
reshaped_key = cloned_key.reshape(num_tokens, num_heads,
head_size // x, x)
block_idx = torch.div(slot_mapping[i],
block_size,
rounding_mode='floor')
block_offset = slot_mapping[i] % block_size
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
cloned_value[i] = value_cache[block_idx, :, :, block_offset]
@ -145,20 +167,30 @@ def run_gather_cached_kv(
def test_copy_blocks() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_copy_blocks(
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
block_size=8, num_blocks=1024, dtype=dtype)
run_copy_blocks(num_mappings=23,
num_layers=7,
num_heads=17,
head_size=16,
block_size=8,
num_blocks=1024,
dtype=dtype)
def test_reshape_and_cache() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
run_reshape_and_cache(num_tokens=3,
num_heads=2,
head_size=16,
block_size=8,
num_blocks=2,
dtype=dtype)
def test_gather_cached_kv() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_gather_cached_kv(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
run_gather_cached_kv(num_tokens=3,
num_heads=2,
head_size=16,
block_size=8,
num_blocks=2,
dtype=dtype)

View File

@ -14,8 +14,10 @@ class RefRMSNorm(nn.Module):
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states

View File

@ -8,8 +8,8 @@ from vllm import pos_encoding_ops
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
@ -38,7 +38,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
self.max_position_embeddings = max_position_embeddings
# Create cos and sin embeddings.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
t = torch.arange(max_position_embeddings).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
emb = torch.cat((freqs, freqs), dim=-1)
@ -49,16 +49,15 @@ class RefRotaryEmbeddingNeox(nn.Module):
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads, head_size]
key: torch.Tensor, # [num_tokens, num_heads, head_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1)
@ -85,12 +84,18 @@ def run_rotary_embedding_neox(
dtype: torch.dtype,
base: int = 10000,
) -> None:
positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
query = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device='cuda')
key = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device='cuda')
# Create the rotary embedding.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()

View File

@ -1,3 +1,5 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine

View File

@ -35,7 +35,8 @@ class LogicalTokenBlock:
def append_tokens(self, token_ids: List[int]) -> None:
assert len(token_ids) <= self.get_num_empty_slots()
self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
curr_idx = self.num_tokens
self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
self.num_tokens += len(token_ids)
def get_token_ids(self) -> List[int]:

View File

@ -8,7 +8,7 @@ from vllm.utils import get_cpu_memory
logger = init_logger(__name__)
_GiB = 1 << 30
_GB = 1 << 30
class ModelConfig:
@ -106,6 +106,7 @@ class CacheConfig:
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
"""
def __init__(
self,
block_size: int,
@ -114,7 +115,7 @@ class CacheConfig:
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GiB
self.swap_space_bytes = swap_space * _GB
self._verify_args()
# Will be set after profiling.
@ -137,14 +138,13 @@ class CacheConfig:
num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
msg = (
f"{cpu_memory_usage / _GiB:.2f} GiB out of "
f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
"allocated for the swap space.")
msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
"allocated for the swap space.")
if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warn("Possibly too large swap space. " + msg)
logger.warning("Possibly too large swap space. " + msg)
class ParallelConfig:
@ -157,6 +157,7 @@ class ParallelConfig:
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
"""
def __init__(
self,
pipeline_parallel_size: int,
@ -189,12 +190,9 @@ class SchedulerConfig:
max_seq_len: Maximum length of a sequence (including prompt
and generated text).
"""
def __init__(
self,
max_num_batched_tokens: int,
max_num_seqs: int,
max_seq_len: int
) -> None:
def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
max_seq_len: int) -> None:
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs
self.max_seq_len = max_seq_len
@ -241,7 +239,7 @@ def _get_and_verify_dtype(
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:

View File

@ -27,8 +27,9 @@ class BlockAllocator:
# Initialize the free blocks.
self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
block = PhysicalTokenBlock(
device=device, block_number=i, block_size=block_size)
block = PhysicalTokenBlock(device=device,
block_number=i,
block_size=block_size)
self.free_blocks.append(block)
def allocate(self) -> PhysicalTokenBlock:
@ -84,10 +85,12 @@ class BlockSpaceManager:
num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks
return (num_free_gpu_blocks - num_required_blocks >=
self.watermark_blocks)
def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same prompt.
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = seq_group.get_seqs()[0]
# Allocate new physical token blocks that will store the prompt tokens.
@ -143,7 +146,8 @@ class BlockSpaceManager:
for block in src_block_table:
block.ref_count += 1
def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
def _get_physical_blocks(
self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()

View File

@ -43,8 +43,7 @@ class SchedulerOutputs:
assert not (blocks_to_swap_in and blocks_to_swap_out)
def is_empty(self) -> bool:
return (not self.blocks_to_swap_in
and not self.blocks_to_swap_out
return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
and not self.blocks_to_copy)
@ -61,7 +60,7 @@ class Scheduler:
self.log_stats = log_stats
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name='fcfs')
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
# Create the block space manager.
self.block_manager = BlockSpaceManager(
block_size=self.cache_config.block_size,
@ -102,7 +101,8 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
def _schedule(
self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
# Blocks that need to be swaped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
@ -160,7 +160,8 @@ class Scheduler:
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.swapped.pop(0)
@ -170,8 +171,7 @@ class Scheduler:
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
)
for seq_group in self.running)
# Join waiting sequences if possible.
prompt_group_ids: List[str] = []
@ -191,7 +191,7 @@ class Scheduler:
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if num_prompt_tokens >= self.scheduler_config.max_seq_len:
logger.warn(
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
" and exceeds limit of "
f"{self.scheduler_config.max_seq_len}")
@ -206,17 +206,19 @@ class Scheduler:
break
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens
> self.scheduler_config.max_num_batched_tokens):
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
num_new_seqs = seq_group.num_seqs(
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
@ -240,12 +242,11 @@ class Scheduler:
elapsed_time = now - self.last_logging_time
if elapsed_time > _LOGGING_INTERVAL_SEC:
self.last_logging_time = now
self.num_input_tokens = [
(t, n) for t, n in self.num_input_tokens
if now - t < _LOGGING_INTERVAL_SEC
]
self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
if now - t < _LOGGING_INTERVAL_SEC]
if len(self.num_input_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
total_num_tokens = sum(n
for _, n in self.num_input_tokens[:-1])
window = now - self.num_input_tokens[0][0]
avg_throughput = total_num_tokens / window
else:
@ -258,26 +259,30 @@ class Scheduler:
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
num_free_cpu_blocks = (
self.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info(
f"Throughput: {avg_throughput:.1f} tokens/s, "
f"Running: {len(self.running)} reqs, "
f"Swapped: {len(self.swapped)} reqs, "
f"Pending: {len(self.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
f"Running: {len(self.running)} reqs, "
f"Swapped: {len(self.swapped)} reqs, "
f"Pending: {len(self.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
List[SequenceGroup]]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()
(scheduler_outputs, prompt_group_ids,
ignored_seq_groups) = self._schedule()
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
@ -311,8 +316,8 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
# The sequence is a fork of the parent sequence (beam
# search). Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id)
@ -385,7 +390,7 @@ class Scheduler:
elif preemption_mode == PreemptionMode.SWAP:
self._preempt_by_swap(seq_group, blocks_to_swap_out)
else:
assert False, 'Invalid preemption mode.'
assert False, "Invalid preemption mode."
def _preempt_by_recompute(
self,

View File

@ -12,11 +12,11 @@ class EngineArgs:
"""Arguments for vLLM engine."""
model: str
tokenizer: Optional[str] = None
tokenizer_mode: str = "auto"
tokenizer_mode: str = 'auto'
download_dir: Optional[str] = None
use_np_weights: bool = False
use_dummy_weights: bool = False
dtype: str = "auto"
dtype: str = 'auto'
seed: int = 0
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
@ -35,76 +35,101 @@ class EngineArgs:
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument('--tokenizer-mode', type=str,
parser.add_argument(
'--model',
type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument(
'--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--download-dir', type=str,
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
parser.add_argument('--use-np-weights', action='store_true',
'default to the default cache dir of '
'huggingface')
parser.add_argument('--use-np-weights',
action='store_true',
help='save a numpy copy of model weights for '
'faster loading. This can increase the disk '
'usage by up to 2x.')
parser.add_argument('--use-dummy-weights', action='store_true',
'faster loading. This can increase the disk '
'usage by up to 2x.')
parser.add_argument('--use-dummy-weights',
action='store_true',
help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
choices=['auto', 'half', 'bfloat16', 'float'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument(
'--dtype',
type=str,
default=EngineArgs.dtype,
choices=['auto', 'half', 'bfloat16', 'float'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
# Parallel arguments
parser.add_argument('--worker-use-ray', action='store_true',
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int,
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int,
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=EngineArgs.seed,
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='random seed')
parser.add_argument('--swap-space', type=int,
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float,
parser.add_argument('--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for'
'the model executor')
parser.add_argument('--max-num-batched-tokens', type=int,
'the model executor')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs', type=int,
'iteration')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true',
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
@ -115,18 +140,19 @@ class EngineArgs:
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode, self.download_dir,
self.use_np_weights, self.use_dummy_weights, self.dtype, self.seed)
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.download_dir,
self.use_np_weights, self.use_dummy_weights,
self.dtype, self.seed)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray)
max_seq_len = min(
self.max_num_batched_tokens,
getattr(model_config.hf_config, "max_position_embeddings",
float("inf")))
model_max_len = getattr(model_config.hf_config,
'max_position_embeddings', float('inf'))
max_seq_len = min(self.max_num_batched_tokens, model_max_len)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs, max_seq_len)
return model_config, cache_config, parallel_config, scheduler_config
@ -140,12 +166,13 @@ class AsyncEngineArgs(EngineArgs):
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray', action='store_true',
parser.add_argument('--engine-use-ray',
action='store_true',
help='use Ray to start the LLM engine in a '
'separate process as the server process.')
parser.add_argument('--disable-log-requests', action='store_true',
'separate process as the server process.')
parser.add_argument('--disable-log-requests',
action='store_true',
help='disable logging requests')
return parser

View File

@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
class AsyncLLMEngine:
@ -35,8 +35,13 @@ class AsyncLLMEngine:
log_requests: Whether to log the requests.
*args, *kwargs: Arguments for LLMEngine.
"""
def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
log_requests: bool = True, *args, **kwargs) -> None:
def __init__(self,
worker_use_ray: bool,
engine_use_ray: bool,
*args,
log_requests: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
@ -76,12 +81,11 @@ class AsyncLLMEngine:
self.request_events[request_id].set()
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None
) -> RequestOutput:
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
@ -117,14 +121,17 @@ class AsyncLLMEngine:
# Add the request into the vLLM engine's waiting queue.
if self.engine_use_ray:
await self.engine.add_request.remote(
request_id, prompt, sampling_params,
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
self.engine.add_request(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
self.engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
# The vLLM engine does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
@ -200,7 +207,8 @@ class AsyncLLMEngine:
self.kicking_request_id = None
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
def from_engine_args(cls,
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
@ -211,8 +219,9 @@ class AsyncLLMEngine:
# Create the async LLM engine.
engine = cls(engine_args.worker_use_ray,
engine_args.engine_use_ray,
not engine_args.disable_log_requests,
*engine_configs,
distributed_init_method, devices,
distributed_init_method,
devices,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats)
return engine

View File

@ -67,8 +67,7 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})"
)
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
@ -78,8 +77,8 @@ class LLMEngine:
self.log_stats = log_stats
self._verify_args()
self.tokenizer = get_tokenizer(model_config.tokenizer,
model_config.tokenizer_mode)
self.tokenizer = get_tokenizer(
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
self.seq_counter = Counter()
# Create the parallel GPU workers.
@ -129,8 +128,8 @@ class LLMEngine:
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log.
logger.info(f'# GPU blocks: {num_gpu_blocks}, '
f'# CPU blocks: {num_cpu_blocks}')
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
@ -152,7 +151,9 @@ class LLMEngine:
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM engine.
engine = cls(*engine_configs, distributed_init_method, devices,
engine = cls(*engine_configs,
distributed_init_method,
devices,
log_stats=not engine_args.disable_log_stats)
return engine
@ -226,8 +227,10 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
(seq_group_metadata_list, scheduler_outputs,
ignored_seq_groups) = self.scheduler.schedule()
if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
and (not ignored_seq_groups)):
# Nothing to do.
return []
@ -281,8 +284,8 @@ class LLMEngine:
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(seq,
SequenceStatus.FINISHED_STOPPED)
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
stopped = True
break
if stopped:
@ -290,7 +293,7 @@ class LLMEngine:
# Check if the sequence has reached max_seq_len.
if (seq.get_len() >=
self.scheduler.scheduler_config.max_seq_len):
self.scheduler.scheduler_config.max_seq_len):
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
@ -302,15 +305,15 @@ class LLMEngine:
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
self.scheduler.free_seq(seq,
SequenceStatus.FINISHED_STOPPED)
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
continue
def _run_workers(
self,
method: str,
get_all_outputs: bool = False,
*args,
get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""

View File

@ -8,7 +8,8 @@ except ImportError:
from vllm.config import ParallelConfig
DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
# rank, node resource (node IP), device id
DeviceID = Tuple[int, Optional[str], int]
def initialize_cluster(
@ -53,15 +54,15 @@ def initialize_cluster(
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
if (not node['Alive']) or node['Resources']['GPU'] <= 0:
if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
continue
if num_devices_per_node is None:
num_devices_per_node = node['Resources']['GPU']
num_devices_per_node = node["Resources"]["GPU"]
else:
assert num_devices_per_node == node['Resources']['GPU'], (
assert num_devices_per_node == node["Resources"]["GPU"], (
"The number of GPUs per node is not uniform.")
for key in node['Resources']:
if key.startswith('node:'):
for key in node["Resources"]:
if key.startswith("node:"):
valid_node_resources.append(key)
# Verify the parallel config.

View File

@ -11,8 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI()
@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text
for output in request_output.outputs
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
@ -63,10 +62,7 @@ async def generate(request: Request) -> Response:
assert final_output is not None
prompt = final_output.prompt
text_outputs = [
prompt + output.text
for output in final_output.outputs
]
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return Response(content=json.dumps(ret))
@ -81,5 +77,8 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

View File

@ -63,8 +63,7 @@ class LLM:
self.request_counter = Counter()
def get_tokenizer(
self,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(

View File

@ -1,4 +1,5 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse
from http import HTTPStatus
@ -29,7 +30,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
@ -38,14 +39,13 @@ app = fastapi.FastAPI()
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value
)
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
@ -126,8 +126,11 @@ async def check_length(request, prompt, engine):
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model,
permission=[ModelPermission()])]
model_cards = [
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
@ -144,12 +147,14 @@ def create_logprobs(token_ids: List[int],
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append(
{tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()})
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()
})
return logprobs
@ -348,7 +353,7 @@ async def create_completion(raw_request: Request):
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
@ -387,22 +392,23 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params,
request_id)
result_generator = engine.generate(prompt, sampling_params, request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream and
(request.best_of is None or request.n == request.best_of) and
not request.use_beam_search)
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None) -> str:
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
@ -443,7 +449,8 @@ async def create_completion(raw_request: Request):
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
logprobs = (LogProbs()
if request.logprobs is not None else None)
response_json = create_stream_response_json(
index=i,
text="",
@ -487,8 +494,8 @@ async def create_completion(raw_request: Request):
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@ -506,9 +513,11 @@ async def create_completion(raw_request: Request):
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
@ -517,26 +526,34 @@ async def create_completion(raw_request: Request):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server."
)
parser.add_argument("--host", type=str, default="localhost", help="host name")
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host",
type=str,
default="localhost",
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
)
parser.add_argument(
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
)
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument("--served-model-name", type=str, default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
"--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
@ -556,7 +573,11 @@ if __name__ == "__main__":
engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode)
uvicorn.run(app, host=args.host, port=args.port, log_level="info",
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

View File

@ -1,4 +1,5 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
@ -98,7 +99,8 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):

View File

@ -1,9 +1,9 @@
# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import logging
import sys
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"

View File

@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",

View File

@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,

View File

@ -6,9 +6,10 @@ from vllm import activation_ops
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
"gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
# NOTE: The following GELU functions may introduce small rounding errors.
"gelu_new": nn.GELU(approximate="tanh"),
"gelu_fast": nn.GELU(approximate="tanh"),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
}
@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
Shapes:
x: (num_tokens, 2 * d)
return: (num_tokens, d)
"""
def __init__(self):
super().__init__()
def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)

View File

@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
class PagedAttention(nn.Module):
# pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The
@ -54,12 +55,20 @@ class PagedAttention(nn.Module):
def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: xops.AttentionBias,
) -> torch.Tensor:
"""Normal attention for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
"""
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
@ -76,12 +85,22 @@ class PagedAttention(nn.Module):
def single_query_cached_kv_attention(
self,
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
@ -97,16 +116,32 @@ class PagedAttention(nn.Module):
def forward(
self,
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size]
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].
) -> torch.Tensor:
"""PagedAttention forward pass.
NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [num_tokens, 3 * num_heads * head_size].
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
and value_cache is not None):
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
if input_metadata.num_generation_tokens > 0:
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens."
)
"generating tokens.")
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata)
query[num_prompt_tokens:num_valid_tokens], key_cache,
value_cache, input_metadata)
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
super().__init__(num_heads, head_size, scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):
def forward(
self,
positions: torch.Tensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding_neox(

View File

@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs.
@ -50,19 +51,20 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(
logits, output_tokens, presence_penalties, frequency_penalties,
self.vocab_size)
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(
temperatures, dtype=logits.dtype, device=logits.device)
t = torch.tensor(temperatures,
dtype=logits.dtype,
device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
@ -75,7 +77,9 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens.
@ -97,8 +101,7 @@ def _prune_hidden_states(
def _get_penalties(
input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
@ -117,9 +120,7 @@ def _get_penalties(
return presence_penalties, frequency_penalties
def _get_output_tokens(
input_metadata: InputMetadata,
) -> List[List[int]]:
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
@ -169,11 +170,13 @@ def _apply_penalties(
device=logits.device)
frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(
frequency_penalties, dtype=logits.dtype, device=logits.device)
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(
presence_penalties, dtype=logits.dtype, device=logits.device)
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
@ -183,9 +186,7 @@ def _apply_penalties(
return logits
def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
@ -252,8 +253,9 @@ def _apply_top_p_top_k(
probs_sort[top_k_mask] = 0.0
# Re-sort the probabilities.
probs = torch.gather(
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))
return probs
@ -296,8 +298,9 @@ def _sample_from_prompt(
# Random sampling.
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(
prob, num_samples=num_seqs, replacement=True)
next_token_ids = torch.multinomial(prob,
num_samples=num_seqs,
replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(
seq_logprobs, dtype=torch.float, device=logprobs.device)
seq_logprobs = torch.tensor(seq_logprobs,
dtype=torch.float,
device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
else:
# Random sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(
probs, num_samples=1, replacement=True)
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
@ -381,15 +386,16 @@ def _sample(
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.logprobs)
next_logprobs = _get_topk_logprobs(logprob,
sampling_params.logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id, next_token_id, output_logprobs)
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
next_token_id,
output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
@ -399,22 +405,24 @@ def _sample(
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids]
for seq_id in seq_ids
]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
for j, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.logprobs)
logprob[j], sampling_params.logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id)
seq_ids, parent_seq_ids, next_token_ids):
j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
output_logprobs[next_token_id] = logprob[j,
next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,

View File

@ -6,8 +6,9 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
LlamaForCausalLM, OPTForCausalLM)
from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
GPTNeoXForCausalLM, LlamaForCausalLM,
OPTForCausalLM)
from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
return _MODEL_REGISTRY[arch]
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
)
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig) -> nn.Module:
@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model, model_config.download_dir,
model_config.use_np_weights)
model.load_weights(model_config.model, model_config.download_dir,
model_config.use_np_weights)
model = model.cuda()
return model.eval()

View File

@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",

View File

@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
bias=True, gather_output=False,
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
def forward(
@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
bias=True, gather_output=False,
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config)
@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.config = config
assert config.add_cross_attention == False
assert config.scale_attn_by_inverse_layer_idx == False
assert config.reorder_and_upcast_attn == False
assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx
assert not config.reorder_and_upcast_attn
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
padded_vocab_size = (param.shape[0] *
tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:

View File

@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
bias=True, gather_output=False,
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
def forward(
@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
bias=True, gather_output=False,
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.config = config
assert config.add_cross_attention == False
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
qkv_array = qkv_array.numpy()
dims_q = n_head * head_dim
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
# q is fine, but k & v have not replicated shape along the first axis
# as long as MQA is not nativly supported, increase memory and replicated
# (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim)
# pylint: disable=unbalanced-tuple-unpacking
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
axis=0)
# q is fine, but k & v have not replicated shape along the first
# axis as long as MQA is not nativly supported, increase memory
# and replicated (head_dim, hidden_dim) to
# (n_heads * head_dim, hidden_dim)
if k.ndim == 2 and v.ndim == 2:
replication = (n_head, 1) # weights
else:
replication = n_head # biases
# replicate n_head times for q, v
k, v = np.tile(k, replication), np.tile(v, replication)
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim)
# concat q, k, v along the first axis
# (n_heads * head_dim, hidden_dim)
# to (3 * n_heads * head_dim, hidden_dim)
qkv_array = np.concatenate((q, k, v), axis=0)
return torch.from_numpy(qkv_array)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along the head dimension.
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:

View File

@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.query_key_value = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
self.query_key_value = ColumnParallelLinear(
config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
scaling = self.head_size ** -0.5
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.dense(attn_output)
return output
@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
config.intermediate_size,
gather_output=False,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.hidden_act)
@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.attention = GPTNeoXAttention(config)
self.mlp = GPTNeoXMLP(config)
@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
self.embed_in = VocabParallelEmbedding(config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layers = nn.ModuleList(
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
bias=False, gather_output=False,
self.embed_out = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.gpt_neox(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.embed_out.weight, hidden_states, input_metadata)
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
_column_parallel_weights = [
"embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
or "rotary_emb.inv_freq" in name):
continue
param = state_dict[name]
if "query_key_value" in name:
@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if 'query_key_value.weight' in name:
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif 'query_key_value.bias' in name:
elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)

View File

@ -1,5 +1,6 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
@ -30,7 +31,6 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
self.scaling, rotary_dim=self.head_dim)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
@ -209,6 +223,7 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True

View File

@ -1,7 +1,9 @@
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
# OPT is set up so that if padding_idx is specified then offset the
# embedding ids by 2 and adjust num_embeddings appropriately. Other
# models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
) -> None:
super().__init__()
self.embed_dim = embed_dim
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
total_num_heads = num_heads
assert num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads, self.head_dim,
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
def forward(
@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output
@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
self.activation_fn = get_act_fn(config.activation_function)
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
perform_initialization=False)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
def forward(
self,
@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)
# Project out & in will be replicated if they exist.
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
self.project_out = nn.Linear(config.hidden_size,
config.word_embed_proj_dim,
bias=False)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
self.project_in = nn.Linear(config.word_embed_proj_dim,
config.hidden_size,
bias=False)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# Note that the only purpose of `config._remove_final_layer_norm` is to
# keep backward compatibility with checkpoints that have been fine-tuned
# before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
)
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList(
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states, kv_caches[i], input_metadata, cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
@ -241,8 +260,8 @@ class OPTModel(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
return self.decoder(
input_ids, positions, kv_caches, input_metadata, cache_events)
return self.decoder(input_ids, positions, kv_caches, input_metadata,
cache_events)
class OPTForCausalLM(nn.Module):
@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head_weight, hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_column_parallel_weights = [
"embed_tokens.weight", "fc1.weight", "fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
def load_weights(self, model_name_or_path: str,
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
continue
@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
name = "model." + name
is_attention_weight = False
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id
:shard_size * (stride_id + 1)]
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True

View File

@ -44,9 +44,9 @@ def hf_model_weights_iterator(
if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, 'np')
np_folder = os.path.join(hf_folder, "np")
os.makedirs(np_folder, exist_ok=True)
weight_names_file = os.path.join(np_folder, 'weight_names.json')
weight_names_file = os.path.join(np_folder, "weight_names.json")
with lock:
if not os.path.exists(weight_names_file):
weight_names = []
@ -57,10 +57,10 @@ def hf_model_weights_iterator(
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
weight_names.append(name)
with open(weight_names_file, 'w') as f:
with open(weight_names_file, "w") as f:
json.dump(weight_names, f)
with open(weight_names_file, 'r') as f:
with open(weight_names_file, "r") as f:
weight_names = json.load(f)
for name in weight_names:
@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
for p in column_parallel_weight_names:
if p in param_name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[start_idx:end_idx]
break
for p in row_parallel_weight_names:
if p in param_name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
loaded_weight = loaded_weight[:, start_idx:end_idx]
break
assert param.shape == loaded_weight.shape, (
f"{param_name} shape mismatch between model and checkpoint: "

View File

@ -55,6 +55,7 @@ class RequestOutput:
outputs: The output sequences of the request.
finished: Whether the whole request is finished.
"""
def __init__(
self,
request_id: str,
@ -75,8 +76,9 @@ class RequestOutput:
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
assert n <= len(seqs)
sorted_seqs = sorted(
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
sorted_seqs = sorted(seqs,
key=lambda seq: seq.get_cumulative_logprob(),
reverse=True)
top_n_seqs = sorted_seqs[:n]
# Create the outputs.

View File

@ -3,6 +3,7 @@ from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingParams:
"""Sampling parameters for text generation.
@ -51,7 +52,7 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
stop: Union[str, List[str]] = [],
stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
@ -64,7 +65,12 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.use_beam_search = use_beam_search
self.stop = [stop] if isinstance(stop, str) else list(stop)
if stop is None:
self.stop = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = list(stop)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs

View File

@ -1,3 +1,4 @@
"""Sequence and its related classes."""
import copy
import enum
from typing import Dict, List, Optional, Union
@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams
class SequenceStatus(enum.Enum):
"""Status of a sequence."""
WAITING = enum.auto()
RUNNING = enum.auto()
SWAPPED = enum.auto()
@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum):
SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED,
SequenceStatus.FINISHED_ABORTED,
SequenceStatus.FINISHED_IGNORED
SequenceStatus.FINISHED_IGNORED,
]
@staticmethod
@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum):
class SequenceData:
"""Data associated with a sequence.
Args:
prompt_token_ids: The token IDs of the prompt.
Attributes:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
def __init__(
self,
@ -75,6 +88,15 @@ class SequenceData:
class Sequence:
"""Stores the data, status, and block information of a sequence.
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
"""
def __init__(
self,
@ -149,19 +171,27 @@ class Sequence:
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: 'Sequence') -> None:
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
def fork(self, child_seq: "Sequence") -> None:
child_seq.logical_token_blocks = copy.deepcopy(
self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.data = copy.deepcopy(self.data)
return None
def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, '
f'status={self.status.name}, '
f'num_blocks={len(self.logical_token_blocks)})')
return (f"Sequence(seq_id={self.seq_id}, "
f"status={self.status.name}, "
f"num_blocks={len(self.logical_token_blocks)})")
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
Args:
request_id: The ID of the request.
seqs: The list of sequences.
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
"""
def __init__(
self,
@ -191,7 +221,7 @@ class SequenceGroup:
for seq in self.seqs:
if seq.seq_id == seq_id:
return seq
raise ValueError(f'Sequence {seq_id} not found.')
raise ValueError(f"Sequence {seq_id} not found.")
def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.seqs)
@ -203,14 +233,25 @@ class SequenceGroup:
class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
Args:
request_id: The ID of the request.
is_prompt: Whether the request is at prompt stage.
seq_data: The sequence data. (Seq id -> sequence data)
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
"""
def __init__(
self,
request_id: str,
is_prompt: bool,
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
block_tables: Dict[int, List[int]],
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
@ -220,13 +261,23 @@ class SequenceGroupMetadata:
class SequenceOutputs:
"""The model output associated with a sequence.
Args:
seq_id: The ID of the sequence.
parent_seq_id: The ID of the parent sequence (for forking in beam
search).
output_token: The output token ID.
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""
def __init__(
self,
seq_id: int,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
logprobs: Dict[int, float],
) -> None:
self.seq_id = seq_id
self.parent_seq_id = parent_seq_id
@ -234,15 +285,15 @@ class SequenceOutputs:
self.logprobs = logprobs
def __repr__(self) -> str:
return (f'SequenceOutputs(seq_id={self.seq_id}, '
f'parent_seq_id={self.parent_seq_id}, '
f'output_token={self.output_token}), '
f'logprobs={self.logprobs}')
return (f"SequenceOutputs(seq_id={self.seq_id}, "
f"parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}), "
f"logprobs={self.logprobs}")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
return NotImplemented
return (self.seq_id == other.seq_id and
self.parent_seq_id == other.parent_seq_id and
self.output_token == other.output_token and
self.logprobs == other.logprobs)
return (self.seq_id == other.seq_id
and self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)

View File

@ -13,8 +13,8 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_tokenizer(
tokenizer_name: str,
tokenizer_mode: str = "auto",
*args,
tokenizer_mode: str = "auto",
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
@ -73,7 +73,8 @@ def detokenize_incrementally(
output_text = tokenizer.convert_tokens_to_string(output_tokens)
return new_token, output_text
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.

View File

@ -17,9 +17,9 @@ class Counter:
self.counter = start
def __next__(self) -> int:
id = self.counter
i = self.counter
self.counter += 1
return id
return i
def reset(self) -> None:
self.counter = 0
@ -38,6 +38,7 @@ def get_cpu_memory() -> int:
def random_uuid() -> str:
return str(uuid.uuid4().hex)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower()

View File

@ -93,8 +93,8 @@ class CacheEngine:
if not pin_memory:
# Pinning memory in WSL is not supported.
# https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
logger.warn("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
logger.warning("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *key_block_shape),
@ -120,11 +120,10 @@ class CacheEngine:
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
cache_ops.swap_blocks(
src_key_cache, dst_key_cache, src_to_dst)
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
cache_ops.swap_blocks(
src_value_cache, dst_value_cache, src_to_dst)
cache_ops.swap_blocks(src_value_cache, dst_value_cache,
src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)

View File

@ -73,8 +73,8 @@ class Worker:
# number of tokens equal to max_num_batched_tokens.
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99,
top_k=self.model.config.vocab_size - 1)
vocab_size = self.model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
seqs = []
@ -91,7 +91,8 @@ class Worker:
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seqs)
# Execute the model.
num_layers = self.model_config.get_num_layers(self.parallel_config)
@ -110,8 +111,9 @@ class Worker:
total_gpu_memory = get_gpu_memory()
cache_block_size = CacheEngine.get_cache_block_size(
block_size, self.model_config, self.parallel_config)
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
- peak_memory) // cache_block_size)
num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
@ -125,8 +127,8 @@ class Worker:
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.block_size = cache_config.block_size
self.cache_engine = CacheEngine(
self.cache_config, self.model_config, self.parallel_config)
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
@ -202,8 +204,8 @@ class Worker:
generation_block_tables.append(block_table)
max_context_len = max(max_context_len, context_len)
max_num_blocks_per_seq = max(
max_num_blocks_per_seq, len(block_table))
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
len(block_table))
context_lens.append(context_len)
block_number = block_table[position // self.block_size]
@ -223,7 +225,8 @@ class Worker:
context_lens_tensor = torch.cuda.IntTensor(context_lens)
padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
for block_table in generation_block_tables
]
block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
seq_data: Dict[int, SequenceData] = {}