[Quality] Add code formatter and linter (#326)

parent 0ffded812a
commit d6fa1be3a8

.pylintrc (new file, 434 lines)
@@ -0,0 +1,434 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc

[MASTER]

# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=

# Pickle collected data for later comparisons.
persistent=no

# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=

# Use multiple processes to speed up Pylint.
jobs=4

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
        apply-builtin,
        arguments-differ,
        attribute-defined-outside-init,
        backtick,
        bad-option-value,
        basestring-builtin,
        buffer-builtin,
        c-extension-no-member,
        consider-using-enumerate,
        cmp-builtin,
        cmp-method,
        coerce-builtin,
        coerce-method,
        delslice-method,
        div-method,
        duplicate-code,
        eq-without-hash,
        execfile-builtin,
        file-builtin,
        filter-builtin-not-iterating,
        fixme,
        getslice-method,
        global-statement,
        hex-method,
        idiv-method,
        implicit-str-concat-in-sequence,
        import-error,
        import-self,
        import-star-module-level,
        inconsistent-return-statements,
        input-builtin,
        intern-builtin,
        invalid-str-codec,
        locally-disabled,
        logging-fstring-interpolation,  # added by vLLM
        logging-not-lazy,  # added by vLLM
        long-builtin,
        long-suffix,
        map-builtin-not-iterating,
        misplaced-comparison-constant,
        missing-class-docstring,  # TODO (vLLM): enable
        missing-function-docstring,
        missing-module-docstring,  # TODO (vLLM): enable
        metaclass-assignment,
        next-method-called,
        next-method-defined,
        no-absolute-import,
        no-else-break,
        no-else-continue,
        no-else-raise,
        no-else-return,
        no-init,  # added
        no-member,
        no-name-in-module,
        no-self-use,
        nonzero-method,
        oct-method,
        old-division,
        old-ne-operator,
        old-octal-literal,
        old-raise-syntax,
        parameter-unpacking,
        print-statement,
        raising-string,
        range-builtin-not-iterating,
        raw_input-builtin,
        rdiv-method,
        reduce-builtin,
        relative-import,
        reload-builtin,
        round-builtin,
        setslice-method,
        signature-differs,
        standarderror-builtin,
        suppressed-message,
        sys-max-int,
        too-few-public-methods,
        too-many-ancestors,
        too-many-arguments,
        too-many-boolean-expressions,
        too-many-branches,
        too-many-instance-attributes,
        too-many-locals,
        too-many-nested-blocks,
        too-many-public-methods,
        too-many-return-statements,
        too-many-statements,
        trailing-newlines,
        unichr-builtin,
        unicode-builtin,
        unnecessary-pass,
        unpacking-in-except,
        unspecified-encoding,
        useless-else-on-loop,
        useless-object-inheritance,
        useless-suppression,
        using-cmp-argument,
        wrong-import-order,
        xrange-builtin,
        zip-builtin-not-iterating,


[REPORTS]

# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text

# Tells whether to display a full report or only the messages
reports=no

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors, warning, and statement, which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=


[BASIC]

# Good variable names which should always be accepted, separated by a comma
good-names=main,_

# Bad variable names which should always be refused, separated by a comma
bad-names=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Include a hint for the correct naming format with invalid-name
include-naming-hint=no

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl

# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$

# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$

# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$

# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$

# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=


[FORMAT]

# Maximum number of characters on a single line.
max-line-length=80

# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
  ^\s*(\#\ )?<?https?://\S+>?$|
  ^\s*(from\s+\S+\s+)?import\s+.+$)

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes

# Maximum number of lines in a module
max-module-lines=99999

# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string='  '

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=TODO


[STRING]

# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes


[VARIABLES]

# Tells whether we should check for unused import in __init__ files.
init-import=no

# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools


[LOGGING]

# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging


[SIMILARITIES]

# Minimum number of lines of a similarity.
min-similarity-lines=4

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no


[SPELLING]

# Spelling dictionary name. Available dictionaries: none. To make it work,
# install the python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no


[IMPORTS]

# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
                   TERMIOS,
                   Bastion,
                   rexec,
                   sets

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
                            class_

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
                       Exception,
                       BaseException
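The `[MESSAGES CONTROL]` comments above double as a quick reference for narrowing a run. As a hedged illustration (the `vllm` target matches what `format.sh` below lints; the flags are standard Pylint options, and running from the repository root is an assumption):

```bash
# Sketch: run the full check with the rcfile added by this commit.
# From the repository root, Pylint picks up .pylintrc automatically;
# it is passed explicitly here for clarity.
pylint --rcfile=.pylintrc vllm

# Example lifted from the [MESSAGES CONTROL] comment: run only the
# classes checker with Warning-level messages suppressed.
pylint --rcfile=.pylintrc --disable=all --enable=classes --disable=W vllm
```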
@@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possible.
 
 In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
 
+We include a formatting script [`format.sh`](./format.sh) to format the code.
+
 ### Pull Requests
 
 When submitting a pull request:
 
 1. Make sure your code has been rebased on top of the latest commit on the main branch.
-2. Include a detailed description of the changes in the pull request.
+2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+3. Include a detailed description of the changes in the pull request.
    Explain why you made the changes you did.
    If your pull request fixes an open issue, please include a reference to it in the description.
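Taken together, the updated checklist implies a simple local loop before opening a pull request. A minimal sketch, assuming `origin/main` is the upstream branch (the commit message mirrors the one suggested in the format.sh header below):

```bash
# Step 1: rebase onto the latest main.
git fetch origin
git rebase origin/main

# Step 2: format files that differ from origin/main, then commit.
bash format.sh
git commit -am "Run yapf and pylint"
```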
@@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
         print(LINE_UP, end=LINE_CLEAR, flush=True)
 
 
-def post_http_request(prompt: str, api_url: str, n: int = 1,
+def post_http_request(prompt: str,
+                      api_url: str,
+                      n: int = 1,
                       stream: bool = False) -> requests.Response:
     headers = {"User-Agent": "Test Client"}
     pload = {
@@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
 
 
 def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
                                      delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
@@ -12,9 +12,14 @@ def http_bot(prompt):
         "stream": True,
         "max_tokens": 128,
     }
-    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
+    response = requests.post(args.model_url,
+                             headers=headers,
+                             json=pload,
+                             stream=True)
 
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
+                                     delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
             output = data["text"][0]
@@ -23,11 +28,11 @@ def http_bot(prompt):
 
 def build_demo():
     with gr.Blocks() as demo:
-        gr.Markdown(
-            "# vLLM text completion demo\n"
-        )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
-        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
+        gr.Markdown("# vLLM text completion demo\n")
+        inputbox = gr.Textbox(label="Input",
+                              placeholder="Enter text and press ENTER")
+        outputbox = gr.Textbox(label="Output",
+                               placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
     return demo
 
@@ -36,7 +41,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
+    parser.add_argument("--model-url",
+                        type=str,
+                        default="http://localhost:8000/generate")
     args = parser.parse_args()
 
     demo = build_demo()
@@ -14,9 +14,14 @@ def main(args: argparse.Namespace):
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
-         SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+         SamplingParams(n=2,
+                        best_of=5,
+                        temperature=0.8,
+                        top_p=0.95,
+                        frequency_penalty=0.1)),
         ("It is only with the heart that one can see rightly",
-         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
+         SamplingParams(n=3, best_of=3, use_beam_search=True,
+                        temperature=0.0)),
     ]
 
     # Run the engine by calling `engine.step()` manually.
@@ -1,6 +1,5 @@
 from vllm import LLM, SamplingParams
 
-
 # Sample prompts.
 prompts = [
     "Hello, my name is",
@@ -12,8 +12,13 @@ print("Models:", models)
 # Test completion API
 stream = True
 completion = openai.Completion.create(
-    model=model, prompt="A robot may not injure a human being", echo=False, n=2,
-    best_of=3, stream=stream, logprobs=3)
+    model=model,
+    prompt="A robot may not injure a human being",
+    echo=False,
+    n=2,
+    best_of=3,
+    stream=stream,
+    logprobs=3)
 
 # print the completion
 if stream:
format.sh (new executable file, 108 lines)
@@ -0,0 +1,108 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
#    # Do work and commit your work.

#    # Format files that differ from origin/main.
#    bash format.sh

#    # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.

# Cause the script to exit if a single command fails
set -eo pipefail

# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1

YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')

# params: tool name, tool version, required version
tool_version_check() {
    if [[ $2 != $3 ]]; then
        echo "Wrong $1 version installed: $3 is required, not $2."
        exit 1
    fi
}

tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"

YAPF_FLAGS=(
    '--recursive'
    '--parallel'
)

YAPF_EXCLUDES=(
    '--exclude' 'build/**'
    '--exclude' 'vllm/model_executor/parallel_utils/**'
)

# Format specified files
format() {
    yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}

# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
    # The `if` guard ensures that the list of filenames is not empty, which
    # could cause yapf to receive 0 positional arguments, making it hang
    # waiting for STDIN.
    #
    # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
    # exist on both branches.
    MERGEBASE="$(git merge-base origin/main HEAD)"

    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
        git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
            yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
    fi

}

# Format all files
format_all() {
    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}

## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
    format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
    format_all
else
    # Format only the files that changed in last commit.
    format_changed
fi
echo 'vLLM yapf: Done'

# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy

# Run Pylint
echo 'vLLM Pylint:'
pylint vllm

if ! git diff --quiet &>/dev/null; then
    echo 'Reformatted files. Please review and stage the changes.'
    echo 'Changes not staged for commit:'
    echo
    git --no-pager diff --name-only

    exit 1
fi
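The dispatch at the bottom of the script gives `format.sh` three entry points; a short usage sketch follows (the file path is a hypothetical example):

```bash
bash format.sh                          # default: format files changed since origin/main
bash format.sh --all                    # format the entire vllm directory
bash format.sh --files path/to/file.py  # --files must be the first argument
```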
@@ -1,2 +1,12 @@
-mypy
+# formatting
+yapf==0.32.0
+pylint==2.8.2
+
+# type checking
+mypy==0.991
+types-PyYAML
+types-requests
+types-setuptools
+
+# testing
 pytest
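Since `tool_version_check` in `format.sh` greps these exact pins out of `requirements-dev.txt`, the development environment should be installed from this file rather than ad hoc. A minimal sketch:

```bash
pip install -r requirements-dev.txt
yapf --version    # expected to report 0.32.0, matching the pin above
pylint --version  # expected to report 2.8.2, matching the pin above
```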
@@ -60,7 +60,7 @@ def ref_single_query_cached_kv_attention(
         keys = torch.stack(keys, dim=0)
         values = torch.stack(values, dim=0)
 
-        scale = 1.0 / (head_size ** 0.5)
+        scale = 1.0 / (head_size**0.5)
         out = ref_masked_attention(q, keys, values, scale)
         out = out.view(num_heads, head_size)
         output[i].copy_(out, non_blocking=True)
@@ -74,7 +74,7 @@ def ref_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> torch.Tensor:
     head_size = query.shape[-1]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
 
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs = []
@@ -84,8 +84,8 @@ def ref_multi_query_kv_attention(
         seq_len = end_idx - start_idx
 
         # Create attention mask.
-        attn_mask = torch.triu(
-            torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')
 
@@ -113,7 +113,7 @@ def ref_multi_query_cached_kv_attention(
     num_heads = value_cache.shape[1]
     head_size = value_cache.shape[2]
     block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
 
     num_queries = len(cu_query_lens) - 1
     ref_outputs = []
@@ -125,8 +125,8 @@ def ref_multi_query_cached_kv_attention(
         block_table = block_tables[i]
 
         # Create attention mask
-        attn_mask = torch.triu(
-            torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
+        attn_mask = torch.triu(torch.ones(query_len, context_len),
+                               diagonal=context_len - query_len + 1) * -1e5
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')
 
         keys = []
@@ -165,22 +165,28 @@ def run_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     qkv.uniform_(-1e-3, 1e-3)
     query, _, _ = qkv.unbind(dim=1)
 
     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(
-        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
+                            dtype=dtype,
+                            device='cuda')
     key_cache.uniform_(-1e-3, 1e-3)
     value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(
-        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
+    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
+                              dtype=dtype,
+                              device='cuda')
     value_cache.uniform_(-1e-3, 1e-3)
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
 
@@ -194,9 +200,12 @@ def run_single_query_cached_kv_attention(
         block_tables.append(block_table)
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
 
-    scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    scale = float(1.0 / (head_size**0.5))
+    output = torch.empty(num_tokens,
+                         num_heads,
+                         head_size,
+                         dtype=dtype,
+                         device='cuda')
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -235,9 +244,13 @@ def run_multi_query_kv_attention(
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
     num_tokens = sum(seq_lens)
 
-    scale = float(1.0 / (head_size ** 0.5))
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    scale = float(1.0 / (head_size**0.5))
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)
 
@@ -26,8 +26,9 @@ def run_copy_blocks(
     key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
     key_caches = []
     for _ in range(num_layers):
-        key_cache = torch.randn(
-            size=key_cache_shape, dtype=dtype, device='cuda')
+        key_cache = torch.randn(size=key_cache_shape,
+                                dtype=dtype,
+                                device='cuda')
         key_caches.append(key_cache)
     cloned_key_caches = []
     for key_cache in key_caches:
@@ -36,8 +37,9 @@ def run_copy_blocks(
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
     value_caches = []
     for _ in range(num_layers):
-        value_cache = torch.randn(
-            size=value_cache_shape, dtype=dtype, device='cuda')
+        value_cache = torch.randn(size=value_cache_shape,
+                                  dtype=dtype,
+                                  device='cuda')
         value_caches.append(value_cache)
     cloned_value_caches = []
     for value_cache in value_caches:
@@ -49,15 +51,18 @@ def run_copy_blocks(
     # Reference implementation.
     for src, dsts in block_mapping.items():
         for dst in dsts:
-            for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
+            for key_cache, cloned_key_cache in zip(key_caches,
+                                                   cloned_key_caches):
                 cloned_key_cache[dst] = cloned_key_cache[src]
-            for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+            for value_cache, cloned_value_cache in zip(value_caches,
+                                                       cloned_value_caches):
                 cloned_value_cache[dst] = cloned_value_cache[src]
 
     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
         assert torch.allclose(key_cache, cloned_key_cache)
-    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
+    for value_cache, cloned_value_cache in zip(value_caches,
+                                               cloned_value_caches):
         assert torch.allclose(value_cache, cloned_value_cache)
 
 
@@ -74,8 +79,12 @@ def run_reshape_and_cache(
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
 
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     _, key, value = qkv.unbind(dim=1)
 
     x = 16 // torch.tensor([], dtype=dtype).element_size()
@@ -84,15 +93,19 @@ def run_reshape_and_cache(
     cloned_key_cache = key_cache.clone()
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(
-        size=value_cache_shape, dtype=dtype, device='cuda')
+    value_cache = torch.randn(size=value_cache_shape,
+                              dtype=dtype,
+                              device='cuda')
     cloned_value_cache = value_cache.clone()
 
-    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
+                                slot_mapping)
 
     for i in range(num_tokens):
         reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
+        block_idx = torch.div(slot_mapping[i],
+                              block_size,
+                              rounding_mode='floor')
         block_offset = slot_mapping[i] % block_size
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]
@@ -114,8 +127,12 @@ def run_gather_cached_kv(
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
 
-    qkv = torch.randn(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.randn(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     _, key, value = qkv.unbind(dim=1)
 
     qkv_clone = qkv.clone()
@@ -126,15 +143,20 @@ def run_gather_cached_kv(
     key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(
-        size=value_cache_shape, dtype=dtype, device='cuda')
+    value_cache = torch.randn(size=value_cache_shape,
+                              dtype=dtype,
+                              device='cuda')
 
-    cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
+    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
+                               slot_mapping)
 
     # Reference implementation.
     for i in range(num_tokens):
-        reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
+        reshaped_key = cloned_key.reshape(num_tokens, num_heads,
+                                          head_size // x, x)
+        block_idx = torch.div(slot_mapping[i],
+                              block_size,
+                              rounding_mode='floor')
         block_offset = slot_mapping[i] % block_size
         reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
         cloned_value[i] = value_cache[block_idx, :, :, block_offset]
@@ -145,20 +167,30 @@ def run_gather_cached_kv(
 
 def test_copy_blocks() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_copy_blocks(
-            num_mappings=23, num_layers=7, num_heads=17, head_size=16,
-            block_size=8, num_blocks=1024, dtype=dtype)
+        run_copy_blocks(num_mappings=23,
+                        num_layers=7,
+                        num_heads=17,
+                        head_size=16,
+                        block_size=8,
+                        num_blocks=1024,
+                        dtype=dtype)
 
 
 def test_reshape_and_cache() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_reshape_and_cache(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
+        run_reshape_and_cache(num_tokens=3,
+                              num_heads=2,
+                              head_size=16,
+                              block_size=8,
+                              num_blocks=2,
+                              dtype=dtype)
 
 
 def test_gather_cached_kv() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_gather_cached_kv(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
+        run_gather_cached_kv(num_tokens=3,
+                             num_heads=2,
+                             head_size=16,
+                             block_size=8,
+                             num_blocks=2,
+                             dtype=dtype)
@@ -14,8 +14,10 @@ class RefRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
+                                                               keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance +
+                                                    self.variance_epsilon)
         if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
         return self.weight * hidden_states
@@ -8,8 +8,8 @@ from vllm import pos_encoding_ops
 
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
@@ -38,7 +38,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
         self.max_position_embeddings = max_position_embeddings
 
         # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
+        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
         t = torch.arange(max_position_embeddings).float()
         freqs = torch.einsum("i,j->ij", t, inv_freq.float())
         emb = torch.cat((freqs, freqs), dim=-1)
@@ -49,16 +49,15 @@ class RefRotaryEmbeddingNeox(nn.Module):
 
     def forward(
         self,
-        positions: torch.Tensor, # [num_tokens]
-        query: torch.Tensor, # [num_tokens, num_heads, head_size]
-        key: torch.Tensor, # [num_tokens, num_heads, head_size]
+        positions: torch.Tensor,  # [num_tokens]
+        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
+        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        query_rot = query[..., : self.rotary_dim]
-        query_pass = query[..., self.rotary_dim :]
-        key_rot = key[..., : self.rotary_dim]
-        key_pass = key[..., self.rotary_dim :]
-
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
 
         query_rot = query_rot.transpose(0, 1)
         key_rot = key_rot.transpose(0, 1)
@@ -85,12 +84,18 @@ def run_rotary_embedding_neox(
     dtype: torch.dtype,
     base: int = 10000,
 ) -> None:
-    positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
-    query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
-    key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
+    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    query = torch.randn(num_tokens,
+                        num_heads * head_size,
+                        dtype=dtype,
+                        device='cuda')
+    key = torch.randn(num_tokens,
+                      num_heads * head_size,
+                      dtype=dtype,
+                      device='cuda')
 
     # Create the rotary embedding.
-    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
+    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
     freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
     cos = freqs.cos()
@@ -1,3 +1,5 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
+
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
@@ -35,7 +35,8 @@ class LogicalTokenBlock:
 
     def append_tokens(self, token_ids: List[int]) -> None:
         assert len(token_ids) <= self.get_num_empty_slots()
-        self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
+        curr_idx = self.num_tokens
+        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
         self.num_tokens += len(token_ids)
 
     def get_token_ids(self) -> List[int]:
@@ -8,7 +8,7 @@ from vllm.utils import get_cpu_memory
 
 logger = init_logger(__name__)
 
-_GiB = 1 << 30
+_GB = 1 << 30
 
 
 class ModelConfig:
@@ -106,6 +106,7 @@ class CacheConfig:
             vLLM execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
     """
+
     def __init__(
         self,
         block_size: int,
@@ -114,7 +115,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space_bytes = swap_space * _GiB
+        self.swap_space_bytes = swap_space * _GB
         self._verify_args()
 
         # Will be set after profiling.
@@ -137,14 +138,13 @@ class CacheConfig:
         num_gpus_per_node = parallel_config.tensor_parallel_size
         cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
 
-        msg = (
-            f"{cpu_memory_usage / _GiB:.2f} GiB out of "
-            f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
-            "allocated for the swap space.")
+        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
+               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
+               "allocated for the swap space.")
         if cpu_memory_usage > 0.7 * total_cpu_memory:
             raise ValueError("Too large swap space. " + msg)
         elif cpu_memory_usage > 0.4 * total_cpu_memory:
-            logger.warn("Possibly too large swap space. " + msg)
+            logger.warning("Possibly too large swap space. " + msg)
 
 
 class ParallelConfig:
@@ -157,6 +157,7 @@ class ParallelConfig:
             True if either pipeline_parallel_size or tensor_parallel_size is
             greater than 1.
     """
+
     def __init__(
         self,
         pipeline_parallel_size: int,
@@ -189,12 +190,9 @@ class SchedulerConfig:
         max_seq_len: Maximum length of a sequence (including prompt
             and generated text).
     """
-    def __init__(
-        self,
-        max_num_batched_tokens: int,
-        max_num_seqs: int,
-        max_seq_len: int
-    ) -> None:
+
+    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
+                 max_seq_len: int) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
         self.max_seq_len = max_seq_len
@@ -241,7 +239,7 @@ def _get_and_verify_dtype(
         pass
     else:
         # Casting between float16 and bfloat16 is allowed with a warning.
-        logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
+        logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
 
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
@@ -27,8 +27,9 @@ class BlockAllocator:
         # Initialize the free blocks.
         self.free_blocks: List[PhysicalTokenBlock] = []
         for i in range(num_blocks):
-            block = PhysicalTokenBlock(
-                device=device, block_number=i, block_size=block_size)
+            block = PhysicalTokenBlock(device=device,
+                                       block_number=i,
+                                       block_size=block_size)
             self.free_blocks.append(block)
 
     def allocate(self) -> PhysicalTokenBlock:
@@ -84,10 +85,12 @@ class BlockSpaceManager:
         num_required_blocks = len(seq.logical_token_blocks)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.
-        return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks
+        return (num_free_gpu_blocks - num_required_blocks >=
+                self.watermark_blocks)
 
     def allocate(self, seq_group: SequenceGroup) -> None:
-        # NOTE: Here we assume that all sequences in the group have the same prompt.
+        # NOTE: Here we assume that all sequences in the group have the same
+        # prompt.
         seq = seq_group.get_seqs()[0]
 
         # Allocate new physical token blocks that will store the prompt tokens.
@@ -143,7 +146,8 @@ class BlockSpaceManager:
         for block in src_block_table:
             block.ref_count += 1
 
-    def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
+    def _get_physical_blocks(
+            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
         # NOTE: Here, we assume that the physical blocks are only shared by
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
@@ -43,8 +43,7 @@ class SchedulerOutputs:
         assert not (blocks_to_swap_in and blocks_to_swap_out)
 
     def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in
-                and not self.blocks_to_swap_out
+        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
                 and not self.blocks_to_copy)
 
 
@@ -61,7 +60,7 @@ class Scheduler:
         self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
-        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
+        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
         # Create the block space manager.
         self.block_manager = BlockSpaceManager(
             block_size=self.cache_config.block_size,
@@ -102,7 +101,8 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
-    def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(
+            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
@@ -160,7 +160,8 @@ class Scheduler:
             num_curr_seqs = sum(
                 seq_group.num_seqs(status=SequenceStatus.RUNNING)
                 for seq_group in self.running)
-            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+            if (num_curr_seqs + num_new_seqs >
+                    self.scheduler_config.max_num_seqs):
                 break
 
             seq_group = self.swapped.pop(0)
@@ -170,8 +171,7 @@ class Scheduler:
 
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
-            for seq_group in self.running
-        )
+            for seq_group in self.running)
 
         # Join waiting sequences if possible.
         prompt_group_ids: List[str] = []
@@ -191,7 +191,7 @@ class Scheduler:
 
             num_prompt_tokens = seq_group.get_seqs()[0].get_len()
             if num_prompt_tokens >= self.scheduler_config.max_seq_len:
-                logger.warn(
+                logger.warning(
                     f"Input prompt ({num_prompt_tokens} tokens) is too long"
                     " and exceeds limit of "
                     f"{self.scheduler_config.max_seq_len}")
@@ -206,17 +206,19 @@ class Scheduler:
                 break
 
             # If the number of batched tokens exceeds the limit, stop.
-            if (num_batched_tokens + num_prompt_tokens
-                    > self.scheduler_config.max_num_batched_tokens):
+            if (num_batched_tokens + num_prompt_tokens >
+                    self.scheduler_config.max_num_batched_tokens):
                 break
 
             # The total number of sequences in the RUNNING state should not
             # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
+            num_new_seqs = seq_group.num_seqs(
+                status=SequenceStatus.WAITING)
             num_curr_seqs = sum(
                 seq_group.num_seqs(status=SequenceStatus.RUNNING)
                 for seq_group in self.running)
-            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
+            if (num_curr_seqs + num_new_seqs >
+                    self.scheduler_config.max_num_seqs):
                 break
 
             seq_group = self.waiting.pop(0)
@@ -240,12 +242,11 @@ class Scheduler:
         elapsed_time = now - self.last_logging_time
         if elapsed_time > _LOGGING_INTERVAL_SEC:
             self.last_logging_time = now
-            self.num_input_tokens = [
-                (t, n) for t, n in self.num_input_tokens
-                if now - t < _LOGGING_INTERVAL_SEC
-            ]
+            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
                                      if now - t < _LOGGING_INTERVAL_SEC]
             if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
+                total_num_tokens = sum(n
+                                       for _, n in self.num_input_tokens[:-1])
                 window = now - self.num_input_tokens[0][0]
                 avg_throughput = total_num_tokens / window
             else:
@@ -258,26 +259,30 @@ class Scheduler:
 
         total_num_cpu_blocks = self.cache_config.num_cpu_blocks
         if total_num_cpu_blocks > 0:
-            num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
+            num_free_cpu_blocks = (
+                self.block_manager.get_num_free_cpu_blocks())
             num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
             cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
         else:
             cpu_cache_usage = 0.0
 
-        logger.info(
-            f"Throughput: {avg_throughput:.1f} tokens/s, "
-            f"Running: {len(self.running)} reqs, "
-            f"Swapped: {len(self.swapped)} reqs, "
-            f"Pending: {len(self.waiting)} reqs, "
-            f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-            f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+        logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
+                    f"Running: {len(self.running)} reqs, "
+                    f"Swapped: {len(self.swapped)} reqs, "
+                    f"Pending: {len(self.waiting)} reqs, "
+                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         return scheduler_outputs, prompt_group_ids, ignored_seq_groups
 
-    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
+    def schedule(
+        self
+    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
+               List[SequenceGroup]]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()
+        (scheduler_outputs, prompt_group_ids,
+         ignored_seq_groups) = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -311,8 +316,8 @@ class Scheduler:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             output = seq_outputs[seq.seq_id]
             if seq.seq_id != output.parent_seq_id:
-                # The sequence is a fork of the parent sequence (beam search).
-                # Free the current sequence.
+                # The sequence is a fork of the parent sequence (beam
+                # search). Free the current sequence.
                 self.block_manager.free(seq)
                 # Fork the parent sequence.
                 parent_seq = seq_group.find(output.parent_seq_id)
@@ -385,7 +390,7 @@ class Scheduler:
         elif preemption_mode == PreemptionMode.SWAP:
             self._preempt_by_swap(seq_group, blocks_to_swap_out)
         else:
-            assert False, 'Invalid preemption mode.'
+            assert False, "Invalid preemption mode."
 
     def _preempt_by_recompute(
         self,
@ -12,11 +12,11 @@ class EngineArgs:
|
||||
"""Arguments for vLLM engine."""
|
||||
model: str
|
||||
tokenizer: Optional[str] = None
|
||||
tokenizer_mode: str = "auto"
|
||||
tokenizer_mode: str = 'auto'
|
||||
download_dir: Optional[str] = None
|
||||
use_np_weights: bool = False
|
||||
use_dummy_weights: bool = False
|
||||
dtype: str = "auto"
|
||||
dtype: str = 'auto'
|
||||
seed: int = 0
|
||||
worker_use_ray: bool = False
|
||||
pipeline_parallel_size: int = 1
|
||||
@ -35,76 +35,101 @@ class EngineArgs:
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
) -> argparse.ArgumentParser:
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
# Model arguments
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m',
|
||||
help='name or path of the huggingface model to use')
|
||||
parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer,
|
||||
help='name or path of the huggingface tokenizer to use')
|
||||
parser.add_argument('--tokenizer-mode', type=str,
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
type=str,
|
||||
default='facebook/opt-125m',
|
||||
help='name or path of the huggingface model to use')
|
||||
parser.add_argument(
|
||||
'--tokenizer',
|
+                            type=str,
+                            default=EngineArgs.tokenizer,
+                            help='name or path of the huggingface tokenizer to use')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
                             choices=['auto', 'slow'],
                             help='tokenizer mode. "auto" will use the fast '
                             'tokenizer if available, and "slow" will '
                             'always use the slow tokenizer.')
-        parser.add_argument('--download-dir', type=str,
+        parser.add_argument('--download-dir',
+                            type=str,
                             default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
                             'default to the default cache dir of '
                             'huggingface')
-        parser.add_argument('--use-np-weights', action='store_true',
+        parser.add_argument('--use-np-weights',
+                            action='store_true',
                             help='save a numpy copy of model weights for '
                             'faster loading. This can increase the disk '
                             'usage by up to 2x.')
-        parser.add_argument('--use-dummy-weights', action='store_true',
+        parser.add_argument('--use-dummy-weights',
+                            action='store_true',
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
-        parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
-                            choices=['auto', 'half', 'bfloat16', 'float'],
-                            help='data type for model weights and activations. '
-                            'The "auto" option will use FP16 precision '
-                            'for FP32 and FP16 models, and BF16 precision '
-                            'for BF16 models.')
+        parser.add_argument(
+            '--dtype',
+            type=str,
+            default=EngineArgs.dtype,
+            choices=['auto', 'half', 'bfloat16', 'float'],
+            help='data type for model weights and activations. '
+            'The "auto" option will use FP16 precision '
+            'for FP32 and FP16 models, and BF16 precision '
+            'for BF16 models.')
         # Parallel arguments
-        parser.add_argument('--worker-use-ray', action='store_true',
+        parser.add_argument('--worker-use-ray',
+                            action='store_true',
                             help='use Ray for distributed serving, will be '
                             'automatically set when using more than 1 GPU')
-        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+        parser.add_argument('--pipeline-parallel-size',
+                            '-pp',
+                            type=int,
                             default=EngineArgs.pipeline_parallel_size,
                             help='number of pipeline stages')
-        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+        parser.add_argument('--tensor-parallel-size',
+                            '-tp',
+                            type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='number of tensor parallel replicas')
         # KV cache arguments
-        parser.add_argument('--block-size', type=int,
+        parser.add_argument('--block-size',
+                            type=int,
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-        parser.add_argument('--seed', type=int, default=EngineArgs.seed,
+        parser.add_argument('--seed',
+                            type=int,
+                            default=EngineArgs.seed,
                             help='random seed')
-        parser.add_argument('--swap-space', type=int,
+        parser.add_argument('--swap-space',
+                            type=int,
                             default=EngineArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU')
-        parser.add_argument('--gpu-memory-utilization', type=float,
+        parser.add_argument('--gpu-memory-utilization',
+                            type=float,
                             default=EngineArgs.gpu_memory_utilization,
                             help='the percentage of GPU memory to be used for'
                             'the model executor')
-        parser.add_argument('--max-num-batched-tokens', type=int,
+        parser.add_argument('--max-num-batched-tokens',
+                            type=int,
                             default=EngineArgs.max_num_batched_tokens,
                             help='maximum number of batched tokens per '
                             'iteration')
-        parser.add_argument('--max-num-seqs', type=int,
+        parser.add_argument('--max-num-seqs',
+                            type=int,
                             default=EngineArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
-        parser.add_argument('--disable-log-stats', action='store_true',
+        parser.add_argument('--disable-log-stats',
+                            action='store_true',
                             help='disable logging statistics')
         return parser

     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
+    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
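The pattern above lets dataclass fields double as argparse defaults, and from_cli_args rebuilds the dataclass by reflecting over its fields. A minimal sketch of the same round trip; ToyArgs is an illustrative stand-in, not vLLM code:

import argparse
import dataclasses

@dataclasses.dataclass
class ToyArgs:
    model: str = 'facebook/opt-125m'
    seed: int = 0

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'ToyArgs':
        # Mirror every dataclass field out of the parsed namespace.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default=ToyArgs.model)
parser.add_argument('--seed', type=int, default=ToyArgs.seed)
print(ToyArgs.from_cli_args(parser.parse_args(['--seed', '42'])))
# ToyArgs(model='facebook/opt-125m', seed=42)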
@@ -115,18 +140,19 @@ class EngineArgs:
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
-        model_config = ModelConfig(
-            self.model, self.tokenizer, self.tokenizer_mode, self.download_dir,
-            self.use_np_weights, self.use_dummy_weights, self.dtype, self.seed)
-        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
-                                   self.swap_space)
+        model_config = ModelConfig(self.model, self.tokenizer,
+                                   self.tokenizer_mode, self.download_dir,
+                                   self.use_np_weights, self.use_dummy_weights,
+                                   self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size,
+                                   self.gpu_memory_utilization,
+                                   self.swap_space)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        max_seq_len = min(
-            self.max_num_batched_tokens,
-            getattr(model_config.hf_config, "max_position_embeddings",
-                    float("inf")))
+        model_max_len = getattr(model_config.hf_config,
+                                'max_position_embeddings', float('inf'))
+        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs, max_seq_len)
         return model_config, cache_config, parallel_config, scheduler_config

@@ -140,12 +166,13 @@ class AsyncEngineArgs(EngineArgs):

     @staticmethod
     def add_cli_args(
-        parser: argparse.ArgumentParser,
-    ) -> argparse.ArgumentParser:
+            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser = EngineArgs.add_cli_args(parser)
-        parser.add_argument('--engine-use-ray', action='store_true',
+        parser.add_argument('--engine-use-ray',
+                            action='store_true',
                             help='use Ray to start the LLM engine in a '
                             'separate process as the server process.')
-        parser.add_argument('--disable-log-requests', action='store_true',
+        parser.add_argument('--disable-log-requests',
+                            action='store_true',
                             help='disable logging requests')
         return parser
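The max_seq_len computation clamps the per-iteration token budget to the model's positional limit when the HF config exposes one, and otherwise leaves it alone. A small sketch of the same getattr pattern; ToyHFConfig is a hypothetical stand-in for a HuggingFace config object:

class ToyHFConfig:
    max_position_embeddings = 2048  # a LLaMA-style positional limit

max_num_batched_tokens = 2560
# Models without the attribute fall back to inf, so the batch budget wins.
model_max_len = getattr(ToyHFConfig, 'max_position_embeddings', float('inf'))
max_seq_len = min(max_num_batched_tokens, model_max_len)
print(max_seq_len)  # 2048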
@@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams

 logger = init_logger(__name__)

-TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


 class AsyncLLMEngine:

@@ -35,8 +35,13 @@ class AsyncLLMEngine:
         log_requests: Whether to log the requests.
         *args, *kwargs: Arguments for LLMEngine.
     """
-    def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
-                 log_requests: bool = True, *args, **kwargs) -> None:
+
+    def __init__(self,
+                 worker_use_ray: bool,
+                 engine_use_ray: bool,
+                 *args,
+                 log_requests: bool = True,
+                 **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests

@@ -76,12 +81,11 @@ class AsyncLLMEngine:
         self.request_events[request_id].set()

     async def generate(
-        self,
-        prompt: Optional[str],
-        sampling_params: SamplingParams,
-        request_id: str,
-        prompt_token_ids: Optional[List[int]] = None
-    ) -> RequestOutput:
+            self,
+            prompt: Optional[str],
+            sampling_params: SamplingParams,
+            request_id: str,
+            prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
         """Generate outputs for a request.

         Generate outputs for a request. This method is a coroutine. It adds the
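Callers consume generate as an async stream; the FastAPI server further down iterates its result with async for. A hedged sketch of a hypothetical call site, assuming an already-constructed AsyncLLMEngine named engine (the construction details are elided here):

# Hypothetical driver; `engine` is assumed to be a ready AsyncLLMEngine.
import asyncio
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

async def main(engine):
    params = SamplingParams(temperature=0.8, max_tokens=64)
    request_id = random_uuid()
    # The returned stream yields incremental RequestOutputs for the request.
    async for request_output in engine.generate('Hello, my name is', params,
                                                request_id):
        print(request_output.outputs[0].text)

# asyncio.run(main(engine))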
@@ -117,14 +121,17 @@ class AsyncLLMEngine:
         # Add the request into the vLLM engine's waiting queue.
         if self.engine_use_ray:
             await self.engine.add_request.remote(
-                request_id, prompt, sampling_params,
+                request_id,
+                prompt,
+                sampling_params,
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)
         else:
-            self.engine.add_request(
-                request_id, prompt, sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
+            self.engine.add_request(request_id,
+                                    prompt,
+                                    sampling_params,
+                                    prompt_token_ids=prompt_token_ids,
+                                    arrival_time=arrival_time)

         # The vLLM engine does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking

@@ -200,7 +207,8 @@ class AsyncLLMEngine:
         self.kicking_request_id = None

     @classmethod
-    def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+    def from_engine_args(cls,
+                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()

@@ -211,8 +219,9 @@ class AsyncLLMEngine:
         # Create the async LLM engine.
         engine = cls(engine_args.worker_use_ray,
                      engine_args.engine_use_ray,
-                     not engine_args.disable_log_requests,
                      *engine_configs,
-                     distributed_init_method, devices,
+                     distributed_init_method,
+                     devices,
+                     log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats)
         return engine
@@ -67,8 +67,7 @@ class LLMEngine:
             f"download_dir={model_config.download_dir!r}, "
             f"use_np_weights={model_config.use_np_weights}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
-            f"seed={model_config.seed})"
-        )
+            f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.

         self.model_config = model_config

@@ -78,8 +77,8 @@ class LLMEngine:
         self.log_stats = log_stats
         self._verify_args()

-        self.tokenizer = get_tokenizer(model_config.tokenizer,
-                                       model_config.tokenizer_mode)
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
         self.seq_counter = Counter()

         # Create the parallel GPU workers.

@@ -129,8 +128,8 @@ class LLMEngine:
         num_gpu_blocks = min(b[0] for b in num_blocks)
         num_cpu_blocks = min(b[1] for b in num_blocks)
         # FIXME(woosuk): Change to debug log.
-        logger.info(f'# GPU blocks: {num_gpu_blocks}, '
-                    f'# CPU blocks: {num_cpu_blocks}')
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")

         if num_gpu_blocks <= 0:
             raise ValueError("No available memory for the cache blocks. "

@@ -152,7 +151,9 @@ class LLMEngine:
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(parallel_config)
         # Create the LLM engine.
-        engine = cls(*engine_configs, distributed_init_method, devices,
+        engine = cls(*engine_configs,
+                     distributed_init_method,
+                     devices,
                      log_stats=not engine_args.disable_log_stats)
         return engine

@@ -226,8 +227,10 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
+        (seq_group_metadata_list, scheduler_outputs,
+         ignored_seq_groups) = self.scheduler.schedule()
+        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
+                and (not ignored_seq_groups)):
             # Nothing to do.
             return []

@@ -281,8 +284,8 @@ class LLMEngine:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:

@@ -290,7 +293,7 @@ class LLMEngine:

             # Check if the sequence has reached max_seq_len.
             if (seq.get_len() >=
-                self.scheduler.scheduler_config.max_seq_len):
+                    self.scheduler.scheduler_config.max_seq_len):
                 self.scheduler.free_seq(
                     seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue

@@ -302,15 +305,15 @@ class LLMEngine:
             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     continue
     def _run_workers(
         self,
         method: str,
-        get_all_outputs: bool = False,
         *args,
+        get_all_outputs: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
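Moving get_all_outputs after *args makes it keyword-only, so a positional worker argument can no longer be swallowed by it accidentally. A small sketch of the difference, using toy functions rather than the vLLM code:

def run_workers_old(method, get_all_outputs=False, *args, **kwargs):
    return method, get_all_outputs, args

def run_workers_new(method, *args, get_all_outputs=False, **kwargs):
    return method, get_all_outputs, args

# Old order: the first positional argument lands in get_all_outputs.
print(run_workers_old('init_model', 'some-positional-arg'))
# ('init_model', 'some-positional-arg', ())
# New order: it goes to *args, as intended.
print(run_workers_new('init_model', 'some-positional-arg'))
# ('init_model', False, ('some-positional-arg',))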
@@ -8,7 +8,8 @@ except ImportError:

 from vllm.config import ParallelConfig

-DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
+# rank, node resource (node IP), device id
+DeviceID = Tuple[int, Optional[str], int]


 def initialize_cluster(

@@ -53,15 +54,15 @@ def initialize_cluster(
     valid_node_resources = []
     num_devices_per_node = None
     for node in ray.nodes():
-        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
+        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
             continue
         if num_devices_per_node is None:
-            num_devices_per_node = node['Resources']['GPU']
+            num_devices_per_node = node["Resources"]["GPU"]
         else:
-            assert num_devices_per_node == node['Resources']['GPU'], (
+            assert num_devices_per_node == node["Resources"]["GPU"], (
                 "The number of GPUs per node is not uniform.")
-        for key in node['Resources']:
-            if key.startswith('node:'):
+        for key in node["Resources"]:
+            if key.startswith("node:"):
                 valid_node_resources.append(key)

     # Verify the parallel config.
@@ -11,8 +11,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

-TIMEOUT_KEEP_ALIVE = 5 # seconds.
-TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
+TIMEOUT_KEEP_ALIVE = 5  # seconds.
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()


@@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
         async for request_output in results_generator:
             prompt = request_output.prompt
             text_outputs = [
-                prompt + output.text
-                for output in request_output.outputs
+                prompt + output.text for output in request_output.outputs
             ]
             ret = {"text": text_outputs}
             yield (json.dumps(ret) + "\0").encode("utf-8")

@@ -63,10 +62,7 @@ async def generate(request: Request) -> Response:

     assert final_output is not None
     prompt = final_output.prompt
-    text_outputs = [
-        prompt + output.text
-        for output in final_output.outputs
-    ]
+    text_outputs = [prompt + output.text for output in final_output.outputs]
     ret = {"text": text_outputs}
     return Response(content=json.dumps(ret))

@@ -81,5 +77,8 @@ if __name__ == "__main__":
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
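The streaming handler above frames each JSON chunk with a NUL byte rather than SSE, so a client has to split the byte stream on b"\0". A hedged client sketch; the /generate route and the "prompt"/"stream" request fields are assumptions, as they are not visible in this diff:

# Hypothetical client for the NUL-delimited stream shown above.
import json
import requests

response = requests.post("http://localhost:8000/generate",
                         json={"prompt": "Hello", "stream": True},
                         stream=True)
for chunk in response.iter_lines(delimiter=b"\0"):
    if chunk:  # skip the empty trailing split
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])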
@@ -63,8 +63,7 @@ class LLM:
         self.request_counter = Counter()

     def get_tokenizer(
-        self,
-    ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
         return self.llm_engine.tokenizer

     def set_tokenizer(
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

 import argparse
 from http import HTTPStatus

@@ -29,7 +30,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

-TIMEOUT_KEEP_ALIVE = 5 # seconds
+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)
 served_model = None

@@ -38,14 +39,13 @@ app = fastapi.FastAPI()

 def create_error_response(status_code: HTTPStatus,
                           message: str) -> JSONResponse:
-    return JSONResponse(
-        ErrorResponse(message=message, type="invalid_request_error").dict(),
-        status_code=status_code.value
-    )
+    return JSONResponse(ErrorResponse(message=message,
+                                      type="invalid_request_error").dict(),
+                        status_code=status_code.value)


 @app.exception_handler(RequestValidationError)
-async def validation_exception_handler(request, exc):
+async def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
     return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


@@ -126,8 +126,11 @@ async def check_length(request, prompt, engine):
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
-    model_cards = [ModelCard(id=served_model, root=served_model,
-                             permission=[ModelPermission()])]
+    model_cards = [
+        ModelCard(id=served_model,
+                  root=served_model,
+                  permission=[ModelPermission()])
+    ]
     return ModelList(data=model_cards)


@@ -144,12 +147,14 @@ def create_logprobs(token_ids: List[int],
         if len(logprobs.text_offset) == 0:
             logprobs.text_offset.append(initial_text_offset)
         else:
-            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
+            logprobs.text_offset.append(logprobs.text_offset[-1] +
+                                        last_token_len)
         last_token_len = len(token)

-        logprobs.top_logprobs.append(
-            {tokenizer.convert_ids_to_tokens(i): p
-             for i, p in id_logprob.items()})
+        logprobs.top_logprobs.append({
+            tokenizer.convert_ids_to_tokens(i): p
+            for i, p in id_logprob.items()
+        })
     return logprobs
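text_offset accumulates the running character position at which each token starts in the completion text. A small worked sketch of the same accumulation, with toy tokens rather than the server's tokenizer output:

tokens = ["Hel", "lo", " world"]
initial_text_offset = 0

text_offset = []
last_token_len = 0
for token in tokens:
    if not text_offset:
        text_offset.append(initial_text_offset)
    else:
        # Each token starts where the previous one started plus its length.
        text_offset.append(text_offset[-1] + last_token_len)
    last_token_len = len(token)

print(text_offset)  # [0, 3, 5]: each token's start index in the text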
@@ -348,7 +353,7 @@ async def create_completion(raw_request: Request):
     if request.suffix is not None:
         # The language models we currently support do not support suffix.
         return create_error_response(HTTPStatus.BAD_REQUEST,
-                                    "suffix is not currently supported")
+                                     "suffix is not currently supported")

     if request.logit_bias is not None:
         # TODO: support logit_bias in vLLM engine.

@@ -387,22 +392,23 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
-    stream = (request.stream and
-              (request.best_of is None or request.n == request.best_of) and
-              not request.use_beam_search)
+    stream = (request.stream
+              and (request.best_of is None or request.n == request.best_of)
+              and not request.use_beam_search)

     async def abort_request() -> None:
         await engine.abort(request_id)

-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    logprobs: Optional[LogProbs] = None,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        logprobs: Optional[LogProbs] = None,
+        finish_reason: Optional[str] = None,
+    ) -> str:
         choice_data = CompletionResponseStreamChoice(
             index=index,
             text=text,

@@ -443,7 +449,8 @@ async def create_completion(raw_request: Request):
                 )
                 yield f"data: {response_json}\n\n"
             if output.finish_reason is not None:
-                logprobs = LogProbs() if request.logprobs is not None else None
+                logprobs = (LogProbs()
+                            if request.logprobs is not None else None)
                 response_json = create_stream_response_json(
                     index=i,
                     text="",

@@ -487,8 +494,8 @@ async def create_completion(raw_request: Request):
         choices.append(choice_data)

     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,

@@ -506,9 +513,11 @@ async def create_completion(raw_request: Request):
     # When user requests streaming but we don't stream, we still need to
     # return a streaming response with a single event.
     response_json = response.json(ensure_ascii=False)

     async def fake_stream_generator() -> AsyncGenerator[str, None]:
         yield f"data: {response_json}\n\n"
         yield "data: [DONE]\n\n"

     return StreamingResponse(fake_stream_generator(),
                              media_type="text/event-stream")

@@ -517,26 +526,34 @@ async def create_completion(raw_request: Request):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server."
-    )
-    parser.add_argument("--host", type=str, default="localhost", help="host name")
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser.add_argument("--host",
+                        type=str,
+                        default="localhost",
+                        help="host name")
     parser.add_argument("--port", type=int, default=8000, help="port number")
-    parser.add_argument(
-        "--allow-credentials", action="store_true", help="allow credentials"
-    )
-    parser.add_argument(
-        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
-    )
-    parser.add_argument(
-        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
-    )
-    parser.add_argument(
-        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
-    )
-    parser.add_argument("--served-model-name", type=str, default=None,
-                        help="The model name used in the API. If not specified, "
-                        "the model name will be the same as the "
-                        "huggingface name.")
+    parser.add_argument("--allow-credentials",
+                        action="store_true",
+                        help="allow credentials")
+    parser.add_argument("--allowed-origins",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed origins")
+    parser.add_argument("--allowed-methods",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed methods")
+    parser.add_argument("--allowed-headers",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed headers")
+    parser.add_argument(
+        "--served-model-name",
+        type=str,
+        default=None,
+        help="The model name used in the API. If not specified, "
+        "the model name will be the same as the "
+        "huggingface name.")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -556,7 +573,11 @@ if __name__ == "__main__":
     engine = AsyncLLMEngine.from_engine_args(engine_args)

     # A separate tokenizer to map token IDs to strings.
-    tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
+    tokenizer = get_tokenizer(engine_args.tokenizer,
+                              tokenizer_mode=engine_args.tokenizer_mode)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="info",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 import time
 from typing import Dict, List, Literal, Optional, Union

@@ -98,7 +99,8 @@ class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)


 class CompletionResponseChoice(BaseModel):
@@ -1,9 +1,9 @@
-# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
-
+# Adapted from
+# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
 """Logging configuration for vLLM."""
 import logging
 import sys


 _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
|
||||
__all__ = [
|
||||
"InputMetadata",
|
||||
"get_model",
|
||||
|
@ -8,11 +8,22 @@ from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
class InputMetadata:
|
||||
"""Metadata for input sequences. Used for PagedAttention.
|
||||
|
||||
Args:
|
||||
seq_groups: List of (seq_ids, sampling_params).
|
||||
seq_data: Seq_id -> SequenceData.
|
||||
prompt_lens: Lengths of prompts.
|
||||
slot_mapping: The address to write the new KV to of each token.
|
||||
context_lens: the length of attention context for each generation token.
|
||||
max_context_len: The maximum context length.
|
||||
block_tables: The block tables. (Seq id -> list of physical block)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
|
||||
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
seq_data: Dict[int, SequenceData],
|
||||
prompt_lens: List[int],
|
||||
slot_mapping: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
|
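slot_mapping and block_tables together encode where each token's KV entry lives in the paged cache. A hedged sketch of how a logical token position maps to a physical slot under the layout the docstring describes; the arithmetic below is the standard paged-KV scheme, assumed here rather than quoted from vLLM:

block_size = 16
# Hypothetical block table for one sequence: logical block i -> physical block.
block_table = [7, 2, 11]

def slot_for_position(pos: int) -> int:
    # Physical slot = physical block number * block_size + offset in the block.
    physical_block = block_table[pos // block_size]
    return physical_block * block_size + pos % block_size

# Token 0 lands in physical block 7; token 17 in block 2 at offset 1.
print(slot_for_position(0), slot_for_position(17))  # 112 33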
@@ -6,9 +6,10 @@ from vllm import activation_ops

 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
-    "gelu_new": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
-    "gelu_fast": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),  # NOTE: This may introduce small rounding errors.
+    # NOTE: The following GELU functions may introduce small rounding errors.
+    "gelu_new": nn.GELU(approximate="tanh"),
+    "gelu_fast": nn.GELU(approximate="tanh"),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "relu": nn.ReLU(),
 }

@@ -25,15 +26,13 @@ class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d)
+        return: (num_tokens, d)
     """

     def __init__(self):
         super().__init__()

-    def forward(
-        self,
-        x: torch.Tensor,  # (num_tokens, 2 * d)
-    ) -> torch.Tensor:  # (num_tokens, d)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         num_tokens = x.shape[0]
         d = x.shape[1] // 2
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
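The fused CUDA kernel behind SiluAndMul can be cross-checked against a plain PyTorch reference of the formula in the docstring. A minimal eager-mode sketch, a stand-in for testing rather than the fused kernel itself:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -> (num_tokens, d)
    d = x.shape[1] // 2
    return F.silu(x[:, :d]) * x[:, d:]

x = torch.randn(4, 8)
print(silu_and_mul_ref(x).shape)  # torch.Size([4, 4])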
@@ -14,6 +14,7 @@ _SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]


 class PagedAttention(nn.Module):
+    # pylint: disable=line-too-long
     """GPT-style multi-head PagedAttention.

     This class takes flattened 1D query, key, and value tensors as input. The

@@ -54,12 +55,20 @@ class PagedAttention(nn.Module):

     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
-        query: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
-        value: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
         attn_bias: xops.AttentionBias,
     ) -> torch.Tensor:
+        """Normal attention for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+        """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),

@@ -76,12 +85,22 @@ class PagedAttention(nn.Module):

     def single_query_cached_kv_attention(
         self,
-        output: torch.Tensor,  # [num_generation_tokens, num_heads, head_size]
-        query: torch.Tensor,  # [num_generation_tokens, num_heads, head_size]
-        key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> None:
+        """PagedAttention for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,

@@ -97,16 +116,32 @@ class PagedAttention(nn.Module):

     def forward(
         self,
-        query: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        value: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        key_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size, block_size]
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
-        # NOTE: The query, key, and value tensors must be sliced from a qkv
-        # tensor of shape [num_tokens, 3 * num_heads * head_size].
+    ) -> torch.Tensor:
+        """PagedAttention forward pass.
+
+        NOTE: The query, key, and value tensors must be sliced from a qkv
+        tensor of shape [num_tokens, 3 * num_heads * head_size].
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_heads * head_size]
+            value: shape = [num_tokens, num_heads * head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+            cache_event: event to wait for the cache operations to finish.
+
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """

         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)

@@ -136,7 +171,7 @@ class PagedAttention(nn.Module):
         # and value vectors will not be cached.
         num_valid_tokens = input_metadata.num_valid_tokens
         if (num_valid_tokens > 0 and key_cache is not None
-            and value_cache is not None):
+                and value_cache is not None):
             # The stride is 3 because the key and value are sliced from qkv.
             cache_ops.reshape_and_cache(
                 key[:num_valid_tokens],

@@ -149,15 +184,12 @@ class PagedAttention(nn.Module):
         if input_metadata.num_generation_tokens > 0:
             assert key_cache is not None and value_cache is not None, (
                 "key_cache and value_cache must be provided when "
-                "generating tokens."
-            )
+                "generating tokens.")
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens],
-                key_cache,
-                value_cache,
-                input_metadata)
+                query[num_prompt_tokens:num_valid_tokens], key_cache,
+                value_cache, input_metadata)

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
@@ -179,9 +211,9 @@ class PagedAttentionWithRoPE(PagedAttention):
         super().__init__(num_heads, head_size, scale)

         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
         t = torch.arange(max_position).float()
-        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
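The cos/sin cache above is the standard rotary-embedding table: one inverse frequency per pair of dimensions, one row per position. A standalone sketch of the same construction with toy sizes; it mirrors the lines above rather than calling into vLLM:

import torch

rotary_dim, max_position, base = 8, 4, 10000
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
# freqs: [max_position, rotary_dim // 2], angle per (position, dim pair).
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cache.shape)  # torch.Size([4, 8]) == [max_position, rotary_dim]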
@@ -195,15 +227,32 @@ class PagedAttentionWithRoPE(PagedAttention):

     def forward(
         self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        value: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
+    ) -> torch.Tensor:
+        """ PagedAttention forward pass with rotary embedding.
+
+        Args:
+            positions: shape = [num_tokens]
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_heads * head_size]
+            value: shape = [num_tokens, num_heads * head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+            cache_event: event to wait for the cache operations to finish.
+
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """

         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         pos_encoding_ops.rotary_embedding_neox(
@@ -13,6 +13,7 @@ from vllm.sequence import SequenceOutputs

 _SAMPLING_EPS = 1e-5


 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.

@@ -50,19 +51,20 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(input_metadata)
+        presence_penalties, frequency_penalties = _get_penalties(
+            input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
-        logits = _apply_penalties(
-            logits, output_tokens, presence_penalties, frequency_penalties,
-            self.vocab_size)
+        logits = _apply_penalties(logits, output_tokens, presence_penalties,
+                                  frequency_penalties, self.vocab_size)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]
         if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(
-                temperatures, dtype=logits.dtype, device=logits.device)
+            t = torch.tensor(temperatures,
+                             dtype=logits.dtype,
+                             device=logits.device)
             # Use in-place division to avoid creating a new tensor.
             logits.div_(t.unsqueeze(dim=1))

@@ -75,7 +77,9 @@ class Sampler(nn.Module):
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
-        if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
+        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
+        do_top_k = any(k != self.vocab_size for k in top_ks)
+        if do_top_p or do_top_k:
             probs = _apply_top_p_top_k(probs, top_ps, top_ks)

         # Sample the next tokens.

@@ -97,8 +101,7 @@ def _prune_hidden_states(


 def _get_penalties(
-    input_metadata: InputMetadata,
-) -> Tuple[List[float], List[float]]:
+        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []

@@ -117,9 +120,7 @@ def _get_penalties(
     return presence_penalties, frequency_penalties


-def _get_output_tokens(
-    input_metadata: InputMetadata,
-) -> List[List[int]]:
+def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, _ = seq_group

@@ -169,11 +170,13 @@ def _apply_penalties(
                                device=logits.device)

     frequency_penalties = [frequency_penalties[i] for i in indices]
-    frequency_penalties = torch.tensor(
-        frequency_penalties, dtype=logits.dtype, device=logits.device)
+    frequency_penalties = torch.tensor(frequency_penalties,
+                                       dtype=logits.dtype,
+                                       device=logits.device)
     presence_penalties = [presence_penalties[i] for i in indices]
-    presence_penalties = torch.tensor(
-        presence_penalties, dtype=logits.dtype, device=logits.device)
+    presence_penalties = torch.tensor(presence_penalties,
+                                      dtype=logits.dtype,
+                                      device=logits.device)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
return logits
|
||||
|
||||
|
||||
def _get_temperatures(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[float]:
|
||||
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||
# Collect the temperatures for the logits.
|
||||
temperatures: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
@ -252,8 +253,9 @@ def _apply_top_p_top_k(
|
||||
probs_sort[top_k_mask] = 0.0
|
||||
|
||||
# Re-sort the probabilities.
|
||||
probs = torch.gather(
|
||||
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
|
||||
probs = torch.gather(probs_sort,
|
||||
dim=-1,
|
||||
index=torch.argsort(probs_idx, dim=-1))
|
||||
return probs
|
||||
|
||||
|
||||
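_apply_top_p_top_k works in sorted order: mask tokens outside the nucleus or past the k-th rank, then scatter the surviving mass back via the argsort of the sort indices. A self-contained sketch of the full filter for one row under the usual nucleus-sampling definition; it mirrors the structure above, not the exact vLLM tensor code:

import torch

def apply_top_p_top_k_row(probs: torch.Tensor, top_p: float,
                          top_k: int) -> torch.Tensor:
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # Top-p: drop tokens once the cumulative mass before them exceeds top_p.
    probs_sum = probs_sort.cumsum(dim=-1)
    probs_sort[(probs_sum - probs_sort) > top_p] = 0.0
    # Top-k: drop everything past the k-th sorted token.
    probs_sort[torch.arange(probs.shape[-1]) >= top_k] = 0.0
    # Re-sort the surviving mass back to vocabulary order.
    return torch.gather(probs_sort, dim=-1,
                        index=torch.argsort(probs_idx, dim=-1))

probs = torch.tensor([0.5, 0.2, 0.2, 0.1])
print(apply_top_p_top_k_row(probs, top_p=0.6, top_k=4))
# tensor([0.5000, 0.2000, 0.0000, 0.0000]): only 0.5 and one 0.2 survive;
# torch.multinomial renormalizes implicitly when sampling.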
@@ -296,8 +298,9 @@ def _sample_from_prompt(
         # Random sampling.
         # Sample `best_of` tokens for the prompt.
         num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(
-            prob, num_samples=num_seqs, replacement=True)
+        next_token_ids = torch.multinomial(prob,
+                                           num_samples=num_seqs,
+                                           replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids

@@ -315,8 +318,9 @@ def _sample_from_generation_tokens(
     if sampling_params.use_beam_search:
         # Beam search.
         # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(
-            seq_logprobs, dtype=torch.float, device=logprobs.device)
+        seq_logprobs = torch.tensor(seq_logprobs,
+                                    dtype=torch.float,
+                                    device=logprobs.device)
         logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

         vocab_size = logprobs.size(-1)

@@ -353,8 +357,9 @@ def _sample_from_generation_tokens(
     else:
         # Random sampling.
         # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(
-            probs, num_samples=1, replacement=True)
+        next_token_ids = torch.multinomial(probs,
+                                           num_samples=1,
+                                           replacement=True)
         next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
         parent_seq_ids = seq_ids
     return parent_seq_ids, next_token_ids

@@ -381,15 +386,16 @@ def _sample(
             # Sample the next tokens.
             next_token_ids = _sample_from_prompt(prob, sampling_params)
             # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(
-                logprob, sampling_params.logprobs)
+            next_logprobs = _get_topk_logprobs(logprob,
+                                               sampling_params.logprobs)

             # Build the output.
             for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                 output_logprobs = next_logprobs.copy()
                 output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(
-                    seq_id, seq_id, next_token_id, output_logprobs)
+                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
+                                                      next_token_id,
+                                                      output_logprobs)
         else:
             # Generate the next tokens for generation tokens.
             prob = probs[idx:idx + len(seq_ids)]

@@ -399,22 +405,24 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
                 input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids]
+                for seq_id in seq_ids
+            ]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)

             # Get top-k log probabilities for the next tokens.
             next_logprobs: Dict[int, Dict[int, float]] = {}
-            for i, seq_id in enumerate(seq_ids):
+            for j, seq_id in enumerate(seq_ids):
                 next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[i], sampling_params.logprobs)
+                    logprob[j], sampling_params.logprobs)

             # Build the output.
             for seq_id, parent_seq_id, next_token_id in zip(
-                seq_ids, parent_seq_ids, next_token_ids):
-                i = seq_ids.index(parent_seq_id)
+                    seq_ids, parent_seq_ids, next_token_ids):
+                j = seq_ids.index(parent_seq_id)
                 output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+                output_logprobs[next_token_id] = logprob[j,
+                                                         next_token_id].item()
                 seq_outputs[seq_id] = SequenceOutputs(
                     seq_id,
                     parent_seq_id,
@@ -6,8 +6,9 @@ import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM, GPTNeoXForCausalLM,
-                                        LlamaForCausalLM, OPTForCausalLM)
+from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
+                                        GPTNeoXForCausalLM, LlamaForCausalLM,
+                                        OPTForCausalLM)
 from vllm.model_executor.weight_utils import initialize_dummy_weights

 # TODO(woosuk): Lazy-load the model classes.

@@ -28,8 +29,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
             return _MODEL_REGISTRY[arch]
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
-    )
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")


 def get_model(model_config: ModelConfig) -> nn.Module:

@@ -46,8 +46,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         initialize_dummy_weights(model)
     else:
         # Load the weights from the cached or downloaded files.
-        model.load_weights(
-            model_config.model, model_config.download_dir,
-            model_config.use_np_weights)
+        model.load_weights(model_config.model, model_config.download_dir,
+                           model_config.use_np_weights)
     model = model.cuda()
     return model.eval()
@@ -4,8 +4,6 @@ from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM

-
-
 __all__ = [
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
 # Copyright 2023 The vLLM team.
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.

@@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
-        self.scale = self.head_dim ** -0.5
+        self.scale = self.head_dim**-0.5

-        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
-                                           bias=True, gather_output=False,
+        self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                           3 * self.hidden_size,
+                                           bias=True,
+                                           gather_output=False,
                                            perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
-                                        bias=True, input_is_parallel=True,
+        self.c_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
                                         perform_initialization=False)
-        self.attn = PagedAttention(self.num_heads, self.head_dim,
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
                                    scale=self.scale)

     def forward(

@@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
-        attn_output = self.attn(
-            q, k, v, key_cache, value_cache, input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata, cache_event)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
-                                         bias=True, gather_output=False,
+        self.c_fc = ColumnParallelLinear(hidden_size,
+                                         intermediate_size,
+                                         bias=True,
+                                         gather_output=False,
                                          perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
-                                        bias=True, input_is_parallel=True,
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
                                         perform_initialization=False)
         self.act = get_act_fn(config.activation_function)

@@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         hidden_size = config.hidden_size
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(config)

@@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         self.config = config
-        assert config.add_cross_attention == False
-        assert config.scale_attn_by_inverse_layer_idx == False
-        assert config.reorder_and_upcast_attn == False
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size

         # Optimization: While the vocab size of GPT-2 is 50257, we extend it

@@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(
-                hidden_states, kv_caches[i], input_metadata, cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
+                                  cache_event)

         hidden_states = self.ln_f(hidden_states)
         return hidden_states

@@ -206,24 +218,26 @@ class GPT2LMHeadModel(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.transformer(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.lm_head_weight, hidden_states, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
         return next_tokens

     _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]

-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

@@ -248,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):

             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
-                # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
-                # When tensor parallelism is used, we shard the weights along the head dimension.
+                # GPT-2's fused QKV has the shape of
+                # [3 * num_heads * head_size, hidden_size].
+                # When tensor parallelism is used, we shard the weights along
+                # the head dimension.
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads

@@ -266,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
                 head_end = (tensor_model_parallel_rank + 1) * num_heads

                 if name.endswith(".weight"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
+                    loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                       head_size, hidden_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                     loaded_weight = loaded_weight.reshape(-1, hidden_size)
                 elif name.endswith(".bias"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
+                    loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                       head_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :]
                     loaded_weight = loaded_weight.reshape(-1)
                 else:
|
||||
# coding=utf-8
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 CTranslate2, and Michael Feil
|
||||
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
||||
@ -49,19 +50,25 @@ class GPTBigCodeAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
|
||||
bias=True, gather_output=False,
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
|
||||
bias=True, input_is_parallel=True,
|
||||
self.c_proj = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.attn = PagedAttention(self.num_heads, self.head_dim,
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scale)
|
||||
|
||||
def forward(
|
||||
@ -74,8 +81,8 @@ class GPTBigCodeAttention(nn.Module):
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -89,11 +96,15 @@ class GPTBigMLP(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
|
||||
bias=True, gather_output=False,
|
||||
self.c_fc = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=True, input_is_parallel=True,
|
||||
self.c_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
@ -109,7 +120,8 @@ class GPTBigCodeBlock(nn.Module):
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(config)
|
||||
@ -147,7 +159,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert config.add_cross_attention == False
|
||||
assert not config.add_cross_attention
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
@ -181,8 +193,8 @@ class GPTBigCodeModel(nn.Module):
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
hidden_states, kv_caches[i], input_metadata, cache_event)
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
|
||||
cache_event)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@ -207,24 +219,26 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.transformer(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head_weight, hidden_states, input_metadata)
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
@ -241,9 +255,11 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
|
||||
padded_vocab_size = param.shape[
|
||||
0] * tensor_model_parallel_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
@ -258,25 +274,31 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
qkv_array = qkv_array.numpy()
|
||||
|
||||
dims_q = n_head * head_dim
|
||||
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim), axis=0)
|
||||
# q is fine, but k & v have not replicated shape along the first axis
|
||||
# as long as MQA is not nativly supported, increase memory and replicated
|
||||
# (head_dim, hidden_dim) to (n_heads * head_dim, hidden_dim)
|
||||
# pylint: disable=unbalanced-tuple-unpacking
|
||||
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
|
||||
axis=0)
|
||||
# q is fine, but k & v have not replicated shape along the first
|
||||
# axis as long as MQA is not nativly supported, increase memory
|
||||
# and replicated (head_dim, hidden_dim) to
|
||||
# (n_heads * head_dim, hidden_dim)
|
||||
if k.ndim == 2 and v.ndim == 2:
|
||||
replication = (n_head, 1) # weights
|
||||
else:
|
||||
replication = n_head # biases
|
||||
# replicate n_head times for q, v
|
||||
k, v = np.tile(k, replication), np.tile(v, replication)
|
||||
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim)
|
||||
# concat q, k, v along the first axis
|
||||
# (n_heads * head_dim, hidden_dim)
|
||||
# to (3 * n_heads * head_dim, hidden_dim)
|
||||
qkv_array = np.concatenate((q, k, v), axis=0)
|
||||
return torch.from_numpy(qkv_array)
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor parallelism is used, we shard the weights along the head dimension.
|
||||
# GPT-2's fused QKV has the shape of
|
||||
# [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor parallelism is used, we shard the weights along
|
||||
# the head dimension.
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
@ -285,13 +307,19 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
if name.endswith(".weight"):
|
||||
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
|
||||
loaded_weight = _expand_mqa_mha(loaded_weight,
|
||||
n_head=total_num_heads,
|
||||
head_dim=head_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif name.endswith(".bias"):
|
||||
loaded_weight = _expand_mqa_mha(loaded_weight, n_head=total_num_heads, head_dim=head_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
|
||||
loaded_weight = _expand_mqa_mha(loaded_weight,
|
||||
n_head=total_num_heads,
|
||||
head_dim=head_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
|
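
The `_expand_mqa_mha` helper in the hunk above replicates the single shared key/value head of a multi-query-attention (MQA) checkpoint so the loader can shard it like ordinary multi-head attention. A minimal standalone sketch of the same idea, not taken from this commit (all sizes are made up):

import numpy as np

# Hypothetical MQA checkpoint sizes (illustration only).
n_head, head_dim, hidden_dim = 4, 2, 8

# MQA fused QKV weight: n_head query heads, but only ONE k and ONE v head.
qkv = np.random.randn(n_head * head_dim + 2 * head_dim, hidden_dim)

dims_q = n_head * head_dim
q, k, v = np.split(qkv, (dims_q, dims_q + head_dim), axis=0)

# Replicate the single k/v head n_head times along the first axis,
# turning (head_dim, hidden_dim) into (n_head * head_dim, hidden_dim).
k = np.tile(k, (n_head, 1))
v = np.tile(v, (n_head, 1))

# Re-fuse to the MHA layout: (3 * n_head * head_dim, hidden_dim).
expanded = np.concatenate((q, k, v), axis=0)
assert expanded.shape == (3 * n_head * head_dim, hidden_dim)

The replication costs memory, which the code comment acknowledges; it is a stopgap until MQA is supported natively by the attention kernel.
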
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
 #
@@ -48,19 +49,23 @@ class GPTNeoXAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads

-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)

-        self.query_key_value = ColumnParallelLinear(config.hidden_size,
-                                                    3 * config.hidden_size,
-                                                    gather_output=False,
-                                                    perform_initialization=False)
-        self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
+        self.query_key_value = ColumnParallelLinear(
+            config.hidden_size,
+            3 * config.hidden_size,
+            gather_output=False,
+            perform_initialization=False)
+        self.dense = RowParallelLinear(config.hidden_size,
+                                       config.hidden_size,
                                        input_is_parallel=True,
                                        perform_initialization=False)

-        scaling = self.head_size ** -0.5
+        scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
         self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
@@ -78,8 +83,8 @@ class GPTNeoXAttention(nn.Module):

         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(
-            position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
+        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
+                                input_metadata, cache_event)
         output, _ = self.dense(attn_output)
         return output

@@ -92,7 +97,8 @@ class GPTNeoXMLP(nn.Module):
                                                    config.intermediate_size,
                                                    gather_output=False,
                                                    perform_initialization=False)
-        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
+        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
+                                               config.hidden_size,
                                                input_is_parallel=True,
                                                perform_initialization=False)
         self.act = get_act_fn(config.hidden_act)
@@ -109,8 +115,10 @@ class GPTNeoXLayer(nn.Module):
     def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.layer_norm_eps)
         self.attention = GPTNeoXAttention(config)
         self.mlp = GPTNeoXMLP(config)

@@ -154,10 +162,13 @@ class GPTNeoXModel(nn.Module):
         super().__init__()
         self.config = config

-        self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
+        self.embed_in = VocabParallelEmbedding(config.vocab_size,
+                                               config.hidden_size,
                                                perform_initialization=False)
-        self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
-        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layers = nn.ModuleList(
+            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
+                                             eps=config.layer_norm_eps)

     def forward(
         self,
@@ -191,8 +202,10 @@ class GPTNeoXForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.gpt_neox = GPTNeoXModel(config)
-        self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
-                                              bias=False, gather_output=False,
+        self.embed_out = ColumnParallelLinear(config.hidden_size,
+                                              config.vocab_size,
+                                              bias=False,
+                                              gather_output=False,
                                               perform_initialization=False)
         self.sampler = Sampler(config.vocab_size)

@@ -204,24 +217,28 @@ class GPTNeoXForCausalLM(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.gpt_neox(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.embed_out.weight, hidden_states, input_metadata)
+        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
+                                      input_metadata, cache_events)
+        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
+                                   input_metadata)
         return next_tokens

-    _column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
+    _column_parallel_weights = [
+        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
+        "dense_h_to_4h.bias"
+    ]
     _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache):
             if ("attention.bias" in name or "attention.masked_bias" in name
-                or "rotary_emb.inv_freq" in name):
+                    or "rotary_emb.inv_freq" in name):
                 continue
             param = state_dict[name]
             if "query_key_value" in name:
@@ -230,17 +247,19 @@ class GPTNeoXForCausalLM(nn.Module):
                 # required shape is [3 * num_heads * head_size, hidden_size].
                 # Thus, we need weight conversion.
                 shard_size = param.shape[0]
-                loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
-                                              :shard_size * (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank:shard_size *
+                    (tensor_model_parallel_rank + 1)]

                 num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // num_heads
-                if 'query_key_value.weight' in name:
-                    loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
+                if "query_key_value.weight" in name:
+                    loaded_weight = loaded_weight.view(-1, 3, head_size,
+                                                       hidden_size)
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif 'query_key_value.bias' in name:
+                elif "query_key_value.bias" in name:
                     loaded_weight = loaded_weight.view(-1, 3, head_size)
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1)
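
GPT-NeoX checkpoints store the fused QKV weight grouped per head, so viewed as a 4-D tensor it is [num_heads, 3, head_size, hidden_size], while the loader above needs all q rows first, then all k, then all v. The view/transpose/reshape in the hunk performs that reordering; here is a toy sketch with made-up sizes (not code from the commit):

import torch

# Made-up sizes for illustration.
num_heads, head_size, hidden_size = 2, 3, 6

# NeoX layout: each head's q, k and v rows are stored together.
w = torch.arange(num_heads * 3 * head_size * hidden_size,
                 dtype=torch.float32).view(-1, hidden_size)

# [num_heads, 3, head_size, hidden] -> [3, num_heads, head_size, hidden].
# The diff uses view(-1, 3, ...) because -1 resolves to the local head count.
w = w.view(num_heads, 3, head_size, hidden_size)
w = w.transpose(0, 1)
w = w.reshape(-1, hidden_size)  # now all q rows, then all k, then all v
assert w.shape == (3 * num_heads * head_size, hidden_size)
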
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -30,7 +31,6 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

-from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
-                                                 bias=False, gather_output=False,
+        self.gate_up_proj = ColumnParallelLinear(hidden_size,
+                                                 2 * intermediate_size,
+                                                 bias=False,
+                                                 gather_output=False,
                                                  perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
-                                           bias=False, input_is_parallel=True,
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           input_is_parallel=True,
                                            perform_initialization=False)
-        if hidden_act != 'silu':
-            raise ValueError(f'Unsupported activation: {hidden_act}. '
-                             'Only silu is supported for now.')
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
         self.act_fn = SiluAndMul()

     def forward(self, x):
@@ -83,12 +87,14 @@ class LlamaAttention(nn.Module):
     ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         self.total_num_heads = num_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5

         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -104,8 +110,10 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
-                                           self.scaling, rotary_dim=self.head_dim)
+        self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                           self.head_dim,
+                                           self.scaling,
+                                           rotary_dim=self.head_dim)

     def forward(
         self,
@@ -118,8 +126,8 @@ class LlamaAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(
-            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
+        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output

@@ -138,8 +146,10 @@ class LlamaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)

     def forward(
         self,
@@ -177,9 +187,13 @@ class LlamaModel(nn.Module):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
-                                                   perform_initialization=False)
-        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            perform_initialization=False)
+        self.layers = nn.ModuleList([
+            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
+        ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
@@ -209,6 +223,7 @@ class LlamaModel(nn.Module):


 class LlamaForCausalLM(nn.Module):
+
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -228,39 +243,42 @@ class LlamaForCausalLM(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.model(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.lm_head.weight, hidden_states, input_metadata)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+                                   input_metadata)
         return next_tokens

-    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
-                                "qkv_proj.weight", "gate_proj.weight",
-                                "up_proj.weight"]
+    _column_parallel_weights = [
+        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
+        "gate_proj.weight", "up_proj.weight"
+    ]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache):
             if "rotary_emb.inv_freq" in name:
                 continue

             is_attention_weight = False
-            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
+            for stride_id, att_weight_name in enumerate(
+                    ["q_proj", "k_proj", "v_proj"]):
                 if att_weight_name not in name:
                     continue
                 param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                 shard_size = param.shape[0] // 3
                 loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank
-                    :shard_size * (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id
-                                         :shard_size * (stride_id + 1)]
+                    shard_size * tensor_model_parallel_rank:shard_size *
+                    (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id:shard_size *
+                                         (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
                 param_slice.copy_(loaded_weight)
                 is_attention_weight = True
@@ -275,10 +293,10 @@ class LlamaForCausalLM(nn.Module):
             param = state_dict[name.replace(weight_name, "gate_up_proj")]
             shard_size = param.shape[0] // 2
             loaded_weight = loaded_weight[
-                shard_size * tensor_model_parallel_rank
-                :shard_size * (tensor_model_parallel_rank + 1)]
-            param_slice = param.data[shard_size * stride_id
-                                     :shard_size * (stride_id + 1)]
+                shard_size * tensor_model_parallel_rank:shard_size *
+                (tensor_model_parallel_rank + 1)]
+            param_slice = param.data[shard_size * stride_id:shard_size *
+                                     (stride_id + 1)]
             assert param_slice.shape == loaded_weight.shape
             param_slice.copy_(loaded_weight)
             is_gate_up_weight = True
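
For Llama the checkpoint keeps q_proj, k_proj and v_proj separate, and the loop above copies each rank's slice of each projection into the matching third (`stride_id`) of the fused `qkv_proj` parameter. A self-contained sketch of that copy pattern under made-up sizes (hypothetical names, not the commit's code):

import torch

# Made-up sizes: 2-way tensor parallelism; each of q/k/v has 12 output rows.
tp_rank, tp_size = 0, 2
total_rows, hidden_size = 12, 4

# The fused qkv_proj parameter on this rank holds a third per projection.
qkv_param = torch.empty(3 * total_rows // tp_size, hidden_size)
shard_size = qkv_param.shape[0] // 3  # rows per projection on this rank

checkpoint = {name: torch.randn(total_rows, hidden_size)
              for name in ("q_proj", "k_proj", "v_proj")}

for stride_id, name in enumerate(("q_proj", "k_proj", "v_proj")):
    w = checkpoint[name]
    # This rank's rows of the full projection weight ...
    shard = w[shard_size * tp_rank:shard_size * (tp_rank + 1)]
    # ... land in the stride_id-th third of the fused parameter.
    qkv_param[shard_size * stride_id:shard_size * (stride_id + 1)] = shard

The `gate_up_proj` branch later in the hunk follows the same pattern with two strides instead of three.
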
@@ -1,7 +1,9 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
 # Copyright 2023 The vLLM team.
-# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
+# reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 class OPTLearnedPositionalEmbedding(nn.Embedding):

     def __init__(self, num_embeddings: int, embedding_dim: int):
-        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
-        # and adjust num_embeddings appropriately. Other models don't have this hack
+        # OPT is set up so that if padding_idx is specified then offset the
+        # embedding ids by 2 and adjust num_embeddings appropriately. Other
+        # models don't have this hack
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)

@@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         total_num_heads = num_heads
         assert num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = embed_dim // total_num_heads
-        self.scaling = self.head_dim ** -0.5
+        self.scaling = self.head_dim**-0.5

-        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
+        self.qkv_proj = ColumnParallelLinear(embed_dim,
+                                             3 * embed_dim,
+                                             bias=bias,
                                              gather_output=False,
                                              perform_initialization=False)
-        self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
+        self.out_proj = RowParallelLinear(embed_dim,
+                                          embed_dim,
+                                          bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = PagedAttention(self.num_heads, self.head_dim,
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
                                    scale=self.scaling)

     def forward(
@@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
-        attn_output = self.attn(
-            q, k, v, key_cache, value_cache, input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output

@@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
         self.activation_fn = get_act_fn(config.activation_function)

         self.self_attn_layer_norm = nn.LayerNorm(
-            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
-        self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
+            self.embed_dim,
+            elementwise_affine=config.layer_norm_elementwise_affine)
+        self.fc1 = ColumnParallelLinear(self.embed_dim,
+                                        config.ffn_dim,
                                         bias=config.enable_bias,
                                         gather_output=False,
                                         perform_initialization=False)
-        self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
+        self.fc2 = RowParallelLinear(config.ffn_dim,
+                                     self.embed_dim,
                                      bias=config.enable_bias,
                                      input_is_parallel=True,
                                      perform_initialization=False)
         self.final_layer_norm = nn.LayerNorm(
-            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
+            self.embed_dim,
+            elementwise_affine=config.layer_norm_elementwise_affine)

     def forward(
         self,
@@ -133,11 +146,10 @@ class OPTDecoderLayer(nn.Module):
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(
-            hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            input_metadata=input_metadata,
-            cache_event=cache_event)
+        hidden_states = self.self_attn(hidden_states=hidden_states,
+                                       kv_cache=kv_cache,
+                                       input_metadata=input_metadata,
+                                       cache_event=cache_event)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -167,35 +179,42 @@ class OPTDecoder(nn.Module):
         self.max_target_positions = config.max_position_embeddings
         self.vocab_size = config.vocab_size

-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.word_embed_proj_dim,
-                                                   perform_initialization=False)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.word_embed_proj_dim,
+            perform_initialization=False)
         # Positional embeddings are replicated (not sharded).
         self.embed_positions = OPTLearnedPositionalEmbedding(
             config.max_position_embeddings, config.hidden_size)

         # Project out & in will be replicated if they exist.
         if config.word_embed_proj_dim != config.hidden_size:
-            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
+            self.project_out = nn.Linear(config.hidden_size,
+                                         config.word_embed_proj_dim,
+                                         bias=False)
         else:
             self.project_out = None

         if config.word_embed_proj_dim != config.hidden_size:
-            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
+            self.project_in = nn.Linear(config.word_embed_proj_dim,
+                                        config.hidden_size,
+                                        bias=False)
         else:
             self.project_in = None

-        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
-        # with checkpoints that have been fine-tuned before transformers v4.20.1
+        # Note that the only purpose of `config._remove_final_layer_norm` is to
+        # keep backward compatibility with checkpoints that have been fine-tuned
+        # before transformers v4.20.1
         # see https://github.com/facebookresearch/metaseq/pull/164
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm(
-                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
-            )
+                config.hidden_size,
+                elementwise_affine=config.layer_norm_elementwise_affine)
         else:
             self.final_layer_norm = None

-        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList(
+            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(
         self,
@@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
             else:
                 cache_event = cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
-                hidden_states, kv_caches[i], input_metadata, cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
+                                  cache_event)

         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
@@ -241,8 +260,8 @@ class OPTModel(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
-        return self.decoder(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
+        return self.decoder(input_ids, positions, kv_caches, input_metadata,
+                            cache_events)


 class OPTForCausalLM(nn.Module):
@@ -264,23 +283,26 @@ class OPTForCausalLM(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.model(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.lm_head_weight, hidden_states, input_metadata)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
         return next_tokens

-    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
+    _column_parallel_weights = [
+        "embed_tokens.weight", "fc1.weight", "fc1.bias"
+    ]
     _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache):
             if "lm_head.weight" in name:
                 continue

@@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
             name = "model." + name

             is_attention_weight = False
-            for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
+            for stride_id, att_weight_name in enumerate(
+                    ["q_proj", "k_proj", "v_proj"]):
                 if att_weight_name not in name:
                     continue
                 param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                 shard_size = param.shape[0] // 3
                 loaded_weight = loaded_weight[
-                    shard_size * tensor_model_parallel_rank
-                    :shard_size * (tensor_model_parallel_rank + 1)]
-                param_slice = param.data[shard_size * stride_id
-                                         :shard_size * (stride_id + 1)]
+                    shard_size * tensor_model_parallel_rank:shard_size *
+                    (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id:shard_size *
+                                         (stride_id + 1)]
                 assert param_slice.shape == loaded_weight.shape
                 param_slice.copy_(loaded_weight)
                 is_attention_weight = True
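
The `OPTLearnedPositionalEmbedding` hunk above keeps OPT's historical hack: the embedding table is widened by two rows and every position id is shifted before lookup. A toy version showing the effect (the `forward` here is a simplified illustration of the known behaviour, not copied from the diff):

import torch
from torch import nn

class ToyOPTPositionalEmbedding(nn.Embedding):
    """Learned positions with OPT's offset-by-2 hack (illustration only)."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # Position 0 actually lives at row 2 of the widened table.
        return super().forward(positions + self.offset)

emb = ToyOPTPositionalEmbedding(num_embeddings=8, embedding_dim=4)
out = emb(torch.arange(5))  # reads rows 2..6 of a (10, 4) table
assert out.shape == (5, 4)
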
@@ -44,9 +44,9 @@ def hf_model_weights_iterator(
     if use_np_cache:
         # Convert the model weights from torch tensors to numpy arrays for
         # faster loading.
-        np_folder = os.path.join(hf_folder, 'np')
+        np_folder = os.path.join(hf_folder, "np")
         os.makedirs(np_folder, exist_ok=True)
-        weight_names_file = os.path.join(np_folder, 'weight_names.json')
+        weight_names_file = os.path.join(np_folder, "weight_names.json")
         with lock:
             if not os.path.exists(weight_names_file):
                 weight_names = []
@@ -57,10 +57,10 @@ def hf_model_weights_iterator(
                     with open(param_path, "wb") as f:
                         np.save(f, param.cpu().detach().numpy())
                     weight_names.append(name)
-                with open(weight_names_file, 'w') as f:
+                with open(weight_names_file, "w") as f:
                     json.dump(weight_names, f)

-        with open(weight_names_file, 'r') as f:
+        with open(weight_names_file, "r") as f:
             weight_names = json.load(f)

         for name in weight_names:
@@ -86,17 +86,16 @@ def load_tensor_parallel_weights(
     for p in column_parallel_weight_names:
         if p in param_name:
             shard_size = param.shape[0]
-            loaded_weight = loaded_weight[
-                shard_size * tensor_model_parallel_rank
-                :shard_size * (tensor_model_parallel_rank + 1)]
+            start_idx = tensor_model_parallel_rank * shard_size
+            end_idx = (tensor_model_parallel_rank + 1) * shard_size
+            loaded_weight = loaded_weight[start_idx:end_idx]
             break
     for p in row_parallel_weight_names:
         if p in param_name:
             shard_size = param.shape[1]
-            loaded_weight = loaded_weight[
-                :,
-                shard_size * tensor_model_parallel_rank
-                :shard_size * (tensor_model_parallel_rank + 1)]
+            start_idx = tensor_model_parallel_rank * shard_size
+            end_idx = (tensor_model_parallel_rank + 1) * shard_size
+            loaded_weight = loaded_weight[:, start_idx:end_idx]
             break
     assert param.shape == loaded_weight.shape, (
         f"{param_name} shape mismatch between model and checkpoint: "
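
The `load_tensor_parallel_weights` rewrite above only renames the slice bounds; the underlying rule is that column-parallel layers take a contiguous block of rows (dim 0) of the checkpoint weight, while row-parallel layers take a block of columns (dim 1). A sketch with made-up shapes:

import torch

tp_rank, tp_size = 1, 2
full = torch.randn(8, 6)  # a made-up checkpoint weight

# Column-parallel: each rank owns a contiguous block of output rows (dim 0).
shard_size = full.shape[0] // tp_size
col_shard = full[shard_size * tp_rank:shard_size * (tp_rank + 1), :]
assert col_shard.shape == (4, 6)

# Row-parallel: each rank owns a contiguous block of input columns (dim 1).
shard_size = full.shape[1] // tp_size
row_shard = full[:, shard_size * tp_rank:shard_size * (tp_rank + 1)]
assert row_shard.shape == (8, 3)
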
@@ -55,6 +55,7 @@ class RequestOutput:
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
     """
+
     def __init__(
         self,
         request_id: str,
@@ -75,8 +76,9 @@ class RequestOutput:
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
         assert n <= len(seqs)
-        sorted_seqs = sorted(
-            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        sorted_seqs = sorted(seqs,
+                             key=lambda seq: seq.get_cumulative_logprob(),
+                             reverse=True)
         top_n_seqs = sorted_seqs[:n]

         # Create the outputs.
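
The `RequestOutput` change above is a pure reformat of a sort call: sequences are ranked by cumulative log probability, descending, and the first `n` become the request's outputs. Illustrated with stand-in tuples instead of real `Sequence` objects:

# (name, cumulative logprob) stand-ins; higher is more likely.
seqs = [("a", -1.2), ("b", -0.3), ("c", -2.5)]
n = 2

sorted_seqs = sorted(seqs, key=lambda s: s[1], reverse=True)
top_n_seqs = sorted_seqs[:n]
assert top_n_seqs == [("b", -0.3), ("a", -1.2)]
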
@@ -3,6 +3,7 @@ from typing import List, Optional, Union

 _SAMPLING_EPS = 1e-5

+
 class SamplingParams:
     """Sampling parameters for text generation.

@@ -51,7 +52,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
-        stop: Union[str, List[str]] = [],
+        stop: Union[None, str, List[str]] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -64,7 +65,12 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
-        self.stop = [stop] if isinstance(stop, str) else list(stop)
+        if stop is None:
+            self.stop = []
+        elif isinstance(stop, str):
+            self.stop = [stop]
+        else:
+            self.stop = list(stop)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
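
The `stop` signature change above is the standard fix for pylint's dangerous-default-value warning (W0102): a `[]` default is built once at function definition time and shared by every call, so a `None` sentinel is used instead. A minimal illustration of the pitfall being avoided:

def bad(item, acc=[]):  # one shared list for every call
    acc.append(item)
    return acc

def good(item, acc=None):  # fresh list per call unless one is passed
    if acc is None:
        acc = []
    acc.append(item)
    return acc

assert bad(1) == [1]
assert bad(2) == [1, 2]  # surprise: state leaked between calls
assert good(1) == [1]
assert good(2) == [2]
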
@@ -1,3 +1,4 @@
+"""Sequence and its related classes."""
 import copy
 import enum
 from typing import Dict, List, Optional, Union
@@ -7,6 +8,7 @@ from vllm.sampling_params import SamplingParams


+
 class SequenceStatus(enum.Enum):
     """Status of a sequence."""
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
@@ -21,7 +23,7 @@ class SequenceStatus(enum.Enum):
         SequenceStatus.FINISHED_STOPPED,
         SequenceStatus.FINISHED_LENGTH_CAPPED,
         SequenceStatus.FINISHED_ABORTED,
-        SequenceStatus.FINISHED_IGNORED
+        SequenceStatus.FINISHED_IGNORED,
     ]

     @staticmethod
@@ -40,6 +42,17 @@ class SequenceStatus(enum.Enum):


 class SequenceData:
+    """Data associated with a sequence.
+
+
+    Args:
+        prompt_token_ids: The token IDs of the prompt.
+
+    Attributes:
+        prompt_token_ids: The token IDs of the prompt.
+        output_token_ids: The token IDs of the output.
+        cumulative_logprob: The cumulative log probability of the output.
+    """

     def __init__(
         self,
@@ -75,6 +88,15 @@ class SequenceData:


 class Sequence:
+    """Stores the data, status, and block information of a sequence.
+
+    Args:
+        seq_id: The ID of the sequence.
+        prompt: The prompt of the sequence.
+        prompt_token_ids: The token IDs of the prompt.
+        block_size: The block size of the sequence. Should be the same as the
+            block size used by the block manager and cache engine.
+    """

     def __init__(
         self,
@@ -149,19 +171,27 @@ class Sequence:
     def is_finished(self) -> bool:
         return SequenceStatus.is_finished(self.status)

-    def fork(self, child_seq: 'Sequence') -> None:
-        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
+    def fork(self, child_seq: "Sequence") -> None:
+        child_seq.logical_token_blocks = copy.deepcopy(
+            self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
         child_seq.data = copy.deepcopy(self.data)
         return None

     def __repr__(self) -> str:
-        return (f'Sequence(seq_id={self.seq_id}, '
-                f'status={self.status.name}, '
-                f'num_blocks={len(self.logical_token_blocks)})')
+        return (f"Sequence(seq_id={self.seq_id}, "
+                f"status={self.status.name}, "
+                f"num_blocks={len(self.logical_token_blocks)})")


 class SequenceGroup:
+    """A group of sequences that are generated from the same prompt.
+
+    Args:
+        request_id: The ID of the request.
+        seqs: The list of sequences.
+        sampling_params: The sampling parameters used to generate the outputs.
+        arrival_time: The arrival time of the request.
+    """

     def __init__(
         self,
@@ -191,7 +221,7 @@ class SequenceGroup:
         for seq in self.seqs:
             if seq.seq_id == seq_id:
                 return seq
-        raise ValueError(f'Sequence {seq_id} not found.')
+        raise ValueError(f"Sequence {seq_id} not found.")

     def is_finished(self) -> bool:
         return all(seq.is_finished() for seq in self.seqs)
@@ -203,14 +233,25 @@ class SequenceGroup:


 class SequenceGroupMetadata:
+    """Metadata for a sequence group. Used to create `InputMetadata`.
+
+
+    Args:
+        request_id: The ID of the request.
+        is_prompt: Whether the request is at prompt stage.
+        seq_data: The sequence data. (Seq id -> sequence data)
+        sampling_params: The sampling parameters used to generate the outputs.
+        block_tables: The block tables. (Seq id -> list of physical block
+            numbers)
+    """

     def __init__(
         self,
         request_id: str,
         is_prompt: bool,
-        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
+        seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
+        block_tables: Dict[int, List[int]],
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -220,13 +261,23 @@ class SequenceGroupMetadata:


 class SequenceOutputs:
+    """The model output associated with a sequence.
+
+    Args:
+        seq_id: The ID of the sequence.
+        parent_seq_id: The ID of the parent sequence (for forking in beam
+            search).
+        output_token: The output token ID.
+        logprobs: The logprobs of the output token.
+            (Token id -> logP(x_i+1 | x_0, ..., x_i))
+    """

     def __init__(
         self,
         seq_id: int,
         parent_seq_id: int,
         output_token: int,
-        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
+        logprobs: Dict[int, float],
     ) -> None:
         self.seq_id = seq_id
         self.parent_seq_id = parent_seq_id
@@ -234,15 +285,15 @@ class SequenceOutputs:
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f'SequenceOutputs(seq_id={self.seq_id}, '
-                f'parent_seq_id={self.parent_seq_id}, '
-                f'output_token={self.output_token}), '
-                f'logprobs={self.logprobs}')
+        return (f"SequenceOutputs(seq_id={self.seq_id}, "
+                f"parent_seq_id={self.parent_seq_id}, "
+                f"output_token={self.output_token}), "
+                f"logprobs={self.logprobs}")

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutputs):
             return NotImplemented
-        return (self.seq_id == other.seq_id and
-                self.parent_seq_id == other.parent_seq_id and
-                self.output_token == other.output_token and
-                self.logprobs == other.logprobs)
+        return (self.seq_id == other.seq_id
+                and self.parent_seq_id == other.parent_seq_id
+                and self.output_token == other.output_token
+                and self.logprobs == other.logprobs)
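
`Sequence.fork` above deep-copies the logical block table, output logprobs and data so a beam-search child can diverge without mutating its parent. A stripped-down illustration (a stand-in class, not vLLM's `Sequence`):

import copy

class ToySeq:
    """Minimal stand-in for the diffed Sequence class (illustration only)."""

    def __init__(self, blocks, logprobs):
        self.logical_token_blocks = blocks
        self.output_logprobs = logprobs

    def fork(self, child: "ToySeq") -> None:
        # Deep copies: the child may diverge without touching the parent.
        child.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
        child.output_logprobs = copy.deepcopy(self.output_logprobs)

parent = ToySeq([[1, 2], [3]], [{-1: -0.5}])
child = ToySeq([], [])
parent.fork(child)
child.logical_token_blocks[0].append(99)
assert parent.logical_token_blocks[0] == [1, 2]  # parent unchanged
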
@@ -13,8 +13,8 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

 def get_tokenizer(
     tokenizer_name: str,
-    tokenizer_mode: str = "auto",
     *args,
+    tokenizer_mode: str = "auto",
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
@@ -73,7 +73,8 @@ def detokenize_incrementally(
     output_text = tokenizer.convert_tokens_to_string(output_tokens)
     return new_token, output_text

-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
 # NOTE(woosuk): The following code is slow because it runs a for loop over
 # the output_tokens. In Python, running a for loop over a list can be slow
 # even when the loop body is very simple.
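
Moving `tokenizer_mode` behind `*args` in `get_tokenizer` makes it keyword-only, so stray positional arguments destined for the underlying HuggingFace tokenizer can no longer bind to it by accident. A toy demonstration (hypothetical function names, not the real API):

def get_tok_old(name, mode="auto", *args, **kwargs):
    return name, mode, args

def get_tok_new(name, *args, mode="auto", **kwargs):
    return name, mode, args

# A positional argument meant for the underlying tokenizer:
assert get_tok_old("llama", "extra") == ("llama", "extra", ())  # bound to mode
assert get_tok_new("llama", "extra") == ("llama", "auto", ("extra",))
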
@@ -17,9 +17,9 @@ class Counter:
         self.counter = start

     def __next__(self) -> int:
-        id = self.counter
+        i = self.counter
         self.counter += 1
-        return id
+        return i

     def reset(self) -> None:
         self.counter = 0
@@ -38,6 +38,7 @@ def get_cpu_memory() -> int:
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)

+
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
     return "microsoft" in " ".join(uname()).lower()
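
The `Counter` change above renames a local from `id` to `i` so the builtin `id()` is no longer shadowed; the iterator behaviour is unchanged. Usage looks like this (a sketch mirroring the diffed class):

class Counter:

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter  # 'i', not 'id': avoids shadowing the builtin id()
        self.counter += 1
        return i

c = Counter()
assert next(c) == 0 and next(c) == 1
assert isinstance(id(c), int)  # the builtin still works as expected
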
@@ -93,8 +93,8 @@ class CacheEngine:
         if not pin_memory:
             # Pinning memory in WSL is not supported.
             # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
-            logger.warn("Using 'pin_memory=False' as WSL is detected. "
-                        "This may slow down the performance.")
+            logger.warning("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down the performance.")
         for _ in range(self.num_layers):
             key_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *key_block_shape),
@@ -120,11 +120,10 @@ class CacheEngine:
             src_key_cache, src_value_cache = src[i]
             dst_key_cache, dst_value_cache = dst[i]
             # Copy the key blocks.
-            cache_ops.swap_blocks(
-                src_key_cache, dst_key_cache, src_to_dst)
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
             # Copy the value blocks.
-            cache_ops.swap_blocks(
-                src_value_cache, dst_value_cache, src_to_dst)
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache,
+                                  src_to_dst)
             event = self.events[i]
             event.record(stream=self.cache_stream)
@@ -73,8 +73,8 @@ class Worker:
         # number of tokens equal to max_num_batched_tokens.

         # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99,
-                                         top_k=self.model.config.vocab_size - 1)
+        vocab_size = self.model.config.vocab_size
+        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
         seqs = []
@@ -91,7 +91,8 @@ class Worker:
             )
             seqs.append(seq)

-        input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(
+            seqs)

         # Execute the model.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
@@ -110,8 +111,9 @@ class Worker:
         total_gpu_memory = get_gpu_memory()
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
-        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
-                              - peak_memory) // cache_block_size)
+        num_gpu_blocks = int(
+            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
+            cache_block_size)
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
@@ -125,8 +127,8 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
-        self.cache_engine = CacheEngine(
-            self.cache_config, self.model_config, self.parallel_config)
+        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
+                                        self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache

@@ -202,8 +204,8 @@ class Worker:
                 generation_block_tables.append(block_table)

                 max_context_len = max(max_context_len, context_len)
-                max_num_blocks_per_seq = max(
-                    max_num_blocks_per_seq, len(block_table))
+                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
+                                             len(block_table))
                 context_lens.append(context_len)

                 block_number = block_table[position // self.block_size]
@@ -223,7 +225,8 @@ class Worker:
         context_lens_tensor = torch.cuda.IntTensor(context_lens)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables]
+            for block_table in generation_block_tables
+        ]
         block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)

         seq_data: Dict[int, SequenceData] = {}
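
The `num_gpu_blocks` reflow above encodes the KV-cache sizing rule: blocks = floor((total_gpu_memory * gpu_memory_utilization - peak_memory) / cache_block_size), clamped at zero. A worked example with entirely made-up numbers:

# Made-up numbers: 16 GiB GPU, 90% utilization target, 4 GiB profiled peak
# activation memory, and 2 MiB of KV cache per block (all hypothetical).
GiB = 1 << 30
total_gpu_memory = 16 * GiB
gpu_memory_utilization = 0.90
peak_memory = 4 * GiB
cache_block_size = 2 * (1 << 20)
cpu_swap_space = 4 * GiB

num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) //
    cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)  # profiling can overshoot the budget
num_cpu_blocks = max(num_cpu_blocks, 0)

assert num_gpu_blocks == 5324  # (14.4 GiB - 4 GiB) / 2 MiB, floored
assert num_cpu_blocks == 2048
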