[Experimental] Add multi-LoRA support (#1804)
Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>

parent 9c1352eb57
commit 9b945daaf1
@@ -41,6 +41,9 @@ steps:
 - label: Worker Test
   command: pytest -v -s worker
 
+- label: LoRA Test
+  command: pytest -v -s lora
+
 - label: Benchmarks
   working_dir: "/vllm-workspace/.buildkite"
   commands:
@@ -65,7 +65,9 @@ def main(args: argparse.Namespace):
     if args.profile:
         profile_dir = args.profile_result_dir
         if not profile_dir:
-            profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+            profile_dir = Path(
+                "."
+            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=args.profile_result_dir)
         return
@@ -123,9 +125,7 @@ if __name__ == '__main__':
         '--profile-result-dir',
         type=str,
         default=None,
-        help=(
-            'path to save the pytorch profiler output. Can be visualized '
-            'with ui.perfetto.dev or Tensorboard.'
-        ))
+        help=('path to save the pytorch profiler output. Can be visualized '
+              'with ui.perfetto.dev or Tensorboard.'))
     args = parser.parse_args()
     main(args)
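For context, --profile-result-dir points the script's PyTorch profiler output at a directory that is then opened in TensorBoard or ui.perfetto.dev. Below is a minimal sketch of how such a trace directory is typically produced and viewed; the directory name and the profiled workload are placeholders of mine, not code from this diff:

# Hypothetical sketch: producing a torch.profiler trace like the one collected
# via the --profile-result-dir flag above.
import torch
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

profile_dir = "vllm_benchmark_result/latency_result_example"  # placeholder path

def run_to_completion():
    # Stand-in for the benchmarked forward pass in benchmark_latency.py.
    torch.randn(1024, 1024) @ torch.randn(1024, 1024)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities,
             on_trace_ready=tensorboard_trace_handler(profile_dir)):
    run_to_completion()

# Inspect with `tensorboard --logdir vllm_benchmark_result`, or open the
# exported .json trace in ui.perfetto.dev.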
csrc/punica/LICENSE  (new file, 217 lines)

Contains code from https://github.com/punica-ai/punica

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.


Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer

BSD-3-Clause:
* third_party/cutlass
csrc/punica/bgmv/bgmv_all.cu  (new file, 21 lines)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
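The 18 instantiation lines above cover every (input, output, weight) dtype combination over fp16/bf16/fp32 inputs and outputs with fp16/bf16 LoRA weights; fp32 weights are never instantiated. A small Python sketch (a helper I am assuming, not part of the diff) that enumerates the same triples, e.g. for building a test matrix:

# Hypothetical mirror of the FOR_BGMV_WIDE_NARROW instantiations above:
# every (in_dtype, out_dtype, weight_dtype) triple bgmv_all.cu compiles.
import itertools
import torch

_IN_OUT_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
_WEIGHT_DTYPES = (torch.float16, torch.bfloat16)  # no float32 weights in the .cu file

def instantiated_dtype_triples():
    return [
        (in_t, out_t, w_t)
        for in_t, out_t, w_t in itertools.product(
            _IN_OUT_DTYPES, _IN_OUT_DTYPES, _WEIGHT_DTYPES)
    ]

print(len(instantiated_dtype_triples()))  # 3 * 3 * 2 = 18 template instantiations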
csrc/punica/bgmv/bgmv_config.h  (new file, 59 lines)

#pragma once

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale);

// clang-format off

#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
    f(in_T, out_T, W_T, narrow, 128) \
    f(in_T, out_T, W_T, narrow, 256) \
    f(in_T, out_T, W_T, narrow, 512) \
    f(in_T, out_T, W_T, narrow, 1024) \
    f(in_T, out_T, W_T, narrow, 1280) \
    f(in_T, out_T, W_T, narrow, 1728) \
    f(in_T, out_T, W_T, narrow, 1792) \
    f(in_T, out_T, W_T, narrow, 2048) \
    f(in_T, out_T, W_T, narrow, 2560) \
    f(in_T, out_T, W_T, narrow, 2752) \
    f(in_T, out_T, W_T, narrow, 3072) \
    f(in_T, out_T, W_T, narrow, 3456) \
    f(in_T, out_T, W_T, narrow, 3584) \
    f(in_T, out_T, W_T, narrow, 4096) \
    f(in_T, out_T, W_T, narrow, 5120) \
    f(in_T, out_T, W_T, narrow, 5504) \
    f(in_T, out_T, W_T, narrow, 5632) \
    f(in_T, out_T, W_T, narrow, 6912) \
    f(in_T, out_T, W_T, narrow, 7168) \
    f(in_T, out_T, W_T, narrow, 8192) \
    f(in_T, out_T, W_T, narrow, 9216) \
    f(in_T, out_T, W_T, narrow, 10240) \
    f(in_T, out_T, W_T, narrow, 11008) \
    f(in_T, out_T, W_T, narrow, 12288) \
    f(in_T, out_T, W_T, narrow, 13824) \
    f(in_T, out_T, W_T, narrow, 14336) \
    f(in_T, out_T, W_T, narrow, 16384) \
    f(in_T, out_T, W_T, narrow, 20480) \
    f(in_T, out_T, W_T, narrow, 28672) \
    f(in_T, out_T, W_T, narrow, 32000) \
    f(in_T, out_T, W_T, narrow, 32256) \
    f(in_T, out_T, W_T, narrow, 32512) \
    f(in_T, out_T, W_T, narrow, 32768) \
    f(in_T, out_T, W_T, narrow, 33024) \
    f(in_T, out_T, W_T, narrow, 36864) \
    f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA

// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)

// clang-format on
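The macros above instantiate the kernel template for every pairing of a LoRA rank in {8, 16, 32, 64} with one of the listed hidden sizes, in both directions (rank -> hidden for expand, hidden -> rank for shrink). A quick sketch of the resulting shape table, using helper names I am assuming for illustration:

# Hypothetical mirror of FOR_BGMV_WIDE_NARROW x FOR_BGMV_WIDE: the set of
# (feat_in, feat_out) shapes for which bgmv_kernel gets instantiated.
BGMV_NARROW = (8, 16, 32, 64)  # LoRA ranks; kept in sync with vllm/config::LoRAConfig
BGMV_WIDE = (
    128, 256, 512, 1024, 1280, 1728, 1792, 2048, 2560, 2752, 3072, 3456,
    3584, 4096, 5120, 5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008,
    12288, 13824, 14336, 16384, 20480, 28672, 32000, 32256, 32512, 32768,
    33024, 36864, 49152,
)

def supported_shapes():
    shapes = set()
    for narrow in BGMV_NARROW:
        for wide in BGMV_WIDE:
            shapes.add((narrow, wide))  # expand: rank -> hidden size
            shapes.add((wide, narrow))  # shrink: hidden size -> rank
    return shapes

print(len(supported_shapes()))  # 4 ranks * 36 widths * 2 directions = 288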
csrc/punica/bgmv/bgmv_impl.cuh  (new file, 294 lines)

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}

// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;

  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
  }
}

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }
}

#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)                \
  template void bgmv_kernel<feat_in, feat_out>(                       \
      out_T * __restrict__ Y, const in_T *__restrict__ X,             \
      const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,      \
      int64_t num_layers, int64_t layer_idx, float scale);

#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
  INST_BGMV(narrow, wide, in_T, out_T, W_T)               \
  INST_BGMV(wide, narrow, in_T, out_T, W_T)
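As a readability aid, here is a plain PyTorch sketch of what the shrink/expand kernels compute as I read them: for each token, gather one LoRA matrix slice via `indicies` (a negative index means "no LoRA"), do a scaled matrix-vector product, and accumulate the result into a slice of Y. This is an assumed reference of mine, not code from the diff:

# Hypothetical reference for bgmv_shrink_kernel / bgmv_expand_kernel semantics.
# Shapes follow the checks in punica_ops.cc:
#   Y: (B, full_y_size), X: (B, feat_in),
#   W: (num_loras, num_layers, feat_out, feat_in), indices: (B,) int64.
import torch

def bgmv_reference(Y, X, W, indices, layer_idx, scale, y_offset=0):
    num_loras, num_layers, feat_out, feat_in = W.shape
    for i in range(X.shape[0]):
        idx = int(indices[i])
        if idx < 0:
            # Negative index: this token has no LoRA adapter; leave Y[i] untouched.
            continue
        w = W[idx, layer_idx].float()          # (feat_out, feat_in)
        y = scale * (w @ X[i].float())         # (feat_out,)
        Y[i, y_offset:y_offset + feat_out] += y.to(Y.dtype)
    return Y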
csrc/punica/bgmv/vec_dtypes.cuh  (new file, 1324 lines)
File diff suppressed because it is too large.
csrc/punica/punica_ops.cc  (new file, 563 lines)

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <cstdint>

#include "bgmv/bgmv_config.h"

namespace {

//====== utils ======

inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
                        const char *a_name, const char *b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
              a.dim(), " vs ", b.dim());
  for (int i = 0; i < a.dim(); ++i) {
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
                ".size(", i, ")");
  }
}

inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

#define CHECK_DIM(d, x) \
  TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)

#define CHECK_EQ(a, b) \
  TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

//====== bgmv ======

template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                               const int64_t *lora_indices,
                               uint16_t in_features, uint16_t out_features,
                               int64_t y_offset, int64_t full_y_size,
                               int64_t batch_size, int64_t num_layers,
                               int64_t layer_idx, float scale) {
  switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out)          \
  case pack_u16(feat_in, feat_out):                                   \
    bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset,   \
                                   full_y_size, batch_size, num_layers, \
                                   layer_idx, scale);                 \
    break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide)  \
  CASE_ONESIDE(in_T, out_T, W_T, narrow, wide)   \
  CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)

    FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
  default:
    return false;
  }

  return true;
}

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t h_in = x.size(1);
  int64_t h_out = y.size(1);
  int64_t num_layers = w.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
    case at::ScalarType::Half:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::BFloat16:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::Float:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              0, h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    default:
      break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t num_layers = w.size(1);
  int64_t full_y_size = y.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
    case at::ScalarType::Half:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_half *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::BFloat16:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<nv_bfloat16 *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::Float:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_half *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<nv_bfloat16 *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_half *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(
              static_cast<float *>(y.data_ptr()),
              static_cast<float *>(x.data_ptr()),
              static_cast<nv_bfloat16 *>(w.data_ptr()),
              indicies.data_ptr<int64_t>(), h_in, h_out,
              y_offset, full_y_size, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    default:
      break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

}  // namespace

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
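A sketch of how the resulting extension op might be called from Python, with the tensor layout implied by the shape checks above (y: [B, h_out], x: [B, h_in], w: [num_loras, num_layers, h_out, h_in], indices: [B] int64). The import path is an assumption on my part; the file only fixes the op names, not the module name:

# Hypothetical call into the compiled extension; the module name is whatever
# TORCH_EXTENSION_NAME resolves to at build time (placeholder below).
import torch
# from vllm import _punica_C as punica_kernels   # assumed import path

B, h_in, h_out = 4, 4096, 16          # e.g. hidden size 4096 shrunk to rank 16
num_loras, num_layers = 8, 32
layer_idx, scale = 5, 1.0

x = torch.randn(B, h_in, dtype=torch.float16, device="cuda")
y = torch.zeros(B, h_out, dtype=torch.float16, device="cuda")   # accumulated in place
w = torch.randn(num_loras, num_layers, h_out, h_in,
                dtype=torch.float16, device="cuda")
indices = torch.tensor([0, 1, -1, 3], dtype=torch.long, device="cuda")  # -1: no LoRA

# Arguments are positional, matching the pybind signature above:
# punica_kernels.dispatch_bgmv(y, x, w, indices, layer_idx, scale)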
117
examples/multilora_inference.py
Normal file
117
examples/multilora_inference.py
Normal file
@ -0,0 +1,117 @@
"""
This example shows how to use the multi-LoRA functionality for offline inference.

Requires HuggingFace credentials for access to Llama2.
"""

from typing import Optional, List, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for base model, 4 requests for the LoRA. We define 2
    different LoRA adapters (using the same model for demo purposes).
    Since we also set `max_loras=1`, the expectation is that the requests
    with the second LoRA adapter will be run after all requests with the
    first adapter have finished.
    """
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        ("To be or not to be,",
         SamplingParams(temperature=0.8,
                        top_k=5,
                        presence_penalty=0.2,
                        max_tokens=128), None),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora", 1, lora_path)),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
         SamplingParams(n=3,
                        best_of=3,
                        use_beam_search=True,
                        temperature=0,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora", 1, lora_path)),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora2", 2, lora_path)),
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
         SamplingParams(n=3,
                        best_of=3,
                        use_beam_search=True,
                        temperature=0,
                        max_tokens=128,
                        stop_token_ids=[32003]),
         LoRARequest("sql-lora", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine."""
    # max_loras: controls the number of LoRAs that can be used in the same
    #   batch. Larger numbers will cause higher memory usage, as each LoRA
    #   slot requires its own preallocated tensor.
    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
    #   numbers will cause higher memory usage. If you know that all LoRAs will
    #   use the same rank, it is recommended to set this as low as possible.
    # max_cpu_loras: controls the size of the CPU LoRA cache.
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8,
                             max_cpu_loras=2,
                             max_num_seqs=256)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()
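The example above drives the LLMEngine step loop directly. The same multi-LoRA flow can also be driven through the higher-level vllm.LLM API used by the tests later in this diff; a minimal sketch (the prompt is a shortened placeholder, the adapter repo is the one from the example above):

# Sketch of the equivalent flow via the high-level API, mirroring
# tests/lora/test_llama.py further below. The prompt text is a placeholder.
from huggingface_hub import snapshot_download

import vllm
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
llm = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=1)

outputs = llm.generate(
    ["[user] Write a SQL query to answer the question ... [/user] [assistant]"],
    vllm.SamplingParams(temperature=0, max_tokens=128),
    lora_request=LoRARequest("sql-lora", 1, lora_path))
for output in outputs:
    print(output.outputs[0].text)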
59  setup.py
@@ -1,13 +1,16 @@
+import contextlib
 import io
 import os
 import re
 import subprocess
-from typing import List, Set
 import warnings
+from pathlib import Path
+from typing import List, Set

 from packaging.version import parse, Version
 import setuptools
 import torch
+import torch.utils.cpp_extension as torch_cpp_ext
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

 ROOT_DIR = os.path.dirname(__file__)
@@ -28,7 +31,7 @@ def _is_neuron() -> bool:
     torch_neuronx_installed = True
     try:
         subprocess.run(["neuron-ls"], capture_output=True, check=True)
-    except FileNotFoundError as e:
+    except FileNotFoundError:
         torch_neuronx_installed = False
     return torch_neuronx_installed

@@ -96,10 +99,16 @@ def get_hipcc_rocm_version():
     return None


+def glob(pattern: str):
+    root = Path(__name__).parent
+    return [str(p) for p in root.glob(pattern)]
+
+
 def get_neuronxcc_version():
     import sysconfig
     site_dir = sysconfig.get_paths()["purelib"]
-    version_file = os.path.join(site_dir, "neuronxcc", "version", "__init__.py")
+    version_file = os.path.join(site_dir, "neuronxcc", "version",
+                                "__init__.py")

     # Check if the command was executed successfully
     with open(version_file, "rt") as fp:
@@ -178,6 +187,8 @@ if _is_cuda() and not compute_capabilities:
             "GPUs with compute capability below 7.0 are not supported.")
         compute_capabilities.add(f"{major}.{minor}")

+ext_modules = []
+
 if _is_cuda():
     nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
     if not compute_capabilities:
@@ -215,6 +226,8 @@ if _is_cuda():
             raise RuntimeError(
                 "CUDA 11.8 or higher is required for compute capability 9.0.")

+    NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy()
+
     # Add target compute capabilities to NVCC flags.
     for capability in compute_capabilities:
         num = capability[0] + capability[2]
@@ -223,6 +236,14 @@ if _is_cuda():
             NVCC_FLAGS += [
                 "-gencode", f"arch=compute_{num},code=compute_{num}"
             ]
+        if int(capability[0]) >= 8:
+            NVCC_FLAGS_PUNICA += [
+                "-gencode", f"arch=compute_{num},code=sm_{num}"
+            ]
+            if capability.endswith("+PTX"):
+                NVCC_FLAGS_PUNICA += [
+                    "-gencode", f"arch=compute_{num},code=compute_{num}"
+                ]

     # Use NVCC threads to parallelize the build.
     if nvcc_cuda_version >= Version("11.2"):
@@ -230,6 +251,36 @@ if _is_cuda():
         num_threads = min(os.cpu_count(), nvcc_threads)
         NVCC_FLAGS += ["--threads", str(num_threads)]

+    # changes for punica kernels
+    NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
+    REMOVE_NVCC_FLAGS = [
+        '-D__CUDA_NO_HALF_OPERATORS__',
+        '-D__CUDA_NO_HALF_CONVERSIONS__',
+        '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
+        '-D__CUDA_NO_HALF2_OPERATORS__',
+    ]
+    for flag in REMOVE_NVCC_FLAGS:
+        with contextlib.suppress(ValueError):
+            torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
+
+    install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1")))
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            install_punica = False
+            break
+    if install_punica:
+        ext_modules.append(
+            CUDAExtension(
+                name="vllm._punica_C",
+                sources=["csrc/punica/punica_ops.cc"] +
+                glob("csrc/punica/bgmv/*.cu"),
+                extra_compile_args={
+                    "cxx": CXX_FLAGS,
+                    "nvcc": NVCC_FLAGS_PUNICA,
+                },
+            ))
 elif _is_hip():
     amd_arch = get_amdgpu_offload_arch()
     if amd_arch not in ROCM_SUPPORTED_ARCHS:
@@ -240,8 +291,6 @@ elif _is_hip():
 elif _is_neuron():
     neuronxcc_version = get_neuronxcc_version()

-ext_modules = []
-
 vllm_extension_sources = [
     "csrc/cache_kernels.cu",
     "csrc/attention/attention_kernels.cu",
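The block above gates the optional kernels on an environment variable and on every visible GPU being SM 8.0 or newer. A small standalone sketch that reproduces this decision, useful for predicting ahead of time whether the punica kernels would be compiled on a given machine (it mirrors the setup.py logic rather than importing it, and assumes a CUDA-enabled torch install):

# Sketch: mirror the install_punica gating from setup.py above.
import os
import torch


def punica_kernels_would_build() -> bool:
    # Opt-out knob, same default as setup.py.
    if not bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1"))):
        return False
    # Every visible GPU must be compute capability >= 8.0.
    return all(
        torch.cuda.get_device_capability(i)[0] >= 8
        for i in range(torch.cuda.device_count()))


print("punica kernels would build:", punica_kernels_would_build())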
@@ -25,6 +25,13 @@ class MockEngine:
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []

+    async def encode_request_async(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return [1]
+
     def generate(self, request_id):
         self.request_id = request_id

@@ -35,6 +42,10 @@ class MockEngine:
         del kwargs  # Unused
         self.add_request_calls += 1

+    async def add_request_async(self, **kwargs):
+        del kwargs  # Unused
+        self.add_request_calls += 1
+
     def abort_request(self, request_id):
         del request_id  # Unused
         self.abort_request_calls += 1
0  tests/lora/__init__.py  Normal file
143  tests/lora/conftest.py  Normal file
@@ -0,0 +1,143 @@
import contextlib
import gc
import tempfile
from collections import OrderedDict
from unittest.mock import patch, MagicMock

import pytest
import ray
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
    destroy_model_parallel, initialize_model_parallel)


def cleanup():
    destroy_model_parallel()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()


@pytest.fixture(autouse=True)
def cleanup_fixture():
    yield
    cleanup()


@pytest.fixture
def dist_init():
    if not torch.distributed.is_initialized():
        temp_file = tempfile.mkstemp()[1]
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=1,
            rank=0,
            init_method=f"file://{temp_file}",
        )
        torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(1, 1)
    yield
    cleanup()


@pytest.fixture
def dist_init_torch_only():
    if torch.distributed.is_initialized():
        return
    temp_file = tempfile.mkstemp()[1]
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=1,
        rank=0,
        init_method=f"file://{temp_file}",
    )


@pytest.fixture
def dummy_model() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", ColumnParallelLinear(50, 10)),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture(scope="session")
def sql_lora_files():
    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
    cleanup()
    get_model_old = get_model

    def get_model_patched(model_config, lora_config=None):
        return get_model_old(model_config,
                             LoRAConfig(max_loras=4, max_lora_rank=8))

    with patch("vllm.worker.model_runner.get_model", get_model_patched):
        engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
    yield engine.llm_engine
    del engine
    cleanup()


@pytest.fixture
def llama_2_7b_model_extra_embeddings(
        llama_2_7b_engine_extra_embeddings) -> nn.Module:
    yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model
709  tests/lora/test_layers.py  Normal file
@@ -0,0 +1,709 @@
import pytest
import random
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple

import torch
import torch.nn.functional as F

from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    VocabParallelEmbeddingWithLoRA,
    RowParallelLinearWithLoRA,
    SamplerWithLoRA,
    LoRAMapping,
    BaseLayerWithLoRA,
)
from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear,
                                               QKVParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.utils import set_random_seed

from .utils import DummyLoRAManager

TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> List[Optional[int]]:
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be larger
            than num_loras.
        log: Whether to log the output.
    """

    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    slots: List[Optional[int]] = [None] * num_slots
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: List[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRAlayer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: whether to generate an
            embeddings tensor for each LoRA.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.
    """

    # Dictionary that maps the lora ID to the
    # corresponding lora weights.
    lora_dict: Dict[int, LoRALayerWeights] = dict()

    # Dictionary that maps the lora ID to the
    # corresponding subloras. Only useful when
    # repeats > 1.
    sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is not None:
            subloras = []
            sublora_len = layer_weights.shape[0] // repeats
            for i in range(repeats):
                sublora = DummyLoRAManager().init_random_lora(
                    module_name=f"fake_{i}",
                    weight=layer_weights,
                    generate_embeddings_tensor=generate_embeddings_tensor,
                )
                sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                    i):(sublora_len * (i + 1))]
                sublora.optimize()
                subloras.append(sublora)

            lora = PackedLoRALayerWeights.pack(
                subloras) if repeats > 1 else subloras[0]

            layer.set_lora(
                slot_idx,
                lora_a=lora.lora_a,
                lora_b=lora.lora_b,
                embeddings_tensor=lora.embeddings_tensor,
            )

            lora_dict[lora_id] = lora
            sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: List[int],
    num_inputs: int,
    input_size: Tuple[int, ...],
    input_range: Tuple[float, float],
    input_type: torch.dtype = torch.int,
) -> Tuple[List[torch.Tensor], List[int], List[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
    """

    low, high = input_range

    inputs, index_mapping, prompt_mapping = [], [], []
    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device="cuda"))
        else:
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device="cuda") *
                high + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings(dist_init, num_loras) -> None:

    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(512, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        embedding.weight.data[512:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, 512),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info)

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, 512),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info, )

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:

    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(512, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[512:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            512 + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=512)
        expanded_embedding.weight.data[:512, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, 512 + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )

        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(
                torch.zeros(embeddings_tensors[0].shape, device="cuda"))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, 512),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        original_inputs = deepcopy(inputs)

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested.
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = 512
            input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = 512 + embeddings_tensor_len - 1

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info, )

        expanded_embedding.weight[512:512 +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, 512),
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        original_inputs = deepcopy(inputs)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)
        lora_embedding.set_mapping(*mapping_info, )

        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_lm_head_sampler(dist_init, num_loras) -> None:

    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_sampler_layer():
        linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
                                1024, 32000)
        linear.weight.data = torch.rand_like(linear.weight.data)
        linear.weight.data[:, 32000:] = 0
        sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000)
        lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype,
                                       linear.weight.device)
        lora_sampler.create_lora_weights(max_loras, lora_config)

        return linear, sampler, lora_sampler

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, sampler, lora_sampler = create_random_sampler_layer()

        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_sampler,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )
        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,  # * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float32,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        input_ = torch.rand(20, 1024, device="cuda")
        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            32000,
            lora_config.lora_extra_vocab_size,
        )
        lora_sampler.set_mapping(*mapping_info, )

        lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
                                               embedding=linear.weight,
                                               embedding_bias=None)

        original_weight = linear.weight.clone()

        linear.weight[sampler.org_vocab_size:sampler.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size
        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = sampler._get_logits(hidden_states=input_,
                                         embedding=linear.weight,
                                         embedding_bias=None)
            result[:, 32000 + embeddings_tensor_len:] = float("-inf")
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        sampler.org_vocab_size = 32000

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_sampler.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float32,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       32000,
                                       lora_config.lora_extra_vocab_size)
        lora_sampler.set_mapping(*mapping_info, )

        lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs),
                                               embedding=original_weight,
                                               embedding_bias=None)[:, :32000]
        expected_result = sampler._get_logits(hidden_states=torch.cat(inputs),
                                              embedding=original_weight,
                                              embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
def test_linear_parallel(dist_init, num_loras, orientation) -> None:

    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_linear_parallel_layer():
        if orientation == "row":
            linear = RowParallelLinear(4096, 4096, bias=False)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = RowParallelLinearWithLoRA(linear)
        else:
            linear = ColumnParallelLinear(4096, 4096, bias=False)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = ColumnParallelLinearWithLoRA(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()

        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float32,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        lora_linear.set_mapping(*mapping_info, )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float32,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)
        lora_linear.set_mapping(*mapping_info, )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3])
def test_column_parallel_packed(dist_init, num_loras, repeats) -> None:

    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_column_parallel_packed_layer():
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = MergedColumnParallelLinearWithLoRA(linear)
        else:
            linear = QKVParallelLinear(4096, 64, 32, bias=False)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = QKVParallelLinearWithLora(linear)

        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)

        linear, lora_linear = create_column_parallel_packed_layer()

        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float32,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        lora_linear.set_mapping(*mapping_info)

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            for i, sublora in enumerate(subloras):
                result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * (
                    i + 1
                )] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float32,
        )
        lora_mapping = LoRAMapping(index_mapping, prompt_mapping)

        mapping_info = convert_mapping(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )
        lora_linear.set_mapping(*mapping_info)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)
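All of these layer tests build their reference output the same way: the base layer's result plus the low-rank update input @ lora_a @ lora_b scaled by the adapter's scaling factor. A tiny self-contained sketch of that reference math, with arbitrary illustrative shapes:

# Sketch of the LoRA reference computation the tests above compare against:
# y = x @ W.T + (x @ A @ B) * scaling. All dimensions are illustrative.
import torch

batch, d_in, d_out, rank, scaling = 4, 64, 32, 8, 0.5
x = torch.randn(batch, d_in)
weight = torch.randn(d_out, d_in)   # base layer weight
lora_a = torch.randn(d_in, rank)    # LoRA down-projection
lora_b = torch.randn(rank, d_out)   # LoRA up-projection

base = x @ weight.T
with_lora = base + x @ lora_a @ lora_b * scaling
assert with_lora.shape == (batch, d_out)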
144  tests/lora/test_llama.py  Normal file
@@ -0,0 +1,144 @@
import pytest
import ray

import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup

MODEL_PATH = "meta-llama/Llama-2-7b-hf"


def do_sample(llm, lora_path: str, lora_id: int):
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]"
    ]
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=256,
                                          stop=["[/assistant]"])
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("tp_size", [1])
def test_llama_lora(sql_lora_files, tp_size):
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < tp_size:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    llm = vllm.LLM(MODEL_PATH,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
                   tensor_parallel_size=tp_size)

    expected_no_lora_output = [
        "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]",
        " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",
        "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",
        " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",
        " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",
        "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
    ]
    expected_lora_output = [
        " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
        " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
        " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ",
        " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",
        " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
        " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "
    ]

    print("lora adapter created")
    assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output

    print("lora 1")
    assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output

    print("no lora")
    assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output

    print("lora 2")
    assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output

    print("removing lora")


@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(sql_lora_files):
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < 4:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    llm_tp1 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       tensor_parallel_size=1)
    output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)

    del llm_tp1
    cleanup()

    llm_tp2 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       tensor_parallel_size=2)
    output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)

    del llm_tp2
    cleanup()

    assert output_tp1 == output_tp2

    llm_tp4 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       tensor_parallel_size=4)
    output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)

    del llm_tp4
    cleanup()

    assert output_tp1 == output_tp4


def test_llama_lora_warmup(sql_lora_files):
    """Test that the LLM initialization works with a warmup LORA path and is more conservative"""

    @ray.remote(num_gpus=1)
    def get_num_gpu_blocks_lora():
        llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
        num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
        return num_gpu_blocks_lora_warmup

    @ray.remote(num_gpus=1)
    def get_num_gpu_blocks_no_lora():
        llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
        num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
        return num_gpu_blocks_no_lora_warmup

    num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
    num_gpu_blocks_no_lora_warmup = ray.get(
        get_num_gpu_blocks_no_lora.remote())
    assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, (
        "The warmup with lora should be more"
        " conservative than without lora, therefore the number of memory blocks for the KV cache should be "
        "less when using lora than when not using lora")
224  tests/lora/test_lora.py  Normal file
@@ -0,0 +1,224 @@
import pytest
import torch

from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice

from .utils import DummyLoRAManager

TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4]
QKV_TENSOR_SIZES = [
    (8192, 1024, 1024),
    (8192 // 8, 1024 // 8, 1024 // 8),
    (4096, 4096, 4096),
    (4096 // 2, 4096 // 2, 4096 // 2),
]
BATCH_SIZES = [8, 32, 256]
RANKS = [8]
DTYPES = [torch.float16]
TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora(m, n, k, rank, dtype) -> None:
    manager = DummyLoRAManager()

    module_name = "module"
    weight = torch.rand([m, n], device="cuda", dtype=dtype)

    manager.init_random_lora(module_name, weight, rank=rank)
    lora = manager.get_module_lora(module_name)

    input = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = input @ lora.lora_a @ lora.lora_b * lora.scaling

    lora_a_stack = torch.zeros(8,
                               1,
                               lora.lora_a.shape[1],
                               lora.lora_a.shape[0],
                               device="cuda",
                               dtype=dtype)
    lora_b_stack = torch.zeros(8,
                               1,
                               lora.lora_b.shape[1],
                               lora.lora_b.shape[0],
                               device="cuda",
                               dtype=dtype)
    for i in range(lora_a_stack.shape[0]):
        lora_a_stack[i][0] = lora.lora_a.T
        lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T

    output = torch.zeros(k, m, device="cuda", dtype=dtype)
    _apply_lora(
        input, lora_a_stack, lora_b_stack,
        torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"),
        output)

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora(input, lora_a_stack, lora_b_stack,
                torch.full((len(input), ), -1, device="cuda"), output)
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()


@pytest.mark.parametrize("m", TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
    if m % 2 != 0:
        pytest.skip("m must be divisible by 2")
    if m // 2 not in TENSOR_SIZES:
        pytest.skip("m//2 must be in TENSOR_SIZES")

    manager = DummyLoRAManager()

    module_name = "module"
    weight = torch.rand([m // 2, n], device="cuda", dtype=dtype)

    manager.init_random_lora(module_name + "1", weight, rank=rank)
    lora_1 = manager.get_module_lora(module_name + "1")
    manager.init_random_lora(module_name + "2", weight, rank=rank)
    lora_2 = manager.get_module_lora(module_name + "2")

    input = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = torch.cat([
        input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling,
        input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling
    ],
                         dim=1)

    lora_a_stacks = [
        torch.zeros(8,
                    1,
                    lora_1.lora_a.shape[1],
                    lora_1.lora_a.shape[0],
                    device="cuda",
                    dtype=dtype) for i in range(2)
    ]
    lora_b_stacks = [
        torch.zeros(8,
                    1,
                    lora_1.lora_b.shape[1],
                    lora_1.lora_b.shape[0],
                    device="cuda",
                    dtype=dtype) for i in range(2)
    ]
    for i in range(lora_a_stacks[0].shape[0]):
        lora_a_stacks[0][i][0] = lora_1.lora_a.T
        lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
        lora_a_stacks[1][i][0] = lora_2.lora_a.T
        lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T

    output = torch.zeros(k, m, device="cuda", dtype=dtype)
    _apply_lora_packed_nslice(
        input, lora_a_stacks, lora_b_stacks,
        torch.randint(0,
                      lora_a_stacks[0].shape[0], (len(input), ),
                      device="cuda"), output, (m // 2, m // 2))

    rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)

    output[:] = 0
    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
                              torch.full((len(input), ), -1, device="cuda"),
                              output, (m // 2, m // 2))
    assert torch.allclose(torch.zeros_like(output), output)

    manager.reset_lora()


@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
@pytest.mark.parametrize("n", TENSOR_SIZES)
@pytest.mark.parametrize("k", BATCH_SIZES)
@pytest.mark.parametrize("rank", RANKS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
    manager = DummyLoRAManager()

    module_name = "module"
    weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype)
    weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype)

    manager.init_random_lora(module_name + "q", weight_q, rank=rank)
    lora_q = manager.get_module_lora(module_name + "q")
    manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
    lora_k = manager.get_module_lora(module_name + "k")
    manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
    lora_v = manager.get_module_lora(module_name + "v")

    input = torch.rand(k, n, device="cuda", dtype=dtype)
    expected = torch.cat([
        input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling,
        input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling,
        input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling
    ],
                         dim=1)

    lora_a_stacks = [
        torch.zeros(8,
                    1,
                    lora_q.lora_a.shape[1],
                    lora_q.lora_a.shape[0],
                    device="cuda",
                    dtype=dtype)
    ] + [
        torch.zeros(8,
                    1,
                    lora_k.lora_a.shape[1],
                    lora_k.lora_a.shape[0],
                    device="cuda",
                    dtype=dtype) for i in range(2)
    ]
    lora_b_stacks = [
        torch.zeros(8,
                    1,
                    lora_q.lora_b.shape[1],
                    lora_q.lora_b.shape[0],
                    device="cuda",
                    dtype=dtype)
    ] + [
        torch.zeros(8,
                    1,
                    lora_k.lora_b.shape[1],
|
||||||
|
lora_k.lora_b.shape[0],
|
||||||
|
device="cuda",
|
||||||
|
dtype=dtype) for i in range(2)
|
||||||
|
]
|
||||||
|
for i in range(lora_a_stacks[0].shape[0]):
|
||||||
|
lora_a_stacks[0][i][0] = lora_q.lora_a.T
|
||||||
|
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
|
||||||
|
lora_a_stacks[1][i][0] = lora_k.lora_a.T
|
||||||
|
lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T
|
||||||
|
lora_a_stacks[2][i][0] = lora_v.lora_a.T
|
||||||
|
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T
|
||||||
|
|
||||||
|
output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype)
|
||||||
|
_apply_lora_packed_nslice(
|
||||||
|
input, lora_a_stacks, lora_b_stacks,
|
||||||
|
torch.randint(0,
|
||||||
|
lora_a_stacks[0].shape[0], (len(input), ),
|
||||||
|
device="cuda"), output, (qkv[0], qkv[1], qkv[2]))
|
||||||
|
|
||||||
|
rtol, atol = TOLERANCES[dtype]
|
||||||
|
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
output[:] = 0
|
||||||
|
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
|
||||||
|
torch.full((len(input), ), -1, device="cuda"),
|
||||||
|
output, (qkv[0], qkv[1], qkv[2]))
|
||||||
|
assert torch.allclose(torch.zeros_like(output), output)
|
||||||
|
|
||||||
|
manager.reset_lora()
|
475
tests/lora/test_lora_manager.py
Normal file
@ -0,0 +1,475 @@
import os
from typing import List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, LoRAMapping)
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear


def test_from_lora_tensors(sql_lora_files):
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))
    lora_model = LoRAModel.from_lora_tensors(1,
                                             8,
                                             16,
                                             tensors,
                                             "cuda",
                                             embeddings=new_embeddings)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int, model: nn.Module,
                sub_modules: List[str]) -> LoRAModel:
    loras = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0]], device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    empty_replaced_module_name=None,
) -> LoRAModel:
    w = model.get_submodule(module_name).weight
    loras = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
    model = dummy_model
    manager = LoRAModelManager(model,
                               1,
                               1,
                               1,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=8,
                                          max_loras=8),
                               lora_target_modules=["dense1", "layer1.dense2"])
    model = manager.model

    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


def test_lora_model_manager(dist_init, dummy_model):
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_lora(model_lora1)
    assert not manager.activate_lora(1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_lora(model_lora2)
    assert not manager.activate_lora(2)
    assert manager.add_lora(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_lora(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_lora(model_lora2.id)
    assert manager.remove_lora(model_lora1.id)
    assert not manager.remove_lora(model_lora1.id)
    assert manager.add_lora(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2


def test_lora_lru_cache_model_manager(dist_init, dummy_model):
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
        lora_target_modules=["dense1", "dense2", "lm_head"])
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_lora(model_lora1)
    assert not manager.activate_lora(1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_lora(model_lora2)
    assert not manager.activate_lora(2)
    assert manager.add_lora(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_lora(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_lora(model_lora2.id)
    assert manager.remove_lora(model_lora1.id)
    assert not manager.remove_lora(model_lora1.id)
    assert manager.add_lora(model_lora1)
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_lora(model_lora2)
    assert manager.deactivate_lora(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_lora(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3


def test_lru_lora_model_manager(dist_init, dummy_model):
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
        ["dense1", "dense2", "lm_head"])

    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
    assert manager.add_lora(model_lora1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(1)
    assert manager.activate_lora(2)

    assert set(manager.list_loras()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    # Add over capacity
    assert manager.add_lora(model_lora3)
    assert manager.add_lora(model_lora4)
    assert manager.activate_lora(3)
    assert manager.activate_lora(4)

    assert set(manager.list_loras()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    # Add 3 again to move it to the top and then add 2
    # should return false since it's in already
    assert not manager.add_lora(model_lora3)
    assert not manager.activate_lora(3)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)

    assert set(manager.list_loras()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

    # Remove manually
    assert manager.remove_lora(3)
    assert not manager.remove_lora(3)

    assert set(manager.list_loras()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2

    assert manager.add_lora(model_lora3)
    assert manager.activate_lora(3)
    assert manager.add_lora(model_lora4)
    assert manager.activate_lora(4)

    assert set(manager.list_loras()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_lora()
    assert set(manager.list_loras()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.remove_oldest_lora()
    assert set(manager.list_loras()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    assert not manager.remove_oldest_lora()
    assert set(manager.list_loras()) == set()
    assert all(x is None for x in manager.lora_index_to_id)


def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                       sql_lora_files):
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
        torch.device("cuda"))
    worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 3, 4}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 6, 7, 8}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8
    assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_lora_manager.set_active_loras([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                             sql_lora_files):
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_lora_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
        torch.device("cuda"))
    worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 3, 4}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1, 2, 5}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5

    worker_lora_manager.set_active_loras([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {1}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None

    worker_lora_manager.set_active_loras([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_lora_manager.list_loras() == {6, 7, 8}
    assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8
    assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6
    assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_lora_manager.set_active_loras([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_packed_loras(dist_init, dummy_model_gate_up):
    model = dummy_model_gate_up
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"])
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
        ["gate_up_proj"])
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_lora(model_lora)
    assert manager.add_lora(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    assert torch.allclose(packed_lora.lora_a[0],
                          model_lora.get_lora("gate_proj").lora_a)
    assert torch.allclose(packed_lora.lora_b[0],
                          model_lora.get_lora("gate_proj").lora_b)
    assert torch.allclose(packed_lora.lora_a[1],
                          model_lora.get_lora("up_proj").lora_a)
    assert torch.allclose(packed_lora.lora_b[1],
                          model_lora.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    assert torch.allclose(packed_lora1.lora_a[1],
                          model_lora1.get_lora("up_proj").lora_a)
    assert torch.allclose(packed_lora1.lora_b[1],
                          model_lora1.get_lora("up_proj").lora_b)
175
tests/lora/test_punica.py
Normal file
@ -0,0 +1,175 @@
# Based on code from https://github.com/punica-ai/punica

import pytest
import torch

import vllm.lora.punica as punica


def assert_close(a, b):
    rtol, atol = {
        torch.float16: (5e-3, 5e-3),
        torch.bfloat16: (3e-2, 2e-2),
        torch.float32: (None, None),
    }[a.dtype]
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


def _lora_ref_impl(
    y_final: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    y_stage_1 = torch.empty(
        (x.size(0), wa_T_all.size(-2)),
        dtype=torch.float32,
        device=x.device,
    )
    bs = x.shape[0]
    s = torch.tensor(scale, dtype=torch.float32, device=x.device)
    for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
        xi = x[i].unsqueeze(0).to(torch.float32)
        wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
        wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)

        tmp = xi @ wa
        y_stage_1[i] = tmp.squeeze(0)
        y_final[i] += (tmp @ wb).squeeze(0) * s
    return y_final, y_stage_1


H1 = H2 = [
    128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
    32256, 32512, 32768, 33024
]
SEED = [0xabcdabcd987]


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness(dtype_str, h1, h2, seed):
    torch.manual_seed(seed)
    num_loras = 4
    num_layers = 1
    r = 8
    bs = 32
    scale = 0.123
    dtype = getattr(torch, dtype_str)
    device = torch.device("cuda")

    wa_T_all = torch.randn(num_loras,
                           num_layers,
                           r,
                           h1,
                           dtype=dtype,
                           device=device)
    wb_T_all = torch.randn(num_loras,
                           num_layers,
                           h2,
                           r,
                           dtype=dtype,
                           device=device)
    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype, device=device)
        y = torch.randn(bs, h2, dtype=dtype, device=device)

        y_ref = y.clone()
        _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)

        y_our = y.clone()
        punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx,
                        scale)

        assert_close(y_ref, y_our)


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_correctness_slice(dtype_str, h1, h2, seed):
    if h2 % 3 != 0 or h2 // 3 not in H1:
        pytest.skip("h2 must be divisible by 3 and in supported shapes")
    torch.manual_seed(seed)
    num_loras = 4
    num_layers = 1
    r = 8
    bs = 32
    scale = 0.123
    dtype = getattr(torch, dtype_str)
    device = torch.device("cuda")

    wa_T_all_0 = torch.randn(num_loras,
                             num_layers,
                             r,
                             h1,
                             dtype=dtype,
                             device=device)
    wa_T_all_1 = torch.randn(num_loras,
                             num_layers,
                             r,
                             h1,
                             dtype=dtype,
                             device=device)
    wa_T_all_2 = torch.randn(num_loras,
                             num_layers,
                             r,
                             h1,
                             dtype=dtype,
                             device=device)
    wb_T_all_0 = torch.randn(num_loras,
                             num_layers,
                             h2 // 3,
                             r,
                             dtype=dtype,
                             device=device)
    wb_T_all_1 = torch.randn(num_loras,
                             num_layers,
                             h2 // 3,
                             r,
                             dtype=dtype,
                             device=device)
    wb_T_all_2 = torch.randn(num_loras,
                             num_layers,
                             h2 // 3,
                             r,
                             dtype=dtype,
                             device=device)

    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype, device=device)
        y = torch.randn(bs, h2, dtype=dtype, device=device)
        s = h2 // 3

        y_ref = y.clone()
        _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices,
                       layer_idx, scale)
        _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices,
                       layer_idx, scale)
        _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices,
                       layer_idx, scale)

        y_our = y.clone()
        punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices,
                              layer_idx, scale, 0, s)
        punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices,
                              layer_idx, scale, s, s)
        punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices,
                              layer_idx, scale, s * 2, s)

        assert_close(y_ref[:, :s], y_our[:, :s])
        assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2])
        assert_close(y_ref[:, s * 2:], y_our[:, s * 2:])
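As a companion to the kernel test above, the following torch-only sketch (not part of this commit) restates the batched multi-LoRA math that _lora_ref_impl encodes: the A and B matrices are stored transposed and stacked per LoRA and per layer, and each row of the batch picks its own adapter via an index vector. The helper name batched_lora_delta and the toy sizes are illustrative assumptions; it runs on CPU with plain PyTorch.

import torch

def batched_lora_delta(x, wa_T_all, wb_T_all, indices, layer_idx, scale):
    # x:        [bs, h1] activations
    # wa_T_all: [num_loras, num_layers, r, h1]  (A matrices, stored transposed)
    # wb_T_all: [num_loras, num_layers, h2, r]  (B matrices, stored transposed)
    # indices:  [bs] LoRA index chosen for each row
    wa = wa_T_all[indices, layer_idx].transpose(-1, -2)  # [bs, h1, r]
    wb = wb_T_all[indices, layer_idx].transpose(-1, -2)  # [bs, r, h2]
    # y[i] += (x[i] @ A_i @ B_i) * scale, with a per-row adapter choice.
    return torch.bmm(torch.bmm(x.unsqueeze(1), wa), wb).squeeze(1) * scale

# Tiny CPU check against the obvious per-row loop.
bs, h1, h2, r, num_loras = 4, 16, 8, 4, 3
x = torch.randn(bs, h1)
wa_T_all = torch.randn(num_loras, 1, r, h1)
wb_T_all = torch.randn(num_loras, 1, h2, r)
idx = torch.randint(num_loras, (bs, ))
delta = batched_lora_delta(x, wa_T_all, wb_T_all, idx, 0, 0.5)
for i in range(bs):
    a = wa_T_all[idx[i], 0].T
    b = wb_T_all[idx[i], 0].T
    assert torch.allclose(delta[i], (x[i] @ a @ b) * 0.5, atol=1e-5)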
69
tests/lora/test_tokenizer.py
Normal file
@ -0,0 +1,69 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer


@pytest.mark.asyncio
async def test_transformers_tokenizer():
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )
    assert reference_tokenizer.encode("prompt") == tokenizer.encode(
        request_id="request_id", prompt="prompt", lora_request=None)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer.encode_async(request_id="request_id",
                                                  prompt="prompt",
                                                  lora_request=None)
    assert isinstance(tokenizer.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer.get_lora_tokenizer(
        None) == await tokenizer.get_lora_tokenizer_async(None)


@pytest.mark.asyncio
async def test_transformers_tokenizer_lora(sql_lora_files):
    reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
    tokenizer = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=True,
        max_num_seqs=1,
        max_input_length=None,
    )
    lora_request = LoRARequest("1", 1, sql_lora_files)
    assert reference_tokenizer.encode("prompt") == tokenizer.encode(
        request_id="request_id", prompt="prompt", lora_request=lora_request)
    assert reference_tokenizer.encode(
        "prompt") == await tokenizer.encode_async(request_id="request_id",
                                                  prompt="prompt",
                                                  lora_request=lora_request)
    assert isinstance(tokenizer.get_lora_tokenizer(None),
                      PreTrainedTokenizerBase)
    assert tokenizer.get_lora_tokenizer(
        None) == await tokenizer.get_lora_tokenizer_async(None)

    assert isinstance(tokenizer.get_lora_tokenizer(lora_request),
                      PreTrainedTokenizerBase)
    assert tokenizer.get_lora_tokenizer(
        lora_request) != tokenizer.get_lora_tokenizer(None)
    assert tokenizer.get_lora_tokenizer(
        lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request)


def test_get_lora_tokenizer(sql_lora_files, tmpdir):
    lora_request = None
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer

    lora_request = LoRARequest("1", 1, sql_lora_files)
    tokenizer = get_lora_tokenizer(lora_request)
    assert tokenizer.get_added_vocab()

    lora_request = LoRARequest("1", 1, str(tmpdir))
    tokenizer = get_lora_tokenizer(lora_request)
    assert not tokenizer
172
tests/lora/test_utils.py
Normal file
@ -0,0 +1,172 @@
from collections import OrderedDict

from torch import nn

from vllm.utils import LRUCache
from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule)


def test_parse_fine_tuned_lora_name():
    fixture = {
        ("base_model.model.lm_head.lora_A.weight", "lm_head", True),
        ("base_model.model.lm_head.lora_B.weight", "lm_head", False),
        (
            "base_model.model.model.embed_tokens.lora_embedding_A",
            "model.embed_tokens",
            True,
        ),
        (
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
        ),
        (
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
        ),
        (
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
        ),
    }
    for name, module_name, is_lora_a in fixture:
        assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)


def test_replace_submodule():
    model = nn.Sequential(
        OrderedDict([
            ("dense1", nn.Linear(764, 100)),
            ("act1", nn.ReLU()),
            ("dense2", nn.Linear(100, 50)),
            (
                "seq1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", nn.Linear(100, 10)),
                        ("dense2", nn.Linear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", nn.Linear(50, 10)),
            ("outact", nn.Sigmoid()),
        ]))

    sigmoid = nn.Sigmoid()

    replace_submodule(model, "act1", sigmoid)
    assert dict(model.named_modules())["act1"] == sigmoid

    dense2 = nn.Linear(1, 5)
    replace_submodule(model, "seq1.dense2", dense2)
    assert dict(model.named_modules())["seq1.dense2"] == dense2


class TestLRUCache(LRUCache):

    def _on_remove(self, key, value):
        if not hasattr(self, "_remove_counter"):
            self._remove_counter = 0
        self._remove_counter += 1


def test_lru_cache():
    cache = TestLRUCache(3)

    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(2, 2)
    assert len(cache) == 2

    cache.put(3, 3)
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    cache.put(4, 4)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1
    assert cache.get(2) == 2

    cache.put(5, 5)
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    assert cache.pop(5) == 5
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.get(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.put(6, 6)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache

    cache.remove_oldest()
    assert len(cache) == 2
    assert set(cache.cache) == {2, 6}
    assert cache._remove_counter == 4

    cache.clear()
    assert len(cache) == 0
    assert cache._remove_counter == 6

    cache._remove_counter = 0

    cache[1] = 1
    assert len(cache) == 1

    cache[1] = 1
    assert len(cache) == 1

    cache[2] = 2
    assert len(cache) == 2

    cache[3] = 3
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    cache[4] = 4
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1
    assert cache[2] == 2

    cache[5] = 5
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    del cache[5]
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache[6] = 6
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache
61
tests/lora/test_worker.py
Normal file
@ -0,0 +1,61 @@
import os
import random
import tempfile
from unittest.mock import patch

from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
from vllm.worker.worker import Worker


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
    worker = Worker(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            "meta-llama/Llama-2-7b-hf",
            tokenizer_mode="auto",
            trust_remote_code=False,
            download_dir=None,
            load_format="dummy",
            seed=0,
            dtype="float16",
            revision=None,
        ),
        parallel_config=ParallelConfig(1, 1, False),
        scheduler_config=SchedulerConfig(32, 32, 32, 256),
        local_rank=0,
        rank=0,
        lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
                               max_loras=32),
        distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
    )
    worker.init_model()
    worker.load_model()

    worker.model_runner.set_active_loras([], LoRAMapping([], []))
    assert worker.list_loras() == set()

    n_loras = 32
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
    ]

    worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }

    for i in range(32):
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, n_loras))
        random.shuffle(iter_lora_requests)
        iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
        worker.model_runner.set_active_loras(iter_lora_requests,
                                             LoRAMapping([], []))
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})
88
tests/lora/utils.py
Normal file
@ -0,0 +1,88 @@
from typing import List, Optional

import torch

from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights


class DummyLoRAManager:

    def __init__(self):
        super().__init__()
        self._loras = {}

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
        self._loras[module_name] = lora

    def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        return self._loras.get(module_name, None)

    def init_random_lora(self,
                         module_name: str,
                         weight: torch.Tensor,
                         rank: int = 8,
                         generate_embeddings_tensor: int = 0):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([weight.shape[1], rank],
                              dtype=weight.dtype,
                              device="cuda"),
            lora_b=torch.rand([rank, weight.shape[0]],
                              dtype=weight.dtype,
                              device="cuda"),
        )
        if generate_embeddings_tensor:
            lora.embeddings_tensor = torch.rand(5,
                                                generate_embeddings_tensor,
                                                dtype=weight.dtype,
                                                device="cuda")
        self.set_module_lora(module_name, lora)

        return lora

    def init_lora(self,
                  module_name: str,
                  input_dim: int,
                  output_dim: int,
                  rank=8,
                  noop=False,
                  embeddings_tensor=None):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([input_dim, rank], device="cuda"),
            lora_b=torch.rand([rank, output_dim], device="cuda"),
            embeddings_tensor=embeddings_tensor,
        )
        self.set_module_lora(module_name, lora)
        return lora

    def reset_lora(self):
        self._loras = {}

    def init_packed_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dims: List[int],
        noop_lora_index: List[int] = None,
        rank=8,
    ):
        base_loras = []
        noop_lora_index = set(noop_lora_index or [])

        for i, out_dim in enumerate(output_dims):
            base_lora = self.init_lora(
                module_name + "_000_" + str(i),
                input_dim,
                out_dim,
                rank=rank,
                noop=i in noop_lora_index,
            )
            base_loras.append(base_lora)
        packed_lora = PackedLoRALayerWeights.pack(base_loras)
        self.set_module_lora(module_name, packed_lora)
        return packed_lora
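A minimal usage sketch of the helper above, as the layer tests earlier in this diff drive it; this is illustrative only (not part of the commit), assumes DummyLoRAManager from tests/lora/utils.py is in scope, and requires a CUDA device because the helper allocates on "cuda".

import torch

# DummyLoRAManager is the test helper defined directly above.
manager = DummyLoRAManager()
weight = torch.rand([256, 128], device="cuda", dtype=torch.float16)
lora = manager.init_random_lora("module", weight, rank=8)

x = torch.rand(4, 128, device="cuda", dtype=torch.float16)
# The low-rank update the layer tests compare against.
delta = x @ lora.lora_a @ lora.lora_b * lora.scaling
assert delta.shape == (4, 256)
manager.reset_lora()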
@ -19,10 +19,11 @@ class MockLogitsSampler(Sampler):
         self.fake_logits = fake_logits

     def forward(self, *args, **kwargs):
-        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
-                   lambda x, y: x), patch(
-                       "vllm.model_executor.layers.sampler._get_logits",
-                       lambda *args, **kwargs: self.fake_logits):
+        with patch(
+                "vllm.model_executor.layers.sampler._prune_hidden_states",
+                lambda x, y: x), patch(
+                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
+                    lambda *args, **kwargs: self.fake_logits):
             return super().forward(*args, **kwargs)
@ -38,7 +39,7 @@ def _prepare_test(
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None)
     return input_tensor, fake_logits, sampler, model_runner
@ -266,7 +267,7 @@ def test_sampler_top_k_top_p(seed: int):
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None)

     generation_model = GenerationMixin()
     generation_config = GenerationConfig(top_k=top_k,
|
|||||||
enforce_eager=enforce_eager,
|
enforce_eager=enforce_eager,
|
||||||
)
|
)
|
||||||
|
|
||||||
(model_config, cache_config, parallel_config,
|
(model_config, cache_config, parallel_config, scheduler_config,
|
||||||
scheduler_config) = engine_args.create_engine_configs()
|
_) = engine_args.create_engine_configs()
|
||||||
|
|
||||||
distributed_init_method = get_distributed_init_method(
|
distributed_init_method = get_distributed_init_method(
|
||||||
get_ip(), get_open_port())
|
get_ip(), get_open_port())
|
||||||
|
@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner


 def test_prepare_prompt():
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None)
     model_runner.set_block_size(16)

     batch_size = random.randint(1, 256)
@ -33,7 +33,7 @@ def test_prepare_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, _, return_prompt_lens, _ = (
+    input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
         model_runner._prepare_prompt(seq_group_metadata_list))
     assert return_prompt_lens == prompt_lens
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
@ -1,4 +1,5 @@
-from typing import Optional, Union
+from typing import Optional, Union, ClassVar
+from dataclasses import dataclass
 import os

 import torch
@ -397,6 +398,54 @@ class SchedulerConfig:
                 f"({self.max_num_seqs}).")


+@dataclass
+class LoRAConfig:
+    max_lora_rank: int
+    max_loras: int
+    max_cpu_loras: Optional[int] = None
+    lora_dtype: Optional[torch.dtype] = None
+    lora_extra_vocab_size: int = 256
+    # This is a constant.
+    lora_vocab_padding_size: ClassVar[int] = 256
+
+    def __post_init__(self):
+        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
+        possible_max_ranks = (8, 16, 32, 64)
+        possible_lora_extra_vocab_size = (0, 256, 512)
+        if self.max_lora_rank not in possible_max_ranks:
+            raise ValueError(
+                f"max_lora_rank ({self.max_lora_rank}) must be one of "
+                f"{possible_max_ranks}.")
+        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
+            raise ValueError(
+                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
+                f"must be one of {possible_lora_extra_vocab_size}.")
+        if self.max_loras < 1:
+            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
+        if self.max_cpu_loras is None:
+            self.max_cpu_loras = self.max_loras
+        elif self.max_cpu_loras < self.max_loras:
+            raise ValueError(
+                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
+                f"max_num_seqs ({self.max_loras})")
+
+    def verify_with_model_config(self, model_config: ModelConfig):
+        if self.lora_dtype in (None, "auto"):
+            self.lora_dtype = model_config.dtype
+        elif isinstance(self.lora_dtype, str):
+            self.lora_dtype = getattr(torch, self.lora_dtype)
+        if model_config.quantization is not None:
+            raise ValueError(
+                "LoRA is not supported with quantized models yet.")
+
+    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
+        if scheduler_config.max_num_batched_tokens > 65528:
+            raise ValueError(
+                "Due to limitations of the custom LoRA CUDA kernel, "
+                "max_num_batched_tokens must be <= 65528 when "
+                "LoRA is enabled.")
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,
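Based on the __post_init__ validation added above, a short sketch (illustrative only, not part of the commit) of how LoRAConfig behaves when constructed the way the tests in this diff do; it assumes only the fields and error messages shown in the hunk.

from vllm.config import LoRAConfig

cfg = LoRAConfig(max_lora_rank=8, max_loras=2)
assert cfg.max_cpu_loras == 2  # defaults to max_loras when left unset

try:
    LoRAConfig(max_lora_rank=12, max_loras=2)  # 12 is not a supported rank
except ValueError as e:
    print(e)  # max_lora_rank (12) must be one of (8, 16, 32, 64).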
@ -1,11 +1,12 @@
 from collections import deque
 import enum
 import time
-from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set

-from vllm.config import CacheConfig, SchedulerConfig
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.block_manager import AllocStatus, BlockSpaceManager
 from vllm.core.policy import PolicyFactory
+from vllm.lora.request import LoRARequest
 from vllm.logger import init_logger
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceStatus)
@ -49,11 +50,25 @@ class SchedulerOutputs:
         assert not (blocks_to_swap_in and blocks_to_swap_out)
         self.ignored_seq_groups = ignored_seq_groups

+        self.num_loras = len(self.lora_requests)
+        if self.num_loras > 0:
+            self._sort_by_lora_ids()
+
     def is_empty(self) -> bool:
         # NOTE: We do not consider the ignored sequence groups.
         return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)

+    def _sort_by_lora_ids(self) -> bool:
+        self.scheduled_seq_groups = sorted(
+            self.scheduled_seq_groups,
+            key=lambda g: (g.lora_request.lora_int_id
+                           if g.lora_request else 0, g.request_id))
+
+    @property
+    def lora_requests(self) -> Set[LoRARequest]:
+        return {g.lora_request for g in self.scheduled_seq_groups}
+

 class Scheduler:

@ -61,9 +76,14 @@ class Scheduler:
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
+        lora_config: Optional[LoRAConfig],
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
+        # Note for LoRA scheduling: the current policy is extremely
+        # simple and NOT fair. It can lead to starvation of some
+        # LoRAs. This should be improved in the future.
+        self.lora_config = lora_config

         self.prompt_limit = min(self.scheduler_config.max_model_len,
                                 self.scheduler_config.max_num_batched_tokens)
@ -87,6 +107,10 @@ class Scheduler:
         # Sequence groups in the SWAPPED state.
         self.swapped: Deque[SequenceGroup] = deque()

+    @property
+    def lora_enabled(self) -> bool:
+        return bool(self.lora_config)
+
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
@ -150,14 +174,17 @@ class Scheduler:
             # requests in the generation phase.
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
+            curr_loras = set(
+                seq_group.lora_int_id
+                for seq_group in self.running) if self.lora_enabled else None
             seq_lens: List[int] = []

             # Optimization: We do not sort the waiting queue since the preempted
             # sequence groups are added to the front and the new sequence groups
             # are added to the back.
+            leftover_waiting_sequences = deque()
             while self.waiting:
                 seq_group = self.waiting[0]

                 waiting_seqs = seq_group.get_seqs(
                     status=SequenceStatus.WAITING)
                 assert len(waiting_seqs) == 1, (
@ -188,6 +215,17 @@ class Scheduler:
                     self.waiting.popleft()
                     continue

+                lora_int_id = 0
+                if self.lora_enabled:
+                    lora_int_id = seq_group.lora_int_id
+                    if lora_int_id > 0 and lora_int_id not in curr_loras and len(
+                            curr_loras) >= self.lora_config.max_loras:
+                        # We don't have a space for another LoRA, so
+                        # we ignore this request for now.
+                        leftover_waiting_sequences.appendleft(seq_group)
+                        self.waiting.popleft()
+                        continue
+
                 # If the number of batched tokens exceeds the limit, stop.
                 new_seq_lens = seq_lens + [num_prompt_tokens]
                 num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
@ -207,12 +245,16 @@ class Scheduler:
                     break
                 seq_lens = new_seq_lens

-                seq_group = self.waiting.popleft()
+                if lora_int_id > 0:
+                    curr_loras.add(lora_int_id)
+                self.waiting.popleft()
                 self._allocate(seq_group)
                 self.running.append(seq_group)
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)

+            self.waiting.extendleft(leftover_waiting_sequences)
+
             if scheduled or ignored_seq_groups:
                 scheduler_outputs = SchedulerOutputs(
                     scheduled_seq_groups=scheduled,
|
||||||
@ -260,9 +302,25 @@ class Scheduler:
|
|||||||
if not preempted:
|
if not preempted:
|
||||||
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||||
for seq_group in self.running)
|
for seq_group in self.running)
|
||||||
|
curr_loras = set(
|
||||||
|
seq_group.lora_int_id
|
||||||
|
for seq_group in self.running) if self.lora_enabled else None
|
||||||
|
|
||||||
|
leftover_swapped = deque()
|
||||||
|
|
||||||
while self.swapped:
|
while self.swapped:
|
||||||
seq_group = self.swapped[0]
|
seq_group = self.swapped[0]
|
||||||
|
lora_int_id = 0
|
||||||
|
if self.lora_enabled:
|
||||||
|
lora_int_id = seq_group.lora_int_id
|
||||||
|
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
|
||||||
|
curr_loras) >= self.lora_config.max_loras:
|
||||||
|
# We don't have a space for another LoRA, so
|
||||||
|
# we ignore this request for now.
|
||||||
|
leftover_swapped.appendleft(seq_group)
|
||||||
|
self.swapped.popleft()
|
||||||
|
continue
|
||||||
|
|
||||||
# If the sequence group cannot be swapped in, stop.
|
# If the sequence group cannot be swapped in, stop.
|
||||||
if not self.block_manager.can_swap_in(seq_group):
|
if not self.block_manager.can_swap_in(seq_group):
|
||||||
break
|
break
|
||||||
@ -274,12 +332,16 @@ class Scheduler:
|
|||||||
self.scheduler_config.max_num_seqs):
|
self.scheduler_config.max_num_seqs):
|
||||||
break
|
break
|
||||||
|
|
||||||
seq_group = self.swapped.popleft()
|
if lora_int_id > 0:
|
||||||
|
curr_loras.add(lora_int_id)
|
||||||
|
self.swapped.popleft()
|
||||||
self._swap_in(seq_group, blocks_to_swap_in)
|
self._swap_in(seq_group, blocks_to_swap_in)
|
||||||
self._append_slot(seq_group, blocks_to_copy)
|
self._append_slot(seq_group, blocks_to_copy)
|
||||||
num_curr_seqs += num_new_seqs
|
num_curr_seqs += num_new_seqs
|
||||||
self.running.append(seq_group)
|
self.running.append(seq_group)
|
||||||
|
|
||||||
|
self.swapped.extendleft(leftover_swapped)
|
||||||
|
|
||||||
# Each sequence in the generation phase only takes one token slot.
|
# Each sequence in the generation phase only takes one token slot.
|
||||||
# Therefore, the number of batched tokens is equal to the number of
|
# Therefore, the number of batched tokens is equal to the number of
|
||||||
# sequences in the RUNNING state.
|
# sequences in the RUNNING state.
|
||||||
@ -320,6 +382,7 @@ class Scheduler:
|
|||||||
seq_data=seq_data,
|
seq_data=seq_data,
|
||||||
sampling_params=seq_group.sampling_params,
|
sampling_params=seq_group.sampling_params,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
|
lora_request=seq_group.lora_request,
|
||||||
prefix=seq_group.prefix,
|
prefix=seq_group.prefix,
|
||||||
)
|
)
|
||||||
seq_group_metadata_list.append(seq_group_metadata)
|
seq_group_metadata_list.append(seq_group_metadata)
|
||||||
|
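The scheduler change above caps the number of distinct LoRA adapters co-scheduled in one batch at lora_config.max_loras and defers, rather than rejects, requests whose adapter does not fit. A minimal standalone sketch of that admission policy, using a hypothetical request object with a lora_int_id attribute instead of the real SequenceGroup and Scheduler types:

from collections import deque

def admit_requests(waiting: deque, max_loras: int) -> list:
    """Toy model of the LoRA-aware admission loop (hypothetical types).

    lora_int_id == 0 means "no LoRA". A request for an adapter that is not
    yet active is only admitted while fewer than max_loras distinct adapters
    are in the batch; otherwise it is pushed back to the front of the waiting
    queue for a later scheduling step (which, as the comment in the diff
    notes, can starve some adapters).
    """
    curr_loras = set()
    scheduled, leftover = [], deque()
    while waiting:
        req = waiting.popleft()
        lora_id = getattr(req, "lora_int_id", 0)
        if (lora_id > 0 and lora_id not in curr_loras
                and len(curr_loras) >= max_loras):
            leftover.appendleft(req)
            continue
        if lora_id > 0:
            curr_loras.add(lora_id)
        scheduled.append(req)
    waiting.extendleft(leftover)
    return scheduled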
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                         SchedulerConfig, LoRAConfig)


 @dataclass
@@ -35,6 +35,12 @@ class EngineArgs:
     quantization: Optional[str] = None
     enforce_eager: bool = False
     max_context_len_to_capture: int = 8192
+    enable_lora: bool = False
+    max_loras: int = 1
+    max_lora_rank: int = 16
+    lora_extra_vocab_size: int = 256
+    lora_dtype = 'auto'
+    max_cpu_loras: Optional[int] = None

     def __post_init__(self):
         if self.tokenizer is None:
@@ -202,6 +208,39 @@ class EngineArgs:
                             help='maximum context length covered by CUDA '
                             'graphs. When a sequence has context length '
                             'larger than this, we fall back to eager mode.')
+        # LoRA related configs
+        parser.add_argument('--enable-lora',
+                            action='store_true',
+                            help='If True, enable handling of LoRA adapters.')
+        parser.add_argument('--max-loras',
+                            type=int,
+                            default=EngineArgs.max_loras,
+                            help='Max number of LoRAs in a single batch.')
+        parser.add_argument('--max-lora-rank',
+                            type=int,
+                            default=EngineArgs.max_lora_rank,
+                            help='Max LoRA rank.')
+        parser.add_argument(
+            '--lora-extra-vocab-size',
+            type=int,
+            default=EngineArgs.lora_extra_vocab_size,
+            help=('Maximum size of extra vocabulary that can be '
+                  'present in a LoRA adapter (added to the base '
+                  'model vocabulary).'))
+        parser.add_argument(
+            '--lora-dtype',
+            type=str,
+            default=EngineArgs.lora_dtype,
+            choices=['auto', 'float16', 'bfloat16', 'float32'],
+            help=('Data type for LoRA. If auto, will default to '
+                  'base model dtype.'))
+        parser.add_argument(
+            '--max-cpu-loras',
+            type=int,
+            default=EngineArgs.max_cpu_loras,
+            help=('Maximum number of LoRAs to store in CPU memory. '
+                  'Must be >= than max_num_seqs. '
+                  'Defaults to max_num_seqs.'))
         return parser

     @classmethod
@@ -214,7 +253,8 @@ class EngineArgs:

     def create_engine_configs(
         self,
-    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
+               Optional[LoRAConfig]]:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
@@ -234,7 +274,14 @@ class EngineArgs:
                                            self.max_num_seqs,
                                            model_config.max_model_len,
                                            self.max_paddings)
-        return model_config, cache_config, parallel_config, scheduler_config
+        lora_config = LoRAConfig(
+            max_lora_rank=self.max_lora_rank,
+            max_loras=self.max_loras,
+            lora_extra_vocab_size=self.lora_extra_vocab_size,
+            lora_dtype=self.lora_dtype,
+            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
+            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
+        return model_config, cache_config, parallel_config, scheduler_config, lora_config


 @dataclass
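With the fields added above, LoRA support is switched on through EngineArgs (or the matching --enable-lora, --max-loras, --max-lora-rank, --lora-extra-vocab-size, --lora-dtype and --max-cpu-loras CLI flags). A minimal sketch of driving it from Python; the model name and the numeric values are placeholder choices, not taken from this diff:

from vllm.engine.arg_utils import EngineArgs

# Hypothetical example values; any HF model path works the same way.
engine_args = EngineArgs(
    model="meta-llama/Llama-2-7b-hf",
    enable_lora=True,         # without this, lora_config stays None
    max_loras=4,              # at most 4 distinct adapters per batch
    max_lora_rank=16,
    lora_extra_vocab_size=256,
)
# create_engine_configs() now also returns an Optional[LoRAConfig].
(model_config, cache_config, parallel_config,
 scheduler_config, lora_config) = engine_args.create_engine_configs()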
@@ -4,6 +4,7 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator)

+from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
@@ -203,6 +204,52 @@ class _AsyncLLMEngine(LLMEngine):

         return self._process_model_outputs(output, scheduler_outputs)

+    async def encode_request_async(
+        self,
+        request_id: str,  # pylint: disable=unused-argument
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+    ):
+        if prompt_token_ids is None:
+            assert prompt is not None
+            prompt_token_ids = await self.tokenizer.encode_async(
+                request_id=request_id,
+                prompt=prompt,
+                lora_request=lora_request)
+        return prompt_token_ids
+
+    async def add_request_async(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
+    ) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+        if arrival_time is None:
+            arrival_time = time.time()
+        prompt_token_ids = await self.encode_request_async(
+            request_id=request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            lora_request=lora_request)
+
+        return self.add_request(
+            request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            prefix_pos=prefix_pos,
+        )
+
     async def _run_workers_async(
         self,
         method: str,
@@ -332,7 +379,7 @@ class AsyncLLMEngine:
             if self.engine_use_ray:
                 await self.engine.add_request.remote(**new_request)
             else:
-                self.engine.add_request(**new_request)
+                await self.engine.add_request_async(**new_request)

         if finished_requests:
             await self._engine_abort(finished_requests)
@@ -371,6 +418,7 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
         prefix_pos: Optional[int] = None,
     ) -> AsyncStream:
         if self.log_requests:
@@ -386,7 +434,8 @@ class AsyncLLMEngine:
                         f"prompt: {shortened_prompt!r}, "
                         f"prefix_pos: {prefix_pos},"
                         f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {shortened_token_ids}.")
+                        f"prompt token ids: {shortened_token_ids}, "
+                        f"lora_request: {lora_request}.")

         if not self.is_running:
             if self.start_engine_loop:
@@ -398,12 +447,21 @@ class AsyncLLMEngine:
                     "error that caused the background loop to stop "
                     "(AsyncEngineDeadError).")

+        if arrival_time is None:
+            arrival_time = time.time()
+        prompt_token_ids = await self.engine.encode_request_async(
+            request_id=request_id,
+            prompt=prompt,
+            prompt_token_ids=prompt_token_ids,
+            lora_request=lora_request)
+
         stream = self._request_tracker.add_request(
             request_id,
             prompt=prompt,
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
+            lora_request=lora_request,
             prefix_pos=prefix_pos)

         return stream
@@ -414,6 +472,7 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         request_id: str,
         prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
         prefix_pos: Optional[int] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
@@ -429,6 +488,7 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
+            lora_request: LoRA request to use for generation, if any.
             prefix_pos: If not None, we use the given position as the prefix
                 position for each prompt. We will cache the prefix's KV
                 cache and reuse it for the next request with the same prefix.
@@ -487,12 +547,15 @@ class AsyncLLMEngine:
         arrival_time = time.monotonic()

         try:
-            stream = await self.add_request(request_id,
-                                            prompt,
-                                            sampling_params,
-                                            prompt_token_ids=prompt_token_ids,
-                                            arrival_time=arrival_time,
-                                            prefix_pos=prefix_pos)
+            stream = await self.add_request(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                prefix_pos=prefix_pos,
+            )

             async for request_output in stream:
                 yield request_output
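The async path mirrors the synchronous one: generate() and add_request() now take a lora_request that is threaded through tokenization (encode_request_async) and into scheduling. A rough usage sketch follows; the LoRARequest constructor arguments (a name, an integer id, a local adapter path) are assumptions based on how lora_int_id is used elsewhere in this change, not something this diff shows:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest

async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))
    # Assumed LoRARequest signature; adjust to the actual definition.
    lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora")
    results = engine.generate("Translate to SQL: ...",
                              SamplingParams(max_tokens=64),
                              request_id="req-0",
                              lora_request=lora)
    async for request_output in results:
        print(request_output.outputs[0].text)

asyncio.run(main())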
@ -5,8 +5,9 @@ import time
|
|||||||
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
|
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
SchedulerConfig)
|
SchedulerConfig, LoRAConfig)
|
||||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.metrics import record_metrics
|
from vllm.engine.metrics import record_metrics
|
||||||
@ -17,7 +18,7 @@ from vllm.sampling_params import SamplingParams
|
|||||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||||
SequenceGroupOutput, SequenceOutput, SequenceStatus)
|
SequenceGroupOutput, SequenceOutput, SequenceStatus)
|
||||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||||
get_tokenizer)
|
TokenizerGroup)
|
||||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
|
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
|
||||||
|
|
||||||
if ray:
|
if ray:
|
||||||
@ -64,6 +65,7 @@ class LLMEngine:
|
|||||||
cache_config: CacheConfig,
|
cache_config: CacheConfig,
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
placement_group: Optional["PlacementGroup"],
|
placement_group: Optional["PlacementGroup"],
|
||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -87,17 +89,13 @@ class LLMEngine:
|
|||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
|
self.lora_config = lora_config
|
||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.log_stats = log_stats
|
self.log_stats = log_stats
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
|
|
||||||
self.tokenizer = get_tokenizer(
|
self._init_tokenizer()
|
||||||
model_config.tokenizer,
|
|
||||||
tokenizer_mode=model_config.tokenizer_mode,
|
|
||||||
trust_remote_code=model_config.trust_remote_code,
|
|
||||||
tokenizer_revision=model_config.tokenizer_revision,
|
|
||||||
revision=model_config.revision)
|
|
||||||
self.seq_counter = Counter()
|
self.seq_counter = Counter()
|
||||||
|
|
||||||
# Create the parallel GPU workers.
|
# Create the parallel GPU workers.
|
||||||
@ -114,7 +112,7 @@ class LLMEngine:
|
|||||||
self._init_cache()
|
self._init_cache()
|
||||||
|
|
||||||
# Create the scheduler.
|
# Create the scheduler.
|
||||||
self.scheduler = Scheduler(scheduler_config, cache_config)
|
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
||||||
|
|
||||||
# Logging.
|
# Logging.
|
||||||
self.last_logging_time = 0.0
|
self.last_logging_time = 0.0
|
||||||
@ -123,6 +121,9 @@ class LLMEngine:
|
|||||||
# List of (timestamp, num_tokens)
|
# List of (timestamp, num_tokens)
|
||||||
self.num_generation_tokens: List[Tuple[float, int]] = []
|
self.num_generation_tokens: List[Tuple[float, int]] = []
|
||||||
|
|
||||||
|
def get_tokenizer_for_seq(self, sequence: Sequence):
|
||||||
|
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
||||||
|
|
||||||
def _init_workers(self):
|
def _init_workers(self):
|
||||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
@ -141,11 +142,24 @@ class LLMEngine:
|
|||||||
local_rank=0,
|
local_rank=0,
|
||||||
rank=0,
|
rank=0,
|
||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
|
lora_config=self.lora_config,
|
||||||
is_driver_worker=True,
|
is_driver_worker=True,
|
||||||
)
|
)
|
||||||
self._run_workers("init_model")
|
self._run_workers("init_model")
|
||||||
self._run_workers("load_model")
|
self._run_workers("load_model")
|
||||||
|
|
||||||
|
def _init_tokenizer(self, **tokenizer_init_kwargs):
|
||||||
|
init_kwargs = dict(
|
||||||
|
enable_lora=bool(self.lora_config),
|
||||||
|
max_num_seqs=self.scheduler_config.max_num_seqs,
|
||||||
|
max_input_length=None,
|
||||||
|
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||||
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
|
revision=self.model_config.tokenizer_revision)
|
||||||
|
init_kwargs.update(tokenizer_init_kwargs)
|
||||||
|
self.tokenizer: TokenizerGroup = TokenizerGroup(
|
||||||
|
self.model_config.tokenizer, **init_kwargs)
|
||||||
|
|
||||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||||
**ray_remote_kwargs):
|
**ray_remote_kwargs):
|
||||||
if self.parallel_config.tensor_parallel_size == 1:
|
if self.parallel_config.tensor_parallel_size == 1:
|
||||||
@ -233,6 +247,7 @@ class LLMEngine:
|
|||||||
local_rank,
|
local_rank,
|
||||||
rank,
|
rank,
|
||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
|
lora_config=self.lora_config,
|
||||||
))
|
))
|
||||||
|
|
||||||
driver_rank = 0
|
driver_rank = 0
|
||||||
@ -244,6 +259,7 @@ class LLMEngine:
|
|||||||
driver_local_rank,
|
driver_local_rank,
|
||||||
driver_rank,
|
driver_rank,
|
||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
|
lora_config=self.lora_config,
|
||||||
is_driver_worker=True,
|
is_driver_worker=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -257,6 +273,10 @@ class LLMEngine:
|
|||||||
def _verify_args(self) -> None:
|
def _verify_args(self) -> None:
|
||||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||||
|
if self.lora_config:
|
||||||
|
self.lora_config.verify_with_model_config(self.model_config)
|
||||||
|
self.lora_config.verify_with_scheduler_config(
|
||||||
|
self.scheduler_config)
|
||||||
|
|
||||||
def _init_cache(self) -> None:
|
def _init_cache(self) -> None:
|
||||||
"""Profiles the memory usage and initializes the KV cache.
|
"""Profiles the memory usage and initializes the KV cache.
|
||||||
@ -332,6 +352,20 @@ class LLMEngine:
|
|||||||
log_stats=not engine_args.disable_log_stats)
|
log_stats=not engine_args.disable_log_stats)
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
def encode_request(
|
||||||
|
self,
|
||||||
|
request_id: str, # pylint: disable=unused-argument
|
||||||
|
prompt: Optional[str],
|
||||||
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
):
|
||||||
|
if prompt_token_ids is None:
|
||||||
|
assert prompt is not None
|
||||||
|
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
lora_request=lora_request)
|
||||||
|
return prompt_token_ids
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@ -339,6 +373,7 @@ class LLMEngine:
|
|||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
prompt_token_ids: Optional[List[int]] = None,
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prefix_pos: Optional[int] = None,
|
prefix_pos: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a request to the engine's request pool.
|
"""Add a request to the engine's request pool.
|
||||||
@ -386,24 +421,31 @@ class LLMEngine:
|
|||||||
>>> # continue the request processing
|
>>> # continue the request processing
|
||||||
>>> ...
|
>>> ...
|
||||||
"""
|
"""
|
||||||
|
if lora_request is not None and not self.lora_config:
|
||||||
|
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||||
|
"not enabled!")
|
||||||
if arrival_time is None:
|
if arrival_time is None:
|
||||||
arrival_time = time.monotonic()
|
arrival_time = time.monotonic()
|
||||||
if prompt_token_ids is None:
|
prompt_token_ids = self.encode_request(
|
||||||
assert prompt is not None
|
request_id=request_id,
|
||||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
prompt=prompt,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
lora_request=lora_request)
|
||||||
|
|
||||||
# Create the sequences.
|
# Create the sequences.
|
||||||
block_size = self.cache_config.block_size
|
block_size = self.cache_config.block_size
|
||||||
seq_id = next(self.seq_counter)
|
seq_id = next(self.seq_counter)
|
||||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
||||||
|
lora_request)
|
||||||
|
|
||||||
# Check whether the input specifies prefix
|
# Check whether the input specifies prefix
|
||||||
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
|
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
|
||||||
prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
|
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
|
||||||
|
if lora_request else 0) if prefix_pos is not None else None
|
||||||
|
|
||||||
# Create the sequence group.
|
# Create the sequence group.
|
||||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||||
arrival_time, prefix)
|
arrival_time, lora_request, prefix)
|
||||||
|
|
||||||
# Add the sequence group to the scheduler.
|
# Add the sequence group to the scheduler.
|
||||||
self.scheduler.add_seq_group(seq_group)
|
self.scheduler.add_seq_group(seq_group)
|
||||||
@ -453,11 +495,13 @@ class LLMEngine:
|
|||||||
|
|
||||||
current_worst_score = (current_worst_seq.get_beam_search_score(
|
current_worst_score = (current_worst_seq.get_beam_search_score(
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
eos_token_id=self.tokenizer.eos_token_id))
|
eos_token_id=self.get_tokenizer_for_seq(
|
||||||
|
current_worst_seq).eos_token_id))
|
||||||
if early_stopping is False:
|
if early_stopping is False:
|
||||||
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
eos_token_id=self.tokenizer.eos_token_id))
|
eos_token_id=self.get_tokenizer_for_seq(
|
||||||
|
best_running_seq).eos_token_id))
|
||||||
else:
|
else:
|
||||||
assert early_stopping == "never"
|
assert early_stopping == "never"
|
||||||
if length_penalty > 0.0:
|
if length_penalty > 0.0:
|
||||||
@ -471,7 +515,8 @@ class LLMEngine:
|
|||||||
highest_attainable_score = (
|
highest_attainable_score = (
|
||||||
best_running_seq.get_beam_search_score(
|
best_running_seq.get_beam_search_score(
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
eos_token_id=self.tokenizer.eos_token_id,
|
eos_token_id=self.get_tokenizer_for_seq(
|
||||||
|
best_running_seq).eos_token_id,
|
||||||
seq_len=max_possible_length))
|
seq_len=max_possible_length))
|
||||||
else:
|
else:
|
||||||
# Otherwise, beam search will prefer shorter sequences. The
|
# Otherwise, beam search will prefer shorter sequences. The
|
||||||
@ -480,7 +525,8 @@ class LLMEngine:
|
|||||||
highest_attainable_score = (
|
highest_attainable_score = (
|
||||||
best_running_seq.get_beam_search_score(
|
best_running_seq.get_beam_search_score(
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
eos_token_id=self.tokenizer.eos_token_id))
|
eos_token_id=self.get_tokenizer_for_seq(
|
||||||
|
best_running_seq).eos_token_id))
|
||||||
return current_worst_score >= highest_attainable_score
|
return current_worst_score >= highest_attainable_score
|
||||||
|
|
||||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||||
@ -571,7 +617,7 @@ class LLMEngine:
|
|||||||
# Sort the finished sequences by their scores.
|
# Sort the finished sequences by their scores.
|
||||||
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
eos_token_id=self.tokenizer.eos_token_id),
|
eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
|
||||||
reverse=True)
|
reverse=True)
|
||||||
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||||
if is_new:
|
if is_new:
|
||||||
@ -599,7 +645,7 @@ class LLMEngine:
|
|||||||
# Sort the running sequences by their scores.
|
# Sort the running sequences by their scores.
|
||||||
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
eos_token_id=self.tokenizer.eos_token_id),
|
eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
|
||||||
reverse=True)
|
reverse=True)
|
||||||
|
|
||||||
# Check if we can stop the beam search.
|
# Check if we can stop the beam search.
|
||||||
@ -837,7 +883,7 @@ class LLMEngine:
|
|||||||
"""Decodes the new token for a sequence."""
|
"""Decodes the new token for a sequence."""
|
||||||
(new_tokens, new_output_text, prefix_offset,
|
(new_tokens, new_output_text, prefix_offset,
|
||||||
read_offset) = detokenize_incrementally(
|
read_offset) = detokenize_incrementally(
|
||||||
self.tokenizer,
|
self.get_tokenizer_for_seq(seq),
|
||||||
all_input_ids=seq.get_token_ids(),
|
all_input_ids=seq.get_token_ids(),
|
||||||
prev_tokens=seq.tokens,
|
prev_tokens=seq.tokens,
|
||||||
prefix_offset=seq.prefix_offset,
|
prefix_offset=seq.prefix_offset,
|
||||||
@ -879,11 +925,28 @@ class LLMEngine:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Check if the sequence has generated the EOS token.
|
# Check if the sequence has generated the EOS token.
|
||||||
if ((not sampling_params.ignore_eos)
|
if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
|
||||||
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
|
== self.get_tokenizer_for_seq(seq).eos_token_id):
|
||||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self._run_workers(
|
||||||
|
"add_lora",
|
||||||
|
lora_request=lora_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
assert lora_id > 0, "lora_id must be greater than 0."
|
||||||
|
return self._run_workers(
|
||||||
|
"remove_lora",
|
||||||
|
lora_id=lora_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_loras(self) -> List[int]:
|
||||||
|
return self._run_workers("list_loras")
|
||||||
|
|
||||||
def _run_workers(
|
def _run_workers(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
|
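Besides routing the per-request lora_request (and a LoRA-aware TokenizerGroup) through the engine, LLMEngine now exposes explicit adapter management that is fanned out to every worker via _run_workers. A short sketch of how these hooks could be driven; the LoRARequest constructor arguments are assumptions, only the add_lora/remove_lora/list_loras methods come from this diff:

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.lora.request import LoRARequest

engine = LLMEngine.from_engine_args(
    EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True))
# Assumed LoRARequest signature (name, integer id, local adapter path).
lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora")

engine.add_lora(lora)                  # loads the adapter on every worker
print(engine.list_loras())             # ids of the currently registered adapters
engine.remove_lora(lora.lora_int_id)   # lora_int_id must be greater than 0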
@@ -3,6 +3,7 @@ from typing import List, Optional, Union
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from vllm.lora.request import LoRARequest
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.outputs import RequestOutput
@@ -122,6 +123,7 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         prefix_pos: Optional[Union[int, List[int]]] = None,
         use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.

@@ -141,6 +143,7 @@ class LLM:
                 This is an experimental feature, and may be replaced with
                 automatic prefix caching in the future.
             use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.

         Returns:
             A list of `RequestOutput` objects containing the generated
@@ -168,7 +171,11 @@ class LLM:
             prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[
                 i]
-            self._add_request(prompt, sampling_params, token_ids, prefix_pos_i)
+            self._add_request(prompt,
+                              sampling_params,
+                              token_ids,
+                              lora_request=lora_request,
+                              prefix_pos=prefix_pos_i)
         return self._run_engine(use_tqdm)

     def _add_request(
@@ -176,6 +183,7 @@ class LLM:
         prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
+        lora_request: Optional[LoRARequest] = None,
         prefix_pos: Optional[int] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
@@ -183,6 +191,7 @@ class LLM:
             prompt,
             sampling_params,
             prompt_token_ids,
+            lora_request=lora_request,
             prefix_pos=prefix_pos)

     def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
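At the top-level LLM API, the same lora_request flows from generate() through _add_request() into the engine, so one adapter can be applied to a whole batch of prompts. A hedged sketch; LoRARequest's constructor arguments and the model name are assumptions, only the lora_request= parameter itself is introduced by this change:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
# Assumed LoRARequest signature (name, integer id, local adapter path).
lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora")

outputs = llm.generate(
    ["Translate to SQL: how many users signed up last week?"],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=lora,   # applied to every prompt in this call
)
print(outputs[0].outputs[0].text)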
 vllm/lora/__init__.py  (new file, 0 lines)
 vllm/lora/layers.py  (new file, 975 lines)
@@ -0,0 +1,975 @@
# pylint: disable=unused-argument
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.config import LoRAConfig
|
||||||
|
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
|
tensor_model_parallel_all_gather,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
tensor_model_parallel_gather,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
MergedColumnParallelLinear)
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lora(
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: torch.Tensor,
|
||||||
|
lora_b_stacked: torch.Tensor,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""Applies lora to each input.
|
||||||
|
|
||||||
|
This method applies all loras to each input. It uses the
|
||||||
|
indices vector to determine which lora yields the
|
||||||
|
correct output. An index of -1 means no lora should be
|
||||||
|
applied. This method adds the final lora results to the
|
||||||
|
output.
|
||||||
|
|
||||||
|
Input shapes:
|
||||||
|
x: (batch_size, hidden_dim)
|
||||||
|
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
|
||||||
|
lora_b_stacked: (num_loras, output_dim, lora_rank)
|
||||||
|
indices: (batch_size)
|
||||||
|
output: (batch_size, output_dim)
|
||||||
|
"""
|
||||||
|
org_output = output
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
output = output.view(-1, output.shape[-1])
|
||||||
|
indices = indices.view(-1)
|
||||||
|
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
|
||||||
|
return output.view_as(org_output)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lora_packed_nslice(
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
|
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
|
indices: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
output_slices: Tuple[int, ...],
|
||||||
|
):
|
||||||
|
"""Applies lora to each input.
|
||||||
|
|
||||||
|
This method applies all loras to each input. It uses the
|
||||||
|
indices vector to determine which lora yields the
|
||||||
|
correct output. An index of -1 means no lora should be
|
||||||
|
applied. This method adds the final lora results to the
|
||||||
|
output.
|
||||||
|
|
||||||
|
This method is used for layers that are composed of multiple sublayers
|
||||||
|
(slices) packed together.
|
||||||
|
|
||||||
|
Input shapes:
|
||||||
|
x: (batch_size, hidden_dim)
|
||||||
|
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
|
||||||
|
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
|
||||||
|
indices: (batch_size)
|
||||||
|
output: (batch_size, q_slice_size + 2*kv_slice_size)
|
||||||
|
output_slices: n-1 element tuple of (slice_size...), where n is number of slices
|
||||||
|
"""
|
||||||
|
org_output = output
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
output = output.view(-1, output.shape[-1])
|
||||||
|
indices = indices.view(-1)
|
||||||
|
offset_left = 0
|
||||||
|
for slice_idx in range(len(output_slices)):
|
||||||
|
add_lora_slice(output, x, lora_a_stacked[slice_idx],
|
||||||
|
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
|
||||||
|
output_slices[slice_idx])
|
||||||
|
offset_left += output_slices[slice_idx]
|
||||||
|
return output.view_as(org_output)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRAMapping:
|
||||||
|
# Per every token in input_ids:
|
||||||
|
index_mapping: Tuple[int, ...]
|
||||||
|
# Per sampled token:
|
||||||
|
prompt_mapping: Tuple[int, ...]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.index_mapping = tuple(self.index_mapping)
|
||||||
|
self.prompt_mapping = tuple(self.prompt_mapping)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLayerWithLoRA(nn.Module):
|
||||||
|
|
||||||
|
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
|
||||||
|
model_config: PretrainedConfig) -> None:
|
||||||
|
"""Initializes lora matrices."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def reset_lora(self, index: int):
|
||||||
|
"""Resets the lora weights at index back to 0."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def set_lora(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
lora_a: torch.Tensor,
|
||||||
|
lora_b: torch.Tensor,
|
||||||
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
"""Overwrites lora tensors at index."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def set_mapping(
|
||||||
|
self,
|
||||||
|
base_indices: torch.Tensor,
|
||||||
|
sampler_indices: torch.Tensor,
|
||||||
|
sampler_indices_padded: torch.Tensor,
|
||||||
|
embeddings_indices: torch.Tensor,
|
||||||
|
indices_len: List[int],
|
||||||
|
):
|
||||||
|
"""Sets the mapping indices."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||||
|
|
||||||
|
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.base_layer = base_layer
|
||||||
|
|
||||||
|
def create_lora_weights(
|
||||||
|
self,
|
||||||
|
max_loras: int,
|
||||||
|
lora_config: LoRAConfig,
|
||||||
|
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||||
|
|
||||||
|
lora_vocab_start_idx = self.base_layer.org_vocab_size
|
||||||
|
weights_idx = None
|
||||||
|
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
|
||||||
|
# We can start adding lora weights
|
||||||
|
weights_idx = max(
|
||||||
|
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
|
||||||
|
self.embeddings_slice = (self.base_layer.vocab_start_index -
|
||||||
|
self.base_layer.org_vocab_size +
|
||||||
|
weights_idx,
|
||||||
|
self.base_layer.vocab_end_index -
|
||||||
|
self.base_layer.org_vocab_size)
|
||||||
|
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
|
||||||
|
self.embeddings_weights.fill_(0)
|
||||||
|
else:
|
||||||
|
self.embeddings_slice = None
|
||||||
|
self.embeddings_weights = None
|
||||||
|
|
||||||
|
self.embeddings_tensors = torch.zeros(
|
||||||
|
(
|
||||||
|
max_loras,
|
||||||
|
lora_config.lora_extra_vocab_size,
|
||||||
|
self.base_layer.embedding_dim,
|
||||||
|
),
|
||||||
|
dtype=self.base_layer.weight.dtype,
|
||||||
|
device=self.base_layer.weight.device,
|
||||||
|
)
|
||||||
|
self.lora_a_stacked = torch.zeros(
|
||||||
|
(
|
||||||
|
max_loras,
|
||||||
|
self.base_layer.org_vocab_size +
|
||||||
|
lora_config.lora_extra_vocab_size,
|
||||||
|
lora_config.max_lora_rank,
|
||||||
|
),
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.base_layer.weight.device,
|
||||||
|
)
|
||||||
|
self.lora_b_stacked = torch.zeros(
|
||||||
|
(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.base_layer.embedding_dim,
|
||||||
|
lora_config.max_lora_rank,
|
||||||
|
),
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.base_layer.weight.device,
|
||||||
|
)
|
||||||
|
self.lora_a_stacked_2d = self.lora_a_stacked.view(
|
||||||
|
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
|
||||||
|
self.lora_a_stacked.shape[2],
|
||||||
|
)
|
||||||
|
self.indices: Optional[torch.Tensor] = None
|
||||||
|
self.indices_len: Optional[List[int]] = None
|
||||||
|
self.embeddings_indices = None
|
||||||
|
|
||||||
|
def reset_lora(self, index: int):
|
||||||
|
self.lora_a_stacked[index] = 0
|
||||||
|
self.lora_b_stacked[index] = 0
|
||||||
|
self.embeddings_tensors[index] = 0
|
||||||
|
|
||||||
|
def set_lora(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
lora_a: torch.Tensor,
|
||||||
|
lora_b: torch.Tensor,
|
||||||
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
self.reset_lora(index)
|
||||||
|
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
||||||
|
lora_a, non_blocking=True)
|
||||||
|
self.lora_b_stacked[index,
|
||||||
|
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||||
|
lora_b.T, non_blocking=True)
|
||||||
|
if embeddings_tensor is not None:
|
||||||
|
self.embeddings_tensors[
|
||||||
|
index, :embeddings_tensor.shape[0], :embeddings_tensor.
|
||||||
|
shape[1]].copy_(embeddings_tensor, non_blocking=True)
|
||||||
|
if self.embeddings_slice is not None:
|
||||||
|
# TODO(yard1): Optimize this copy, we don't need to copy
|
||||||
|
# everything, just the modified part
|
||||||
|
embeddings = self.embeddings_tensors.view(
|
||||||
|
self.embeddings_tensors.shape[0] *
|
||||||
|
self.embeddings_tensors.shape[1],
|
||||||
|
self.embeddings_tensors.shape[2]
|
||||||
|
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
|
||||||
|
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
|
||||||
|
|
||||||
|
def set_mapping(
|
||||||
|
self,
|
||||||
|
base_indices: torch.Tensor,
|
||||||
|
sampler_indices: torch.Tensor,
|
||||||
|
sampler_indices_padded: torch.Tensor,
|
||||||
|
embeddings_indices: torch.Tensor,
|
||||||
|
indices_len: List[int],
|
||||||
|
):
|
||||||
|
self.indices = base_indices
|
||||||
|
self.embeddings_indices = embeddings_indices
|
||||||
|
self.indices_len = indices_len
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
|
||||||
|
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
|
||||||
|
full_lora_a_embeddings = F.embedding(
|
||||||
|
x + indices,
|
||||||
|
self.lora_a_stacked_2d,
|
||||||
|
)
|
||||||
|
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
|
||||||
|
full_output = self.base_layer.forward(
|
||||||
|
x.add_(indices * added_tokens_mask))
|
||||||
|
|
||||||
|
full_output_org = full_output
|
||||||
|
if full_output.ndim == 3:
|
||||||
|
full_output = full_output.view(
|
||||||
|
full_output.shape[0] * full_output.shape[1], -1)
|
||||||
|
if full_lora_a_embeddings.ndim == 3:
|
||||||
|
full_lora_a_embeddings = full_lora_a_embeddings.view(
|
||||||
|
full_lora_a_embeddings.shape[0] *
|
||||||
|
full_lora_a_embeddings.shape[1], -1)
|
||||||
|
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
|
||||||
|
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||||
|
return full_output.view_as(full_output_org)
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||||
|
|
||||||
|
def __init__(self, base_layer: ColumnParallelLinear) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.base_layer = base_layer
|
||||||
|
|
||||||
|
def create_lora_weights(
|
||||||
|
self,
|
||||||
|
max_loras: int,
|
||||||
|
lora_config: LoRAConfig,
|
||||||
|
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||||
|
self.lora_a_stacked = torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
lora_config.max_lora_rank,
|
||||||
|
self.base_layer.weight.shape[1],
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.base_layer.weight.device,
|
||||||
|
)
|
||||||
|
self.lora_b_stacked = torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.base_layer.weight.shape[0],
|
||||||
|
lora_config.max_lora_rank,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.base_layer.weight.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.indices: Optional[torch.Tensor] = None
|
||||||
|
self.indices_len: Optional[List[int]] = None
|
||||||
|
self.output_dim = self.lora_b_stacked.shape[1]
|
||||||
|
|
||||||
|
def reset_lora(self, index: int):
|
||||||
|
self.lora_a_stacked[index] = 0
|
||||||
|
self.lora_b_stacked[index] = 0
|
||||||
|
|
||||||
|
def set_lora(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
lora_a: torch.Tensor,
|
||||||
|
lora_b: torch.Tensor,
|
||||||
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
self.reset_lora(index)
|
||||||
|
|
||||||
|
self.lora_a_stacked[index,
|
||||||
|
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||||
|
lora_a.T, non_blocking=True)
|
||||||
|
self.lora_b_stacked[index,
|
||||||
|
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||||
|
lora_b.T, non_blocking=True)
|
||||||
|
|
||||||
|
def set_mapping(
|
||||||
|
self,
|
||||||
|
base_indices: torch.Tensor,
|
||||||
|
sampler_indices: torch.Tensor,
|
||||||
|
sampler_indices_padded: torch.Tensor,
|
||||||
|
embeddings_indices: torch.Tensor,
|
||||||
|
indices_len: List[int],
|
||||||
|
):
|
||||||
|
self.indices = base_indices
|
||||||
|
self.indices_len = indices_len
|
||||||
|
|
||||||
|
def apply_weights(self, x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
|
output = self.base_layer.linear_method.apply_weights(
|
||||||
|
self.base_layer.linear_weights, x, bias)
|
||||||
|
_apply_lora(
|
||||||
|
x,
|
||||||
|
self.lora_a_stacked,
|
||||||
|
self.lora_b_stacked,
|
||||||
|
self.indices[:self.indices_len[0]],
|
||||||
|
output,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(self, input_):
|
||||||
|
"""Forward of ColumnParallelLinear
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_: Tensor whose last dimension is `input_size`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- output
|
||||||
|
- bias
|
||||||
|
"""
|
||||||
|
bias = (self.base_layer.bias
|
||||||
|
if not self.base_layer.skip_bias_add else None)
|
||||||
|
|
||||||
|
# Matrix multiply.
|
||||||
|
output_parallel = self.apply_weights(input_, bias)
|
||||||
|
if self.base_layer.gather_output:
|
||||||
|
# All-gather across the partitions.
|
||||||
|
output = tensor_model_parallel_all_gather(output_parallel)
|
||||||
|
else:
|
||||||
|
output = output_parallel
|
||||||
|
output_bias = (self.base_layer.bias
|
||||||
|
if self.base_layer.skip_bias_add else None)
|
||||||
|
return output, output_bias
|
||||||
|
|
||||||
|
@property
|
||||||
|
def linear_weights(self):
|
||||||
|
return self.base_layer.linear_weights
|
||||||
|
|
||||||
|
|
||||||
|
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||||
|
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
|
||||||
|
packed together (eg. gate_proj + up_proj -> gate_up_proj).
|
||||||
|
|
||||||
|
This means we have 2 LoRAs, each applied to one half of the layer.
|
||||||
|
|
||||||
|
Both slices must have the same size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
|
||||||
|
super().__init__(base_layer)
|
||||||
|
|
||||||
|
def create_lora_weights(
|
||||||
|
self,
|
||||||
|
max_loras: int,
|
||||||
|
lora_config: LoRAConfig,
|
||||||
|
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||||
|
n_slices = 2
|
||||||
|
if not (len(self.base_layer.output_sizes) == n_slices
|
||||||
|
and self.base_layer.output_sizes[0]
|
||||||
|
== self.base_layer.output_sizes[1]):
|
||||||
|
raise ValueError(
|
||||||
|
"LoRAColumnParallelLinear2Slice requires 2 slices with "
|
||||||
|
"the same size.")
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
self.lora_a_stacked = tuple(
|
||||||
|
torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
lora_config.max_lora_rank,
|
||||||
|
self.base_layer.weight.shape[1],
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.base_layer.weight.device,
|
||||||
|
) for _ in range(n_slices))
|
||||||
|
self.lora_b_stacked = tuple(
|
||||||
|
torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.base_layer.weight.shape[0] // 2,
|
||||||
|
lora_config.max_lora_rank,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.base_layer.weight.device,
|
||||||
|
) for _ in range(n_slices))
|
||||||
|
|
||||||
|
self.indices: Optional[torch.Tensor] = None
|
||||||
|
self.output_dim = self.lora_b_stacked[0].shape[2]
|
||||||
|
|
||||||
|
def reset_lora(self, index: int):
|
||||||
|
self.lora_a_stacked[0][index] = 0
|
||||||
|
self.lora_a_stacked[1][index] = 0
|
||||||
|
self.lora_b_stacked[0][index] = 0
|
||||||
|
self.lora_b_stacked[1][index] = 0
|
||||||
|
|
||||||
|
def set_lora(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
lora_a: torch.Tensor,
|
||||||
|
lora_b: torch.Tensor,
|
||||||
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
|
self.reset_lora(index)
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
|
shard_size = self.output_dim
|
||||||
|
start_idx = tensor_model_parallel_rank * shard_size
|
||||||
|
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||||
|
lora_b = lora_b[0][:,
|
||||||
|
start_idx:end_idx], lora_b[1][:,
|
||||||
|
start_idx:end_idx]
|
||||||
|
|
||||||
|
if lora_a[0] is not None:
|
||||||
|
self.lora_a_stacked[0][
|
||||||
|
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
|
||||||
|
lora_a[0].T, non_blocking=True)
|
||||||
|
self.lora_b_stacked[0][
|
||||||
|
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
||||||
|
lora_b[0].T, non_blocking=True)
|
||||||
|
if lora_a[1] is not None:
|
||||||
|
self.lora_a_stacked[1][
|
||||||
|
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
||||||
|
lora_a[1].T, non_blocking=True)
|
||||||
|
self.lora_b_stacked[1][
|
||||||
|
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
||||||
|
lora_b[1].T, non_blocking=True)
|
||||||
|
|
||||||
|
def apply_weights(self, x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
|
output = self.base_layer.linear_method.apply_weights(
|
||||||
|
self.base_layer.linear_weights, x, bias)
|
||||||
|
_apply_lora_packed_nslice(
|
||||||
|
x,
|
||||||
|
self.lora_a_stacked,
|
||||||
|
self.lora_b_stacked,
|
||||||
|
self.indices[:self.indices_len[0]],
|
||||||
|
output,
|
||||||
|
(self.output_dim, self.output_dim),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output

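The shard bookkeeping above decides which columns of each per-slice lora_b a tensor-parallel rank keeps: query columns are sharded by rank, while key/value columns are sharded by the KV-head replica group. A small arithmetic sketch (illustrative only; the numbers are hypothetical, not taken from a real model):

num_heads, num_kv_heads, head_size = 8, 2, 128   # per tensor-parallel rank
tp_rank, num_kv_head_replicas = 3, 4

q_proj_shard_size = num_heads * head_size        # 1024 q columns per rank
kv_proj_shard_size = num_kv_heads * head_size    # 256 k (and v) columns per rank
q_shard_id = tp_rank                             # 3
kv_shard_id = tp_rank // num_kv_head_replicas    # 0: KV heads are replicated

# Column slice of the full q lora_b that this rank keeps:
q_cols = slice(q_proj_shard_size * q_shard_id,
               q_proj_shard_size * (q_shard_id + 1))
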
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.weight.shape[0],
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                `input_is_parallel` is set, then the last dimension
                is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight

class SamplerWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: Sampler,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not 32000 <= self.base_layer.vocab_size <= 33024:
            raise ValueError(
                "When using LoRA, vocab size must satisfy "
                "32000 <= vocab_size <= 33024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1], ] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]

        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)


def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
    supported_layer_types = {
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLora,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_sampler(
    layer: Sampler,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
                          lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret

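Because `from_layer` dispatches on the exact layer type, callers can try to wrap every module and only the supported ones are replaced; everything else is returned unchanged. A hedged sketch (not part of this commit; `base` and `cfg` are assumed to be an existing ColumnParallelLinear and a LoRAConfig from the surrounding vLLM setup):

# Hypothetical objects: `base` is a ColumnParallelLinear, `cfg` a LoRAConfig.
wrapped = from_layer(base, max_loras=4, lora_config=cfg, model_config=None)
assert isinstance(wrapped, ColumnParallelLinearWithLoRA)

# Unsupported layer types pass through untouched.
untouched = from_layer(torch.nn.LayerNorm(16), 4, cfg)
assert isinstance(untouched, torch.nn.LayerNorm)
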
160 vllm/lora/lora.py Normal file
@ -0,0 +1,160 @@
from typing import List, Optional

import torch
from vllm.utils import in_wsl


class LoRALayerWeights:
    """LoRA weights for a layer composed of two low rank matrices."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.embeddings_tensor = embeddings_tensor

        if scaling is None:
            self.scaling = self.lora_alpha / self.rank
        else:
            self.scaling = scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling == 1:
            return self
        self.lora_b *= self.scaling
        self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        return self.lora_a.shape[0]

    @property
    def output_dim(self) -> int:
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        return False

    @property
    def extra_vocab_size(self) -> int:
        return self.embeddings_tensor.shape[
            0] if self.embeddings_tensor is not None else 0

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
            device: torch.device,
            embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
        pin_memory = str(device) == "cpu" and not in_wsl()
        lora_a = torch.zeros([input_dim, rank],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        lora_b = torch.zeros([rank, output_dim],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        embeddings_tensor = torch.rand(
            10,
            embeddings_tensor_dim,
            dtype=dtype,
            device=device,
            pin_memory=pin_memory) if embeddings_tensor_dim else None
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            embeddings_tensor=embeddings_tensor,
        )


class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (eg. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: List[int],
        lora_a: List[torch.Tensor],
        lora_b: List[torch.Tensor],
        scaling: Optional[List[float]] = None,
    ) -> None:
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            self.scaling = [
                lora_alpha / self.rank for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        """
        first_lora = next(lora for lora in loras if lora is not None)
        for lora in loras:
            if lora is None:
                continue
            lora.optimize()
        rank = first_lora.rank
        module_name = first_lora.module_name
        obj = cls(
            module_name,
            rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            scaling=[1 if lora is not None else None for lora in loras])
        return obj

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        for i in range(len(self.lora_b)):
            if self.scaling[i] == 1 or self.lora_b[i] is None:
                continue
            self.lora_b[i] *= self.scaling[i]
            self.scaling[i] = 1
        return self

    @property
    def input_dim(self) -> int:
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        return True

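To show how per-projection adapters end up in one packed entry, here is a minimal sketch of `PackedLoRALayerWeights.pack` (illustrative only, not part of this commit; the module names and dimensions are made up):

import torch

q = LoRALayerWeights.create_dummy_lora_weights(
    "model.layers.0.self_attn.q_proj", 4096, 4096, 8, torch.float16, "cpu")
k = None  # this submodule has no LoRA
v = LoRALayerWeights.create_dummy_lora_weights(
    "model.layers.0.self_attn.v_proj", 4096, 1024, 8, torch.float16, "cpu")

packed = PackedLoRALayerWeights.pack([q, k, v])
assert packed.is_packed
assert packed.lora_a[1] is None  # missing sub-LoRA stays None in the packed lists
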
654 vllm/lora/models.py Normal file
@ -0,0 +1,654 @@
import copy
import json
import logging
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type,
                    Union)

import safetensors.torch
import torch
from torch import nn

from vllm.config import LoRAConfig
from vllm.utils import LRUCache, in_wsl

from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule

logger = logging.getLogger(__name__)

# TODO: The mappings below should be moved to individual model classes.

PACKED_MODULES_CFG = {
    "qkv_proj": [
        "q_proj",
        "k_proj",
        "v_proj",
    ],
    "gate_up_proj": [
        "gate_proj",
        "up_proj",
    ],
}

TARGET_MODULES_QKV = [
    "qkv_proj",
    "o_proj",
    "gate_up_proj",
    "down_proj",
    "embed_tokens",
    "lm_head",
]

EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]

_GLOBAL_LORA_ID = 0


def convert_mapping(
    mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
    max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            indices_len: List of lengths of the above tensors.
    """
    indices = list(mapping.index_mapping).copy()
    embedding_indices = indices.copy()
    lora_indices = indices.copy()
    prompt_mapping = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(indices[i])
                    if indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if indices[i] > 0 else 0
        indices[i] = i
        lora_indices[i] = lora_idx

    indices = torch.tensor([indices, lora_indices, embedding_indices],
                           dtype=torch.long,
                           device="cuda")
    prompt_mapping = torch.tensor(prompt_mapping,
                                  device="cuda",
                                  dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size)
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = (
        torch.arange(
            0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
        (sampler_indices_padded * len(sampler_indices_padded)))
    indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
                   sampler_indices_padded.shape[-1],
                   embeddings_indices.shape[-1])

    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, indices_len)

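A small worked example of the list-side bookkeeping inside `convert_mapping`, before the results are moved to CUDA tensors (illustrative only; the LoRA ids and batch are hypothetical):

lora_index_to_id = [None, 7, 12, None]   # GPU slot -> LoRA id
index_mapping = [12, 12, 12, 0, 0]       # per-token LoRA id (0 = base model)
base_indices = [lora_index_to_id.index(x) if x > 0 else -1 for x in index_mapping]
assert base_indices == [2, 2, 2, -1, -1]  # slot of each token's LoRA

prompt_mapping = [12, 0]                 # per-request LoRA id
sampler_indices = [lora_index_to_id.index(x) if x > 0 else -1 for x in prompt_mapping]
assert sampler_indices == [2, -1]        # -1 means no LoRA for that request
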

def get_lora_id():
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID

class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
    ) -> None:
        self.id = lora_model_id
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and not in_wsl()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    embeddings_module = next(
                        (k for k in EMBEDDING_MODULES if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            EMBEDDING_MODULES[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                if any(name in module_name
                       for name in EMBEDDING_PADDING_MODULES
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras)

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint."""
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        if os.path.isfile(lora_tensor_path):
            tensors = safetensors.torch.load_file(lora_tensor_path)
        elif os.path.isfile(lora_bin_file_path):
            tensors = torch.load(lora_bin_file_path)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path)

        with open(lora_config_path) as f:
            config = json.load(f)
        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
        )

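A hedged usage sketch of `from_local_checkpoint` (not part of this commit; the path is a placeholder and must point at a PEFT-style directory containing adapter_config.json plus adapter_model.safetensors or adapter_model.bin, as checked above):

import torch

lora = LoRAModel.from_local_checkpoint(
    "/path/to/sql-lora",   # hypothetical adapter directory
    lora_model_id=1,
    device="cpu",
    dtype=torch.float16,
)
print(lora.id, lora.rank, sorted(lora.loras))
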
class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
        packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences the model can run in
                a single batch.
            max_num_batched_tokens: the maximum number of tokens the model can
                run in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            lora_target_modules: the target module patterns to be adapted.
                Supports both a single module name and a list of module names.
            packed_modules_mapping: the mapping for packed modules. vLLM
                packs some modules into one module, e.g., qkv_proj
                packs q_proj, k_proj, and v_proj. These modules
                have a single layer in the original model, but they are split
                into multiple layers in the adapted model.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.base_indices = torch.empty(self.max_num_batched_tokens,
                                        dtype=torch.long,
                                        device="cuda")
        self.sampler_indices = torch.empty(self.max_num_batched_tokens,
                                           dtype=torch.long,
                                           device="cuda")
        self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
                                                  dtype=torch.long,
                                                  device="cuda")
        self.embeddings_indices = torch.empty(2,
                                              self.max_num_batched_tokens,
                                              dtype=torch.long,
                                              device="cuda")
        self.offsets = []
        # 4 is the number of indices tensors defined above:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices
        self.indices_len = [None] * 4

        self.model: nn.Module = model
        self.lora_target_modules: List[str] = ([
            lora_target_modules
        ] if isinstance(lora_target_modules, str) else lora_target_modules)
        self.lora_target_modules = copy.deepcopy(self.lora_target_modules)
        self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._registered_loras: Dict[int, LoRAModel] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_loras: Dict[int, None] = {}
        self._last_mapping = None
        self._create_lora_modules()
        self.model.lora_manager = self

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

    def __len__(self) -> int:
        return len(self._registered_loras)

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
        if lora_id in self._active_loras:
            return False
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_loras[lora_id] = None
        lora_model = self._registered_loras[lora_id]
        logger.debug(
            f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_lora(self, lora_id: int):
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def deactivate_lora(self, lora_id: int) -> bool:
        """Remove a LoRA from a GPU buffer."""
        if lora_id in self._active_loras:
            self._deactivate_lora(lora_id)
            self._active_loras.pop(lora_id)
            return True
        return False

    def _add_lora(self, lora: LoRAModel) -> bool:
        self._create_merged_loras_inplace(lora)
        self._registered_loras[lora.id] = lora

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager CPU cache."""
        if lora.id not in self._registered_loras:
            if len(self._registered_loras) >= self.capacity:
                raise RuntimeError("No free LoRA slots.")
            self._add_lora(lora)
            return True
        return False

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRAModel from the manager CPU cache."""
        # TODO: should we check active lora?
        self.deactivate_lora(lora_id)
        return bool(self._registered_loras.pop(lora_id, None))

    # TODO see if this can be vectorized
    def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
        (base_indices, sampler_indices, sampler_indices_padded,
         embeddings_indices,
         indices_len) = convert_mapping(mapping, self.lora_index_to_id,
                                        self.lora_slots + 1, self.vocab_size,
                                        self.lora_config.lora_extra_vocab_size)
        self.base_indices[:base_indices.shape[0]].copy_(base_indices)
        self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self.embeddings_indices[:embeddings_indices.
                                shape[0], :embeddings_indices.shape[1]].copy_(
                                    embeddings_indices)
        # Maintain the reference
        self.indices_len[:] = indices_len

    def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
        if self._last_mapping != lora_mapping:
            self._set_lora_mapping(lora_mapping)
        self._last_mapping = lora_mapping

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras)

    def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
        return self._registered_loras.get(lora_id, None)

    def remove_all_loras(self) -> bool:
        """Remove all LoRAModels from the manager."""
        self._registered_loras.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_loras.clear()

    def _create_lora_modules(self):
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name):
                continue

            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           self.model.config))
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                sampler_module = self.model.get_submodule("sampler")
                new_module = replace_submodule(
                    self.model, "sampler",
                    from_layer_sampler(sampler_module, module, self.lora_slots,
                                       self.lora_config, self.model.config))
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            new_module.set_mapping(self.base_indices, self.sampler_indices,
                                   self.sampler_indices_padded,
                                   self.embeddings_indices, self.indices_len)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name) or not isinstance(
                    module, BaseLayerWithLoRA):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                if parts[-1] in EMBEDDING_MODULES:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.lora_target_modules)

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name)
        if not replacements:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras = []
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
            if not has_replacement:
                continue
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

class LoRALRUCache(LRUCache):

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                   None]):
        super().__init__(capacity)
        self.deactivate_lora_fn = deactivate_lora_fn

    def _on_remove(self, key: Hashable, value: Any):
        logger.debug(f"Removing LoRA. int id: {key}")
        self.deactivate_lora_fn(key)
        return super()._on_remove(key, value)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
        packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, lora_target_modules,
                         packed_modules_mapping)
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        self._active_loras: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_lora)

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras.cache)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        if lora.id not in self._registered_loras:
            self._add_lora(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_loras.touch(lora.id)
            was_added = False
        return was_added

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        if lora_id not in self._active_loras and len(
                self._active_loras) >= self.lora_slots:
            self._active_loras.remove_oldest()
        result = super().activate_lora(lora_id)
        # We always touch to update the LRU cache order
        self._active_loras.touch(lora_id)
        return result

    def remove_oldest_lora(self) -> bool:
        if len(self._registered_loras) > 0:
            self._registered_loras.remove_oldest()
            return True
        return False


def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
    if not getattr(model, "supports_lora", False):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        lora_target_modules=target_modules,
        **kwargs)
    return lora_manager

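A hedged sketch of how the pieces above fit together (not part of this commit and not runnable on its own: `model` is assumed to be a vLLM model whose class sets `supports_lora = True`, and `lora_cfg` an existing LoRAConfig; the paths and numbers are placeholders):

manager = create_lora_manager(
    model,
    max_num_seqs=256,
    max_num_batched_tokens=4096,
    vocab_size=32000,
    lora_config=lora_cfg,
    lora_manager_cls=LRUCacheLoRAModelManager,
)
adapter = LoRAModel.from_local_checkpoint("/path/to/adapter", device="cpu")
manager.add_lora(adapter)        # register in the CPU cache
manager.activate_lora(adapter.id)  # copy into a free GPU slot
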
173 vllm/lora/punica.py Normal file
@ -0,0 +1,173 @@
# Based on code from https://github.com/punica-ai/punica

from typing import Optional

import torch

import_exc = None

try:
    import vllm._punica_C as punica_kernels
except ImportError as e:
    import_exc = e

if import_exc is None:

    def bgmv(
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        indicies: torch.LongTensor,
        layer_idx: int,
        scale: float,
    ):
        """
        Semantics:
          y[i] += (
              x[i].unsqueeze(0)
              @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              * scale
            ).squeeze(0)

        Args:
          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
          x: Shape: `[B, H1]`. Input vectors.
          w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
            matrices.
          indicies: Shape: `[B]`. Indices of the weight matrices.
          layer_idx: Layer index of the weight matrices.
          scale: Scaling factor.
        """
        punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)

    def add_lora(y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 indicies: torch.LongTensor,
                 layer_idx: int,
                 scale: float,
                 *,
                 buffer: Optional[torch.Tensor] = None):
        """
        Semantics:
          y[i] += (
              x[i].unsqueeze(0)
              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              * scale
            ).squeeze(0)

        Args:
          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
          x: Shape: `[B, H1]`. Input vectors.
          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
            LoRA B matrices.
          indicies: Shape: `[B]`. Indices of the LoRA weights.
          layer_idx: Layer index of LoRA weights.
          scale: Scaling factor.
          buffer: Optional. Shape: `[B, R]`. Temporary buffer.
        """
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
            # due to downcasting.
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
                                     1.0)
        punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                     scale)

    def add_lora_slice(y: torch.Tensor,
                       x: torch.Tensor,
                       wa_t_all: torch.Tensor,
                       wb_t_all: torch.Tensor,
                       indicies: torch.LongTensor,
                       layer_idx: int,
                       scale: float,
                       y_offset: int,
                       y_slice_size: int,
                       *,
                       buffer: Optional[torch.Tensor] = None):
        """
        Same as `add_lora` but you can operate on slices of y.
        Pass whole y, define y_offset and y_slice_size.

        Semantics:
          y[i] += (
              x[i].unsqueeze(0)
              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              * scale
            ).squeeze(0)

        Args:
          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
          x: Shape: `[B, H1]`. Input vectors.
          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
            LoRA B matrices.
          indicies: Shape: `[B]`. Indices of the LoRA weights.
          layer_idx: Layer index of LoRA weights.
          scale: Scaling factor.
          y_offset: Offset to apply to the starting column of y.
          y_slice_size: Size of the y column slice.
        """
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
            # due to downcasting.
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        punica_kernels.dispatch_bgmv_low_level(
            buffer,
            x,
            wa_t_all,
            indicies,
            layer_idx,
            1.0,
            x.size(1),
            buffer.size(1),
            0,
        )
        punica_kernels.dispatch_bgmv_low_level(
            y,
            buffer,
            wb_t_all,
            indicies,
            layer_idx,
            scale,
            buffer.size(1),
            y_slice_size,
            y_offset,
        )

else:

    def _raise_exc(
            *args,  # pylint: disable=unused-argument
            **kwargs  # pylint: disable=unused-argument
    ):
        if torch.cuda.get_device_capability() < (8, 0):
            raise ImportError(
                "LoRA kernels require compute capability >= 8.0") from import_exc
        else:
            raise import_exc

    bgmv = _raise_exc
    add_lora = _raise_exc
    add_lora_slice = _raise_exc

__all__ = [
    "bgmv",
    "add_lora",
    "add_lora_slice",
]

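The docstrings above pin down the semantics precisely. As a sanity reference (illustrative only, not the CUDA kernel and far slower), the documented `bgmv` computation can be transcribed directly in PyTorch; the shapes and indices below are made up:

import torch

def bgmv_reference(y, x, w_t_all, indices, layer_idx, scale):
    # Literal transcription of the documented semantics, for testing only.
    for i in range(x.size(0)):
        w_t = w_t_all[indices[i], layer_idx]  # [H2, H1]
        y[i] += (x[i].unsqueeze(0) @ w_t.transpose(-1, -2) * scale).squeeze(0)

B, H1, H2, L, N = 3, 16, 8, 2, 4
y = torch.zeros(B, H2)
x = torch.randn(B, H1)
w_t_all = torch.randn(N, L, H2, H1)
bgmv_reference(y, x, w_t_all, torch.tensor([0, 2, 1]), layer_idx=1, scale=1.0)
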
32
vllm/lora/request.py
Normal file
32
vllm/lora/request.py
Normal file
@ -0,0 +1,32 @@
from dataclasses import dataclass


@dataclass
class LoRARequest:
    """
    Request for a LoRA adapter.

    Note that this class should be used internally. For online
    serving, it is recommended to not allow users to use this class but
    instead provide another layer of abstraction to prevent users from
    accessing unauthorized LoRA adapters.

    lora_int_id must be globally unique for a given adapter.
    This is currently not enforced in vLLM.
    """

    lora_name: str
    lora_int_id: int
    lora_local_path: str

    def __post_init__(self):
        if self.lora_int_id < 1:
            raise ValueError(
                f"lora_int_id must be > 0, got {self.lora_int_id}")

    def __eq__(self, value: object) -> bool:
        return isinstance(
            value, LoRARequest) and self.lora_int_id == value.lora_int_id

    def __hash__(self) -> int:
        return self.lora_int_id
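A small sketch of the request object above (not part of this commit; the adapter name and path are hypothetical). Because equality and hashing are defined on `lora_int_id` only, duplicate requests for the same adapter collapse in sets:

from vllm.lora.request import LoRARequest

a = LoRARequest(lora_name="sql-lora", lora_int_id=1,
                lora_local_path="/path/to/sql-lora")        # hypothetical adapter
b = LoRARequest(lora_name="sql-lora-copy", lora_int_id=1,
                lora_local_path="/other/path/to/sql-lora")  # same id, different path
assert a == b and len({a, b}) == 1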
vllm/lora/utils.py  (new file)
@@ -0,0 +1,39 @@
import logging
from typing import Tuple

from torch import nn

logger = logging.getLogger(__name__)


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module


def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
    """Parse the name of LoRA weights.

    Args:
        name: the name of the fine-tuned LoRA, e.g.
            base_model.model.dense1.weight

    Returns:
        Tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1.
            is_lora_a: whether the tensor is lora_A or lora_B.
    """
    parts = name.split(".")
    assert parts[0] == "base_model"
    assert parts[1] == "model"
    if parts[-1] == "weight":
        assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
        return ".".join(parts[2:-2]), parts[-2] == "lora_A"

    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is in an unsupported format")
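A small sketch of the name parsing above (not part of this commit; the tensor names are illustrative):

from vllm.lora.utils import parse_fine_tuned_lora_name

assert parse_fine_tuned_lora_name(
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight") == (
        "model.layers.0.self_attn.q_proj", True)
assert parse_fine_tuned_lora_name(
    "base_model.model.lm_head.lora_embedding_B") == ("lm_head", False)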
vllm/lora/worker_manager.py  (new file)
@@ -0,0 +1,237 @@
import logging
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, List, Optional, Set, Type, Union

import torch

from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.layers import LoRAMapping
from vllm.config import LoRAConfig

logger = logging.getLogger(__name__)


class WorkerLoRAManager(ABC):
    """Abstract class for managing LoRA models on the worker side."""

    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
                 vocab_size: int, lora_config: LoRAConfig,
                 device: torch.device):
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.lora_config = lora_config

    @abstractproperty
    def is_enabled(self) -> bool:
        ...

    @abstractmethod
    def create_lora_manager(
        self,
        model: torch.nn.Module,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    ) -> Any:
        ...

    @abstractmethod
    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        ...

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        ...

    @abstractmethod
    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        ...

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        ...

    @abstractmethod
    def remove_all_loras(self) -> bool:
        ...

    @abstractmethod
    def list_loras(self) -> Set[int]:
        ...


class WorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_model_cls: Type[LoRAModel] = LoRAModel,
    ):
        self._lora_manager: Optional[LoRAModelManager] = None
        self._lora_model_cls = lora_model_cls
        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                         lora_config, device)

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            target_modules=target_modules,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._lora_manager_cls,
        )
        self._lora_manager: LoRAModelManager = lora_manager
        return lora_manager.model

    def set_active_loras(self, lora_requests: List[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        self._apply_loras(lora_requests)
        self._lora_manager.set_lora_mapping(lora_mapping)

    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
        loras_that_exist = self.list_loras()
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")

        new_loras = set(loras_map)
        loras_to_add = new_loras - loras_that_exist
        loras_to_remove = loras_that_exist - new_loras

        for lora_id in loras_to_remove:
            self.remove_lora(lora_id)

        for lora_id in loras_to_add:
            self.add_lora(loras_map[lora_id])

    def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
        try:
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_request.lora_local_path,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
            )
        except Exception as e:
            raise RuntimeError(
                f"Loading lora {lora_request.lora_local_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(
                f"LoRA added vocab size {lora.extra_vocab_size} is greater than "
                f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}."
            )
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        if lora_request.lora_int_id in self.list_loras():
            return False
        return self._lora_manager.add_lora(
            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
                                                 rank))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id in self.list_loras():
            return False
        lora = self._load_lora(lora_request)
        loaded = self._lora_manager.add_lora(lora)
        self._lora_manager.activate_lora(lora.id)
        return loaded

    def remove_lora(self, lora_id: int) -> bool:
        return self._lora_manager.remove_lora(lora_id)

    def remove_all_loras(self) -> bool:
        self._lora_manager.remove_all_loras()

    def list_loras(self) -> Set[int]:
        return set(self._lora_manager.list_loras())


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _lora_manager_cls: Type[
        LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
        target_modules: Union[str, List[str]] = TARGET_MODULES_QKV,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            target_modules=target_modules,
            lora_manager_cls=self._lora_manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._lora_manager: LRUCacheLoRAModelManager = lora_manager
        return lora_manager.model

    def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_lora(lora)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id not in self.list_loras():
            # Remove before we load the new lora to save memory
            if len(self._lora_manager) + 1 > self._lora_manager.capacity:
                self._lora_manager.remove_oldest_lora()
            lora = self._load_lora(lora_request)
            loaded = self._lora_manager.add_lora(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
            self._lora_manager.activate_lora(lora_request.lora_int_id)
        return loaded
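A rough sketch of the worker-side call order implied above (not part of this commit; `model`, `lora_config`, `batch_lora_requests`, `index_mapping`, `prompt_mapping` and the LoRAMapping field layout are assumptions):

import torch
from vllm.lora.layers import LoRAMapping
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager

manager = LRUCacheWorkerLoRAManager(
    max_num_seqs=256, max_num_batched_tokens=4352, vocab_size=32000,
    lora_config=lora_config, device=torch.device("cuda"))
model = manager.create_lora_manager(model)   # wrap the base model once at load time

# Per scheduled batch: load or touch the requested adapters (evicting the least
# recently used ones only when the GPU slots are full), then install the
# token-to-adapter mapping for the kernels.
manager.set_active_loras(batch_lora_requests,
                         LoRAMapping(index_mapping, prompt_mapping))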
@@ -27,9 +27,25 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """
 
-    def __init__(self, vocab_size: int) -> None:
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
 
     def forward(
         self,
@@ -42,8 +58,7 @@ class Sampler(nn.Module):
         hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
 
         # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias,
-                             self.vocab_size)
+        logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because
@@ -98,20 +113,6 @@ class Sampler(nn.Module):
                                  prompt_logprobs, sample_logprobs)
 
 
-def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
-                embedding_bias: Optional[torch.Tensor],
-                vocab_size: int) -> Optional[torch.Tensor]:
-    # Get the logits for the next tokens.
-    logits = torch.matmul(hidden_states, embedding.t())
-    if embedding_bias is not None:
-        logits += embedding_bias
-    logits = tensor_model_parallel_gather(logits)
-    # Remove paddings in vocab (if any).
-    if logits is not None:
-        logits = logits[:, :vocab_size]
-    return logits
-
-
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     sampling_metadata: SamplingMetadata,
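A small sketch of the padding removal performed by `_get_logits` (not part of this commit; sizes are assumptions):

import torch

org_vocab_size = 32000                 # vocabulary the sampler should expose (assumed)
padded_cols = 32256                    # lm_head output width after padding (assumed)
logits = torch.randn(4, padded_cols)   # [num_tokens, padded vocab]
logits = logits[:, :org_vocab_size]    # drop the padded columns, as in _get_logits
assert logits.shape == (4, org_vocab_size)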
@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.utils import set_weight_attrs
 
+DEFAULT_VOCAB_PADDING_SIZE = 64
 
-def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
+
+def pad_vocab_size(vocab_size: int,
+                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
 
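A small sketch of the rounding that `pad_vocab_size` performs (not part of this commit; the vocabulary and padding sizes are assumptions):

assert ((32000 + 64 - 1) // 64) * 64 == 32000        # already a multiple of 64
assert ((32016 + 64 - 1) // 64) * 64 == 32064        # 32000 base + 16 extra LoRA tokens
assert ((32016 + 256 - 1) // 256) * 256 == 32256     # coarser padding for the LoRA kernels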
@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
         params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
     """
 
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
-                 params_dtype: Optional[torch.dtype] = None):
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__()
 
         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
-        self.num_embeddings_padded = pad_vocab_size(num_embeddings)
+        self.org_vocab_size = org_num_embeddings or num_embeddings
+        self.num_embeddings_padded = pad_vocab_size(num_embeddings,
+                                                    padding_size)
         self.embedding_dim = embedding_dim
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -77,7 +86,7 @@ class VocabParallelEmbedding(torch.nn.Module):
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         parallel_dim = param.parallel_dim
-        assert loaded_weight.shape[parallel_dim] == self.num_embeddings
+        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
         loaded_weight = loaded_weight[self.vocab_start_index:self.
                                       vocab_end_index]
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
@@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding):
         embedding_dim: size of hidden state.
         bias: whether to use bias.
         params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
     """
 
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
                  bias: bool = False,
-                 params_dtype: Optional[torch.dtype] = None):
-        super().__init__(num_embeddings, embedding_dim, params_dtype)
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+        super().__init__(num_embeddings, embedding_dim, params_dtype,
+                         org_num_embeddings, padding_size)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
@@ -1,12 +1,12 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Type
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, LoRAConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
@@ -32,7 +32,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")
 
 
-def get_model(model_config: ModelConfig) -> nn.Module:
+def get_model(model_config: ModelConfig,
+              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
 
     # Get the (maybe quantized) linear method.
@@ -62,7 +63,17 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     # Create a model instance.
     # The weights will be initialized as empty tensors.
     with torch.device("cuda"):
-        model = model_class(model_config.hf_config, linear_method)
+        if getattr(model_class, "supports_lora", False):
+            model = model_class(model_config.hf_config, linear_method,
+                                lora_config)
+        elif lora_config:
+            raise ValueError(
+                f"Model {model_class.__name__} does not support LoRA, "
+                "but LoRA is enabled. Support for this model may "
+                "be added in the future. If this is important to you, "
+                "please open an issue on github.")
+        else:
+            model = model_class(model_config.hf_config, linear_method)
     if model_config.load_format == "dummy":
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -225,14 +226,19 @@ class LlamaModel(nn.Module):
         self,
         config: LlamaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            self.vocab_size,
             config.hidden_size,
+            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config, linear_method)
@@ -263,18 +269,31 @@ class LlamaModel(nn.Module):
 
 
 class LlamaForCausalLM(nn.Module):
+    supports_lora = True
 
     def __init__(
         self,
        config: LlamaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = LlamaModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
+        unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,
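A small sketch of the vocabulary arithmetic above (not part of this commit; the config values are assumptions): the embedding reserves extra rows for every resident adapter, while the LM head and sampler only expose one adapter's extra tokens per step.

base_vocab = 32000
extra_vocab = 256                      # lora_extra_vocab_size (assumed)
max_loras = 4                          # max resident adapters (assumed)

embed_rows = base_vocab + extra_vocab * max_loras    # rows in embed_tokens
lm_head_rows = base_vocab + extra_vocab              # unpadded lm_head / sampler vocab
assert (embed_rows, lm_head_rows) == (33024, 32256)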
@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
+from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -220,15 +221,20 @@ class MistralModel(nn.Module):
         self,
         config: MistralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
+            self.vocab_size,
             config.hidden_size,
+            org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
             MistralDecoderLayer(config, linear_method)
@@ -259,18 +265,33 @@ class MistralModel(nn.Module):
 
 
 class MistralForCausalLM(nn.Module):
+    supports_lora = True
 
     def __init__(
         self,
         config: MistralConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.linear_method = linear_method
-        self.model = MistralModel(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
+        self.model = MistralModel(config,
+                                  linear_method,
+                                  lora_config=lora_config)
+        unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
+        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
 
     def forward(
         self,
@@ -195,10 +195,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 
 def destroy_model_parallel():
-    """Set the groups to none."""
+    """Set the groups to none and destroy them."""
     global _TENSOR_MODEL_PARALLEL_GROUP
+    if _TENSOR_MODEL_PARALLEL_GROUP:
+        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
     _TENSOR_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
+    if _PIPELINE_MODEL_PARALLEL_GROUP:
+        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
     _PIPELINE_GLOBAL_RANKS = None
@@ -2,6 +2,7 @@ from typing import List, Optional
 
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
                            SequenceStatus)
+from vllm.lora.request import LoRARequest
 
 
 class CompletionOutput:
@@ -16,6 +17,7 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        lora_request: The LoRA request that was used to generate the output.
     """
 
     def __init__(
@@ -26,6 +28,7 @@ class CompletionOutput:
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
         self.text = text
@@ -33,6 +36,7 @@ class CompletionOutput:
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.lora_request = lora_request
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -56,6 +60,7 @@ class RequestOutput:
         prompt_logprobs: The log probabilities to return per prompt token.
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
+        lora_request: The LoRA request that was used to generate the output.
     """
 
     def __init__(
@@ -66,6 +71,7 @@ class RequestOutput:
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
+        lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -73,6 +79,7 @@ class RequestOutput:
         self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished
+        self.lora_request = lora_request
 
     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -108,8 +115,13 @@ class RequestOutput:
         prompt_token_ids = seq_group.prompt_token_ids
         prompt_logprobs = seq_group.prompt_logprobs
         finished = seq_group.is_finished()
-        return cls(seq_group.request_id, prompt, prompt_token_ids,
-                   prompt_logprobs, outputs, finished)
+        return cls(seq_group.request_id,
+                   prompt,
+                   prompt_token_ids,
+                   prompt_logprobs,
+                   outputs,
+                   finished,
+                   lora_request=seq_group.lora_request)
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
@@ -117,4 +129,5 @@ class RequestOutput:
                 f"prompt_token_ids={self.prompt_token_ids}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
-                f"finished={self.finished})")
+                f"finished={self.finished}, "
+                f"lora_request={self.lora_request})")
@@ -74,13 +74,14 @@ class PrefixPool:
         new_length = len(token_ids) // self.block_size * self.block_size
         return tuple(token_ids[:new_length])
 
-    def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]:
+    def add_or_get_prefix(self, token_ids: Sequence[int],
+                          lora_int_id: int) -> Optional[Prefix]:
         token_ids = self._truncate_token_ids(token_ids)
         if len(token_ids) == 0:
             # Prefix is empty.
             return None
         prefix = Prefix(token_ids, self.block_size)
-        prefix_hash = hash(prefix)
+        prefix_hash = hash((prefix, lora_int_id))
        if prefix_hash not in self.prefixes:
             self.prefixes[prefix_hash] = prefix
         return self.prefixes[prefix_hash]
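A small sketch of why the cache key now includes the adapter id (not part of this commit): the same token prefix produces different KV contents under different adapters, so the entries must not be shared.

same_tokens = (1, 2, 3, 4)
key_base = hash((same_tokens, 0))    # no LoRA
key_lora = hash((same_tokens, 1))    # adapter with lora_int_id == 1
assert key_base != key_lora          # separate prefix-cache entries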
@@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Union
 from vllm.block import LogicalTokenBlock
 from vllm.prefix import Prefix
 from vllm.sampling_params import SamplingParams
+from vllm.lora.request import LoRARequest
 
 PromptLogprobs = List[Optional[Dict[int, float]]]
 SampleLogprobs = List[Dict[int, float]]
@@ -106,6 +107,7 @@ class Sequence:
         prompt_token_ids: The token IDs of the prompt.
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
+        lora_request: LoRA request.
     """
 
     def __init__(
@@ -114,10 +116,12 @@ class Sequence:
         prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
+        lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
         self.prompt = prompt
         self.block_size = block_size
+        self.lora_request = lora_request
 
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
@@ -134,6 +138,10 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
 
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),
@@ -229,6 +237,7 @@ class SequenceGroup:
         seqs: The list of sequences.
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
+        lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
 
@@ -238,12 +247,14 @@ class SequenceGroup:
         seqs: List[Sequence],
         sampling_params: SamplingParams,
         arrival_time: float,
+        lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.lora_request = lora_request
         self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
 
@@ -259,6 +270,10 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
@@ -338,6 +353,7 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
 
@@ -348,6 +364,7 @@ class SequenceGroupMetadata:
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
+        lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
@@ -355,8 +372,13 @@ class SequenceGroupMetadata:
         self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
+        self.lora_request = lora_request
         self.prefix = prefix
 
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
 
 class SequenceOutput:
     """The model output associated with a sequence.
@@ -4,6 +4,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.utils import make_async, LRUCache
 from vllm.transformers_utils.tokenizers import *
 
 logger = init_logger(__name__)
@@ -65,6 +67,84 @@ def get_tokenizer(
     return tokenizer
 
 
+def get_lora_tokenizer(lora_request: LoRARequest, *args,
+                       **kwargs) -> Optional[PreTrainedTokenizer]:
+    if lora_request is None:
+        return None
+    try:
+        tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
+                                  **kwargs)
+    except OSError as e:
+        # No tokenizer was found in the LoRA folder,
+        # use base model tokenizer
+        logger.warning(
+            f"No tokenizer found in {lora_request.lora_local_path}, "
+            "using base model tokenizer instead. "
+            f"(Exception: {str(e)})")
+        tokenizer = None
+    return tokenizer
+
+
+get_lora_tokenizer_async = make_async(get_lora_tokenizer)
+
+
+class TokenizerGroup:
+    """A group of tokenizers that can be used for LoRA adapters."""
+
+    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
+                 max_input_length: Optional[int], **tokenizer_config):
+        self.tokenizer_id = tokenizer_id
+        self.tokenizer_config = tokenizer_config
+        self.enable_lora = enable_lora
+        self.max_input_length = max_input_length
+        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
+        if enable_lora:
+            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
+        else:
+            self.lora_tokenizers = None
+
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
+        tokenizer = self.get_lora_tokenizer(lora_request)
+        return tokenizer.encode(prompt)
+
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
+        tokenizer = await self.get_lora_tokenizer_async(lora_request)
+        return tokenizer.encode(prompt)
+
+    def get_lora_tokenizer(
+            self,
+            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+        if not lora_request or not self.enable_lora:
+            return self.tokenizer
+        if lora_request.lora_int_id not in self.lora_tokenizers:
+            tokenizer = (get_lora_tokenizer(
+                lora_request, **self.tokenizer_config) or self.tokenizer)
+            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
+            return tokenizer
+        else:
+            return self.lora_tokenizers.get(lora_request.lora_int_id)
+
+    async def get_lora_tokenizer_async(
+            self,
+            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+        if not lora_request or not self.enable_lora:
+            return self.tokenizer
+        if lora_request.lora_int_id not in self.lora_tokenizers:
+            tokenizer = (await get_lora_tokenizer_async(
+                lora_request, **self.tokenizer_config) or self.tokenizer)
+            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
+            return tokenizer
+        else:
+            return self.lora_tokenizers.get(lora_request.lora_int_id)
+
+
 def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
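A small sketch of the tokenizer group above (not part of this commit; the model id, adapter path and the `vllm.transformers_utils.tokenizer` import path are assumptions). Adapter tokenizers are cached per `lora_int_id` and the base tokenizer is used as the fallback when the adapter folder ships no tokenizer:

from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup  # path assumed

group = TokenizerGroup("meta-llama/Llama-2-7b-hf", enable_lora=True,
                       max_num_seqs=256, max_input_length=None)
req = LoRARequest("sql-lora", 1, "/path/to/sql-lora")          # hypothetical adapter
base_ids = group.encode("SELECT 1;")                           # base tokenizer
lora_ids = group.encode("SELECT 1;", lora_request=req)         # adapter tokenizer or fallback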
@@ -7,6 +7,17 @@ from typing import List
 
 import psutil
 import torch
+import asyncio
+from functools import partial
+from typing import (
+    Awaitable,
+    Callable,
+    TypeVar,
+)
+from collections import OrderedDict
+from typing import Any, Hashable, Optional
+
+T = TypeVar("T")
 
 
 class Device(enum.Enum):
@@ -28,6 +39,69 @@ class Counter:
         self.counter = 0
 
 
+class LRUCache:
+
+    def __init__(self, capacity: int):
+        self.cache = OrderedDict()
+        self.capacity = capacity
+
+    def __contains__(self, key: Hashable) -> bool:
+        return key in self.cache
+
+    def __len__(self) -> int:
+        return len(self.cache)
+
+    def __getitem__(self, key: Hashable) -> Any:
+        return self.get(key)
+
+    def __setitem__(self, key: Hashable, value: Any) -> None:
+        self.put(key, value)
+
+    def __delitem__(self, key: Hashable) -> None:
+        self.pop(key)
+
+    def touch(self, key: Hashable) -> None:
+        self.cache.move_to_end(key)
+
+    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+        if key in self.cache:
+            value = self.cache[key]
+            self.cache.move_to_end(key)
+        else:
+            value = default_value
+        return value
+
+    def put(self, key: Hashable, value: Any) -> None:
+        self.cache[key] = value
+        self.cache.move_to_end(key)
+        self._remove_old_if_needed()
+
+    def _on_remove(self, key: Hashable, value: Any):
+        pass
+
+    def remove_oldest(self):
+        if not self.cache:
+            return
+        key, value = self.cache.popitem(last=False)
+        self._on_remove(key, value)
+
+    def _remove_old_if_needed(self) -> None:
+        while len(self.cache) > self.capacity:
+            self.remove_oldest()
+
+    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+        run_on_remove = key in self.cache
+        value = self.cache.pop(key, default_value)
+        if run_on_remove:
+            self._on_remove(key, value)
+        return value
+
+    def clear(self):
+        while len(self.cache) > 0:
+            self.remove_oldest()
+        self.cache.clear()
+
+
 def is_hip() -> bool:
     return torch.version.hip is not None
 
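A small sketch of the eviction order implemented above (not part of this commit): once capacity is exceeded, the least recently touched entry goes first.

from vllm.utils import LRUCache

cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")       # "a" becomes most recently used
cache.put("c", 3)    # evicts "b", the least recently used entry
assert "b" not in cache and "a" in cache and "c" in cache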
@@ -59,6 +133,22 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()
 
 
+def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
+    """Take a blocking function, and run it in an executor thread.
+
+    This function prevents the blocking function from blocking the
+    asyncio event loop.
+    The code in this function needs to be thread safe.
+    """
+
+    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
+        loop = asyncio.get_event_loop()
+        p_func = partial(func, *args, **kwargs)
+        return loop.run_in_executor(executor=None, func=p_func)
+
+    return _async_wrapper
+
+
 def get_ip() -> str:
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
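A small sketch of `make_async` in use (not part of this commit; the blocking function is hypothetical):

import asyncio
import time
from vllm.utils import make_async

def slow_tokenize(text: str) -> int:
    time.sleep(0.1)              # stands in for blocking tokenizer work
    return len(text.split())

async def main() -> None:
    # Runs in a thread-pool executor, so the event loop stays responsive.
    count = await make_async(slow_tokenize)("hello lora world")
    assert count == 3

asyncio.run(main())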
@ -1,23 +1,27 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Set, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
|
from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
|
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
broadcast_tensor_dict)
|
broadcast_tensor_dict)
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||||
|
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||||
|
from vllm.lora.layers import LoRAMapping
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.utils import in_wsl
|
from vllm.utils import in_wsl
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
_PAD_SLOT_ID = -1
|
_PAD_SLOT_ID = -1
|
||||||
|
LORA_WARMUP_RANK = 8
|
||||||
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
|
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
|
||||||
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
|
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
|
||||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||||
@ -30,19 +34,23 @@ class ModelRunner:
|
|||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
scheduler_config: SchedulerConfig,
|
scheduler_config: SchedulerConfig,
|
||||||
|
lora_config: Optional[LoRAConfig],
|
||||||
is_driver_worker: bool = False,
|
is_driver_worker: bool = False,
|
||||||
):
|
):
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.parallel_config = parallel_config
|
self.parallel_config = parallel_config
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
|
self.lora_config = lora_config
|
||||||
self.is_driver_worker = is_driver_worker
|
self.is_driver_worker = is_driver_worker
|
||||||
|
|
||||||
# model_config can be None in tests/samplers/test_sampler.py.
|
# model_config can be None in tests/samplers/test_sampler.py.
|
||||||
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
||||||
self.sliding_window = (model_config.get_sliding_window()
|
self.sliding_window = (model_config.get_sliding_window()
|
||||||
if model_config is not None else None)
|
if model_config is not None else None)
|
||||||
|
self.device = torch.device(torch.cuda.current_device())
|
||||||
self.model = None
|
self.model = None
|
||||||
self.block_size = None # Set after initial profiling.
|
self.block_size = None # Set after initial profiling.
|
||||||
|
self.lora_manager = None
|
||||||
|
|
||||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||||
self.graph_memory_pool = None # Set during graph capture.
|
self.graph_memory_pool = None # Set during graph capture.
|
||||||
@@ -61,7 +69,17 @@ class ModelRunner:
         self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
-        self.model = get_model(self.model_config)
+        self.model = get_model(self.model_config, self.lora_config)
+
+        vocab_size = self.model.config.vocab_size
+
+        if self.lora_config:
+            self.lora_manager = LRUCacheWorkerLoRAManager(
+                self.scheduler_config.max_num_seqs,
+                self.scheduler_config.max_num_batched_tokens +
+                self.scheduler_config.max_paddings, vocab_size,
+                self.lora_config, self.device)
+            self.model = self.lora_manager.create_lora_manager(self.model)
 
     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size
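
When a LoRA config is present, load_model wraps the freshly loaded model with an LRU-cached worker-side manager sized from the scheduler limits (one slot per sequence, one per padded batched token). Below is a minimal pure-Python sketch of the contract this hunk relies on; WorkerLoRAManagerSketch is an illustrative stand-in, not the vllm.lora.worker_manager implementation.

# Minimal sketch (illustrative only): the manager wraps the model once and is
# later handed (lora_requests, lora_mapping) pairs before each forward pass.
class WorkerLoRAManagerSketch:
    def __init__(self, max_num_seqs, max_num_batched_tokens, vocab_size):
        self.max_num_seqs = max_num_seqs                      # per-sequence capacity
        self.max_num_batched_tokens = max_num_batched_tokens  # padded token budget
        self.vocab_size = vocab_size
        self.active_adapters = {}

    def create_lora_manager(self, model):
        # The real manager swaps supported modules for LoRA-aware layers;
        # this sketch returns the model unchanged.
        return model

    def set_active_loras(self, lora_requests, lora_mapping):
        self.active_adapters = {r.lora_int_id: r for r in lora_requests}
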
@@ -74,12 +92,15 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
-               List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
+               List[int], List[int], Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()
 
         prompt_lens: List[int] = []
         context_lens: List[int] = []
@@ -113,6 +134,17 @@ class ModelRunner:
             input_positions.append(
                 list(range(prefix_len, prefix_len + len(prompt_tokens))))
 
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)
+
+            lora_index_mapping.append([lora_id] * prompt_len)
+            lora_prompt_mapping.extend(
+                [lora_id] *
+                (prompt_len
+                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))
+
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
@@ -156,6 +188,10 @@ class ModelRunner:
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long)
+        lora_index_mapping = [
+            _pad_to_max(mapping, max_prompt_len, pad=0)
+            for mapping in lora_index_mapping
+        ]
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device='cuda')
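
Putting the two hunks above together: during prefill each prompt contributes one adapter id per token to lora_index_mapping (later padded with 0, the "no LoRA" id, to the max prompt length), while lora_prompt_mapping gets one entry per sampled position, i.e. every token when prompt logprobs are requested, otherwise only the last. A pure-Python sketch of that bookkeeping, using plain tuples instead of the vLLM metadata objects:

# Illustrative sketch of the prefill mapping logic above (not vLLM code).
def build_prefill_mappings(prompts):
    """prompts: list of (lora_id, prompt_len, wants_prompt_logprobs)."""
    lora_index_mapping, lora_prompt_mapping = [], []
    for lora_id, prompt_len, wants_prompt_logprobs in prompts:
        # One id per prompt token; padded to the max prompt length later.
        lora_index_mapping.append([lora_id] * prompt_len)
        # One id per logit row that will be sampled or scored.
        lora_prompt_mapping.extend(
            [lora_id] * (prompt_len if wants_prompt_logprobs else 1))
    return lora_index_mapping, lora_prompt_mapping

idx_map, prompt_map = build_prefill_mappings([(1, 3, False), (2, 2, True)])
assert idx_map == [[1, 1, 1], [2, 2]]
assert prompt_map == [1, 2, 2]
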
@@ -188,23 +224,33 @@ class ModelRunner:
                 use_cuda_graph=False,
             )
         return (input_tokens, input_positions, input_metadata, prompt_lens,
-                subquery_lens)
+                subquery_lens, lora_index_mapping, lora_prompt_mapping,
+                lora_requests)
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
+               Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
         block_tables: List[List[int]] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()
 
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)
+
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -223,6 +269,8 @@ class ModelRunner:
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
+                lora_index_mapping.append([lora_id])
+                lora_prompt_mapping.append(lora_id)
 
                 if self.sliding_window is not None:
                     sliding_window_blocks = (self.sliding_window //
@@ -287,6 +335,10 @@ class ModelRunner:
             device="cuda",
         )
 
+        lora_index_mapping = [
+            _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
+        ]
+
         input_metadata = InputMetadata(
             is_prompt=False,
             slot_mapping=slot_mapping,
@@ -298,7 +350,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
-        return input_tokens, input_positions, input_metadata
+        return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
 
     def _prepare_sample(
         self,
@@ -375,7 +427,8 @@ class ModelRunner:
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
+               Set[int], LoRAMapping]:
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
@@ -383,16 +436,29 @@ class ModelRunner:
             # Prepare input tensors.
             if is_prompt:
                 (input_tokens, input_positions, input_metadata, prompt_lens,
-                 subquery_lens) = self._prepare_prompt(seq_group_metadata_list)
+                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
+                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
             else:
-                (input_tokens, input_positions, input_metadata
-                 ) = self._prepare_decode(seq_group_metadata_list)
-                subquery_lens = None
+                (input_tokens, input_positions, input_metadata,
+                 lora_index_mapping, lora_prompt_mapping,
+                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                 prompt_lens = []
+                subquery_lens = None
             sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens)
 
+            if self.lora_config:
+                flat_lora_index_mapping = [
+                    item for sublist in lora_index_mapping for item in sublist
+                ]
+                lora_mapping = LoRAMapping(
+                    flat_lora_index_mapping,
+                    lora_prompt_mapping,
+                )
+            else:
+                lora_mapping = None
+
             # Broadcast the metadata.
             metadata_dict = {
                 "input_tokens": input_tokens,
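
Before broadcasting, the per-sequence token lists are flattened into a single per-token list and paired with the per-logit list inside a LoRAMapping. The sketch below reproduces just that flattening step with a NamedTuple stand-in, not the real vllm.lora.layers.LoRAMapping:

from typing import List, NamedTuple

# Stand-in for illustration only; the real class lives in vllm.lora.layers.
class LoRAMappingSketch(NamedTuple):
    index_mapping: List[int]   # one adapter id per (padded) input token
    prompt_mapping: List[int]  # one adapter id per sampled logit row

lora_index_mapping = [[1, 1, 1, 0], [2, 2, 0, 0]]  # padded to max prompt len
lora_prompt_mapping = [1, 2]

flat = [item for sublist in lora_index_mapping for item in sublist]
mapping = LoRAMappingSketch(flat, lora_prompt_mapping)
assert mapping.index_mapping == [1, 1, 1, 0, 2, 2, 0, 0]
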
@@ -408,12 +474,16 @@ class ModelRunner:
                 "use_cuda_graph": input_metadata.use_cuda_graph,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
+                "lora_requests": lora_requests,
+                "lora_mapping": lora_mapping,
             }
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
             input_tokens = metadata_dict["input_tokens"]
             input_positions = metadata_dict["input_positions"]
+            lora_mapping = metadata_dict["lora_mapping"]
+            lora_requests = metadata_dict["lora_requests"]
             input_metadata = InputMetadata(
                 is_prompt=metadata_dict["is_prompt"],
                 slot_mapping=metadata_dict["slot_mapping"],
@@ -434,7 +504,7 @@ class ModelRunner:
                 perform_sampling=False,
             )
 
-        return input_tokens, input_positions, input_metadata, sampling_metadata
+        return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
 
     @torch.inference_mode()
     def execute_model(
@@ -442,8 +512,12 @@ class ModelRunner:
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        input_tokens, input_positions, input_metadata, sampling_metadata = (
+        input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
             self.prepare_input_tensors(seq_group_metadata_list))
+
+        if self.lora_config:
+            self.set_active_loras(lora_requests, lora_mapping)
+
         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]
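
The new call order in execute_model matters: the adapters referenced by the batch are activated before the (graph-captured or eager) forward pass runs. A hedged sketch of that ordering; `runner` and `run_forward` are placeholders, not vLLM APIs:

# Sketch of the ordering execute_model now enforces (placeholders only).
def run_step(runner, seq_group_metadata_list, kv_caches, run_forward):
    (input_tokens, input_positions, input_metadata, sampling_metadata,
     lora_requests, lora_mapping) = runner.prepare_input_tensors(
         seq_group_metadata_list)

    # Activate exactly the adapters this batch needs, keyed by the mapping,
    # before any kernels are launched.
    if runner.lora_config:
        runner.set_active_loras(lora_requests, lora_mapping)

    return run_forward(input_tokens, input_positions, kv_caches,
                       input_metadata)
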
@@ -472,6 +546,28 @@ class ModelRunner:
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
 
+        # This represents the maximum number of different requests
+        # that will have unique loras, and therefore the max amount of memory
+        # consumption. Create dummy lora request copies from the lora request
+        # passed in, which contains a lora from the lora warmup path.
+        dummy_lora_requests = []
+        dummy_lora_requests_per_seq = []
+        if self.lora_config:
+            for idx in range(self.lora_config.max_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_local_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
+
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
         seqs: List[SequenceGroupMetadata] = []
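
The warmup path reserves memory for the worst case: max_loras dummy adapters at LORA_WARMUP_RANK, handed out round-robin across the profiled sequences. The arithmetic of that assignment is spelled out below (plain Python with illustrative values):

# Round-robin assignment used during profiling (illustrative values).
max_loras, max_num_seqs = 4, 10
dummy_lora_ids = [idx + 1 for idx in range(max_loras)]  # ids 1..max_loras
per_seq_ids = [
    dummy_lora_ids[idx % len(dummy_lora_ids)] for idx in range(max_num_seqs)
]
assert per_seq_ids == [1, 2, 3, 4, 1, 2, 3, 4, 1, 2]
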
@@ -485,6 +581,8 @@ class ModelRunner:
                 seq_data={group_id: seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
             )
             seqs.append(seq)
 
@@ -495,6 +593,32 @@ class ModelRunner:
         torch.cuda.synchronize()
         return
 
+    def remove_all_loras(self) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_all_loras()
+
+    def set_active_loras(self, lora_requests: List[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.set_active_loras(lora_requests, lora_mapping)
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_loras()
+
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         assert not self.model_config.enforce_eager
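
Every management helper above fails fast with a RuntimeError when the runner was built without a LoRA config. A caller-side sketch of coping with that guard; the `runner` argument is any object exposing these methods, and this is not engine code:

# Illustrative caller-side guard handling (not vLLM engine code).
def try_add_adapter(runner, lora_request) -> bool:
    try:
        return runner.add_lora(lora_request)
    except RuntimeError:
        # Raised when the runner was constructed without a lora_config.
        return False
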
@@ -541,6 +665,13 @@ class ModelRunner:
                 use_cuda_graph=True,
             )
 
+            if self.lora_config:
+                lora_mapping = LoRAMapping(
+                    [0] * batch_size,
+                    [0] * batch_size,
+                )
+                self.set_active_loras(set(), lora_mapping)
+
             graph_runner = CUDAGraphRunner(self.model)
             graph_runner.capture(
                 input_tokens[:batch_size],
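
During CUDA graph capture no real adapter is active, so an all-zeros mapping (id 0 meaning "no LoRA") with one slot per batch element is installed, presumably so the LoRA index buffers have a fixed shape inside the captured graph. The shape of that dummy mapping, spelled out:

# Dummy capture-time mapping (illustrative): decode runs one token per
# sequence, so both lists have exactly batch_size entries of the no-LoRA id.
batch_size = 8
capture_index_mapping = [0] * batch_size
capture_prompt_mapping = [0] * batch_size
assert len(capture_index_mapping) == len(capture_prompt_mapping) == batch_size
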
vllm/worker/worker.py

@@ -1,12 +1,13 @@
 """A GPU worker class."""
+import gc
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple, Set, Optional
 
 import torch
 import torch.distributed
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                         SchedulerConfig, LoRAConfig)
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
@@ -15,6 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
+from vllm.lora.request import LoRARequest
 
 
 class Worker:
@@ -33,6 +35,7 @@ class Worker:
         local_rank: int,
         rank: int,
         distributed_init_method: str,
+        lora_config: Optional[LoRAConfig] = None,
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -41,12 +44,16 @@ class Worker:
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.lora_config = lora_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
 
-        self.model_runner = ModelRunner(model_config, parallel_config,
-                                        scheduler_config, is_driver_worker)
+        self.model_runner = ModelRunner(model_config,
+                                        parallel_config,
+                                        scheduler_config,
+                                        lora_config=self.lora_config,
+                                        is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -117,6 +124,9 @@ class Worker:
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+        if self.model_runner.lora_manager:
+            self.model_runner.remove_all_loras()
+        gc.collect()
         torch.cuda.empty_cache()
         return num_gpu_blocks, num_cpu_blocks
 
@@ -199,6 +209,15 @@ class Worker:
                                                  self.gpu_cache)
         return output
 
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_runner.list_loras()
+
 
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
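
The Worker-level hooks above are thin pass-throughs to the ModelRunner. A hedged sketch of how a higher-level caller might drive them; the object and helper names are assumptions, and the real engine routes these calls through its scheduler rather than invoking the worker directly:

# Illustrative driver-side helpers (assumed names, not vLLM engine code).
def ensure_adapter_loaded(worker, lora_request) -> None:
    if lora_request.lora_int_id not in worker.list_loras():
        worker.add_lora(lora_request)

def evict_adapter(worker, lora_id: int) -> None:
    if lora_id in worker.list_loras():
        worker.remove_lora(lora_id)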