[Experimental] Add multi-LoRA support (#1804)

Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
Antoni Baum 2024-01-24 00:26:37 +01:00 committed by GitHub
parent 9c1352eb57
commit 9b945daaf1
52 changed files with 8035 additions and 126 deletions


@@ -41,6 +41,9 @@ steps:
 - label: Worker Test
   command: pytest -v -s worker
+- label: LoRA Test
+  command: pytest -v -s lora
 - label: Benchmarks
   working_dir: "/vllm-workspace/.buildkite"
   commands:


@@ -65,7 +65,9 @@ def main(args: argparse.Namespace):
     if args.profile:
         profile_dir = args.profile_result_dir
         if not profile_dir:
-            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+            profile_dir = Path(
+                "."
+            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=args.profile_result_dir)
         return
@@ -123,9 +125,7 @@ if __name__ == '__main__':
         '--profile-result-dir',
         type=str,
         default=None,
-        help=(
-            'path to save the pytorch profiler output. Can be visualized '
-            'with ui.perfetto.dev or Tensorboard.'
-        ))
+        help=('path to save the pytorch profiler output. Can be visualized '
+              'with ui.perfetto.dev or Tensorboard.'))
     args = parser.parse_args()
     main(args)

217
csrc/punica/LICENSE Normal file

@@ -0,0 +1,217 @@
Contains code from https://github.com/punica-ai/punica
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer
BSD-3-Clause:
* third_party/cutlass


@@ -0,0 +1,21 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)


@@ -0,0 +1,59 @@
#pragma once
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
// clang-format on
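The wide and narrow lists above are the only shapes that get compiled; a (LoRA rank, hidden size) pair outside them falls through the dispatch switch in punica_ops.cc and fails with "No suitable kernel." As a quick illustration, a Python sketch of that constraint (the sets are simply copied from the macro arguments above; the helper name is hypothetical):

# Illustrative only; mirrors the FOR_BGMV_WIDE / FOR_BGMV_WIDE_NARROW arguments above.
SUPPORTED_LORA_RANKS = {8, 16, 32, 64}
SUPPORTED_HIDDEN_SIZES = {
    128, 256, 512, 1024, 1280, 1728, 1792, 2048, 2560, 2752, 3072, 3456,
    3584, 4096, 5120, 5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008,
    12288, 13824, 14336, 16384, 20480, 28672, 32000, 32256, 32512, 32768,
    33024, 36864, 49152,
}

def has_precompiled_bgmv_kernel(rank: int, hidden_size: int) -> bool:
    """True if INST_BGMV_TWOSIDE emits kernels for this (rank, hidden_size) pair."""
    return rank in SUPPORTED_LORA_RANKS and hidden_size in SUPPORTED_HIDDEN_SIZES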


@@ -0,0 +1,294 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
#include "vec_dtypes.cuh"
namespace cg = cooperative_groups;
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t j = blockIdx.x;
constexpr size_t num_pipeline_stages = 2;
constexpr size_t tile_size = tx * ty * vec_size;
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
__shared__ float y_warpwise[ty];
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
auto pipe = cuda::make_pipeline();
// pipeline load W/X and compute WX;
pipe.producer_acquire();
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
pipe.producer_commit();
size_t copy_idx, compute_idx;
float y = 0.f;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
++tile_idx) {
copy_idx = tile_idx % num_pipeline_stages;
// pipeline stage: async copy W fragment
pipe.producer_acquire();
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) + tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
}
pipe.producer_commit();
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// pipeline stage: compute WX
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] = sum;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
}
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// final pipeline stage
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] =
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
? sum
: 0.f;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
// write Y;
if (block.thread_rank() == 0) {
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
}
}
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t tile_idx = blockIdx.x;
// load X;
vec_t<in_T, vec_size> x_vec;
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
// load W;
vec_t<W_T, vec_size> w_vec;
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
block.thread_rank() * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += g.shfl_down(sum, offset);
}
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
}
}
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
constexpr size_t vec_size = 8;
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in < feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
constexpr int ty = 32 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
constexpr int ty = 16 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else {
constexpr int ty = 8 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
} else {
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
if constexpr (feat_in % (vec_size * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
vec_size * sizeof(W_T), tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
constexpr int tx = 16;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
}
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
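In plain terms, both the shrink and expand kernels implement the same batched, gathered matrix-vector product. A rough PyTorch reference of what bgmv_kernel computes, under the layout implied by the indexing above (W viewed as [num_loras * num_layers, feat_out, feat_in]); this is a sketch, not part of the commit:

import torch

def bgmv_reference(Y, X, W, indices, y_offset, layer_idx, num_layers, scale):
    # Y: [B, full_y_size], X: [B, feat_in], W: [num_loras * num_layers, feat_out, feat_in]
    feat_out = W.shape[1]
    for b in range(X.shape[0]):
        idx = int(indices[b]) * num_layers + layer_idx
        if idx < 0:
            continue  # negative index: this row has no LoRA applied
        Y[b, y_offset:y_offset + feat_out] += (
            scale * (W[idx].float() @ X[b].float())).to(Y.dtype)
    return Y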

File diff suppressed because it is too large

563
csrc/punica/punica_ops.cc Normal file

@@ -0,0 +1,563 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <cstdint>
#include "bgmv/bgmv_config.h"
namespace {
//====== utils ======
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
const char *a_name, const char *b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
".size(", i, ")");
}
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
return (uint32_t(a) << 16) | uint32_t(b);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) \
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
//====== bgmv ======
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint16_t in_features, uint16_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u16(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t h_in = x.size(1);
int64_t h_out = y.size(1);
int64_t num_layers = w.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out,
int64_t y_offset) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t num_layers = w.size(1);
int64_t full_y_size = y.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
bool ok = false;
if (h_in < 65536 && h_out < 65536) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_bfloat16 *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
} // namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}
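For orientation, the CHECK_* calls above pin down the expected shapes: x is [B, h_in], y is [B, h_out] (or [B, full_y_size] for the low-level variant), w is [num_loras, num_layers, h_out, h_in], and indicies carries one int64 LoRA slot per batch row, with a negative value meaning no LoRA. A hedged sketch of driving the binding from Python, assuming the extension builds as vllm._punica_C as configured in setup.py later in this diff:

import torch
import vllm._punica_C as punica_kernels  # extension name taken from setup.py below

B, h_in, h_out, num_loras, num_layers = 4, 4096, 8, 2, 1
x = torch.randn(B, h_in, dtype=torch.float16, device="cuda")
y = torch.zeros(B, h_out, dtype=torch.float16, device="cuda")
w = torch.randn(num_loras, num_layers, h_out, h_in,
                dtype=torch.float16, device="cuda")
# One LoRA slot per batch row; -1 makes the kernel skip that row entirely.
indices = torch.tensor([0, 1, 0, -1], dtype=torch.long, device="cuda")
# (4096, 8) is one of the precompiled (wide, narrow) pairs from bgmv_config.h.
punica_kernels.dispatch_bgmv(y, x, w, indices, 0, 1.0)  # layer_idx=0, scale=1.0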


@@ -0,0 +1,117 @@
"""
This example shows how to use the multi-LoRA functionality for offline inference.
Requires HuggingFace credentials for access to Llama2.
"""
from typing import Optional, List, Tuple
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
from vllm.lora.request import LoRARequest
def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
"""Create a list of test prompts with their sampling parameters.
2 requests for base model, 4 requests for the LoRA. We define 2
different LoRA adapters (using the same model for demo purposes).
Since we also set `max_loras=1`, the expectation is that the requests
with the second LoRA adapter will be run after all requests with the
first adapter have finished.
"""
return [
("A robot may not injure a human being",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128), None),
("To be or not to be,",
SamplingParams(temperature=0.8,
top_k=5,
presence_penalty=0.2,
max_tokens=128), None),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
SamplingParams(n=3,
best_of=3,
use_beam_search=True,
temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
SamplingParams(n=3,
best_of=3,
use_beam_search=True,
temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
]
def process_requests(engine: LLMEngine,
test_prompts: List[Tuple[str, SamplingParams,
Optional[LoRARequest]]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
lora_request=lora_request)
request_id += 1
request_outputs: List[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print(request_output)
def initialize_engine() -> LLMEngine:
"""Initialize the LLMEngine."""
# max_loras: controls the number of LoRAs that can be used in the same
# batch. Larger numbers will cause higher memory usage, as each LoRA
# slot requires its own preallocated tensor.
# max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
# numbers will cause higher memory usage. If you know that all LoRAs will
# use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=256)
return LLMEngine.from_engine_args(engine_args)
def main():
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine()
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)
if __name__ == '__main__':
main()


@@ -1,13 +1,16 @@
+import contextlib
 import io
 import os
 import re
 import subprocess
-from typing import List, Set
 import warnings
+from pathlib import Path
+from typing import List, Set
 
 from packaging.version import parse, Version
 import setuptools
 import torch
+import torch.utils.cpp_extension as torch_cpp_ext
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
@@ -28,7 +31,7 @@ def _is_neuron() -> bool:
     torch_neuronx_installed = True
     try:
         subprocess.run(["neuron-ls"], capture_output=True, check=True)
-    except FileNotFoundError as e:
+    except FileNotFoundError:
         torch_neuronx_installed = False
     return torch_neuronx_installed
@@ -96,10 +99,16 @@ def get_hipcc_rocm_version():
         return None
 
 
+def glob(pattern: str):
+    root = Path(__name__).parent
+    return [str(p) for p in root.glob(pattern)]
+
+
 def get_neuronxcc_version():
     import sysconfig
     site_dir = sysconfig.get_paths()["purelib"]
-    version_file = os.path.join(site_dir, "neuronxcc", "version", "__init__.py")
+    version_file = os.path.join(site_dir, "neuronxcc", "version",
+                                "__init__.py")
 
     # Check if the command was executed successfully
     with open(version_file, "rt") as fp:
@@ -178,6 +187,8 @@ if _is_cuda() and not compute_capabilities:
                 "GPUs with compute capability below 7.0 are not supported.")
         compute_capabilities.add(f"{major}.{minor}")
 
+ext_modules = []
+
 if _is_cuda():
     nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
     if not compute_capabilities:
@@ -215,6 +226,8 @@ if _is_cuda():
             raise RuntimeError(
                 "CUDA 11.8 or higher is required for compute capability 9.0.")
 
+    NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy()
+
     # Add target compute capabilities to NVCC flags.
     for capability in compute_capabilities:
         num = capability[0] + capability[2]
@@ -223,6 +236,14 @@ if _is_cuda():
             NVCC_FLAGS += [
                 "-gencode", f"arch=compute_{num},code=compute_{num}"
            ]
+        if int(capability[0]) >= 8:
+            NVCC_FLAGS_PUNICA += [
+                "-gencode", f"arch=compute_{num},code=sm_{num}"
+            ]
+            if capability.endswith("+PTX"):
+                NVCC_FLAGS_PUNICA += [
+                    "-gencode", f"arch=compute_{num},code=compute_{num}"
+                ]
 
     # Use NVCC threads to parallelize the build.
     if nvcc_cuda_version >= Version("11.2"):
@@ -230,6 +251,36 @@ if _is_cuda():
         num_threads = min(os.cpu_count(), nvcc_threads)
         NVCC_FLAGS += ["--threads", str(num_threads)]
 
+    # changes for punica kernels
+    NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
+    REMOVE_NVCC_FLAGS = [
+        '-D__CUDA_NO_HALF_OPERATORS__',
+        '-D__CUDA_NO_HALF_CONVERSIONS__',
+        '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
+        '-D__CUDA_NO_HALF2_OPERATORS__',
+    ]
+    for flag in REMOVE_NVCC_FLAGS:
+        with contextlib.suppress(ValueError):
+            torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
+    install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1")))
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            install_punica = False
+            break
+    if install_punica:
+        ext_modules.append(
+            CUDAExtension(
+                name="vllm._punica_C",
+                sources=["csrc/punica/punica_ops.cc"] +
+                glob("csrc/punica/bgmv/*.cu"),
+                extra_compile_args={
+                    "cxx": CXX_FLAGS,
+                    "nvcc": NVCC_FLAGS_PUNICA,
+                },
+            ))
+
 elif _is_hip():
     amd_arch = get_amdgpu_offload_arch()
     if amd_arch not in ROCM_SUPPORTED_ARCHS:
@@ -240,8 +291,6 @@ elif _is_hip():
 elif _is_neuron():
     neuronxcc_version = get_neuronxcc_version()
 
-ext_modules = []
-
 vllm_extension_sources = [
     "csrc/cache_kernels.cu",
     "csrc/attention/attention_kernels.cu",


@@ -25,6 +25,13 @@ class MockEngine:
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []
 
+    async def encode_request_async(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return [1]
+
     def generate(self, request_id):
         self.request_id = request_id
 
@@ -35,6 +42,10 @@ class MockEngine:
         del kwargs  # Unused
         self.add_request_calls += 1
 
+    async def add_request_async(self, **kwargs):
+        del kwargs  # Unused
+        self.add_request_calls += 1
+
     def abort_request(self, request_id):
         del request_id  # Unused
         self.abort_request_calls += 1

0
tests/lora/__init__.py Normal file

143
tests/lora/conftest.py Normal file

@@ -0,0 +1,143 @@
import contextlib
import gc
import tempfile
from collections import OrderedDict
from unittest.mock import patch, MagicMock
import pytest
import ray
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
import vllm
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel, initialize_model_parallel)
def cleanup():
destroy_model_parallel()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
torch.cuda.empty_cache()
ray.shutdown()
@pytest.fixture(autouse=True)
def cleanup_fixture():
yield
cleanup()
@pytest.fixture
def dist_init():
if not torch.distributed.is_initialized():
temp_file = tempfile.mkstemp()[1]
torch.distributed.init_process_group(
backend="nccl",
world_size=1,
rank=0,
init_method=f"file://{temp_file}",
)
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(1, 1)
yield
cleanup()
@pytest.fixture
def dist_init_torch_only():
if torch.distributed.is_initialized():
return
temp_file = tempfile.mkstemp()[1]
torch.distributed.init_process_group(
backend="nccl",
world_size=1,
rank=0,
init_method=f"file://{temp_file}",
)
@pytest.fixture
def dummy_model() -> nn.Module:
model = nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
(
"layer1",
nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(100, 10)),
("dense2", RowParallelLinear(10, 50)),
])),
),
("act2", nn.ReLU()),
("output", ColumnParallelLinear(50, 10)),
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512))
]))
model.config = MagicMock()
return model
@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
model = nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(764, 100)),
("dense2", RowParallelLinear(100, 50)),
(
"layer1",
nn.Sequential(
OrderedDict([
("dense1", ColumnParallelLinear(100, 10)),
("dense2", RowParallelLinear(10, 50)),
])),
),
("act2", nn.ReLU()),
("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
("outact", nn.Sigmoid()),
# Special handling for lm_head & sampler
("lm_head", ParallelLMHead(512, 10)),
("sampler", Sampler(512))
]))
model.config = MagicMock()
return model
@pytest.fixture(scope="session")
def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
get_model_old = get_model
def get_model_patched(model_config, lora_config=None):
return get_model_old(model_config,
LoRAConfig(max_loras=4, max_lora_rank=8))
with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
yield engine.llm_engine
del engine
cleanup()
@pytest.fixture
def llama_2_7b_model_extra_embeddings(
llama_2_7b_engine_extra_embeddings) -> nn.Module:
yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model

709
tests/lora/test_layers.py Normal file

@@ -0,0 +1,709 @@
import pytest
import random
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple
import torch
import torch.nn.functional as F
from vllm.lora.layers import (
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLora,
VocabParallelEmbeddingWithLoRA,
RowParallelLinearWithLoRA,
SamplerWithLoRA,
LoRAMapping,
BaseLayerWithLoRA,
)
from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
QKVParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.utils import set_random_seed
from .utils import DummyLoRAManager
TOLERANCES = {
torch.float16: (5e-3, 5e-3),
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
def get_random_id_to_index(num_loras: int,
num_slots: int,
log: bool = True) -> List[Optional[int]]:
"""Creates a random lora_id_to_index mapping.
Args:
num_loras: The number of active loras in the mapping.
num_slots: The number of slots in the mapping. Must be larger
than num_loras.
log: Whether to log the output.
"""
if num_loras > num_slots:
raise ValueError(
f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
"num_loras must be less than or equal to num_slots.")
slots: List[Optional[int]] = [None] * num_slots
random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
slots[slot_idx] = lora_id
if log:
print(f"Created lora_id_to_index mapping: {slots}.")
return slots
def populate_loras(
id_to_index: List[Optional[int]],
layer: BaseLayerWithLoRA,
layer_weights: torch.Tensor,
generate_embeddings_tensor: int = 0,
repeats: int = 1,
) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]:
"""This method populates the lora layers with lora weights.
Args:
id_to_index: a list of lora ids. The index of the lora id
represents which memory slot the lora matrices are
stored in. A None value indicates a free slot.
layer: the LoRA layer to populate.
layer_weights: the PyTorch tensor containing the layer's
weights.
generate_embeddings_tensor: whether to generate an
embeddings tensor for each LoRA.
repeats: must only be set for column parallel packed
layers. Indicates the number of loras to compose
together to create a single lora layer.
"""
# Dictionary that maps the lora ID to the
# corresponding lora weights.
lora_dict: Dict[int, LoRALayerWeights] = dict()
# Dictionary that maps the lora ID to the
# corresponding subloras. Only useful when
# repeats > 1.
sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()
for slot_idx, lora_id in enumerate(id_to_index):
if lora_id is not None:
subloras = []
sublora_len = layer_weights.shape[0] // repeats
for i in range(repeats):
sublora = DummyLoRAManager().init_random_lora(
module_name=f"fake_{i}",
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[:, (sublora_len *
i):(sublora_len * (i + 1))]
sublora.optimize()
subloras.append(sublora)
lora = PackedLoRALayerWeights.pack(
subloras) if repeats > 1 else subloras[0]
layer.set_lora(
slot_idx,
lora_a=lora.lora_a,
lora_b=lora.lora_b,
embeddings_tensor=lora.embeddings_tensor,
)
lora_dict[lora_id] = lora
sublora_dict[lora_id] = subloras
return lora_dict, sublora_dict
def create_random_inputs(
active_lora_ids: List[int],
num_inputs: int,
input_size: Tuple[int, ...],
input_range: Tuple[float, float],
input_type: torch.dtype = torch.int,
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
"""Creates random inputs.
Args:
active_lora_ids: lora IDs of active lora weights.
num_inputs: the number of inputs to create.
input_size: the size of each individual input.
input_range: the range of values to include in the input.
input_range[0] <= possible input values < input_range[1]
input_type: the type of values in the input.
"""
low, high = input_range
inputs, index_mapping, prompt_mapping = [], [], []
for _ in range(num_inputs):
if input_type == torch.int:
inputs.append(
torch.randint(low=int(low),
high=int(high),
size=input_size,
device="cuda"))
else:
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device="cuda") *
                (high - low) + low)
lora_id = random.choice(active_lora_ids)
index_mapping += [lora_id] * input_size[0]
prompt_mapping += [lora_id]
return inputs, index_mapping, prompt_mapping
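# A short illustration (sketch only, mirroring how these tests use the helper
# above): index_mapping carries one LoRA id per token, which is why it is
# extended by `input_size[0]` entries per input, while prompt_mapping carries
# one LoRA id per sequence. For two 3-token sequences served by LoRA ids 2
# and 5 the mappings would look like this:
_EXAMPLE_INDEX_MAPPING = [2, 2, 2, 5, 5, 5]  # one entry per token
_EXAMPLE_PROMPT_MAPPING = [2, 5]  # one entry per sequence
_EXAMPLE_MAPPING = LoRAMapping(_EXAMPLE_INDEX_MAPPING, _EXAMPLE_PROMPT_MAPPING)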
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings(dist_init, num_loras) -> None:
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256)
embedding.weight.data = torch.rand_like(embedding.weight.data)
embedding.weight.data[512:, :] = 0
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
lora_embedding.create_lora_weights(max_loras, lora_config)
return embedding, lora_embedding
for i in range(10):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
embedding, lora_embedding = create_random_embedding_layer()
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_embedding,
layer_weights=embedding.weight.T,
)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info)
lora_result = lora_embedding(torch.cat(inputs))
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = embedding(input_)
after_a = F.embedding(
input_,
lora.lora_a,
)
result += (after_a @ lora.lora_b)
expected_results.append(result)
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
# Check that resetting the lora weights succeeds
for slot_idx in range(max_loras):
lora_embedding.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(inputs))
expected_result = embedding(torch.cat(inputs))
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
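# Standalone sketch (illustrative shapes, not taken from vLLM internals) of
# what test_embeddings verifies: for embedding layers the LoRA "A" side is a
# lookup into an extra (vocab_size, rank) table rather than a matmul, and
# "B" projects the rank-sized vectors up to the embedding dimension. Any
# scaling is assumed to be folded into lora_b, as in the reference
# computation above.
def _embedding_lora_sketch() -> torch.Tensor:
    vocab_size, rank, embed_dim = 512, 8, 256
    base_table = torch.randn(vocab_size, embed_dim)
    lora_a = torch.randn(vocab_size, rank)
    lora_b = torch.randn(rank, embed_dim)
    token_ids = torch.randint(0, vocab_size, (10, ))
    base = F.embedding(token_ids, base_table)
    delta = F.embedding(token_ids, lora_a) @ lora_b
    return base + delta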
@torch.inference_mode()
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256)
embedding_data = torch.rand_like(embedding.weight.data)
embedding.weight.data = embedding_data
embedding.weight.data[512:, :] = 0
expanded_embedding = VocabParallelEmbedding(
512 + lora_config.lora_extra_vocab_size * max_loras,
256,
org_num_embeddings=512)
expanded_embedding.weight.data[:512, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
# in place
lora_embedding = VocabParallelEmbeddingWithLoRA(
deepcopy(expanded_embedding))
lora_embedding.create_lora_weights(max_loras, lora_config)
return expanded_embedding, lora_embedding
for i in range(10):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
expanded_embedding, lora_embedding = create_random_embedding_layer()
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_embedding,
layer_weights=torch.zeros(
(256, 512 + lora_config.lora_extra_vocab_size)),
generate_embeddings_tensor=256,
)
# All embeddings tensors have the same shape.
embeddings_tensors = [
lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
]
embeddings_tensor_len = embeddings_tensors[0].shape[0]
# Add empty embeddings_tensors for unoccupied lora slots.
for _ in range(max_loras - len(embeddings_tensors)):
embeddings_tensors.append(
torch.zeros(embeddings_tensors[0].shape, device="cuda"))
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
original_inputs = deepcopy(inputs)
# Force some of the inputs to be in the extended embeddings range
# to guarantee that their behavior is tested.
for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping):
embedding_id = lora_id - 1
input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
original_input_[-1] = 512
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = 512 + embeddings_tensor_len - 1
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
expanded_embedding.weight[512:512 +
(embeddings_tensor_len *
max_loras)] = torch.cat(embeddings_tensors)
lora_result = lora_embedding(torch.cat(original_inputs))
expected_results = []
for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping):
lora = lora_dict[lora_id]
result = expanded_embedding(input_)
after_a = F.embedding(
original_input_,
lora.lora_a,
)
result += (after_a @ lora.lora_b)
expected_results.append(result)
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
# Check that resetting the lora weights succeeds
for slot_idx in range(max_loras):
lora_embedding.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, 512),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
original_inputs = deepcopy(inputs)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(original_inputs))
expected_result = expanded_embedding(torch.cat(inputs))
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_lm_head_sampler(dist_init, num_loras) -> None:
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
def create_random_sampler_layer():
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
1024, 32000)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, 32000:] = 0
sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
linear.weight.device)
lora_sampler.create_lora_weights(max_loras, lora_config)
return linear, sampler, lora_sampler
for i in range(10):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, sampler, lora_sampler = create_random_sampler_layer()
# NOTE: all the generated loras share the same embeddings tensor.
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_sampler,
layer_weights=linear.weight,
generate_embeddings_tensor=1024,
)
embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
embeddings_tensor_len = embeddings_tensor.shape[0]
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=8 * num_loras, # * 3,
input_size=(1, 1024),
input_range=(0, 1),
input_type=torch.float32,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
input_ = torch.rand(20, 1024, device="cuda")
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
max_loras,
32000,
lora_config.lora_extra_vocab_size,
)
lora_sampler.set_mapping(*mapping_info, )
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=linear.weight,
embedding_bias=None)
original_weight = linear.weight.clone()
linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor
sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = sampler._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
sampler.org_vocab_size = 32000
# Check that resetting the lora weights succeeds
for slot_idx in range(max_loras):
lora_sampler.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=8 * num_loras * 3,
input_size=(1, 1024),
input_range=(0, 1),
input_type=torch.float32,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
32000,
lora_config.lora_extra_vocab_size)
lora_sampler.set_mapping(*mapping_info, )
lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)[:, :32000]
expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
embedding=original_weight,
embedding_bias=None)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
def test_linear_parallel(dist_init, num_loras, orientation) -> None:
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
def create_random_linear_parallel_layer():
if orientation == "row":
linear = RowParallelLinear(4096, 4096, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = RowParallelLinearWithLoRA(linear)
else:
linear = ColumnParallelLinear(4096, 4096, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ColumnParallelLinearWithLoRA(linear)
lora_linear.create_lora_weights(max_loras, lora_config)
return linear, lora_linear
for i in range(10):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_parallel_layer()
lora_dict, _ = populate_loras(
id_to_index,
layer=lora_linear,
layer_weights=linear.weight,
)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float32,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)
lora_linear.set_mapping(*mapping_info, )
lora_result = lora_linear(torch.cat(inputs))[0]
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
# Check that resetting the lora weights succeeds
for slot_idx in range(max_loras):
lora_linear.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float32,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size)
lora_linear.set_mapping(*mapping_info, )
lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3])
def test_column_parallel_packed(dist_init, num_loras, repeats) -> None:
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
lora_dtype=torch.float16)
def create_column_parallel_packed_layer():
if repeats == 2:
linear = MergedColumnParallelLinear(4096, [4096] * repeats,
bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedColumnParallelLinearWithLoRA(linear)
else:
linear = QKVParallelLinear(4096, 64, 32, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = QKVParallelLinearWithLora(linear)
@dataclass
class FakeConfig:
hidden_size = 4096
num_key_value_heads = 32
num_attention_heads = 32
lora_linear.create_lora_weights(max_loras,
lora_config,
model_config=FakeConfig())
return linear, lora_linear
for i in range(10):
set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_column_parallel_packed_layer()
lora_dict, sublora_dict = populate_loras(
id_to_index,
layer=lora_linear,
layer_weights=linear.weight,
repeats=repeats,
)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float32,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)
lora_linear.set_mapping(*mapping_info)
lora_result = lora_linear(torch.cat(inputs))[0]
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
result = linear(input_)[0]
subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * (
i + 1
)] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
for slot_idx in range(max_loras):
lora_linear.reset_lora(slot_idx)
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
input_type=torch.float32,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
)
lora_linear.set_mapping(*mapping_info)
lora_result = lora_linear(torch.cat(inputs))[0]
expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result,
expected_result,
rtol=rtol,
atol=atol)
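# Standalone sketch of the reference computation that every linear-layer test
# in this file reduces to: the base projection plus a low-rank update
# x @ lora_a @ lora_b, scaled per adapter (alpha / rank in standard LoRA).
# The shapes and the scaling factor below are illustrative assumptions, not
# values used by the tests above.
def _linear_lora_sketch() -> torch.Tensor:
    batch, in_features, out_features, rank, alpha = 4, 16, 32, 8, 16
    x = torch.randn(batch, in_features)
    base_weight = torch.randn(in_features, out_features)
    lora_a = torch.randn(in_features, rank)
    lora_b = torch.randn(rank, out_features)
    scaling = alpha / rank
    return x @ base_weight + (x @ lora_a) @ lora_b * scaling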

144
tests/lora/test_llama.py Normal file
View File

@ -0,0 +1,144 @@
import pytest
import ray
import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
def do_sample(llm, lora_path: str, lora_id: int):
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]"
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("tp_size", [1])
def test_llama_lora(sql_lora_files, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size)
expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
]
expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ",
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "
]
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output
print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output
print("removing lora")
@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(sql_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
del llm_tp1
cleanup()
llm_tp2 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2)
output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
del llm_tp2
cleanup()
assert output_tp1 == output_tp2
llm_tp4 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4)
output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
del llm_tp4
cleanup()
assert output_tp1 == output_tp4
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and is more conservative"""
@ray.remote(num_gpus=1)
def get_num_gpu_blocks_lora():
llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
return num_gpu_blocks_lora_warmup
@ray.remote(num_gpus=1)
def get_num_gpu_blocks_no_lora():
llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
return num_gpu_blocks_no_lora_warmup
num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
num_gpu_blocks_no_lora_warmup = ray.get(
get_num_gpu_blocks_no_lora.remote())
assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, (
"The warmup with lora should be more"
" conservative than without lora, therefore the number of memory blocks for the KV cache should be "
"less when using lora than when not using lora")

224
tests/lora/test_lora.py Normal file
View File

@ -0,0 +1,224 @@
import pytest
import torch
from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice
from .utils import DummyLoRAManager
TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4]
QKV_TENSOR_SIZES = [
(8192, 1024, 1024),
(8192 // 8, 1024 // 8, 1024 // 8),
(4096, 4096, 4096),
(4096 // 2, 4096 // 2, 4096 // 2),
]
BATCH_SIZES = [8, 32, 256]
RANKS = [8]
DTYPES = [torch.float16]
TOLERANCES = {
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora(m, n, k, rank, dtype) -> None:
manager = DummyLoRAManager()
module_name = "module"
weight = torch.rand([m, n], device="cuda", dtype=dtype)
manager.init_random_lora(module_name, weight, rank=rank)
lora = manager.get_module_lora(module_name)
input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = input @ lora.lora_a @ lora.lora_b * lora.scaling
lora_a_stack = torch.zeros(8,
1,
lora.lora_a.shape[1],
lora.lora_a.shape[0],
device="cuda",
dtype=dtype)
lora_b_stack = torch.zeros(8,
1,
lora.lora_b.shape[1],
lora.lora_b.shape[0],
device="cuda",
dtype=dtype)
for i in range(lora_a_stack.shape[0]):
lora_a_stack[i][0] = lora.lora_a.T
lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T
output = torch.zeros(k, m, device="cuda", dtype=dtype)
_apply_lora(
input, lora_a_stack, lora_b_stack,
torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"),
output)
rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
output[:] = 0
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.full((len(input), ), -1, device="cuda"), output)
assert torch.allclose(torch.zeros_like(output), output)
manager.reset_lora()
@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
if m % 2 != 0:
pytest.skip("m must be divisible by 2")
if m // 2 not in TENSOR_SIZES:
pytest.skip("m//2 must be in TENSOR_SIZES")
manager = DummyLoRAManager()
module_name = "module"
weight = torch.rand([m // 2, n], device="cuda", dtype=dtype)
manager.init_random_lora(module_name + "1", weight, rank=rank)
lora_1 = manager.get_module_lora(module_name + "1")
manager.init_random_lora(module_name + "2", weight, rank=rank)
lora_2 = manager.get_module_lora(module_name + "2")
input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = torch.cat([
input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling,
input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling
],
dim=1)
lora_a_stacks = [
torch.zeros(8,
1,
lora_1.lora_a.shape[1],
lora_1.lora_a.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(8,
1,
lora_1.lora_b.shape[1],
lora_1.lora_b.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_1.lora_a.T
lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
lora_a_stacks[1][i][0] = lora_2.lora_a.T
lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T
output = torch.zeros(k, m, device="cuda", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="cuda"), output, (m // 2, m // 2))
rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="cuda"),
output, (m // 2, m // 2))
assert torch.allclose(torch.zeros_like(output), output)
manager.reset_lora()
@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
manager = DummyLoRAManager()
module_name = "module"
weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype)
weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype)
manager.init_random_lora(module_name + "q", weight_q, rank=rank)
lora_q = manager.get_module_lora(module_name + "q")
manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
lora_k = manager.get_module_lora(module_name + "k")
manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
lora_v = manager.get_module_lora(module_name + "v")
input = torch.rand(k, n, device="cuda", dtype=dtype)
expected = torch.cat([
input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling,
input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling,
input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling
],
dim=1)
lora_a_stacks = [
torch.zeros(8,
1,
lora_q.lora_a.shape[1],
lora_q.lora_a.shape[0],
device="cuda",
dtype=dtype)
] + [
torch.zeros(8,
1,
lora_k.lora_a.shape[1],
lora_k.lora_a.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(8,
1,
lora_q.lora_b.shape[1],
lora_q.lora_b.shape[0],
device="cuda",
dtype=dtype)
] + [
torch.zeros(8,
1,
lora_k.lora_b.shape[1],
lora_k.lora_b.shape[0],
device="cuda",
dtype=dtype) for i in range(2)
]
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_q.lora_a.T
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
lora_a_stacks[1][i][0] = lora_k.lora_a.T
lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T
lora_a_stacks[2][i][0] = lora_v.lora_a.T
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T
output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="cuda"), output, (qkv[0], qkv[1], qkv[2]))
rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="cuda"),
output, (qkv[0], qkv[1], qkv[2]))
assert torch.allclose(torch.zeros_like(output), output)
manager.reset_lora()
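# Loop-form reference (illustration only) for the stacked layout these tests
# build: lora_a_stacked holds A transposed per slot with shape
# (num_slots, 1, rank, in_features), lora_b_stacked holds B transposed with
# the scaling already folded in, shape (num_slots, 1, out_features, rank),
# and indices assigns one slot per token, with -1 meaning "no LoRA".
def _apply_lora_reference(x: torch.Tensor, lora_a_stacked: torch.Tensor,
                          lora_b_stacked: torch.Tensor,
                          indices: torch.Tensor,
                          output: torch.Tensor) -> torch.Tensor:
    for i, slot in enumerate(indices.tolist()):
        if slot < 0:
            continue  # no active LoRA: leave this row of the output untouched
        a_t = lora_a_stacked[slot, 0]  # (rank, in_features)
        b_t = lora_b_stacked[slot, 0]  # (out_features, rank)
        output[i] += (a_t @ x[i]) @ b_t.T
    return output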

View File

@ -0,0 +1,475 @@
import os
from typing import List
import pytest
import torch
from safetensors.torch import load_file
from torch import nn
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, LoRAMapping)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
def test_from_lora_tensors(sql_lora_files):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
new_embeddings = load_file(
os.path.join(sql_lora_files, "new_embeddings.safetensors"))
lora_model = LoRAModel.from_lora_tensors(1,
8,
16,
tensors,
"cuda",
embeddings=new_embeddings)
for module_name, lora in lora_model.loras.items():
assert lora.module_name == module_name
assert lora.rank == 8
assert lora.lora_alpha == 16
assert lora.lora_a is not None
assert lora.lora_b is not None
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
assert lora.lora_a.shape[1] == 8
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name), None)
if embeddings_module:
assert torch.equal(
lora.embeddings_tensor,
new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
device=lora.embeddings_tensor.device))
else:
assert lora.embeddings_tensor is None
def create_lora(lora_id: int, model: nn.Module,
sub_modules: List[str]) -> LoRAModel:
loras = {}
for name in sub_modules:
w = model.get_submodule(name).weight
loras[name] = LoRALayerWeights(
name,
8,
16,
torch.rand([w.shape[1], 8], device="cuda"),
torch.rand([8, w.shape[0]], device="cuda"),
)
return LoRAModel(lora_id, 8, loras)
def create_packed_lora(
lora_id: int,
model: nn.Module,
module_name,
replaced_module_names,
empty_replaced_module_name=None,
) -> LoRAModel:
w = model.get_submodule(module_name).weight
loras = {}
for replaced_module_name in replaced_module_names:
if replaced_module_name == empty_replaced_module_name:
continue
loras[replaced_module_name] = LoRALayerWeights(
replaced_module_name,
8,
16,
torch.rand([w.shape[1], 8], device="cuda"),
torch.rand([8, w.shape[0] // len(replaced_module_names)],
device="cuda"),
)
return LoRAModel(lora_id, 8, loras)
def test_replace_submodules(dist_init, dummy_model):
model = dummy_model
manager = LoRAModelManager(model,
1,
1,
1,
LoRAConfig(max_lora_rank=8,
max_cpu_loras=8,
max_loras=8),
lora_target_modules=["dense1", "layer1.dense2"])
model = manager.model
assert isinstance(model.get_submodule("dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("layer1.dense1"),
ColumnParallelLinearWithLoRA)
assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
assert isinstance(model.get_submodule("layer1.dense2"),
RowParallelLinearWithLoRA)
def test_lora_model_manager(dist_init, dummy_model):
model = dummy_model
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
manager = LoRAModelManager(
model,
2,
2,
2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
lora_target_modules=["dense1", "dense2", "lm_head"])
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_lora(model_lora1)
assert not manager.activate_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_lora(model_lora2)
assert not manager.activate_lora(2)
assert manager.add_lora(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
with pytest.raises(ValueError):
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.remove_lora(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_lora(model_lora2.id)
assert manager.remove_lora(model_lora1.id)
assert not manager.remove_lora(model_lora1.id)
assert manager.add_lora(model_lora1)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] is None
assert manager.add_lora(model_lora2)
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] is None
assert manager.activate_lora(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
def test_lora_lru_cache_model_manager(dist_init, dummy_model):
model = dummy_model
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
manager = LRUCacheLoRAModelManager(
model,
2,
2,
2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
lora_target_modules=["dense1", "dense2", "lm_head"])
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_lora(model_lora1)
assert not manager.activate_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_lora(model_lora2)
assert not manager.activate_lora(2)
assert manager.add_lora(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
assert manager.remove_lora(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_lora(model_lora2.id)
assert manager.remove_lora(model_lora1.id)
assert not manager.remove_lora(model_lora1.id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.add_lora(model_lora2)
assert manager.deactivate_lora(3)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
def test_lru_lora_model_manager(dist_init, dummy_model):
    # This tests just the LRU cache functionality; everything else is
    # tested in test_lora_model_manager.
model = dummy_model
model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
manager = LRUCacheLoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
["dense1", "dense2", "lm_head"])
assert all(x is None for x in manager.lora_index_to_id)
# Add up to capacity
assert manager.add_lora(model_lora1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(1)
assert manager.activate_lora(2)
assert set(manager.list_loras()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
# Add over capacity
assert manager.add_lora(model_lora3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(3)
assert manager.activate_lora(4)
assert set(manager.list_loras()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
    # Add 3 again to move it to the top and then add 2;
    # add_lora should return False since 3 is already in the cache.
assert not manager.add_lora(model_lora3)
assert not manager.activate_lora(3)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert set(manager.list_loras()) == {3, 2}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
# Remove manually
assert manager.remove_lora(3)
assert not manager.remove_lora(3)
assert set(manager.list_loras()) == {2}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 2
assert manager.add_lora(model_lora3)
assert manager.activate_lora(3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(4)
assert set(manager.list_loras()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == set()
assert all(x is None for x in manager.lora_index_to_id)
assert not manager.remove_oldest_lora()
assert set(manager.list_loras()) == set()
assert all(x is None for x in manager.lora_index_to_id)
def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
torch.device("cuda"))
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 3, 4}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 6, 7, 8}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6
# Over capacity
with pytest.raises(RuntimeError):
worker_lora_manager.set_active_loras([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
], mapping)
def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
worker_lora_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
torch.device("cuda"))
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 3, 4}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
worker_lora_manager.set_active_loras([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None
assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None
worker_lora_manager.set_active_loras([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {6, 7, 8}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7
# Over capacity
with pytest.raises(RuntimeError):
worker_lora_manager.set_active_loras([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
], mapping)
def test_packed_loras(dist_init, dummy_model_gate_up):
model = dummy_model_gate_up
model_lora = create_packed_lora(
1,
model,
module_name="gate_up_proj",
replaced_module_names=["gate_proj", "up_proj"])
model_lora1 = create_packed_lora(
2,
model,
module_name="gate_up_proj",
replaced_module_names=["gate_proj", "up_proj"],
empty_replaced_module_name="gate_proj",
)
manager = LoRAModelManager(
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
["gate_up_proj"])
model = manager.model
assert isinstance(model.get_submodule("gate_up_proj"),
MergedColumnParallelLinearWithLoRA)
assert manager.add_lora(model_lora)
assert manager.add_lora(model_lora1)
packed_lora = model_lora.get_lora("gate_up_proj")
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
assert torch.allclose(packed_lora.lora_a[0],
model_lora.get_lora("gate_proj").lora_a)
assert torch.allclose(packed_lora.lora_b[0],
model_lora.get_lora("gate_proj").lora_b)
assert torch.allclose(packed_lora.lora_a[1],
model_lora.get_lora("up_proj").lora_a)
assert torch.allclose(packed_lora.lora_b[1],
model_lora.get_lora("up_proj").lora_b)
packed_lora1 = model_lora1.get_lora("gate_up_proj")
assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
assert packed_lora1.lora_a[0] is None
assert packed_lora1.lora_b[0] is None
assert torch.allclose(packed_lora1.lora_a[1],
model_lora1.get_lora("up_proj").lora_a)
assert torch.allclose(packed_lora1.lora_b[1],
model_lora1.get_lora("up_proj").lora_b)
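# Minimal sketch (names and sizes are illustrative) of what test_packed_loras
# checks: a packed adapter keeps one (lora_a, lora_b) pair per sub-projection
# of a merged weight, a None entry means that sub-projection has no adapter,
# and each pair only updates its own column slice of the merged output.
def _packed_lora_sketch() -> torch.Tensor:
    hidden, intermediate, rank = 16, 32, 8
    x = torch.rand(4, hidden)
    merged_out = torch.zeros(4, 2 * intermediate)  # base gate_up_proj output stand-in
    packed_a = [torch.rand(hidden, rank), None]  # gate_proj has an adapter, up_proj does not
    packed_b = [torch.rand(rank, intermediate), None]
    for slice_idx, (a, b) in enumerate(zip(packed_a, packed_b)):
        if a is None:
            continue  # empty sub-adapter: this slice keeps the base result
        start = slice_idx * intermediate
        merged_out[:, start:start + intermediate] += x @ a @ b
    return merged_out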

175
tests/lora/test_punica.py Normal file
View File

@ -0,0 +1,175 @@
# Based on code from https://github.com/punica-ai/punica
import pytest
import torch
import vllm.lora.punica as punica
def assert_close(a, b):
rtol, atol = {
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
torch.float32: (None, None),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def _lora_ref_impl(
y_final: torch.Tensor,
x: torch.Tensor,
wa_T_all: torch.Tensor,
wb_T_all: torch.Tensor,
    indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
y_stage_1 = torch.empty(
(x.size(0), wa_T_all.size(-2)),
dtype=torch.float32,
device=x.device,
)
bs = x.shape[0]
s = torch.tensor(scale, dtype=torch.float32, device=x.device)
    for i, lora_idx in zip(range(bs), indices.cpu().tolist()):
xi = x[i].unsqueeze(0).to(torch.float32)
wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
tmp = xi @ wa
y_stage_1[i] = tmp.squeeze(0)
y_final[i] += (tmp @ wb).squeeze(0) * s
return y_final, y_stage_1
H1 = H2 = [
128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
32256, 32512, 32768, 33024
]
SEED = [0xabcdabcd987]
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness(dtype_str, h1, h2, seed):
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
r = 8
bs = 32
scale = 0.123
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")
wa_T_all = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wb_T_all = torch.randn(num_loras,
num_layers,
h2,
r,
dtype=dtype,
device=device)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype, device=device)
y = torch.randn(bs, h2, dtype=dtype, device=device)
y_ref = y.clone()
_lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
y_our = y.clone()
punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx,
scale)
assert_close(y_ref, y_our)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness_slice(dtype_str, h1, h2, seed):
if h2 % 3 != 0 or h2 // 3 not in H1:
pytest.skip("h2 must be divisible by 3 and in supported shapes")
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
r = 8
bs = 32
scale = 0.123
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")
wa_T_all_0 = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wa_T_all_1 = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wa_T_all_2 = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
wb_T_all_0 = torch.randn(num_loras,
num_layers,
h2 // 3,
r,
dtype=dtype,
device=device)
wb_T_all_1 = torch.randn(num_loras,
num_layers,
h2 // 3,
r,
dtype=dtype,
device=device)
wb_T_all_2 = torch.randn(num_loras,
num_layers,
h2 // 3,
r,
dtype=dtype,
device=device)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype, device=device)
y = torch.randn(bs, h2, dtype=dtype, device=device)
s = h2 // 3
y_ref = y.clone()
_lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices,
layer_idx, scale)
_lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices,
layer_idx, scale)
_lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices,
layer_idx, scale)
y_our = y.clone()
punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices,
layer_idx, scale, 0, s)
punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices,
layer_idx, scale, s, s)
punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices,
layer_idx, scale, s * 2, s)
assert_close(y_ref[:, :s], y_our[:, :s])
assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2])
assert_close(y_ref[:, s * 2:], y_our[:, s * 2:])
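# Loop-form sketch of what add_lora_slice computes for a single layer,
# consistent with _lora_ref_impl above: the same low-rank update, written only
# into the column slice y[:, offset:offset + size], so that one packed output
# (e.g. concatenated Q/K/V) can be updated slice by slice. The fused CUDA
# kernel replaces this loop; the function below is for illustration only.
def _add_lora_slice_reference(y, x, wa_T_all, wb_T_all, indices, layer_idx,
                              scale, offset, size):
    for i, slot in enumerate(indices.cpu().tolist()):
        wa = wa_T_all[slot, layer_idx].transpose(-1, -2).to(torch.float32)
        wb = wb_T_all[slot, layer_idx].transpose(-1, -2).to(torch.float32)
        delta = (x[i].to(torch.float32).unsqueeze(0) @ wa @ wb).squeeze(0)
        y[i, offset:offset + size] += (delta * scale).to(y.dtype)
    return y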

View File

@ -0,0 +1,69 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer
@pytest.mark.asyncio
async def test_transformers_tokenizer():
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
assert reference_tokenizer.encode("prompt") == tokenizer.encode(
request_id="request_id", prompt="prompt", lora_request=None)
assert reference_tokenizer.encode(
"prompt") == await tokenizer.encode_async(request_id="request_id",
prompt="prompt",
lora_request=None)
assert isinstance(tokenizer.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer.get_lora_tokenizer(
None) == await tokenizer.get_lora_tokenizer_async(None)
@pytest.mark.asyncio
async def test_transformers_tokenizer_lora(sql_lora_files):
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
tokenizer = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
max_input_length=None,
)
lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer.encode(
request_id="request_id", prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode(
"prompt") == await tokenizer.encode_async(request_id="request_id",
prompt="prompt",
lora_request=lora_request)
assert isinstance(tokenizer.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer.get_lora_tokenizer(
None) == await tokenizer.get_lora_tokenizer_async(None)
assert isinstance(tokenizer.get_lora_tokenizer(lora_request),
PreTrainedTokenizerBase)
assert tokenizer.get_lora_tokenizer(
lora_request) != tokenizer.get_lora_tokenizer(None)
assert tokenizer.get_lora_tokenizer(
lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request)
def test_get_lora_tokenizer(sql_lora_files, tmpdir):
lora_request = None
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
lora_request = LoRARequest("1", 1, sql_lora_files)
tokenizer = get_lora_tokenizer(lora_request)
assert tokenizer.get_added_vocab()
lora_request = LoRARequest("1", 1, str(tmpdir))
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
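# Sketch of the fallback behaviour these tests rely on; the helper below is
# hypothetical and not part of vLLM: prefer the tokenizer shipped with the
# adapter when it has one, otherwise fall back to the base tokenizer.
def _resolve_tokenizer(base_tokenizer, lora_request):
    if lora_request is None:
        return base_tokenizer
    lora_tokenizer = get_lora_tokenizer(lora_request)  # falsy if no tokenizer
    return lora_tokenizer or base_tokenizer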

172
tests/lora/test_utils.py Normal file
View File

@ -0,0 +1,172 @@
from collections import OrderedDict
from torch import nn
from vllm.utils import LRUCache
from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule)
def test_parse_fine_tuned_lora_name():
fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
(
"base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens",
True,
),
(
"base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens",
False,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj",
True,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj",
False,
),
}
for name, module_name, is_lora_a in fixture:
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
def test_replace_submodule():
model = nn.Sequential(
OrderedDict([
("dense1", nn.Linear(764, 100)),
("act1", nn.ReLU()),
("dense2", nn.Linear(100, 50)),
(
"seq1",
nn.Sequential(
OrderedDict([
("dense1", nn.Linear(100, 10)),
("dense2", nn.Linear(10, 50)),
])),
),
("act2", nn.ReLU()),
("output", nn.Linear(50, 10)),
("outact", nn.Sigmoid()),
]))
sigmoid = nn.Sigmoid()
replace_submodule(model, "act1", sigmoid)
assert dict(model.named_modules())["act1"] == sigmoid
dense2 = nn.Linear(1, 5)
replace_submodule(model, "seq1.dense2", dense2)
assert dict(model.named_modules())["seq1.dense2"] == dense2
class TestLRUCache(LRUCache):
def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1
def test_lru_cache():
cache = TestLRUCache(3)
cache.put(1, 1)
assert len(cache) == 1
cache.put(1, 1)
assert len(cache) == 1
cache.put(2, 2)
assert len(cache) == 2
cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache.get(2) == 2
cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache
cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4
cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6
cache._remove_counter = 0
cache[1] = 1
assert len(cache) == 1
cache[1] = 1
assert len(cache) == 1
cache[2] = 2
assert len(cache) == 2
cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}
cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2
cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2
del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3
cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache

tests/lora/test_worker.py Normal file

@ -0,0 +1,61 @@
import os
import random
import tempfile
from unittest.mock import patch
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
from vllm.worker.worker import Worker
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
worker = Worker(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
"meta-llama/Llama-2-7b-hf",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32, 256),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
max_loras=32),
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_model()
worker.load_model()
worker.model_runner.set_active_loras([], LoRAMapping([], []))
assert worker.list_loras() == set()
n_loras = 32
lora_requests = [
LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
]
worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
assert worker.list_loras() == {
lora_request.lora_int_id
for lora_request in lora_requests
}
for i in range(32):
random.seed(i)
iter_lora_requests = random.choices(lora_requests,
k=random.randint(1, n_loras))
random.shuffle(iter_lora_requests)
iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
worker.model_runner.set_active_loras(iter_lora_requests,
LoRAMapping([], []))
assert worker.list_loras().issuperset(
{lora_request.lora_int_id
for lora_request in iter_lora_requests})

tests/lora/utils.py Normal file

@ -0,0 +1,88 @@
from typing import List, Optional
import torch
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
class DummyLoRAManager:
def __init__(self):
super().__init__()
self._loras = {}
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
self._loras[module_name] = lora
def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
return self._loras.get(module_name, None)
def init_random_lora(self,
module_name: str,
weight: torch.Tensor,
rank: int = 8,
generate_embeddings_tensor: int = 0):
lora = LoRALayerWeights(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
dtype=weight.dtype,
device="cuda"),
lora_b=torch.rand([rank, weight.shape[0]],
dtype=weight.dtype,
device="cuda"),
)
if generate_embeddings_tensor:
lora.embeddings_tensor = torch.rand(5,
generate_embeddings_tensor,
dtype=weight.dtype,
device="cuda")
self.set_module_lora(module_name, lora)
return lora
def init_lora(self,
module_name: str,
input_dim: int,
output_dim: int,
rank=8,
noop=False,
embeddings_tensor=None):
lora = LoRALayerWeights(
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([input_dim, rank], device="cuda"),
lora_b=torch.rand([rank, output_dim], device="cuda"),
embeddings_tensor=embeddings_tensor,
)
self.set_module_lora(module_name, lora)
return lora
def reset_lora(self):
self._loras = {}
def init_packed_lora(
self,
module_name: str,
input_dim: int,
output_dims: List[int],
noop_lora_index: Optional[List[int]] = None,
rank=8,
):
base_loras = []
noop_lora_index = set(noop_lora_index or [])
for i, out_dim in enumerate(output_dims):
base_lora = self.init_lora(
module_name + "_000_" + str(i),
input_dim,
out_dim,
rank=rank,
noop=i in noop_lora_index,
)
base_loras.append(base_lora)
packed_lora = PackedLoRALayerWeights.pack(base_loras)
self.set_module_lora(module_name, packed_lora)
return packed_lora
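For orientation, DummyLoRAManager is meant to be driven from layer tests roughly as below. This is a sketch only: the relative import and the CUDA device are assumptions about how the helper is consumed, not code from this diff.

import torch
from utils import DummyLoRAManager  # assumed: run from inside tests/lora

manager = DummyLoRAManager()
# The base weight is (output_dim, input_dim); LoRA A/B are created on CUDA to match it.
weight = torch.zeros(128, 64, dtype=torch.float16, device="cuda")
lora = manager.init_random_lora("layers.0.mlp.down_proj", weight, rank=8)
assert manager.get_module_lora("layers.0.mlp.down_proj") is lora
manager.reset_lora()
assert manager.get_module_lora("layers.0.mlp.down_proj") is None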


@ -19,9 +19,10 @@ class MockLogitsSampler(Sampler):
self.fake_logits = fake_logits self.fake_logits = fake_logits
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
with patch("vllm.model_executor.layers.sampler._prune_hidden_states", with patch(
"vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x), patch( lambda x, y: x), patch(
"vllm.model_executor.layers.sampler._get_logits", "vllm.model_executor.layers.sampler.Sampler._get_logits",
lambda *args, **kwargs: self.fake_logits): lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
@ -38,7 +39,7 @@ def _prepare_test(
device=input_tensor.device, device=input_tensor.device,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits) sampler = MockLogitsSampler(32000, fake_logits)
model_runner = ModelRunner(None, None, None) model_runner = ModelRunner(None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner return input_tensor, fake_logits, sampler, model_runner
@ -266,7 +267,7 @@ def test_sampler_top_k_top_p(seed: int):
device=input_tensor.device, device=input_tensor.device,
dtype=input_tensor.dtype) dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits) sampler = MockLogitsSampler(32000, fake_logits)
model_runner = ModelRunner(None, None, None) model_runner = ModelRunner(None, None, None, None)
generation_model = GenerationMixin() generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k, generation_config = GenerationConfig(top_k=top_k,


@ -83,8 +83,8 @@ def create_worker(cls: type,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
) )
(model_config, cache_config, parallel_config, (model_config, cache_config, parallel_config, scheduler_config,
scheduler_config) = engine_args.create_engine_configs() _) = engine_args.create_engine_configs()
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port()) get_ip(), get_open_port())


@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner
def test_prepare_prompt(): def test_prepare_prompt():
model_runner = ModelRunner(None, None, None) model_runner = ModelRunner(None, None, None, None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
@ -33,7 +33,7 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx + expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1) prompt_len - 1)
selected_token_start_idx += max_seq_len selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens, _ = ( input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
model_runner._prepare_prompt(seq_group_metadata_list)) model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,


@ -1,4 +1,5 @@
from typing import Optional, Union from typing import Optional, Union, ClassVar
from dataclasses import dataclass
import os import os
import torch import torch
@ -397,6 +398,54 @@ class SchedulerConfig:
f"({self.max_num_seqs}).") f"({self.max_num_seqs}).")
@dataclass
class LoRAConfig:
max_lora_rank: int
max_loras: int
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
possible_max_ranks = (8, 16, 32, 64)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
raise ValueError(
f"max_lora_rank ({self.max_lora_rank}) must be one of "
f"{possible_max_ranks}.")
if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
raise ValueError(
f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
f"must be one of {possible_lora_extra_vocab_size}.")
if self.max_loras < 1:
raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
if self.max_cpu_loras is None:
self.max_cpu_loras = self.max_loras
elif self.max_cpu_loras < self.max_loras:
raise ValueError(
f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_num_seqs ({self.max_loras})")
def verify_with_model_config(self, model_config: ModelConfig):
if self.lora_dtype in (None, "auto"):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization is not None:
raise ValueError(
"LoRA is not supported with quantized models yet.")
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
raise ValueError(
"Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when "
"LoRA is enabled.")
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16, "half": torch.float16,
"float16": torch.float16, "float16": torch.float16,


@ -1,11 +1,12 @@
from collections import deque from collections import deque
import enum import enum
import time import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set
from vllm.config import CacheConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory from vllm.core.policy import PolicyFactory
from vllm.lora.request import LoRARequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceStatus)
@ -49,11 +50,25 @@ class SchedulerOutputs:
assert not (blocks_to_swap_in and blocks_to_swap_out) assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups self.ignored_seq_groups = ignored_seq_groups
self.num_loras = len(self.lora_requests)
if self.num_loras > 0:
self._sort_by_lora_ids()
def is_empty(self) -> bool: def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups. # NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy) and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool:
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.lora_request.lora_int_id
if g.lora_request else 0, g.request_id))
@property
def lora_requests(self) -> Set[LoRARequest]:
return {g.lora_request for g in self.scheduled_seq_groups}
class Scheduler: class Scheduler:
@ -61,9 +76,14 @@ class Scheduler:
self, self,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
# simple and NOT fair. It can lead to starvation of some
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
self.prompt_limit = min(self.scheduler_config.max_model_len, self.prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
@ -87,6 +107,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
self.swapped: Deque[SequenceGroup] = deque() self.swapped: Deque[SequenceGroup] = deque()
@property
def lora_enabled(self) -> bool:
return bool(self.lora_config)
def add_seq_group(self, seq_group: SequenceGroup) -> None: def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue. # Add sequence groups to the waiting queue.
self.waiting.append(seq_group) self.waiting.append(seq_group)
@ -150,14 +174,17 @@ class Scheduler:
# requests in the generation phase. # requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs() num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running) for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = [] seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted # Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups # sequence groups are added to the front and the new sequence groups
# are added to the back. # are added to the back.
leftover_waiting_sequences = deque()
while self.waiting: while self.waiting:
seq_group = self.waiting[0] seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs( waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING) status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, ( assert len(waiting_seqs) == 1, (
@ -188,6 +215,17 @@ class Scheduler:
self.waiting.popleft() self.waiting.popleft()
continue continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_waiting_sequences.appendleft(seq_group)
self.waiting.popleft()
continue
# If the number of batched tokens exceeds the limit, stop. # If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens] new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
@ -207,12 +245,16 @@ class Scheduler:
break break
seq_lens = new_seq_lens seq_lens = new_seq_lens
seq_group = self.waiting.popleft() if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.waiting.popleft()
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
scheduled.append(seq_group) scheduled.append(seq_group)
self.waiting.extendleft(leftover_waiting_sequences)
if scheduled or ignored_seq_groups: if scheduled or ignored_seq_groups:
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled, scheduled_seq_groups=scheduled,
@ -260,9 +302,25 @@ class Scheduler:
if not preempted: if not preempted:
num_curr_seqs = sum(seq_group.get_max_num_running_seqs() num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running) for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
leftover_swapped = deque()
while self.swapped: while self.swapped:
seq_group = self.swapped[0] seq_group = self.swapped[0]
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_swapped.appendleft(seq_group)
self.swapped.popleft()
continue
# If the sequence group cannot be swapped in, stop. # If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group): if not self.block_manager.can_swap_in(seq_group):
break break
@ -274,12 +332,16 @@ class Scheduler:
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
seq_group = self.swapped.popleft() if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in) self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy) self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
self.running.append(seq_group) self.running.append(seq_group)
self.swapped.extendleft(leftover_swapped)
# Each sequence in the generation phase only takes one token slot. # Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of # Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state. # sequences in the RUNNING state.
@ -320,6 +382,7 @@ class Scheduler:
seq_data=seq_data, seq_data=seq_data,
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix, prefix=seq_group.prefix,
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
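To make the scheduling change above concrete, here is a standalone toy version (hypothetical names, not the Scheduler API) of the admission rule this diff adds: a waiting group whose LoRA is not already active gets deferred once max_loras distinct LoRAs are in the batch, and deferred groups are pushed back to the front of the queue.

from collections import deque

def admit(waiting, curr_loras, max_loras):
    """Toy version of the gating loop; returns the groups scheduled this step."""
    scheduled, leftover = [], deque()
    while waiting:
        group = waiting.popleft()
        lora_id = group.get("lora_int_id", 0)
        if (lora_id > 0 and lora_id not in curr_loras
                and len(curr_loras) >= max_loras):
            leftover.appendleft(group)  # no slot for another LoRA right now
            continue
        if lora_id > 0:
            curr_loras.add(lora_id)
        scheduled.append(group)
    waiting.extendleft(leftover)  # deferred groups keep their position
    return scheduled

waiting = deque([{"id": "a", "lora_int_id": 1}, {"id": "b", "lora_int_id": 2}])
print(admit(waiting, curr_loras={1}, max_loras=1))  # only "a" fits this step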


@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig, LoRAConfig)
@dataclass @dataclass
@ -35,6 +35,12 @@ class EngineArgs:
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: bool = False
max_context_len_to_capture: int = 8192 max_context_len_to_capture: int = 8192
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
@ -202,6 +208,39 @@ class EngineArgs:
help='maximum context length covered by CUDA ' help='maximum context length covered by CUDA '
'graphs. When a sequence has context length ' 'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.') 'larger than this, we fall back to eager mode.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
help='Max number of LoRAs in a single batch.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
'--lora-dtype',
type=str,
default=EngineArgs.lora_dtype,
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
default=EngineArgs.max_cpu_loras,
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be greater than or equal to max_loras. '
'Defaults to max_loras.'))
return parser return parser
@classmethod @classmethod
@ -214,7 +253,8 @@ class EngineArgs:
def create_engine_configs( def create_engine_configs(
self, self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
Optional[LoRAConfig]]:
model_config = ModelConfig(self.model, self.tokenizer, model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code, self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format, self.download_dir, self.load_format,
@ -234,7 +274,14 @@ class EngineArgs:
self.max_num_seqs, self.max_num_seqs,
model_config.max_model_len, model_config.max_model_len,
self.max_paddings) self.max_paddings)
return model_config, cache_config, parallel_config, scheduler_config lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
return model_config, cache_config, parallel_config, scheduler_config, lora_config
@dataclass @dataclass
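As a usage note, the options above are ordinary argparse flags, so a hypothetical driver script could exercise them like this (facebook/opt-125m is only a stand-in so the example parses without a gated model):

import argparse
from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--enable-lora", "--max-loras", "2", "--max-lora-rank", "8",
])
engine_args = EngineArgs.from_cli_args(args)
# create_engine_configs() now returns a fifth element: the optional LoRAConfig.
*_, lora_config = engine_args.create_engine_configs()
print(lora_config)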


@ -4,6 +4,7 @@ from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator) Union, AsyncIterator)
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
@ -203,6 +204,52 @@ class _AsyncLLMEngine(LLMEngine):
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
async def encode_request_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async def _run_workers_async( async def _run_workers_async(
self, self,
method: str, method: str,
@ -332,7 +379,7 @@ class AsyncLLMEngine:
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.add_request.remote(**new_request) await self.engine.add_request.remote(**new_request)
else: else:
self.engine.add_request(**new_request) await self.engine.add_request_async(**new_request)
if finished_requests: if finished_requests:
await self._engine_abort(finished_requests) await self._engine_abort(finished_requests)
@ -371,6 +418,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, prefix_pos: Optional[int] = None,
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
@ -386,7 +434,8 @@ class AsyncLLMEngine:
f"prompt: {shortened_prompt!r}, " f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos}," f"prefix_pos: {prefix_pos},"
f"sampling params: {sampling_params}, " f"sampling params: {sampling_params}, "
f"prompt token ids: {shortened_token_ids}.") f"prompt token ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
@ -398,12 +447,21 @@ class AsyncLLMEngine:
"error that caused the background loop to stop " "error that caused the background loop to stop "
"(AsyncEngineDeadError).") "(AsyncEngineDeadError).")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
prompt=prompt, prompt=prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos) prefix_pos=prefix_pos)
return stream return stream
@ -414,6 +472,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
@ -429,6 +488,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request. request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix. cache and reuse it for the next request with the same prefix.
@ -487,12 +547,15 @@ class AsyncLLMEngine:
arrival_time = time.monotonic() arrival_time = time.monotonic()
try: try:
stream = await self.add_request(request_id, stream = await self.add_request(
request_id,
prompt, prompt,
sampling_params, sampling_params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time, arrival_time=arrival_time,
prefix_pos=prefix_pos) lora_request=lora_request,
prefix_pos=prefix_pos,
)
async for request_output in stream: async for request_output in stream:
yield request_output yield request_output
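For the streaming path, the new lora_request parameter threads through AsyncLLMEngine.generate as sketched below; the adapter name, id and path are placeholders, and the base model is assumed to be one with LoRA support in this commit (Llama-family):

import asyncio
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))

async def run() -> str:
    final = None
    async for out in engine.generate(
            "Write a SQL query that lists all users.",
            SamplingParams(max_tokens=32),
            request_id="req-0",
            lora_request=LoRARequest("sql-adapter", 1, "/path/to/sql_lora_adapter")):
        final = out  # keep the latest (cumulative) RequestOutput
    return final.outputs[0].text

print(asyncio.run(run()))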


@ -5,8 +5,9 @@ import time
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union) Union)
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig, LoRAConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import record_metrics from vllm.engine.metrics import record_metrics
@ -17,7 +18,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus) SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer) TokenizerGroup)
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
if ray: if ray:
@ -64,6 +65,7 @@ class LLMEngine:
cache_config: CacheConfig, cache_config: CacheConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
placement_group: Optional["PlacementGroup"], placement_group: Optional["PlacementGroup"],
log_stats: bool, log_stats: bool,
) -> None: ) -> None:
@ -87,17 +89,13 @@ class LLMEngine:
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.log_stats = log_stats self.log_stats = log_stats
self._verify_args() self._verify_args()
self.tokenizer = get_tokenizer( self._init_tokenizer()
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
tokenizer_revision=model_config.tokenizer_revision,
revision=model_config.revision)
self.seq_counter = Counter() self.seq_counter = Counter()
# Create the parallel GPU workers. # Create the parallel GPU workers.
@ -114,7 +112,7 @@ class LLMEngine:
self._init_cache() self._init_cache()
# Create the scheduler. # Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config) self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Logging. # Logging.
self.last_logging_time = 0.0 self.last_logging_time = 0.0
@ -123,6 +121,9 @@ class LLMEngine:
# List of (timestamp, num_tokens) # List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = [] self.num_generation_tokens: List[Tuple[float, int]] = []
def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _init_workers(self): def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers # Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker # before CUDA_VISIBLE_DEVICES is set in the Worker
@ -141,11 +142,24 @@ class LLMEngine:
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=True, is_driver_worker=True,
) )
self._run_workers("init_model") self._run_workers("init_model")
self._run_workers("load_model") self._run_workers("load_model")
def _init_tokenizer(self, **tokenizer_init_kwargs):
init_kwargs = dict(
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup(
self.model_config.tokenizer, **init_kwargs)
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1: if self.parallel_config.tensor_parallel_size == 1:
@ -233,6 +247,7 @@ class LLMEngine:
local_rank, local_rank,
rank, rank,
distributed_init_method, distributed_init_method,
lora_config=self.lora_config,
)) ))
driver_rank = 0 driver_rank = 0
@ -244,6 +259,7 @@ class LLMEngine:
driver_local_rank, driver_local_rank,
driver_rank, driver_rank,
distributed_init_method, distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=True, is_driver_worker=True,
) )
@ -257,6 +273,10 @@ class LLMEngine:
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def _init_cache(self) -> None: def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache. """Profiles the memory usage and initializes the KV cache.
@ -332,6 +352,20 @@ class LLMEngine:
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats)
return engine return engine
def encode_request(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
@ -339,6 +373,7 @@ class LLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, prefix_pos: Optional[int] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
@ -386,24 +421,31 @@ class LLMEngine:
>>> # continue the request processing >>> # continue the request processing
>>> ... >>> ...
""" """
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None: if arrival_time is None:
arrival_time = time.monotonic() arrival_time = time.monotonic()
if prompt_token_ids is None: prompt_token_ids = self.encode_request(
assert prompt is not None request_id=request_id,
prompt_token_ids = self.tokenizer.encode(prompt) prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request)
# Check whether the input specifies prefix # Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix( prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, prefix) arrival_time, lora_request, prefix)
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group) self.scheduler.add_seq_group(seq_group)
@ -453,11 +495,13 @@ class LLMEngine:
current_worst_score = (current_worst_seq.get_beam_search_score( current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id)) eos_token_id=self.get_tokenizer_for_seq(
current_worst_seq).eos_token_id))
if early_stopping is False: if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score( highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id)) eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id))
else: else:
assert early_stopping == "never" assert early_stopping == "never"
if length_penalty > 0.0: if length_penalty > 0.0:
@ -471,7 +515,8 @@ class LLMEngine:
highest_attainable_score = ( highest_attainable_score = (
best_running_seq.get_beam_search_score( best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id, eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id,
seq_len=max_possible_length)) seq_len=max_possible_length))
else: else:
# Otherwise, beam search will prefer shorter sequences. The # Otherwise, beam search will prefer shorter sequences. The
@ -480,7 +525,8 @@ class LLMEngine:
highest_attainable_score = ( highest_attainable_score = (
best_running_seq.get_beam_search_score( best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id)) eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id))
return current_worst_score >= highest_attainable_score return current_worst_score >= highest_attainable_score
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
@ -571,7 +617,7 @@ class LLMEngine:
# Sort the finished sequences by their scores. # Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id), eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True) reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]: for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new: if is_new:
@ -599,7 +645,7 @@ class LLMEngine:
# Sort the running sequences by their scores. # Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id), eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True) reverse=True)
# Check if we can stop the beam search. # Check if we can stop the beam search.
@ -837,7 +883,7 @@ class LLMEngine:
"""Decodes the new token for a sequence.""" """Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset, (new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally( read_offset) = detokenize_incrementally(
self.tokenizer, self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(), all_input_ids=seq.get_token_ids(),
prev_tokens=seq.tokens, prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset, prefix_offset=seq.prefix_offset,
@ -879,11 +925,28 @@ class LLMEngine:
return return
# Check if the sequence has generated the EOS token. # Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos) if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
and seq.get_last_token_id() == self.tokenizer.eos_token_id): == self.get_tokenizer_for_seq(seq).eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> List[int]:
return self._run_workers("list_loras")
def _run_workers( def _run_workers(
self, self,
method: str, method: str,


@ -3,6 +3,7 @@ from typing import List, Optional, Union
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
@ -122,6 +123,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None, prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
@ -141,6 +143,7 @@ class LLM:
This is an experimental feature, and may be replaced with This is an experimental feature, and may be replaced with
automatic prefix caching in the future. automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns: Returns:
A list of `RequestOutput` objects containing the generated A list of `RequestOutput` objects containing the generated
@ -168,7 +171,11 @@ class LLM:
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[ token_ids = None if prompt_token_ids is None else prompt_token_ids[
i] i]
self._add_request(prompt, sampling_params, token_ids, prefix_pos_i) self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos_i)
return self._run_engine(use_tqdm) return self._run_engine(use_tqdm)
def _add_request( def _add_request(
@ -176,6 +183,7 @@ class LLM:
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]], prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None, prefix_pos: Optional[int] = None,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
@ -183,6 +191,7 @@ class LLM:
prompt, prompt,
sampling_params, sampling_params,
prompt_token_ids, prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos) prefix_pos=prefix_pos)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
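Putting the new generate() parameter together with LoRARequest (constructed the same way as in the tests above), offline multi-LoRA inference looks roughly like this sketch; the base model and adapter path are placeholders:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf",
          enable_lora=True, max_loras=2, max_lora_rank=8)
sql_lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora_adapter")

outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=sql_lora,  # omit or pass None to serve the base model
)
print(outputs[0].outputs[0].text)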

vllm/lora/__init__.py Normal file

vllm/lora/layers.py Normal file

@ -0,0 +1,975 @@
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather,
)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear,
QKVParallelLinear,
MergedColumnParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim
if TYPE_CHECKING:
pass
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
indices: torch.Tensor,
output: torch.Tensor,
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
lora_b_stacked: (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
return output.view_as(org_output)
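As a cross-check on the shapes documented above, the punica add_lora call is equivalent to the (much slower) pure-PyTorch loop below; all names are local to this sketch:

import torch

def apply_lora_reference(x, lora_a_stacked, lora_b_stacked, indices, output):
    # x: (bs, hidden), lora_a_stacked: (num_loras, rank, hidden),
    # lora_b_stacked: (num_loras, out, rank), indices: (bs,), output: (bs, out)
    for i, idx in enumerate(indices.tolist()):
        if idx < 0:  # -1 means "no LoRA" for this token
            continue
        a = lora_a_stacked[idx]      # (rank, hidden)
        b = lora_b_stacked[idx]      # (out, rank)
        output[i] += b @ (a @ x[i])  # rank-r update at scale 1.0
    return output

x = torch.randn(4, 16)
out = torch.zeros(4, 32)
apply_lora_reference(x, torch.randn(2, 8, 16), torch.randn(2, 32, 8),
                     torch.tensor([0, -1, 1, 0]), out)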
def _apply_lora_packed_nslice(
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
):
"""Applies lora to each input.
This method applies all loras to each input. It uses the
indices vector to determine which lora yields the
correct output. An index of -1 means no lora should be
applied. This method adds the final lora results to the
output.
This method is used for layers that are composed of multiple sublayers
(slices) packed together.
Input shapes:
x: (batch_size, hidden_dim)
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n element tuple of (slice_size...), where n is the number of slices
"""
org_output = output
x = x.view(-1, x.shape[-1])
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx in range(len(output_slices)):
add_lora_slice(output, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
return output.view_as(org_output)
@dataclass
class LoRAMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
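A hand-built example of the two mappings for a batch of two prompts, assuming the entries carry lora_int_id values with 0 meaning the base model (an assumption about the mapping semantics, not stated explicitly in this diff):

from vllm.lora.models import LoRAMapping  # same import as tests/lora/test_worker.py

# Prompt 1: 3 input tokens served by LoRA id 1; prompt 2: 2 tokens on the base model.
mapping = LoRAMapping(
    index_mapping=[1, 1, 1, 0, 0],  # one entry per input token
    prompt_mapping=[1, 0],          # one entry per sampled (output) token
)
assert mapping.index_mapping == (1, 1, 1, 0, 0)  # __post_init__ converts to tuples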
class BaseLayerWithLoRA(nn.Module):
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
model_config: PretrainedConfig) -> None:
"""Initializes lora matrices."""
...
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
"""Overwrites lora tensors at index."""
...
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
"""Sets the mapping indices."""
...
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
lora_vocab_start_idx = self.base_layer.org_vocab_size
weights_idx = None
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
# We can start adding lora weights
weights_idx = max(
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
self.embeddings_slice = (self.base_layer.vocab_start_index -
self.base_layer.org_vocab_size +
weights_idx,
self.base_layer.vocab_end_index -
self.base_layer.org_vocab_size)
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
self.embeddings_weights.fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size +
lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.embedding_dim,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.embeddings_indices = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1]].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] *
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.embeddings_indices = embeddings_indices
self.indices_len = indices_len
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
full_output_org = full_output
if full_output.ndim == 3:
full_output = full_output.view(
full_output.shape[0] * full_output.shape[1], -1)
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] *
full_lora_a_embeddings.shape[1], -1)
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org)
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_a_stacked = torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0],
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[1]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output
def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = (self.base_layer.bias
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
output_parallel = self.apply_weights(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = (self.base_layer.bias
if self.base_layer.skip_bias_add else None)
return output, output_bias
@property
def linear_weights(self):
return self.base_layer.linear_weights
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
Both slices must have the same size.
"""
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
and self.base_layer.output_sizes[0]
== self.base_layer.output_sizes[1]):
raise ValueError(
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
) for _ in range(n_slices))
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
self.base_layer.weight.shape[0] // 2,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[0][:,
start_idx:end_idx], lora_b[1][:,
start_idx:end_idx]
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
(self.output_dim, self.output_dim),
)
return output
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
This means we have 3 LoRAs, each applied to one slice of the layer.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
)
self.lora_b_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
),
)
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
self.lora_b_stacked[0][index] = 0
self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[1][index] = 0
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
else:
if lora_b[0] is not None:
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if lora_b[1] is not None:
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if lora_b[2] is not None:
self.lora_b_stacked[2][
index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
lora_b[2].T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
if lora_a[2] is not None:
self.lora_a_stacked[2][
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
self.output_slices,
)
return output
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.base_layer.weight.shape[1],
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.weight.shape[0],
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.base_layer.weight.shape[1]
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
self.indices_len = indices_len
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x)
_apply_lora(
x,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[0]],
output,
)
return output
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# Set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (output_ + self.base_layer.bias
if self.base_layer.bias is not None else output_)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
@property
def weight(self):
return self.base_layer.weight
class SamplerWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: Sampler,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
@property
def vocab_size(self):
return self.base_layer.vocab_size
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@property
def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not 32000 <= self.base_layer.vocab_size <= 33024:
            raise ValueError(
                "When using LoRA, vocab size must satisfy "
                "32000 <= vocab_size <= 33024.")
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
# Pad for kernel compatibility
math.ceil(self.base_layer.vocab_size /
lora_config.lora_vocab_padding_size) *
lora_config.lora_vocab_padding_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.embeddings_tensors = torch.full(
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
fill_value=float("-inf"),
dtype=self.dtype,
device=self.device,
)
self.indices = None
self.indices_padded = None
self.indices_len = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1], ] = embeddings_tensor
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = sampler_indices
self.indices_padded = sampler_indices_padded
self.indices_len = indices_len
def _get_logits(
self,
hidden_states: torch.Tensor,
embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
if logits is None:
return None
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],
hidden_states.shape[0],
dtype=self.embeddings_tensors.dtype,
device=self.embeddings_tensors.device,
)
torch.matmul(self.embeddings_tensors,
hidden_states.T,
out=lora_logits[:-1])
lora_logits[-1] = float("-inf")
lora_logits = lora_logits.mT
lora_logits = (lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
).index_select(0,
self.indices_padded[:self.indices_len[2]]).nan_to_num_(
nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
_apply_lora(
hidden_states,
self.lora_a_stacked,
self.lora_b_stacked,
self.indices[:self.indices_len[1]],
logits,
)
# Remove paddings in vocab (if any).
logits = logits[:, :self.base_layer.vocab_size]
return logits
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
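# Layout note for the logits produced above: columns [0, org_vocab_size) hold
# the base vocabulary, while columns [org_vocab_size, org_vocab_size +
# lora_extra_vocab_size) hold each LoRA's extra tokens. Rows whose request has
# no active LoRA are routed (via sampler_indices_padded) to the all "-inf"
# sentinel row, so extra-vocab tokens can never be sampled for them.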
def from_layer(
layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
supported_layer_types = {
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLora,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_sampler(
layer: Sampler,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
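All of the linear-layer wrappers above delegate to _apply_lora /
_apply_lora_packed_nslice (defined earlier in this file), which dispatch to
the Punica kernels. As a reading aid, here is a minimal pure-PyTorch sketch of
the per-row semantics the stacked buffers implement; the function and tensor
names are illustrative and not part of this diff:

import torch

def apply_lora_reference(x, lora_a_stacked, lora_b_stacked, indices, output):
    # x: [num_tokens, hidden]; lora_a_stacked: [max_loras, 1, rank, hidden];
    # lora_b_stacked: [max_loras, 1, out_dim, rank]; indices: [num_tokens].
    for i, slot in enumerate(indices.tolist()):
        if slot < 0:  # -1 marks a row with no active LoRA
            continue
        a = lora_a_stacked[slot, 0]  # stored transposed: [rank, hidden]
        b = lora_b_stacked[slot, 0]  # stored transposed: [out_dim, rank]
        output[i] += x[i] @ a.t() @ b.t()
    return output

x = torch.randn(4, 16)
out = torch.zeros(4, 32)
apply_lora_reference(x, torch.randn(2, 1, 8, 16), torch.randn(2, 1, 32, 8),
                     torch.tensor([0, -1, 1, 0]), out)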

160
vllm/lora/lora.py Normal file
View File

@ -0,0 +1,160 @@
from typing import List, Optional
import torch
from vllm.utils import in_wsl
class LoRALayerWeights:
"""LoRA weights for a layer composed of two low rank matrixes."""
def __init__(
self,
module_name: str,
rank: int,
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None,
) -> None:
self.module_name = module_name
self.rank = rank
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
self.embeddings_tensor = embeddings_tensor
if scaling is None:
self.scaling = self.lora_alpha / self.rank
else:
self.scaling = scaling
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1:
return
self.lora_b *= self.scaling
self.scaling = 1
return self
@property
def input_dim(self) -> int:
return self.lora_a.shape[0]
@property
def output_dim(self) -> int:
return self.lora_b.shape[1]
@property
def is_packed(self) -> bool:
return False
@property
def extra_vocab_size(self) -> int:
return self.embeddings_tensor.shape[
0] if self.embeddings_tensor is not None else 0
@classmethod
def create_dummy_lora_weights(
cls,
module_name: str,
input_dim: int,
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and not in_wsl()
lora_a = torch.zeros([input_dim, rank],
dtype=dtype,
device=device,
pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
embeddings_tensor = torch.rand(
10,
embeddings_tensor_dim,
dtype=dtype,
device=device,
pin_memory=pin_memory) if embeddings_tensor_dim else None
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
embeddings_tensor=embeddings_tensor,
)
class PackedLoRALayerWeights(LoRALayerWeights):
"""LoRA used for packed layers (eg. qkv_proj)."""
def __init__(
self,
module_name: str,
rank: int,
lora_alphas: List[int],
lora_a: List[torch.Tensor],
lora_b: List[torch.Tensor],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
module_name=module_name,
rank=rank,
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling,
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [
lora_alpha / self.rank for lora_alpha in self.lora_alphas
]
@classmethod
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
for lora in loras:
if lora is None:
continue
lora.optimize()
rank = first_lora.rank
module_name = first_lora.module_name
obj = cls(
module_name,
rank,
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[1 if lora is not None else None for lora in loras])
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None:
continue
self.lora_b[i] *= self.scaling[i]
self.scaling[i] = 1
return self
@property
def input_dim(self) -> int:
raise NotImplementedError()
@property
def output_dim(self) -> int:
raise NotImplementedError()
@property
def is_packed(self) -> bool:
return True
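A small sketch of how the two classes above compose, assuming this module is
importable as vllm.lora.lora; module names and shapes are illustrative:

import torch
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights

rank, hidden = 8, 4096
q = LoRALayerWeights("layers.0.self_attn.q_proj", rank, 16,
                     torch.zeros(hidden, rank), torch.zeros(rank, hidden))
k = LoRALayerWeights("layers.0.self_attn.k_proj", rank, 16,
                     torch.zeros(hidden, rank), torch.zeros(rank, hidden))
v = None  # a missing sub-LoRA stays as a None slot after packing
packed = PackedLoRALayerWeights.pack([q, k, v])
assert packed.is_packed and packed.scaling == [1, 1, None]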

654
vllm/lora/models.py Normal file
View File

@ -0,0 +1,654 @@
import copy
import json
import logging
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type,
Union)
import safetensors.torch
import torch
from torch import nn
from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
logger = logging.getLogger(__name__)
# TODO: The mappings below should be moved to individual model classes.
PACKED_MODULES_CFG = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
TARGET_MODULES_QKV = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
EMBEDDING_PADDING_MODULES = ["lm_head"]
_GLOBAL_LORA_ID = 0
def convert_mapping(
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
                Same as sampler_indices, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
indices_len: List of lengths of the above tensors.
"""
indices = list(mapping.index_mapping).copy()
embedding_indices = indices.copy()
lora_indices = indices.copy()
prompt_mapping = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(indices)):
# TODO index can be slow. optimize
lora_idx = (lora_index_to_id.index(indices[i])
if indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
indices[i] = i
lora_indices[i] = lora_idx
indices = torch.tensor([indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
prompt_mapping = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
sampler_indices = prompt_mapping
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1])
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
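# Worked example (illustrative values): with lora_index_to_id = [None, 101,
# None, 102] (LoRA id 101 in slot 1, id 102 in slot 3),
# index_mapping = (101, 101, 102, 0) and prompt_mapping = (101, 102, 0):
#   base_indices           -> [1, 1, 3, -1]
#   sampler_indices        -> [1, 3, -1]
#   sampler_indices_padded -> -1 becomes max_loras - 1, then entry i becomes
#                             padded[i] * n + i (n = number of rows), so the
#                             flattened per-LoRA logits can be gathered with a
#                             single index_select.
#   embeddings_indices     -> row 0 offsets each token into its LoRA's
#                             extra-vocab range, row 1 into its full
#                             (base + extra) embedding range.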
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
return _GLOBAL_LORA_ID
class LoRAModel:
"""A LoRA fine-tuned model."""
def __init__(
self,
lora_model_id: int,
rank: int,
loras: Dict[str, LoRALayerWeights],
) -> None:
self.id = lora_model_id
assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
@property
def extra_vocab_size(self) -> int:
return max(lora.extra_vocab_size
for lora in self.loras.values()) if self.loras else 0
def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
# (yard1): TODO see if we can derive target_embedding_padding automatically
@classmethod
def from_lora_tensors(
cls,
lora_model_id: int,
rank: int,
lora_alpha: int,
tensors: Dict[str, torch.Tensor],
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and not in_wsl()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name),
None)
if embeddings_module:
lora_embeddings_tensor = embeddings[
EMBEDDING_MODULES[embeddings_module]].to(
device=device, dtype=dtype)
if pin_memory:
lora_embeddings_tensor = (
lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None,
lora_embeddings_tensor)
if is_lora_a:
loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t()
if pin_memory:
loras[module_name].lora_a = loras[
module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
if any(name in module_name
for name in EMBEDDING_PADDING_MODULES
) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1]
addition = target_embedding_padding - lora_b.shape[1]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, addition))
if pin_memory:
loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, rank, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint."""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path)
elif os.path.isfile(lora_bin_file_path):
tensors = torch.load(lora_bin_file_path)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
embeddings = None
if os.path.isfile(new_embeddings_tensor_path):
embeddings = safetensors.torch.load_file(
new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path)
with open(lora_config_path) as f:
config = json.load(f)
rank = config["r"]
lora_alpha = config["lora_alpha"]
return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
rank=rank,
lora_alpha=lora_alpha,
tensors=tensors,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
)
class LoRAModelManager:
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
):
"""Create a LoRAModelManager and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
            lora_target_modules: the target module patterns to be adapted.
                Supports both a single module name and a list of module
                names.
            packed_modules_mapping: the mapping for packed modules. vLLM
                packs some modules into one module, e.g., q_proj, k_proj,
                and v_proj are packed into qkv_proj. These sub-modules are
                separate layers in the LoRA checkpoint, but form a single
                packed layer in the vLLM model.
"""
self.lora_config = lora_config
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.embeddings_indices = torch.empty(2,
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.offsets = []
        # 4 is the number of indices tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len = [None] * 4
self.model: nn.Module = model
self.lora_target_modules: List[str] = ([
lora_target_modules
] if isinstance(lora_target_modules, str) else lora_target_modules)
        self.lora_target_modules = copy.deepcopy(self.lora_target_modules)
self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping = None
self._create_lora_modules()
self.model.lora_manager = self
@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras
@property
def lora_slots(self) -> int:
return self.lora_config.max_loras
def __len__(self) -> int:
return len(self._registered_loras)
def activate_lora(
self,
lora_id: int,
) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_loras:
return False
first_free_slot = next(
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
if lora_id is None), None)
if first_free_slot is None:
raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id]
logger.debug(
f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
if module_lora:
module_lora.optimize()
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
module_lora.embeddings_tensor)
else:
module.reset_lora(index)
return True
def _deactivate_lora(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
def deactivate_lora(self, lora_id: int) -> bool:
"""Remove a LoRA from a GPU buffer."""
if lora_id in self._active_loras:
self._deactivate_lora(lora_id)
self._active_loras.pop(lora_id)
return True
return False
def _add_lora(self, lora: LoRAModel) -> bool:
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
self._add_lora(lora)
return True
return False
def remove_lora(self, lora_id: int) -> bool:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
# Maintain the reference
self.indices_len[:] = indices_len
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
if self._last_mapping != lora_mapping:
self._set_lora_mapping(lora_mapping)
self._last_mapping = lora_mapping
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras)
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self) -> bool:
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_loras.clear()
def _create_lora_modules(self):
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name):
continue
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
self.model.config))
# (yard1): TODO make this more robust
if "lm_head" in module_name:
sampler_module = self.model.get_submodule("sampler")
new_module = replace_submodule(
self.model, "sampler",
from_layer_sampler(sampler_module, module, self.lora_slots,
self.lora_config, self.model.config))
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded,
self.embeddings_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
self.modules[module_name] = module
def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
if parts[-1] in EMBEDDING_MODULES:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1])
output_dim = module.base_layer.embedding_dim if hasattr(
module.base_layer,
"embedding_dim") else module.base_layer.weight.shape[0]
embeddings_tensor_dim = (module.base_layer.embedding_dim if
hasattr(module.base_layer,
"embedding_dim") else
module.base_layer.weight.shape[1])
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
output_dim,
rank,
module.lora_a_stacked.dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.lora_a_stacked.shape[-1],
module.lora_b_stacked.shape[-2],
rank,
module.lora_a_stacked.dtype,
"cpu",
)
lora.optimize()
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
module.lora_a_stacked[i].shape[-1],
module.lora_b_stacked[i].shape[-2],
rank,
module.lora_a_stacked[i].dtype,
"cpu",
)
lora.optimize()
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module),
module_name) or target_module == module_name
for target_module in self.lora_target_modules)
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name)
if not replacements:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [
prefix + "." + r if prefix else r for r in replacements
]
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
replacement_loras.append(lora)
if lora:
has_replacement = True
if not has_replacement:
continue
for i in range(len(replacement_loras)):
if replacement_loras[i]:
continue
replacement_loras[i] = None
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras)
class LoRALRUCache(LRUCache):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: Any):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config, lora_target_modules,
packed_modules_mapping)
self._registered_loras: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_lora)
self._active_loras: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_lora)
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras.cache)
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
if lora.id not in self._registered_loras:
self._add_lora(lora)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_loras.touch(lora.id)
was_added = False
return was_added
def activate_lora(
self,
lora_id: int,
) -> bool:
if lora_id not in self._active_loras and len(
self._active_loras) >= self.lora_slots:
self._active_loras.remove_oldest()
result = super().activate_lora(lora_id)
# We always touch to update the LRU cache order
self._active_loras.touch(lora_id)
return result
def remove_oldest_lora(self) -> bool:
if len(self._registered_loras) > 0:
self._registered_loras.remove_oldest()
return True
return False
def create_lora_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
**kwargs) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not getattr(model, "supports_lora", False):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
lora_target_modules=target_modules,
**kwargs)
return lora_manager
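A minimal usage sketch for the loader above; the adapter directory is
hypothetical and must contain adapter_config.json plus the adapter weights:

import torch
from vllm.lora.models import LoRAModel

lora = LoRAModel.from_local_checkpoint(
    "/path/to/adapter",  # hypothetical PEFT-style adapter directory
    lora_model_id=1,
    device="cpu",
    dtype=torch.float16,
)
print(lora.id, lora.rank, sorted(lora.loras))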

173
vllm/lora/punica.py Normal file
View File

@ -0,0 +1,173 @@
# Based on code from https://github.com/punica-ai/punica
from typing import Optional
import torch
import_exc = None
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
import_exc = e
if import_exc is None:
def bgmv(
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
matrices.
indicies: Shape: `[B]`. Indices of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
"""
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
*,
buffer: Optional[torch.Tensor] = None):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
"""
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
1.0)
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
scale)
def add_lora_slice(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
scale: float,
y_offset: int,
y_slice_size: int,
*,
buffer: Optional[torch.Tensor] = None):
"""
Same as `add_lora` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
r = wb_t_all.size(-1)
if buffer is None:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
punica_kernels.dispatch_bgmv_low_level(
buffer,
x,
wa_t_all,
indicies,
layer_idx,
1.0,
x.size(1),
buffer.size(1),
0,
)
punica_kernels.dispatch_bgmv_low_level(
y,
buffer,
wb_t_all,
indicies,
layer_idx,
scale,
buffer.size(1),
y_slice_size,
y_offset,
)
else:
def _raise_exc(
*args, # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
if torch.cuda.get_device_capability() < (8, 0):
raise ImportError(
"LoRA kernels require compute capability>=8.0") from import_exc
else:
raise import_exc
bgmv = _raise_exc
add_lora = _raise_exc
add_lora_slice = _raise_exc
__all__ = [
"bgmv",
"add_lora",
"add_lora_slice",
]
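For readers without a suitable GPU, a pure-PyTorch sketch of the documented
add_lora_slice semantics (names are illustrative; the real kernels also
require compute capability >= 8.0 and hidden sizes compiled into
csrc/punica/bgmv/bgmv_config.h):

import torch

def add_lora_slice_reference(y, x, wa_t_all, wb_t_all, indices, layer_idx,
                             scale, y_offset, y_slice_size):
    # Adds each row's LoRA delta to a column slice of y, mirroring the
    # docstring above; a row index of -1 is treated here as "no LoRA".
    for i, idx in enumerate(indices.tolist()):
        if idx < 0:
            continue
        delta = (x[i] @ wa_t_all[idx, layer_idx].t()
                 @ wb_t_all[idx, layer_idx].t()) * scale
        y[i, y_offset:y_offset + y_slice_size] += delta
    return y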

32
vllm/lora/request.py Normal file
View File

@ -0,0 +1,32 @@
from dataclasses import dataclass
@dataclass
class LoRARequest:
"""
Request for a LoRA adapter.
    Note that this class should only be used internally. For online
    serving, it is recommended not to allow users to use this class
    directly, but instead to provide another layer of abstraction that
    prevents users from accessing unauthorized LoRA adapters.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
lora_name: str
lora_int_id: int
lora_local_path: str
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(
f"lora_int_id must be > 0, got {self.lora_int_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, LoRARequest) and self.lora_int_id == value.lora_int_id
def __hash__(self) -> int:
return self.lora_int_id
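Example construction; the adapter name and path are illustrative:

from vllm.lora.request import LoRARequest

req = LoRARequest(lora_name="sql-adapter",
                  lora_int_id=1,
                  lora_local_path="/path/to/sql_adapter")
assert hash(req) == 1  # identity is the integer id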

39
vllm/lora/utils.py Normal file
View File

@ -0,0 +1,39 @@
import logging
from typing import Tuple
from torch import nn
logger = logging.getLogger(__name__)
def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
"""Parse the name of lora weights.
args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
return:
Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
"""
parts = name.split(".")
assert parts[0] == "base_model"
assert parts[1] == "model"
if parts[-1] == "weight":
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported format")

237
vllm/lora/worker_manager.py Normal file
View File

@ -0,0 +1,237 @@
import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional, Set, Type, Union
import torch
from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.layers import LoRAMapping
from vllm.config import LoRAConfig
logger = logging.getLogger(__name__)
class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
vocab_size: int, lora_config: LoRAConfig,
device: torch.device):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
@abstractproperty
def is_enabled(self) -> bool:
...
@abstractmethod
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
...
@abstractmethod
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
...
@abstractmethod
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
...
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
def remove_all_loras(self) -> bool:
...
@abstractmethod
def list_loras(self) -> Set[int]:
...
class WorkerLoRAManager(AbstractWorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
    On every request, the requested LoRAs will be loaded (unless they are
    already loaded), and every other LoRA will be unloaded."""
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
lora_model_cls: Type[LoRAModel] = LoRAModel,
):
self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
@property
def is_enabled(self) -> bool:
return True
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
lora_manager = create_lora_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
target_modules=target_modules,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
)
self._lora_manager: LoRAModelManager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
new_loras = set(loras_map)
loras_to_add = new_loras - loras_that_exist
loras_to_remove = loras_that_exist - new_loras
for lora_id in loras_to_remove:
self.remove_lora(lora_id)
for lora_id in loras_to_add:
self.add_lora(loras_map[lora_id])
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try:
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size +
self.lora_config.lora_extra_vocab_size,
)
except Exception as e:
raise RuntimeError(
f"Loading lora {lora_request.lora_local_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank "
f"{self.lora_config.max_lora_rank}.")
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora.extra_vocab_size} is greater than "
f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}."
)
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
return self._lora_manager.add_lora(
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
rank))
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
self._lora_manager.activate_lora(lora.id)
return loaded
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
def remove_all_loras(self) -> bool:
self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]:
return set(self._lora_manager.list_loras())
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
    Uses an LRU cache. On every request, the requested LoRAs will be loaded
    (unless they are already loaded), and the least recently used LoRAs will
    be unloaded if the cache is above capacity."""
_lora_manager_cls: Type[
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
model: torch.nn.Module,
target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
) -> Any:
lora_manager = create_lora_manager(
model,
target_modules=target_modules,
lora_manager_cls=self._lora_manager_cls,
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._lora_manager: LRUCacheLoRAModelManager = lora_manager
return lora_manager.model
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
for lora in loras_map.values():
self.add_lora(lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded
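# Per-step flow on the worker (sketch; everything other than the classes in
# this file is hypothetical):
#   manager = LRUCacheWorkerLoRAManager(max_num_seqs, max_num_batched_tokens,
#                                       vocab_size, lora_config, device)
#   model = manager.create_lora_manager(model)  # wraps the target modules
#   manager.set_active_loras(lora_requests, lora_mapping)
# set_active_loras loads any adapters that are not yet resident (evicting the
# least recently used ones when over capacity) and then writes the batch's
# index tensors via set_lora_mapping.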

View File

@ -27,9 +27,25 @@ class Sampler(nn.Module):
parameters (e.g., sampling method, temperature, top-p, top-k, etc.). parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
""" """
def __init__(self, vocab_size: int) -> None: def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.vocab_size = vocab_size self.vocab_size = vocab_size
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def forward( def forward(
self, self,
@ -42,8 +58,7 @@ class Sampler(nn.Module):
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias, logits = self._get_logits(hidden_states, embedding, embedding_bias)
self.vocab_size)
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because # Note: `_get_logits` is still distributed across TP workers because
@ -98,20 +113,6 @@ class Sampler(nn.Module):
prompt_logprobs, sample_logprobs) prompt_logprobs, sample_logprobs)
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor],
vocab_size: int) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :vocab_size]
return logits
def _prune_hidden_states( def _prune_hidden_states(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,

View File

@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value.""" """Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to return ((vocab_size + pad_to - 1) // pad_to) * pad_to
@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
num_embeddings: vocabulary size. num_embeddings: vocabulary size.
embedding_dim: size of hidden state. embedding_dim: size of hidden state.
params_dtype: type of the parameters. params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
params_dtype: Optional[torch.dtype] = None): params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__() super().__init__()
# Keep the input dimensions. # Keep the input dimensions.
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings) self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
@ -77,7 +86,7 @@ class VocabParallelEmbedding(torch.nn.Module):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.num_embeddings assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self. loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index] vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight) param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding):
embedding_dim: size of hidden state. embedding_dim: size of hidden state.
bias: whether to use bias. bias: whether to use bias.
params_dtype: type of the parameters. params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
bias: bool = False, bias: bool = False,
params_dtype: Optional[torch.dtype] = None): params_dtype: Optional[torch.dtype] = None,
super().__init__(num_embeddings, embedding_dim, params_dtype) org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,

View File

@ -1,12 +1,12 @@
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib import contextlib
from typing import Type from typing import Optional, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import ModelConfig from vllm.config import ModelConfig, LoRAConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
@ -32,7 +32,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
f"Supported architectures: {ModelRegistry.get_supported_archs()}") f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config) model_class = _get_model_architecture(model_config.hf_config)
# Get the (maybe quantized) linear method. # Get the (maybe quantized) linear method.
@ -62,6 +63,16 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
with torch.device("cuda"): with torch.device("cuda"):
if getattr(model_class, "supports_lora", False):
model = model_class(model_config.hf_config, linear_method,
lora_config)
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
else:
model = model_class(model_config.hf_config, linear_method) model = model_class(model_config.hf_config, linear_method)
if model_config.load_format == "dummy": if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
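A simplified stand-in for the dispatch above: a model class opts in by exposing a supports_lora attribute and accepting the extra lora_config argument, while models without the flag keep their original two-argument constructor. The classes here are illustrative, not vLLM models.

class LegacyModel:
    def __init__(self, hf_config, linear_method):
        self.hf_config = hf_config

class LoRAAwareModel:
    supports_lora = True  # the flag get_model checks via getattr

    def __init__(self, hf_config, linear_method, lora_config=None):
        self.hf_config = hf_config
        self.lora_config = lora_config

def build(model_class, hf_config, linear_method, lora_config=None):
    # Mirrors the branch added to get_model above.
    if getattr(model_class, "supports_lora", False):
        return model_class(hf_config, linear_method, lora_config)
    if lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA")
    return model_class(hf_config, linear_method)

build(LegacyModel, hf_config=None, linear_method=None)        # ok
build(LoRAAwareModel, None, None, lora_config=object())       # ok
# build(LegacyModel, None, None, lora_config=object())  -> raises ValueError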
@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -225,14 +226,19 @@ class LlamaModel(nn.Module):
self, self,
config: LlamaConfig, config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer(config, linear_method) LlamaDecoderLayer(config, linear_method)
@ -263,18 +269,31 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module):
supports_lora = True
def __init__( def __init__(
self, self,
config: LlamaConfig, config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.model = LlamaModel(config, linear_method) self.model = LlamaModel(config, linear_method, lora_config=lora_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) unpadded_vocab_size = config.vocab_size
self.sampler = Sampler(config.vocab_size) if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
def forward( def forward(
self, self,
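The lm_head sizing above can be summarised in a small sketch. The padding constants below are placeholders for DEFAULT_VOCAB_PADDING_SIZE and lora_config.lora_vocab_padding_size, whose values are not visible in this excerpt.

def lm_head_sizes(vocab_size, lora_extra_vocab_size=None,
                  default_padding=64, lora_padding=256):
    # default_padding / lora_padding are illustrative placeholders.
    unpadded_vocab_size = vocab_size
    padding_size = default_padding
    if lora_extra_vocab_size is not None:
        unpadded_vocab_size += lora_extra_vocab_size
        # Bigger padding is required for LoRA kernel compatibility
        # (see the inline comment above).
        padding_size = lora_padding
    return unpadded_vocab_size, padding_size

print(lm_head_sizes(32000))       # base model: (32000, 64)
print(lm_head_sizes(32000, 256))  # with LoRA:  (32256, 256)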
@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -220,15 +221,20 @@ class MistralModel(nn.Module):
self, self,
config: MistralConfig, config: MistralConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
MistralDecoderLayer(config, linear_method) MistralDecoderLayer(config, linear_method)
@ -259,18 +265,33 @@ class MistralModel(nn.Module):
class MistralForCausalLM(nn.Module): class MistralForCausalLM(nn.Module):
supports_lora = True
def __init__( def __init__(
self, self,
config: MistralConfig, config: MistralConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.model = MistralModel(config, linear_method) self.model = MistralModel(config,
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) linear_method,
self.sampler = Sampler(config.vocab_size) lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
def forward( def forward(
self, self,
@ -195,10 +195,14 @@ def get_pipeline_model_parallel_prev_rank():
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none.""" """Set the groups to none and destroy them."""
global _TENSOR_MODEL_PARALLEL_GROUP global _TENSOR_MODEL_PARALLEL_GROUP
if _TENSOR_MODEL_PARALLEL_GROUP:
torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP global _PIPELINE_MODEL_PARALLEL_GROUP
if _PIPELINE_MODEL_PARALLEL_GROUP:
torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
@ -2,6 +2,7 @@ from typing import List, Optional
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
SequenceStatus) SequenceStatus)
from vllm.lora.request import LoRARequest
class CompletionOutput: class CompletionOutput:
@ -16,6 +17,7 @@ class CompletionOutput:
logprobs: The log probabilities of the top probability words at each logprobs: The log probabilities of the top probability words at each
position if the logprobs are requested. position if the logprobs are requested.
finish_reason: The reason why the sequence is finished. finish_reason: The reason why the sequence is finished.
lora_request: The LoRA request that was used to generate the output.
""" """
def __init__( def __init__(
@ -26,6 +28,7 @@ class CompletionOutput:
cumulative_logprob: float, cumulative_logprob: float,
logprobs: Optional[SampleLogprobs], logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None, finish_reason: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.index = index self.index = index
self.text = text self.text = text
@ -33,6 +36,7 @@ class CompletionOutput:
self.cumulative_logprob = cumulative_logprob self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs self.logprobs = logprobs
self.finish_reason = finish_reason self.finish_reason = finish_reason
self.lora_request = lora_request
def finished(self) -> bool: def finished(self) -> bool:
return self.finish_reason is not None return self.finish_reason is not None
@ -56,6 +60,7 @@ class RequestOutput:
prompt_logprobs: The log probabilities to return per prompt token. prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request. outputs: The output sequences of the request.
finished: Whether the whole request is finished. finished: Whether the whole request is finished.
lora_request: The LoRA request that was used to generate the output.
""" """
def __init__( def __init__(
@ -66,6 +71,7 @@ class RequestOutput:
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
finished: bool, finished: bool,
lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.prompt = prompt self.prompt = prompt
@ -73,6 +79,7 @@ class RequestOutput:
self.prompt_logprobs = prompt_logprobs self.prompt_logprobs = prompt_logprobs
self.outputs = outputs self.outputs = outputs
self.finished = finished self.finished = finished
self.lora_request = lora_request
@classmethod @classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@ -108,8 +115,13 @@ class RequestOutput:
prompt_token_ids = seq_group.prompt_token_ids prompt_token_ids = seq_group.prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished() finished = seq_group.is_finished()
return cls(seq_group.request_id, prompt, prompt_token_ids, return cls(seq_group.request_id,
prompt_logprobs, outputs, finished) prompt,
prompt_token_ids,
prompt_logprobs,
outputs,
finished,
lora_request=seq_group.lora_request)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, " return (f"RequestOutput(request_id={self.request_id}, "
@ -117,4 +129,5 @@ class RequestOutput:
f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_token_ids={self.prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, "
f"outputs={self.outputs}, " f"outputs={self.outputs}, "
f"finished={self.finished})") f"finished={self.finished}, "
f"lora_request={self.lora_request})")
@ -74,13 +74,14 @@ class PrefixPool:
new_length = len(token_ids) // self.block_size * self.block_size new_length = len(token_ids) // self.block_size * self.block_size
return tuple(token_ids[:new_length]) return tuple(token_ids[:new_length])
def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: def add_or_get_prefix(self, token_ids: Sequence[int],
lora_int_id: int) -> Optional[Prefix]:
token_ids = self._truncate_token_ids(token_ids) token_ids = self._truncate_token_ids(token_ids)
if len(token_ids) == 0: if len(token_ids) == 0:
# Prefix is empty. # Prefix is empty.
return None return None
prefix = Prefix(token_ids, self.block_size) prefix = Prefix(token_ids, self.block_size)
prefix_hash = hash(prefix) prefix_hash = hash((prefix, lora_int_id))
if prefix_hash not in self.prefixes: if prefix_hash not in self.prefixes:
self.prefixes[prefix_hash] = prefix self.prefixes[prefix_hash] = prefix
return self.prefixes[prefix_hash] return self.prefixes[prefix_hash]
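The hash change matters because the same token prefix served through different adapters must not share a cached entry: the adapter can change what the prefix's KV cache contains. A standalone illustration, with a stand-in Prefix assumed to hash on its token ids.

class Prefix:
    # Stand-in for vllm.prefix.Prefix, assumed to hash on its token ids.
    def __init__(self, token_ids, block_size=16):
        self.token_ids = tuple(token_ids)
        self.block_size = block_size

    def __hash__(self):
        return hash(self.token_ids)

tokens = (101, 7592, 2088, 102)
base_key = hash((Prefix(tokens), 0))  # base model: lora_int_id == 0
lora_key = hash((Prefix(tokens), 7))  # adapter with lora_int_id == 7
print(base_key == lora_key)           # False: separate pool entries per adapter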
@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest
PromptLogprobs = List[Optional[Dict[int, float]]] PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]] SampleLogprobs = List[Dict[int, float]]
@ -106,6 +107,7 @@ class Sequence:
prompt_token_ids: The token IDs of the prompt. prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
lora_request: LoRA request.
""" """
def __init__( def __init__(
@ -114,10 +116,12 @@ class Sequence:
prompt: str, prompt: str,
prompt_token_ids: List[int], prompt_token_ids: List[int],
block_size: int, block_size: int,
lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.prompt = prompt self.prompt = prompt
self.block_size = block_size self.block_size = block_size
self.lora_request = lora_request
self.data = SequenceData(prompt_token_ids) self.data = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
@ -134,6 +138,10 @@ class Sequence:
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[List[str]] = None
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def _append_logical_block(self) -> None: def _append_logical_block(self) -> None:
block = LogicalTokenBlock( block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks), block_number=len(self.logical_token_blocks),
@ -229,6 +237,7 @@ class SequenceGroup:
seqs: The list of sequences. seqs: The list of sequences.
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request. arrival_time: The arrival time of the request.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group. prefix: The prefix of the prompt of the sequence group.
""" """
@ -238,12 +247,14 @@ class SequenceGroup:
seqs: List[Sequence], seqs: List[Sequence],
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None, prefix: Optional[Prefix] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.arrival_time = arrival_time self.arrival_time = arrival_time
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
@ -259,6 +270,10 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids return next(iter(self.seqs_dict.values())).data.prompt_token_ids
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def get_max_num_running_seqs(self) -> int: def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining """The maximum number of sequences running in parallel in the remaining
lifetime of the request.""" lifetime of the request."""
@ -338,6 +353,7 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block block_tables: The block tables. (Seq id -> list of physical block
numbers) numbers)
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group. prefix: The prefix of the prompt of the sequence group.
""" """
@ -348,6 +364,7 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData], seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams, sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None, prefix: Optional[Prefix] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
@ -355,8 +372,13 @@ class SequenceGroupMetadata:
self.seq_data = seq_data self.seq_data = seq_data
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.block_tables = block_tables self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix self.prefix = prefix
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
class SequenceOutput: class SequenceOutput:
"""The model output associated with a sequence. """The model output associated with a sequence.
@ -4,6 +4,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast) PreTrainedTokenizerFast)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.transformers_utils.tokenizers import * from vllm.transformers_utils.tokenizers import *
logger = init_logger(__name__) logger = init_logger(__name__)
@ -65,6 +67,84 @@ def get_tokenizer(
return tokenizer return tokenizer
def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[PreTrainedTokenizer]:
if lora_request is None:
return None
try:
tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
**kwargs)
except OSError as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
f"No tokenizer found in {lora_request.lora_local_path}, "
"using base model tokenizer instead. "
f"(Exception: {str(e)})")
tokenizer = None
return tokenizer
get_lora_tokenizer_async = make_async(get_lora_tokenizer)
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
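A hedged usage sketch of the TokenizerGroup above. The module path, model id, and adapter path are assumptions for illustration; the constructor arguments, the encode() signature, and the LoRARequest fields come from this diff.

from vllm.lora.request import LoRARequest
# Assumed module path; the file name is not visible in this excerpt.
from vllm.transformers_utils.tokenizer import TokenizerGroup

tokenizer_group = TokenizerGroup(
    tokenizer_id="meta-llama/Llama-2-7b-hf",  # illustrative base model
    enable_lora=True,
    max_num_seqs=256,       # caps the LRU cache of per-adapter tokenizers
    max_input_length=None,
)

sql_lora = LoRARequest(
    lora_name="sql-adapter",              # illustrative
    lora_int_id=1,
    lora_local_path="/path/to/sql_lora",  # illustrative
)

# Falls back to the base tokenizer if the adapter folder has no tokenizer.
adapter_ids = tokenizer_group.encode("SELECT 1;", lora_request=sql_lora)
base_ids = tokenizer_group.encode("SELECT 1;")  # no adapter -> base tokenizer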
def _convert_tokens_to_string_with_added_encoders( def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str], output_tokens: List[str],
@ -7,6 +7,17 @@ from typing import List
import psutil import psutil
import torch import torch
import asyncio
from functools import partial
from typing import (
Awaitable,
Callable,
TypeVar,
)
from collections import OrderedDict
from typing import Any, Hashable, Optional
T = TypeVar("T")
class Device(enum.Enum): class Device(enum.Enum):
@ -28,6 +39,69 @@ class Counter:
self.counter = 0 self.counter = 0
class LRUCache:
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> Any:
return self.get(key)
def __setitem__(self, key: Hashable, value: Any) -> None:
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
self.pop(key)
def touch(self, key: Hashable) -> None:
self.cache.move_to_end(key)
def get(self, key: Hashable, default_value: Optional[Any] = None) -> Optional[Any]:
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
return value
def put(self, key: Hashable, value: Any) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: Any):
pass
def remove_oldest(self):
if not self.cache:
return
key, value = self.cache.popitem(last=False)
self._on_remove(key, value)
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
run_on_remove = key in self.cache
value = self.cache.pop(key, default_value)
if run_on_remove:
self._on_remove(key, value)
return value
def clear(self):
while len(self.cache) > 0:
self.remove_oldest()
self.cache.clear()
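Usage sketch for the LRUCache helper above (importable per the `from vllm.utils import make_async, LRUCache` line earlier in this diff). Note that get() refreshes recency, while exceeding capacity evicts the least recently used entry through remove_oldest().

from vllm.utils import LRUCache

cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")        # touches "a", so "b" is now the oldest entry
cache.put("c", 3)     # exceeds capacity -> evicts "b" via remove_oldest()
print("b" in cache)                       # False
print(cache.get("b", default_value=-1))   # -1
print(len(cache))                         # 2 ("a" and "c" remain)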
def is_hip() -> bool: def is_hip() -> bool:
return torch.version.hip is not None return torch.version.hip is not None
@ -59,6 +133,22 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower() return "microsoft" in " ".join(uname()).lower()
def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
asyncio event loop.
The code in this function needs to be thread safe.
"""
def _async_wrapper(*args, **kwargs) -> asyncio.Future:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=None, func=p_func)
return _async_wrapper
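Usage sketch for make_async, the same pattern the tokenizer module uses via get_lora_tokenizer_async earlier in this diff. The slow_tokenize function is illustrative.

import asyncio
import time
from typing import List

from vllm.utils import make_async

def slow_tokenize(text: str) -> List[int]:
    time.sleep(0.1)  # stand-in for blocking tokenizer / disk work
    return [ord(c) for c in text]

slow_tokenize_async = make_async(slow_tokenize)

async def main() -> None:
    # Runs in the default executor thread, so the event loop stays responsive.
    ids = await slow_tokenize_async("hello")
    print(ids)

asyncio.run(main())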
def get_ip() -> str: def get_ip() -> str:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
@ -1,23 +1,27 @@
import time import time
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Set, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.utils import in_wsl from vllm.utils import in_wsl
logger = init_logger(__name__) logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1 _PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
@ -30,19 +34,23 @@ class ModelRunner:
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ):
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py. # model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this. # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window() self.sliding_window = (model_config.get_sliding_window()
if model_config is not None else None) if model_config is not None else None)
self.device = torch.device(torch.cuda.current_device())
self.model = None self.model = None
self.block_size = None # Set after initial profiling. self.block_size = None # Set after initial profiling.
self.lora_manager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture. self.graph_memory_pool = None # Set during graph capture.
@ -61,7 +69,17 @@ class ModelRunner:
self.in_wsl = in_wsl() self.in_wsl = in_wsl()
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(self.model_config) self.model = get_model(self.model_config, self.lora_config)
vocab_size = self.model.config.vocab_size
if self.lora_config:
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, vocab_size,
self.lora_config, self.device)
self.model = self.lora_manager.create_lora_manager(self.model)
def set_block_size(self, block_size: int) -> None: def set_block_size(self, block_size: int) -> None:
self.block_size = block_size self.block_size = block_size
@ -74,12 +92,15 @@ class ModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
List[int]]: List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = [] prompt_lens: List[int] = []
context_lens: List[int] = [] context_lens: List[int] = []
@ -113,6 +134,17 @@ class ModelRunner:
input_positions.append( input_positions.append(
list(range(prefix_len, prefix_len + len(prompt_tokens)))) list(range(prefix_len, prefix_len + len(prompt_tokens))))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping.append([lora_id] * prompt_len)
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping. # yet. In this case, we just use a dummy slot mapping.
@ -156,6 +188,10 @@ class ModelRunner:
max_prompt_len, max_prompt_len,
pad=_PAD_SLOT_ID, pad=_PAD_SLOT_ID,
dtype=torch.long) dtype=torch.long)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device='cuda') device='cuda')
@ -188,23 +224,33 @@ class ModelRunner:
use_cuda_graph=False, use_cuda_graph=False,
) )
return (input_tokens, input_positions, input_metadata, prompt_lens, return (input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens) subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests)
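To make the bookkeeping above concrete, here is a standalone sketch of what _prepare_prompt accumulates for a small batch. lora_index_mapping has one entry per token (padded with _pad_to_max and later flattened into a LoRAMapping in prepare_input_tensors further down), while lora_prompt_mapping has one entry per sequence unless prompt logprobs are requested. The batch itself is illustrative.

# Two prompt sequences: seq A (4 tokens) uses adapter id 3,
# seq B (2 tokens) runs on the base model (lora_int_id == 0).
batch = [("A", 4, 3, False), ("B", 2, 0, False)]  # (name, prompt_len, lora_id, wants_prompt_logprobs)

lora_index_mapping = []   # per-token adapter ids, one sub-list per sequence
lora_prompt_mapping = []  # per-sequence ids (per-token if prompt logprobs)
lora_requests = set()

for name, prompt_len, lora_id, wants_prompt_logprobs in batch:
    if lora_id > 0:
        lora_requests.add(lora_id)  # the real code stores the LoRARequest object
    lora_index_mapping.append([lora_id] * prompt_len)
    lora_prompt_mapping.extend(
        [lora_id] * (prompt_len if wants_prompt_logprobs else 1))

max_prompt_len = 4
# Mirrors the _pad_to_max call above: pad every sub-list to max_prompt_len.
lora_index_mapping = [m + [0] * (max_prompt_len - len(m)) for m in lora_index_mapping]
flat = [x for sub in lora_index_mapping for x in sub]

print(flat)                 # [3, 3, 3, 3, 0, 0, 0, 0]
print(lora_prompt_mapping)  # [3, 0]
print(lora_requests)        # {3}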
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[List[int]] = []
context_lens: List[int] = [] context_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id() generation_token = seq_data.get_last_token_id()
@ -223,6 +269,8 @@ class ModelRunner:
block_offset = position % self.block_size block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append([slot]) slot_mapping.append([slot])
lora_index_mapping.append([lora_id])
lora_prompt_mapping.append(lora_id)
if self.sliding_window is not None: if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window // sliding_window_blocks = (self.sliding_window //
@ -287,6 +335,10 @@ class ModelRunner:
device="cuda", device="cuda",
) )
lora_index_mapping = [
_pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
]
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
@ -298,7 +350,7 @@ class ModelRunner:
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
) )
return input_tokens, input_positions, input_metadata return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
def _prepare_sample( def _prepare_sample(
self, self,
@ -375,7 +427,8 @@ class ModelRunner:
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
Set[int], LoRAMapping]:
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
@ -383,16 +436,29 @@ class ModelRunner:
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, input_metadata, prompt_lens, (input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens) = self._prepare_prompt(seq_group_metadata_list) subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, input_metadata (input_tokens, input_positions, input_metadata,
) = self._prepare_decode(seq_group_metadata_list) lora_index_mapping, lora_prompt_mapping,
subquery_lens = None lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] prompt_lens = []
subquery_lens = None
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens) subquery_lens)
if self.lora_config:
flat_lora_index_mapping = [
item for sublist in lora_index_mapping for item in sublist
]
lora_mapping = LoRAMapping(
flat_lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
# Broadcast the metadata. # Broadcast the metadata.
metadata_dict = { metadata_dict = {
"input_tokens": input_tokens, "input_tokens": input_tokens,
@ -408,12 +474,16 @@ class ModelRunner:
"use_cuda_graph": input_metadata.use_cuda_graph, "use_cuda_graph": input_metadata.use_cuda_graph,
"selected_token_indices": "selected_token_indices":
sampling_metadata.selected_token_indices, sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
} }
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
else: else:
metadata_dict = broadcast_tensor_dict(src=0) metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict["input_tokens"] input_tokens = metadata_dict["input_tokens"]
input_positions = metadata_dict["input_positions"] input_positions = metadata_dict["input_positions"]
lora_mapping = metadata_dict["lora_mapping"]
lora_requests = metadata_dict["lora_requests"]
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=metadata_dict["is_prompt"], is_prompt=metadata_dict["is_prompt"],
slot_mapping=metadata_dict["slot_mapping"], slot_mapping=metadata_dict["slot_mapping"],
@ -434,7 +504,7 @@ class ModelRunner:
perform_sampling=False, perform_sampling=False,
) )
return input_tokens, input_positions, input_metadata, sampling_metadata return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
@ -442,8 +512,12 @@ class ModelRunner:
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
input_tokens, input_positions, input_metadata, sampling_metadata = ( input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
self.prepare_input_tensors(seq_group_metadata_list)) self.prepare_input_tensors(seq_group_metadata_list))
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Execute the model. # Execute the model.
if input_metadata.use_cuda_graph: if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] graph_batch_size = input_tokens.shape[0]
@ -472,6 +546,28 @@ class ModelRunner:
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique LoRAs, and therefore the max amount of memory
# consumption. Create dummy LoRA request copies from the LoRA request
# passed in, which contains a LoRA from the LoRA warmup path.
dummy_lora_requests = []
dummy_lora_requests_per_seq = []
if self.lora_config:
for idx in range(self.lora_config.max_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
]
# Profile memory usage with max_num_sequences sequences and the total # Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens. # number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = [] seqs: List[SequenceGroupMetadata] = []
@ -485,6 +581,8 @@ class ModelRunner:
seq_data={group_id: seq_data}, seq_data={group_id: seq_data},
sampling_params=sampling_params, sampling_params=sampling_params,
block_tables=None, block_tables=None,
lora_request=dummy_lora_requests_per_seq[group_id]
if dummy_lora_requests_per_seq else None,
) )
seqs.append(seq) seqs.append(seq)
@ -495,6 +593,32 @@ class ModelRunner:
torch.cuda.synchronize() torch.cuda.synchronize()
return return
def remove_all_loras(self) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: List[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_loras(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_loras()
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None: def capture_model(self, kv_caches: List[KVCache]) -> None:
assert not self.model_config.enforce_eager assert not self.model_config.enforce_eager
@ -541,6 +665,13 @@ class ModelRunner:
use_cuda_graph=True, use_cuda_graph=True,
) )
if self.lora_config:
lora_mapping = LoRAMapping(
[0] * batch_size,
[0] * batch_size,
)
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model) graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture( graph_runner.capture(
input_tokens[:batch_size], input_tokens[:batch_size],
@ -1,12 +1,13 @@
"""A GPU worker class.""" """A GPU worker class."""
import gc
import os import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Tuple, Set, Optional
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
@ -15,6 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.lora.request import LoRARequest
class Worker: class Worker:
@ -33,6 +35,7 @@ class Worker:
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
@ -41,12 +44,16 @@ class Worker:
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config, parallel_config, self.model_runner = ModelRunner(model_config,
scheduler_config, is_driver_worker) parallel_config,
scheduler_config,
lora_config=self.lora_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# self.init_cache_engine(). # self.init_cache_engine().
self.cache_config = None self.cache_config = None
@ -117,6 +124,9 @@ class Worker:
num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_cpu_blocks = int(cpu_swap_space // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0) num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks
@ -199,6 +209,15 @@ class Worker:
self.gpu_cache) self.gpu_cache)
return output return output
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
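These thin wrappers let the engine manage adapters on each worker without reaching into the model runner directly. A hedged sketch of the call pattern; the worker is assumed to have been constructed with a lora_config, and only the three methods above plus the LoRARequest fields come from this diff.

from vllm.lora.request import LoRARequest

def register_adapter(worker, adapter_id: int, path: str) -> None:
    # worker: a Worker constructed with lora_config set.
    request = LoRARequest(
        lora_name=f"adapter-{adapter_id}",  # illustrative name
        lora_int_id=adapter_id,
        lora_local_path=path,
    )
    if adapter_id not in worker.list_loras():
        assert worker.add_lora(request)

def unregister_adapter(worker, adapter_id: int) -> None:
    if adapter_id in worker.list_loras():
        worker.remove_lora(adapter_id)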
def _init_distributed_environment( def _init_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,