[Lora] Support long context lora (#4787)
Currently we need to call the rotary embedding kernel once per LoRA, which makes it hard to serve multiple LoRAs with long context lengths. This change adds a batched rotary embedding kernel and pipes it through, replacing the rotary embedding layer with one that is aware of multiple cos/sin caches, one per scaling factor. Follow-up to https://github.com/vllm-project/vllm/pull/3095/files
Parent: c0724fc915 | Commit: 2e9a2227ec
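For context, a minimal usage sketch of what this commit enables. This is not part of the diff: the model name, adapter paths, and LoRA ids are placeholders, and the constructor arguments mirror the lora_llm fixture and _create_lora_request helper added in tests/lora/test_long_context.py below.

import vllm
from vllm.lora.request import LoRARequest

# Hypothetical example: serve two long-context adapters at the same time.
# Scaling factors of 4x and 8x (roughly 16k and 32k on a 4k base model)
# follow the values used in the new tests; paths and ids are illustrative.
llm = vllm.LLM(
    "meta-llama/Llama-2-7b-hf",
    enable_lora=True,
    max_loras=2,
    long_lora_scaling_factors=(4.0, 8.0),
    max_num_batched_tokens=4096 * 8,
)

# The last LoRARequest argument is the adapter's long_lora_max_len, which
# raises the per-request prompt limit for that adapter (see the scheduler
# and stop-checker changes below).
lora_16k = LoRARequest("lora-16k", 1, "/path/to/16k_adapter", 4096 * 4)
lora_32k = LoRARequest("lora-32k", 2, "/path/to/32k_adapter", 4096 * 8)

outputs = llm.generate(["<long prompt for the 16k adapter>"],
                       lora_request=lora_16k)
print(outputs[0].outputs[0].text)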
@@ -119,9 +119,23 @@ steps:
 - label: LoRA Test %N
   #mirror_hardwares: [amd]
-  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
+  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
   parallelism: 4
 
+- label: LoRA Long Context (Distributed)
+  #mirror_hardwares: [amd]
+  num_gpus: 4
+  # This test runs llama 13B, so it is required to run on 4 GPUs.
+  commands:
+    # Temporarily run this way because we cannot clean up GPU mem usage
+    # for multi GPU tests.
+    # TODO(sang): Fix it.
+    - pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
+    - pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
+    - pytest -v -s lora/test_long_context.py::test_self_consistency
+    - pytest -v -s lora/test_long_context.py::test_quality
+    - pytest -v -s lora/test_long_context.py::test_max_len
+
 - label: Tensorizer Test
   #mirror_hardwares: [amd]
   command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
@@ -112,7 +112,7 @@ mypy vllm/model_executor --config-file pyproject.toml
 
 
 CODESPELL_EXCLUDES=(
-    '--skip' '*docs/source/_build/**'
+    '--skip' '*docs/source/_build/**,./tests/lora/data'
 )
 
 # check spelling of specified files

@@ -133,7 +133,6 @@ spell_check_changed() {
     # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
     # exist on both branches.
     MERGEBASE="$(git merge-base origin/main HEAD)"
-
     if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
         git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
             codespell "${CODESPELL_EXCLUDES[@]}"
@@ -60,7 +60,7 @@ exclude = [
 
 [tool.codespell]
 ignore-words-list = "dout, te, indicies"
-skip = "./tests/prompts,./benchmarks/sonnet.txt"
+skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data"
 
 [tool.isort]
 use_parentheses = true
@@ -21,6 +21,17 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader import get_model
 
+LONG_LORA_INFOS = [{
+    "lora_id": 1,
+    "context_length": "16k",
+}, {
+    "lora_id": 2,
+    "context_length": "16k",
+}, {
+    "lora_id": 3,
+    "context_length": "32k",
+}]
+
 
 def cleanup():
     destroy_model_parallel()

@@ -154,6 +165,45 @@ def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
 
 
+@pytest.fixture(scope="session")
+def long_context_lora_files_16k_1():
+    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
+
+
+@pytest.fixture(scope="session")
+def long_context_lora_files_16k_2():
+    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")
+
+
+@pytest.fixture(scope="session")
+def long_context_lora_files_32k():
+    return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")
+
+
+# SANG-TODO Download long lora files.
+@pytest.fixture(scope="session")
+def long_context_infos(long_context_lora_files_16k_1,
+                       long_context_lora_files_16k_2,
+                       long_context_lora_files_32k):
+    cleanup()
+    infos = {}
+    for lora_checkpoint_info in LONG_LORA_INFOS:
+        lora_id = lora_checkpoint_info["lora_id"]
+        if lora_id == 1:
+            lora = long_context_lora_files_16k_1
+        elif lora_id == 2:
+            lora = long_context_lora_files_16k_2
+        elif lora_id == 3:
+            lora = long_context_lora_files_32k
+        else:
+            raise AssertionError("Unknown lora id")
+        infos[lora_id] = {
+            "context_length": lora_checkpoint_info["context_length"],
+            "lora": lora,
+        }
+    return infos
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
New file: tests/lora/data/__init__.py (0 lines)
New file: tests/lora/data/long_context_test_data.py (97 lines; file diff suppressed because one or more lines are too long)
@@ -15,6 +15,7 @@ from vllm.lora.fully_sharded_layers import (
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                              LinearScalingRotaryEmbeddingWithLora,
                               LogitsProcessorWithLoRA, LoRAMapping,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,

@@ -22,13 +23,14 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
 # yapf: enable
-from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
-                              convert_mapping)
+from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
+                              PackedLoRALayerWeights, convert_mapping)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.utils import set_random_seed

@@ -771,3 +773,97 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                 expected_result,
                                 rtol=rtol,
                                 atol=atol)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 8])
+@pytest.mark.parametrize("device", ["cuda"])
+@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
+                                             (6.0, 1.0)])
+@pytest.mark.parametrize("max_position", [11, 4096, 32768])
+@pytest.mark.parametrize("is_neox_style", [True, False])
+@pytest.mark.parametrize("rotary_dim", [None, 32])
+@pytest.mark.parametrize("head_size", [32, 108])
+@pytest.mark.parametrize("seq_len", [11, 1024])
+def test_rotary_embedding_long_context(dist_init, num_loras, device,
+                                       scaling_factors, max_position,
+                                       is_neox_style, rotary_dim, head_size,
+                                       seq_len) -> None:
+    dtype = torch.float16
+    seed = 0
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+
+    max_loras = 8
+    lora_config = LoRAConfig(max_loras=max_loras,
+                             max_lora_rank=8,
+                             long_lora_scaling_factors=scaling_factors,
+                             lora_dtype=dtype)
+
+    if rotary_dim is None:
+        rotary_dim = head_size
+    base = 10000
+    batch_size = 5 * num_loras
+    num_heads = 7
+
+    # Verify lora is equivalent to linear scaling rotary embedding.
+    rope = get_rope(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+    )
+    lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
+    lora_rope.create_lora_weights(max_loras, lora_config)
+    linear_rope = get_rope(head_size, rotary_dim, max_position, base,
+                           is_neox_style, {
+                               "type": "linear",
+                               "factor": scaling_factors
+                           })
+    linear_rope = linear_rope.to(dtype=dtype)
+    id_to_index = get_random_id_to_index(num_loras, max_loras)
+    _, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[0],
+        num_inputs=batch_size,
+        input_size=(1, max_position),
+        input_range=(0, lora_config.lora_extra_vocab_size),
+        input_type=torch.float16,
+    )
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
+    long_lora_context = LongContextLoRAContext(list(scaling_factors),
+                                               rotary_dim)
+
+    next_expected_offset = 0
+    # Make sure the offset is correct.
+    scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
+    for scaling_factor, offset in scaling_factor_to_offset.items():
+        assert offset == next_expected_offset
+        next_expected_offset += scaling_factor * max_position
+
+    for i in range(len(scaling_factors)):
+        long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
+            scaling_factors[i], 0)
+    mapping_info = convert_mapping(
+        lora_mapping,
+        id_to_index,
+        max_loras,
+        512,
+        lora_config.lora_extra_vocab_size,
+        long_lora_context=long_lora_context,
+    )
+    lora_rope.set_mapping(*mapping_info)
+
+    positions = torch.randint(0, max_position, (batch_size, seq_len))
+    query = torch.randn(batch_size,
+                        seq_len,
+                        num_heads * head_size,
+                        dtype=dtype)
+    key = torch.randn_like(query)
+    ref_q, ref_k = linear_rope(positions, query, key)
+    actual_q, actual_k = lora_rope(positions, query, key)
+
+    torch.allclose(ref_q, actual_q)
+    torch.allclose(ref_k, actual_k)
New file: tests/lora/test_long_context.py (292 lines)

import ast
from typing import List, Optional, Tuple

import numpy as np
import pytest

import vllm
from vllm import SamplingParams
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.rotary_embedding import (
    LinearScalingRotaryEmbedding)

from .data.long_context_test_data import prompts_and_responses

context_len_to_scaling_factor = {
    "16k": 4,
    "32k": 8,
}

# We use the same sampling params for all requests
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=100,
)


def _create_lora_request(lora_id, long_context_infos):
    context_len = long_context_infos[lora_id]["context_length"]
    scaling_factor = context_len_to_scaling_factor[context_len]
    return LoRARequest(context_len, lora_id,
                       long_context_infos[lora_id]["lora"],
                       4096 * scaling_factor)


def evaluate_json_response(model_response, golden_response):
    """Evaluates the model response against the golden response.

    Returns a score between 0 and 1, where 1 is a perfect match and 0 is no
    match. The score quantifies how well the model is able to extract the
    golden JSON from the long context.
    """
    try:
        model_response = ast.literal_eval(model_response)
    except Exception as e:
        raise ValueError(
            f"Model response is not a valid JSON. Expected {golden_response}, "
            f"got {model_response}") from e

    # Normally, we would flatten the dictionary and compare the values, but in
    # this case, we know that the dictionary is only 2 levels deep
    positive_values = 0
    total_values = 0
    # We look at all the attributes of the person that we are extracting a
    # biography of and compare them to the golden response
    for person_attribute, person_attribute_value in golden_response.items():
        if person_attribute in model_response:
            if isinstance(person_attribute_value, dict):
                for (sub_attribute,
                     sub_attribute_value) in person_attribute_value.items():
                    total_values += 1
                    if sub_attribute in model_response[
                            person_attribute] and model_response[
                                person_attribute][
                                    sub_attribute] == sub_attribute_value:
                        positive_values += 1
            else:
                total_values += 1
                if model_response[person_attribute] == person_attribute_value:
                    positive_values += 1
        else:
            # We count a missing sub-dict as a single missed value.
            total_values += 1

    # Return a score between 0 and 1
    return positive_values / total_values


def generate(
    llm,
    inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
):
    prompts, sampling_param, lora_request = inputs
    outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
    return outputs[0].outputs[0].text.strip()


def batched_generate(
    llm,
    inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
):
    for input in inputs:
        prompt, sampling_param, lora_req = input
        requests_data = llm._validate_and_prepare_requests(
            prompt,
            sampling_param,
            lora_request=lora_req,
        )

        # Add requests to the engine and run the engine
        for request_data in requests_data:
            llm._add_request(**request_data)
    outputs = llm._run_engine(use_tqdm=True)
    return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]


@pytest.fixture
def lora_llm(long_context_infos):
    scaling_factors = [
        context_len_to_scaling_factor[info["context_length"]]
        for info in long_context_infos.values()
    ]

    llm = vllm.LLM(
        "meta-llama/Llama-2-13b-chat-hf",
        enable_lora=True,
        max_num_seqs=16,
        max_loras=2,
        long_lora_scaling_factors=tuple(scaling_factors),
        max_num_batched_tokens=4096 * 8,
        tensor_parallel_size=4,
    )
    yield llm
    del llm


def test_rotary_emb_replaced(dist_init):
    """Verify rotary emb in all the layers are replaced"""
    from vllm.engine.arg_utils import EngineArgs
    from vllm.worker.model_runner import ModelRunner
    engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
                             long_lora_scaling_factors=(4.0, ),
                             enable_lora=True)
    engine_config = engine_args.create_engine_config()
    model_runner = ModelRunner(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        lora_config=engine_config.lora_config,
        is_driver_worker=True,
    )
    model_runner.load_model()
    rotary_emb_count = 0
    for module_name, module in model_runner.model.named_modules(
            remove_duplicate=False):
        if "rotary_emb" in module_name:
            if "base_layer" not in module_name:
                rotary_emb_count += 1
                assert isinstance(module, LinearScalingRotaryEmbeddingWithLora)
            else:
                assert isinstance(module, LinearScalingRotaryEmbedding)
    # Llama 2 has 32 layers.
    assert rotary_emb_count == 32


def test_batched_rope_kernel(lora_llm, long_context_infos):
    """We test the batched kernel by comparing the results of batched and
    non-batched generation.
    """
    # Create non batched results first to compare against batched results
    non_batched_results = []

    for lora_id, info in long_context_infos.items():
        context_len = info["context_length"]
        lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
                       sampling_params,
                       _create_lora_request(lora_id, long_context_infos))
        lora_output = generate(lora_llm, lora_prompt)
        non_batched_results.append(lora_output)

    # Create batched results
    # Each element of the batch must be
    # (prompt, prompt_sampling_params, prompt_lora_request)
    batched_prompts = []
    for lora_id, info in long_context_infos.items():
        context_len = info["context_length"]
        batched_prompts.extend([
            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
             _create_lora_request(lora_id, long_context_infos))
        ])
    batched_results = batched_generate(lora_llm, batched_prompts)

    # Results should be the same
    for non_batched, batched in zip(non_batched_results, batched_results):
        assert non_batched == batched, (
            "Non batched and batched results should be the "
            f"same:\n{batched}\n{non_batched}")


def test_self_consistency(lora_llm, long_context_infos):
    """We test consistency of the batched kernel by permuting batched
    inputs and comparing the results to the non-permuted batched results.
    """
    num_loras = len(long_context_infos)

    # Create results in order of long_context_infos
    batched_prompts = []
    for lora_id, info in long_context_infos.items():
        context_len = info["context_length"]
        batched_prompts.extend([
            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
             _create_lora_request(lora_id, long_context_infos))
        ])

    batched_results = batched_generate(lora_llm, batched_prompts)

    permutation = np.random.default_rng(seed=42).permutation(num_loras)

    # Create results in random order of permutation
    batched_prompts = []
    for i in permutation:
        lora_id, info = list(long_context_infos.items())[i]
        context_len = info["context_length"]
        batched_prompts.extend([
            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
             _create_lora_request(lora_id, long_context_infos))
        ])

    permutated_batched_results = batched_generate(lora_llm, batched_prompts)

    # Results should be the same
    for i in range(num_loras):
        assert batched_results[i] == permutated_batched_results[
            permutation[i]], (
                f"Results should be the same:\n{batched_results[i]}"
                f"\n{permutated_batched_results[permutation[i]]}")


def test_quality(lora_llm, long_context_infos):
    """We test the quality of the answers given by the LoRA model by
    comparing the generated text to the merged model's outputs.

    This is effectively a mini-benchmark over four prompts.
    If this test fails, this indicates that the quality of the LoRA model
    is suboptimal compared to the merged model. For example, if the model
    does not output valid dictionaries, this test will fail.

    If needed for testing, the merged versions of the models are available
    as part of the `conftest`.

    The test is expected to run for about 1 minute on a p4de.24xlarge
    instance.
    """
    scores = []
    for lora_id, info in long_context_infos.items():
        context_len = info["context_length"]
        for prompt_and_response in prompts_and_responses[context_len]:
            lora_prompt = (prompt_and_response["prompt"], sampling_params,
                           _create_lora_request(lora_id, long_context_infos))
            response = generate(lora_llm, lora_prompt)
            golden_answer = prompt_and_response["golden_answer"]
            score = evaluate_json_response(response, golden_answer)
            scores.append(score)
            assert score > 0.3, ("Quality of the answer is not good enough. "
                                 f"Expected {golden_answer}, got {response}")
    assert np.mean(scores) > 0.5


def test_max_len(lora_llm, long_context_infos):
    """Test that we raise a ValueError when the input of a given LoRA
    model exceeds the maximum length."""
    # Since each LoRA model has a different maximum length, we need to
    # test each one separately
    for lora_id, info in long_context_infos.items():
        context_len = info["context_length"]
        lora_request = _create_lora_request(lora_id, long_context_infos)
        # Good prompt should be fine
        good_prompt = prompts_and_responses[context_len][0]["prompt"]
        generate(lora_llm, (good_prompt, sampling_params, lora_request))
        # Bad prompt should raise an error
        bad_prompt = good_prompt * 2
        with pytest.raises(ValueError):
            generate(lora_llm, (bad_prompt, sampling_params, lora_request))

    # Also test batched
    batched_prompts = []
    for lora_id_with_bad_inputs in long_context_infos:
        for lora_id, info in long_context_infos.items():
            context_len = info["context_length"]
            batched_prompts.extend([
                (prompts_and_responses[context_len][0]["prompt"] *
                 (2 if lora_id == lora_id_with_bad_inputs else 1),
                 sampling_params,
                 _create_lora_request(lora_id, long_context_infos))
            ])
    # Turn good prompt into bad prompt inside of batched prompts

    with pytest.raises(ValueError):
        batched_generate(lora_llm, batched_prompts)
@@ -1,7 +1,7 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
 
 import torch
 from transformers import PretrainedConfig

@@ -968,6 +968,7 @@ class LoRAConfig:
     lora_extra_vocab_size: int = 256
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
+    long_lora_scaling_factors: Optional[Tuple[float]] = None
 
     def __post_init__(self):
         # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
@@ -264,13 +264,6 @@ class Scheduler:
         # LoRAs. This should be improved in the future.
         self.lora_config = lora_config
 
-        if self.scheduler_config.chunked_prefill_enabled:
-            self.prompt_limit = self.scheduler_config.max_model_len
-        else:
-            self.prompt_limit = min(
-                self.scheduler_config.max_model_len,
-                self.scheduler_config.max_num_batched_tokens)
-
         version = "v1"
         if self.scheduler_config.use_v2_block_manager:
             version = "v2"

@@ -596,6 +589,21 @@ class Scheduler:
             infeasible_seq_groups=infeasible_seq_groups,
         )
 
+    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
+        if self.scheduler_config.chunked_prefill_enabled:
+            prompt_limit = self.scheduler_config.max_model_len
+        else:
+            prompt_limit = min(self.scheduler_config.max_model_len,
+                               self.scheduler_config.max_num_batched_tokens)
+
+        # Model is fine tuned with long context. Return the fine tuned max_len.
+        if (seq_group.lora_request
+                and seq_group.lora_request.long_lora_max_len):
+            assert prompt_limit <= seq_group.lora_request.long_lora_max_len
+            return seq_group.lora_request.long_lora_max_len
+        else:
+            return prompt_limit
+
     def _schedule_prefills(
         self,
         waiting_queue: deque,

@@ -650,11 +658,11 @@ class Scheduler:
             num_prompt_tokens = waiting_seqs[0].get_len()
             assert num_new_tokens == num_prompt_tokens
 
-            if num_new_tokens > self.prompt_limit:
+            prompt_limit = self._get_prompt_limit(seq_group)
+            if num_new_tokens > prompt_limit:
                 logger.warning(
                     "Input prompt (%d tokens) is too long"
-                    " and exceeds limit of %d", num_new_tokens,
-                    self.prompt_limit)
+                    " and exceeds limit of %d", num_new_tokens, prompt_limit)
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
@@ -1,7 +1,7 @@
 import argparse
 import dataclasses
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,

@@ -63,6 +63,7 @@ class EngineArgs:
     max_lora_rank: int = 16
     fully_sharded_loras: bool = False
     lora_extra_vocab_size: int = 256
+    long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'

@@ -397,6 +398,17 @@ class EngineArgs:
             choices=['auto', 'float16', 'bfloat16', 'float32'],
             help=('Data type for LoRA. If auto, will default to '
                   'base model dtype.'))
+        parser.add_argument(
+            '--long-lora-scaling-factors',
+            type=nullable_str,
+            default=EngineArgs.long_lora_scaling_factors,
+            help=('Specify multiple scaling factors (which can '
+                  'be different from base model scaling factor '
+                  '- see eg. Long LoRA) to allow for multiple '
+                  'LoRA adapters trained with those scaling '
+                  'factors to be used at the same time. If not '
+                  'specified, only adapters trained with the '
+                  'base model scaling factor are allowed.'))
         parser.add_argument(
             '--max-cpu-loras',
             type=int,

@@ -593,6 +605,7 @@ class EngineArgs:
             max_loras=self.max_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
+            long_lora_scaling_factors=self.long_lora_scaling_factors,
             lora_dtype=self.lora_dtype,
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
@@ -131,10 +131,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, sampling_params)
 
+            # TODO(sang): Support lora.
             self.stop_checker.maybe_stop_sequence(
                 seq,
                 new_char_count=new_char_count,
-                sampling_params=sampling_params)
+                sampling_params=sampling_params,
+            )
             if seq.is_finished():
                 break
@@ -118,8 +118,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                     seq, seq_group.sampling_params)
             else:
                 new_char_count = 0
-            self.stop_checker.maybe_stop_sequence(seq, new_char_count,
-                                                  seq_group.sampling_params)
+            self.stop_checker.maybe_stop_sequence(
+                seq,
+                new_char_count,
+                seq_group.sampling_params,
+                lora_req=seq_group.lora_request,
+            )
 
         # Non-beam search case
         if not seq_group.sampling_params.use_beam_search:
@@ -2,6 +2,7 @@ from typing import Callable, Optional
 
 from transformers import PreTrainedTokenizer
 
+from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Sequence, SequenceStatus
 

@@ -16,11 +17,23 @@ class StopChecker:
     def __init__(self, max_model_len: int,
                  get_tokenizer_for_seq: Callable[[Sequence],
                                                  PreTrainedTokenizer]):
-        self.max_model_len = max_model_len
+        # Do not use it directly, but use `self._get_max_model_len`.
+        self._max_model_len = max_model_len
         self.get_tokenizer_for_seq = get_tokenizer_for_seq
 
-    def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
-                            sampling_params: SamplingParams) -> None:
+    def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
+        if lora_req and lora_req.long_lora_max_len:
+            return lora_req.long_lora_max_len
+        else:
+            return self._max_model_len
+
+    def maybe_stop_sequence(
+        self,
+        seq: Sequence,
+        new_char_count: int,
+        sampling_params: SamplingParams,
+        lora_req: Optional[LoRARequest] = None,
+    ) -> None:
         """Stop the finished sequences.
 
         new_char_count is the number of chars added to the

@@ -59,7 +72,7 @@ class StopChecker:
             return
 
         # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.max_model_len:
+        if seq.get_len() > self._get_max_model_len(lora_req):
             seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
             return
 
@@ -1,7 +1,7 @@
 # pylint: disable=unused-argument
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -22,6 +22,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.rotary_embedding import (
+    LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 

@@ -185,6 +187,7 @@ class BaseLayerWithLoRA(nn.Module):
         sampler_indices: torch.Tensor,
         sampler_indices_padded: torch.Tensor,
         embeddings_indices: torch.Tensor,
+        long_lora_indices: torch.Tensor,
         indices_len: List[int],
     ):
         """Sets the mapping indices."""

@@ -306,6 +309,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         sampler_indices: torch.Tensor,
         sampler_indices_padded: torch.Tensor,
         embeddings_indices: torch.Tensor,
+        long_lora_indices: torch.Tensor,
         indices_len: List[int],
     ):
         self.indices = base_indices

@@ -431,6 +435,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         sampler_indices: torch.Tensor,
         sampler_indices_padded: torch.Tensor,
         embeddings_indices: torch.Tensor,
+        long_lora_indices: torch.Tensor,
         indices_len: List[int],
     ):
         self.indices = base_indices

@@ -951,6 +956,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         sampler_indices: torch.Tensor,
         sampler_indices_padded: torch.Tensor,
         embeddings_indices: torch.Tensor,
+        long_lora_indices: torch.Tensor,
         indices_len: List[int],
     ):
         self.indices = base_indices

@@ -1127,6 +1133,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         sampler_indices: torch.Tensor,
         sampler_indices_padded: torch.Tensor,
         embeddings_indices: torch.Tensor,
+        long_lora_indices: torch.Tensor,
         indices_len: List[int],
     ):
         self.indices = sampler_indices

@@ -1193,3 +1200,101 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
                           model_config: Optional[PretrainedConfig]) -> bool:
         # Special handling for the LogitsProcessor.
         return False
+
+
+class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
+    """Implements RoPE-scaled embeddings with linear scaling for
+    multiple LoRA adapters with a specialized kernel.
+
+    Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
+    which can handle multiple LoRA adapters in a specialized kernel.
+    """
+
+    def __init__(self, base_layer: RotaryEmbedding) -> None:
+        super().__init__()
+        self.base_layer = base_layer
+        # Lazily initialized
+        self.long_lora_indices: torch.Tensor
+        self.indices_len: List[int]
+
+    @property
+    def scaling_factors(self):
+        return self.base_layer.scaling_factors
+
+    @property
+    def rotary_dim(self):
+        return self.base_layer.rotary_dim
+
+    def create_lora_weights(
+        self,
+        max_loras: int,
+        lora_config: LoRAConfig,
+        model_config: Optional[PretrainedConfig] = None,
+    ) -> None:
+        scaling_factors = list(
+            lora_config.long_lora_scaling_factors
+        ) if lora_config.long_lora_scaling_factors else []
+        base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
+            self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
+        scaling_factors = sorted(
+            list(set([base_scaling_factor] + scaling_factors)))
+        self.base_layer = LinearScalingRotaryEmbedding(
+            self.base_layer.head_size,
+            self.base_layer.rotary_dim,
+            self.base_layer.max_position_embeddings,
+            self.base_layer.base,
+            self.base_layer.is_neox_style,
+            scaling_factors,
+            self.base_layer.dtype,
+        )
+
+    def reset_lora(self, index: int):
+        ...
+
+    def set_lora(
+        self,
+        index: int,
+        lora_a: torch.Tensor,
+        lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
+    ):
+        ...
+
+    def set_mapping(
+        self,
+        base_indices: torch.Tensor,
+        sampler_indices: torch.Tensor,
+        sampler_indices_padded: torch.Tensor,
+        embeddings_indices: torch.Tensor,
+        long_lora_indices: torch.Tensor,
+        indices_len: List[int],
+    ):
+        self.long_lora_indices = long_lora_indices
+        self.indices_len = indices_len
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.base_layer(
+            positions,
+            query,
+            key,
+            offsets=self.long_lora_indices[:self.indices_len[4]])
+
+    @property
+    def scaling_factor_to_offset(self) -> Dict[float, int]:
+        return self.base_layer.scaling_factor_to_offset
+
+    @classmethod
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        """Returns True if the layer can be replaced by this LoRA layer."""
+        return type(source_layer) is LinearScalingRotaryEmbedding or type(
+            source_layer) is RotaryEmbedding
+
+    def extra_repr(self) -> str:
+        return self.base_layer.extra_repr()
@ -3,7 +3,8 @@ import json
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Type
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
@ -11,7 +12,9 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
|
from vllm.lora.layers import (BaseLayerWithLoRA,
|
||||||
|
LinearScalingRotaryEmbeddingWithLora,
|
||||||
|
LoRAMapping)
|
||||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||||
parse_fine_tuned_lora_name, replace_submodule)
|
parse_fine_tuned_lora_name, replace_submodule)
|
||||||
@ -22,10 +25,27 @@ logger = init_logger(__name__)
|
|||||||
_GLOBAL_LORA_ID = 0
|
_GLOBAL_LORA_ID = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LongContextLoRAContext:
|
||||||
|
"""Context for lora adapters that support long context."""
|
||||||
|
# The scaling factors to support long context lora fine tuned models.
|
||||||
|
scaling_factors: List[float]
|
||||||
|
# dimension to apply rotary embedding.
|
||||||
|
rot_dim: int
|
||||||
|
# offsets to the sin_cos_cache for each lora_id loaded.
|
||||||
|
# This value is dynamically modified.
|
||||||
|
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
def convert_mapping(
|
def convert_mapping(
|
||||||
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
|
mapping: LoRAMapping,
|
||||||
max_loras: int, vocab_size: int, extra_vocab_size: int
|
lora_index_to_id: List[Optional[int]],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
|
max_loras: int,
|
||||||
|
vocab_size: int,
|
||||||
|
extra_vocab_size: int,
|
||||||
|
long_lora_context: Optional[LongContextLoRAContext] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||||
|
Optional[torch.Tensor], List[int]]:
|
||||||
"""Converts LoRAMapping to index tensors.
|
"""Converts LoRAMapping to index tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -34,6 +54,7 @@ def convert_mapping(
|
|||||||
max_loras: Maximum number of LoRAs.
|
max_loras: Maximum number of LoRAs.
|
||||||
vocab_size: Model vocab size.
|
vocab_size: Model vocab size.
|
||||||
extra_vocab_size: Extra vocab size each LoRA can have.
|
extra_vocab_size: Extra vocab size each LoRA can have.
|
||||||
|
long_lora_context: Passed if there are long context lora in a batch.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of tensors:
|
A tuple of tensors:
|
||||||
@ -51,11 +72,23 @@ def convert_mapping(
|
|||||||
requests to embedding indices. First row is for embeddings
|
requests to embedding indices. First row is for embeddings
|
||||||
added by the LoRAs, second row is for the LoRA.lora_a
|
added by the LoRAs, second row is for the LoRA.lora_a
|
||||||
embeddings.
|
embeddings.
|
||||||
|
long_lora_indices: Tensor of shape [batch_size] mapping
|
||||||
|
requests to RoPE offsets and rot dims for long LoRAs.
|
||||||
|
None if long context lora doesn't exist.
|
||||||
indices_len: List of lengths of the above tensors.
|
indices_len: List of lengths of the above tensors.
|
||||||
|
Used to index into each tensor. It contains length for
|
||||||
|
(base_indices, sampler_indices, sampler_indices_padded,
|
||||||
|
embeddings_indices, long_lora_indices). If long_lora doesn't
|
||||||
|
exist, it only contains first 4 entries.
|
||||||
"""
|
"""
|
||||||
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
|
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
|
||||||
embedding_indices = index_mapping_indices.copy()
|
embedding_indices = index_mapping_indices.copy()
|
||||||
lora_indices = index_mapping_indices.copy()
|
lora_indices = index_mapping_indices.copy()
|
||||||
|
long_lora_offsets: Optional[torch.Tensor] = None
|
||||||
|
if long_lora_context:
|
||||||
|
long_lora_offsets = torch.zeros(len(index_mapping_indices),
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.long)
|
||||||
prompt_mapping: List[int] = [
|
prompt_mapping: List[int] = [
|
||||||
lora_index_to_id.index(x) if x > 0 else -1
|
lora_index_to_id.index(x) if x > 0 else -1
|
||||||
for x in mapping.prompt_mapping
|
for x in mapping.prompt_mapping
|
||||||
@ -66,13 +99,22 @@ def convert_mapping(
|
|||||||
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
|
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
|
||||||
if index_mapping_indices[i] > 0 else -1)
|
if index_mapping_indices[i] > 0 else -1)
|
||||||
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
|
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
|
||||||
index_mapping_indices[i] = i
|
|
||||||
lora_indices[i] = lora_idx
|
lora_indices[i] = lora_idx
|
||||||
|
if long_lora_context:
|
||||||
|
assert long_lora_offsets is not None
|
||||||
|
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
|
||||||
|
index_mapping_indices[i], 0)
|
||||||
|
long_lora_offsets[i] = lora_offset
|
||||||
|
# SANG-TODO
|
||||||
|
# index_mapping_indices[i] = i
|
||||||
|
|
||||||
indices = torch.tensor(
|
indices_list: List[Union[List[int], torch.Tensor]] = [
|
||||||
[index_mapping_indices, lora_indices, embedding_indices],
|
index_mapping_indices, lora_indices, embedding_indices
|
||||||
dtype=torch.long,
|
]
|
||||||
device="cuda")
|
if long_lora_context:
|
||||||
|
assert long_lora_offsets is not None
|
||||||
|
indices_list.append(long_lora_offsets)
|
||||||
|
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
|
||||||
prompt_mapping_tensor = torch.tensor(prompt_mapping,
|
prompt_mapping_tensor = torch.tensor(prompt_mapping,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=torch.long)
|
dtype=torch.long)
|
||||||
@ -89,13 +131,21 @@ def convert_mapping(
|
|||||||
torch.arange(
|
torch.arange(
|
||||||
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
|
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
|
||||||
(sampler_indices_padded * len(sampler_indices_padded)))
|
(sampler_indices_padded * len(sampler_indices_padded)))
|
||||||
|
long_lora_indices = None
|
||||||
|
long_lora_indices_len: Optional[int] = None
|
||||||
|
if long_lora_context:
|
||||||
|
long_lora_indices = indices[3]
|
||||||
|
long_lora_indices_len = long_lora_indices.shape[-1]
|
||||||
|
# Contain length of indices tensors. Used to index into each tensor.
|
||||||
indices_len = [
|
indices_len = [
|
||||||
base_indices.shape[-1], sampler_indices.shape[-1],
|
base_indices.shape[-1], sampler_indices.shape[-1],
|
||||||
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
|
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
|
||||||
]
|
]
|
||||||
|
if long_lora_indices_len is not None:
|
||||||
|
indices_len.append(long_lora_indices_len)
|
||||||
|
|
||||||
return (base_indices, sampler_indices, sampler_indices_padded,
|
return (base_indices, sampler_indices, sampler_indices_padded,
|
||||||
embeddings_indices, indices_len)
|
embeddings_indices, long_lora_indices, indices_len)
|
||||||
|
|
||||||
|
|
||||||
def get_lora_id():
|
def get_lora_id():
|
||||||
@ -112,8 +162,20 @@ class LoRAModel:
|
|||||||
lora_model_id: int,
|
lora_model_id: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
loras: Dict[str, LoRALayerWeights],
|
loras: Dict[str, LoRALayerWeights],
|
||||||
|
scaling_factor: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
lora_model_id: The integer id for the lora model.
|
||||||
|
rank: lora rank.
|
||||||
|
loras: module name -> weights for lora-replaced layers.
|
||||||
|
scaling_factor: Scaling factor to support long context lora model.
|
||||||
|
None if the lora is not tuned for long context support.
|
||||||
|
"""
|
||||||
self.id = lora_model_id
|
self.id = lora_model_id
|
||||||
|
# Scaling factor for long context lora model. None if it is not
|
||||||
|
# fine tuned for the long context.
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
assert (lora_model_id >
|
assert (lora_model_id >
|
||||||
0), f"a valid lora id should be greater than 0, got {self.id}"
|
0), f"a valid lora id should be greater than 0, got {self.id}"
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
@ -150,6 +212,7 @@ class LoRAModel:
|
|||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
embeddings: Optional[Dict[str, torch.Tensor]] = None,
|
embeddings: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
target_embedding_padding: Optional[int] = None,
|
target_embedding_padding: Optional[int] = None,
|
||||||
|
scaling_factor: Optional[float] = None,
|
||||||
embedding_modules: Optional[Dict[str, str]] = None,
|
embedding_modules: Optional[Dict[str, str]] = None,
|
||||||
embedding_padding_modules: Optional[List[str]] = None,
|
embedding_padding_modules: Optional[List[str]] = None,
|
||||||
) -> "LoRAModel":
|
) -> "LoRAModel":
|
||||||
@ -199,13 +262,15 @@ class LoRAModel:
|
|||||||
|
|
||||||
for lora in loras.values():
|
for lora in loras.values():
|
||||||
lora.optimize()
|
lora.optimize()
|
||||||
return cls(lora_model_id, rank, loras)
|
return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local_checkpoint(
|
def from_local_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
lora_dir: str,
|
lora_dir: str,
|
||||||
expected_lora_modules: List[str],
|
expected_lora_modules: List[str],
|
||||||
|
*,
|
||||||
|
max_position_embeddings: Optional[int] = None,
|
||||||
lora_model_id: Optional[int] = None,
|
lora_model_id: Optional[int] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
@ -213,7 +278,23 @@ class LoRAModel:
|
|||||||
embedding_modules: Optional[Dict[str, str]] = None,
|
embedding_modules: Optional[Dict[str, str]] = None,
|
||||||
embedding_padding_modules: Optional[List[str]] = None,
|
embedding_padding_modules: Optional[List[str]] = None,
|
||||||
) -> "LoRAModel":
|
) -> "LoRAModel":
|
||||||
"""Create a LoRAModel from a local checkpoint."""
|
"""Create a LoRAModel from a local checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lora_dir: The local path that has lora data.
|
||||||
|
expected_lora_modules: Name of modules that are expected to be
|
||||||
|
replaced by lora.
|
||||||
|
max_position_embeddings: Max position embedding length. Used to
|
||||||
|
scaling the largest context length. If None, the lora model's
|
||||||
|
context length is not scaled.
|
||||||
|
lora_model_id: Lora model id. If not given, automatically set by
|
||||||
|
a global counter.
|
||||||
|
device: Device where the lora model is loaded.
|
||||||
|
dtype: dtype of the lora model weights.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded LoRA Model.
|
||||||
|
"""
|
||||||
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
|
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
|
||||||
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
||||||
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
|
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
|
||||||
@ -253,6 +334,14 @@ class LoRAModel:
|
|||||||
|
|
||||||
rank = config["r"]
|
rank = config["r"]
|
||||||
lora_alpha = config["lora_alpha"]
|
lora_alpha = config["lora_alpha"]
|
||||||
|
context_length = config.get("context_length", None)
|
||||||
|
scaling_factor = None
|
||||||
|
if context_length:
|
||||||
|
if max_position_embeddings is None:
|
||||||
|
max_position_embeddings = context_length
|
||||||
|
scaling_factor = float(
|
||||||
|
math.ceil(context_length / max_position_embeddings))
|
||||||
|
|
||||||
return cls.from_lora_tensors(
|
return cls.from_lora_tensors(
|
||||||
lora_model_id=get_lora_id()
|
lora_model_id=get_lora_id()
|
||||||
if lora_model_id is None else lora_model_id,
|
if lora_model_id is None else lora_model_id,
|
||||||
@ -263,6 +352,7 @@ class LoRAModel:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
embeddings=embeddings,
|
embeddings=embeddings,
|
||||||
target_embedding_padding=target_embedding_padding,
|
target_embedding_padding=target_embedding_padding,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
embedding_modules=embedding_modules,
|
embedding_modules=embedding_modules,
|
||||||
embedding_padding_modules=embedding_padding_modules,
|
embedding_padding_modules=embedding_padding_modules,
|
||||||
)
|
)
|
||||||
@@ -296,6 +386,7 @@ class LoRAModelManager:
         self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
         self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
         self.vocab_size = vocab_size
+        self.long_lora_context: Optional[LongContextLoRAContext] = None
         self.base_indices = torch.empty(self.max_num_batched_tokens,
                                         dtype=torch.long,
                                         device="cuda")
@@ -309,6 +400,12 @@ class LoRAModelManager:
                                               self.max_num_batched_tokens,
                                               dtype=torch.long,
                                               device="cuda")
+        self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
+                                             dtype=torch.long,
+                                             device="cuda")
+        # Scaling factor -> offset into the sin_cos_cache.
+        # Used for long context lora.
+        self.scaling_factor_to_offset: Dict[float, int] = {}
         # 4 is the number of indicies tensors defined above
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices
@@ -318,6 +415,10 @@ class LoRAModelManager:
         if hasattr(self.model, "supported_lora_modules"):
             self.supported_lora_modules = copy.deepcopy(
                 self.model.supported_lora_modules)
+            if lora_config.long_lora_scaling_factors:
+                # We need to replace rotary emb layer to do batch computation
+                # for long lora.
+                self.supported_lora_modules.append("rotary_emb")
             self.packed_modules_mapping = copy.deepcopy(
                 self.model.packed_modules_mapping)
         self.packed_modules: Dict[str, List[str]] = {}
@@ -383,12 +484,32 @@ class LoRAModelManager:
                 return True
         return False

+    def _set_long_lora_context(self, lora: LoRAModel):
+        if self.long_lora_context is None:
+            return
+
+        if lora.scaling_factor is None:
+            return
+
+        if (lora.scaling_factor not in self.scaling_factor_to_offset):
+            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
+                             " has not been initialized.")
+
+        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
+        if offsets:
+            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
+
     def _add_lora(self, lora: LoRAModel):
         self._create_merged_loras_inplace(lora)
         self._registered_loras[lora.id] = lora
+        self._set_long_lora_context(lora)

     def add_lora(self, lora: LoRAModel) -> bool:
         """Add a LoRAModel to the manager CPU cache."""
+        logger.debug(
+            "Adding lora. Model id: %d, "
+            "int id: %d, "
+            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
         if lora.id not in self._registered_loras:
             if len(self._registered_loras) >= self.capacity:
                 raise RuntimeError("No free LoRA slots.")
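The long_lora_context consulted above is a small per-manager record: the scaling factors the replaced rotary layer was built with, its rotary dimension, and a mapping from LoRA id to that adapter's row offset in the concatenated cos/sin cache. A minimal illustrative container with those pieces (the field names below are assumptions for illustration, not necessarily the ones used in this change):

    from dataclasses import dataclass, field
    from typing import Dict, List


    @dataclass
    class LongContextLoRAContextSketch:
        """Illustrative stand-in for the long-lora bookkeeping object."""
        # Scaling factors the batched rotary embedding was built with, e.g. [1, 4, 8].
        scaling_factors: List[float]
        # Rotary dimension of the replaced rotary embedding layer.
        rot_dim: int
        # lora_id -> row offset into the concatenated cos/sin cache.
        offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)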
@@ -400,15 +521,18 @@ class LoRAModelManager:
         """Remove a LoRAModel from the manager CPU cache."""
         # TODO: should we check active lora?
         self.deactivate_lora(lora_id)
+        if self.long_lora_context:
+            self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
         return bool(self._registered_loras.pop(lora_id, None))

     # TODO see if this can be vectorized
     def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
         (base_indices, sampler_indices, sampler_indices_padded,
-         embeddings_indices,
+         embeddings_indices, long_lora_offsets_tensor,
          indices_len) = convert_mapping(mapping, self.lora_index_to_id,
                                         self.lora_slots + 1, self.vocab_size,
-                                        self.lora_config.lora_extra_vocab_size)
+                                        self.lora_config.lora_extra_vocab_size,
+                                        self.long_lora_context)
         self.base_indices[:base_indices.shape[0]].copy_(base_indices)
         self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
         self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
@@ -416,6 +540,11 @@ class LoRAModelManager:
         self.embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
+        if long_lora_offsets_tensor is not None:
+            self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
+                long_lora_offsets_tensor)
+        else:
+            self.long_lora_indices.zero_()
         # Maintain the reference
         self.indices_len[:] = indices_len

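Conceptually, long_lora_offsets_tensor carries one entry per batched token: the cache offset of the LoRA that produced that token, or 0 for tokens served by the base model. A rough sketch of how such a tensor could be assembled (the helper data below is made up; this is not the convert_mapping implementation):

    import torch

    # Hypothetical batch: token i belongs to the request whose LoRA id is
    # token_lora_ids[i]; 0 means "no LoRA / base model".
    token_lora_ids = [0, 0, 7, 7, 7, 9]
    # Offsets previously registered for each long-context LoRA.
    offsets_by_lora_id = {7: 4096, 9: 20480}

    long_lora_offsets = torch.tensor(
        [offsets_by_lora_id.get(lora_id, 0) for lora_id in token_lora_ids],
        dtype=torch.long)
    # The batched rotary embedding kernel can then add this per-token offset to
    # each token's position before indexing the concatenated cos/sin cache.
    print(long_lora_offsets)  # tensor([    0,     0,  4096,  4096,  4096, 20480])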
@@ -438,7 +567,8 @@ class LoRAModelManager:
         self._active_loras.clear()

     def _create_lora_modules(self):
-        for module_name, module in self.model.named_modules():
+        for module_name, module in self.model.named_modules(
+                remove_duplicate=False):
             if not self._match_target_modules(module_name):
                 continue
             parts = module_name.split(".")[-1]
@@ -447,6 +577,13 @@ class LoRAModelManager:
                 self.model, module_name,
                 from_layer(module, self.lora_slots, self.lora_config,
                            packed_moduled_lst, self.model.config))
+            # LinearScalingRotaryEmbeddingWithLora is used to handle
+            # long context lora. Register relevant metadata.
+            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
+                self.long_lora_context = LongContextLoRAContext(
+                    new_module.scaling_factors, new_module.rotary_dim)
+                self.scaling_factor_to_offset = \
+                    new_module.scaling_factor_to_offset
             # (yard1): TODO make this more robust
             if "lm_head" in module_name:
                 logits_processor_module = self.model.get_submodule(
@@ -461,7 +598,8 @@ class LoRAModelManager:
             self._register_packed_modules(module_name)
             new_module.set_mapping(self.base_indices, self.sampler_indices,
                                    self.sampler_indices_padded,
-                                   self.embeddings_indices, self.indices_len)
+                                   self.embeddings_indices,
+                                   self.long_lora_indices, self.indices_len)

     def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
         assert isinstance(module, BaseLayerWithLoRA)
@@ -471,12 +609,14 @@ class LoRAModelManager:
             self,
             lora_id: int,
             rank: int,
+            scaling_factor: Optional[float],
             embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
         """Create zero-initialized LoRAModel for warmup."""
-        model = LoRAModel(lora_id, rank, {})
+        model = LoRAModel(lora_id, rank, {}, scaling_factor)
         for module_name, module in self.model.named_modules():
             if not self._match_target_modules(module_name) or not isinstance(
-                    module, BaseLayerWithLoRA):
+                    module, BaseLayerWithLoRA) or isinstance(
+                        module, LinearScalingRotaryEmbeddingWithLora):
                 continue
             parts = module_name.split(".")
             if module_name not in self.packed_modules:
@@ -606,6 +746,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):

     def add_lora(self, lora: LoRAModel) -> bool:
         """Add a LoRAModel to the manager."""
+        logger.debug(
+            "Adding lora. Model id: %d, "
+            "int id: %d, "
+            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
         if lora.id not in self._registered_loras:
             self._add_lora(lora)
             was_added = True
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import Optional


 @dataclass
@@ -18,6 +19,7 @@ class LoRARequest:
     lora_name: str
     lora_int_id: int
     lora_local_path: str
+    long_lora_max_len: Optional[int] = None

     def __post_init__(self):
         if self.lora_int_id < 1:
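With the new field, a request can carry the larger input limit of its long-context adapter. A minimal usage sketch (the adapter name, path, and length below are placeholders):

    from vllm.lora.request import LoRARequest

    long_context_lora = LoRARequest(
        lora_name="sql-lora-4x",             # hypothetical adapter name
        lora_int_id=1,
        lora_local_path="/path/to/adapter",  # placeholder path
        # Prompts routed to this adapter may be up to 16384 tokens instead of
        # the tokenizer group's default max_input_length.
        long_lora_max_len=16384,
    )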
@@ -13,6 +13,7 @@ from vllm.lora.fully_sharded_layers import (
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                              LinearScalingRotaryEmbeddingWithLora,
                               LogitsProcessorWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedQKVParallelLinearWithLora,
@@ -26,12 +27,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 logger = init_logger(__name__)

 _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
-    VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
-    MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
-    MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
-    LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
+    VocabParallelEmbeddingWithLoRA,
+    ColumnParallelLinearWithLoRA,
+    MergedColumnParallelLinearWithLoRA,
+    QKVParallelLinearWithLora,
+    MergedQKVParallelLinearWithLora,
+    RowParallelLinearWithLoRA,
+    LogitsProcessorWithLoRA,
+    ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithShardedLoRA,
-    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
+    MergedQKVParallelLinearWithShardedLora,
+    RowParallelLinearWithShardedLoRA,
+    LinearScalingRotaryEmbeddingWithLora,
 }


@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
 from contextlib import contextmanager
-from typing import Any, Dict, List, Literal, Set, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Set, Type, Union

 import torch

@@ -17,11 +17,16 @@ logger = init_logger(__name__)
 class AbstractWorkerLoRAManager(ABC):
     """Abstract class for managing LoRA models on the worker side."""

-    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
-                 vocab_size: int, lora_config: LoRAConfig,
-                 device: torch.device):
+    def __init__(self,
+                 max_num_seqs: int,
+                 max_num_batched_tokens: int,
+                 vocab_size: int,
+                 lora_config: LoRAConfig,
+                 device: torch.device,
+                 max_position_embeddings: Optional[int] = None):
         self.max_num_seqs = max_num_seqs
         self.max_num_batched_tokens = max_num_batched_tokens
+        self.max_position_embeddings = max_position_embeddings
         self.vocab_size = vocab_size
         self.device = device
         self.lora_config = lora_config
@@ -92,14 +97,21 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
         embedding_modules: Dict[str, str],
         embedding_padding_modules: List[str],
         lora_model_cls: Type[LoRAModel] = LoRAModel,
+        max_position_embeddings: Optional[int] = None,
     ):
         self._lora_model_cls = lora_model_cls
         self.embedding_modules = embedding_modules
         self.embedding_padding_modules = embedding_padding_modules
         # Lazily initialized by create_lora_manager.
         self._lora_manager: LoRAModelManager
-        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
-                         lora_config, device)
+        super().__init__(
+            max_num_seqs,
+            max_num_batched_tokens,
+            vocab_size,
+            lora_config,
+            device,
+            max_position_embeddings=max_position_embeddings,
+        )

     @property
     def is_enabled(self) -> bool:
@@ -162,6 +174,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_request.lora_local_path,
                 expected_lora_modules,
+                max_position_embeddings=self.max_position_embeddings,
                 lora_model_id=lora_request.lora_int_id,
                 device="cpu",
                 dtype=self.lora_config.lora_dtype,
@@ -191,7 +204,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
                 lora_request.lora_int_id)
         else:
             dummy_lora = self._lora_manager.create_dummy_lora(
-                lora_request.lora_int_id, rank, self.embedding_modules)
+                lora_request.lora_int_id, rank, 1, self.embedding_modules)
             if self._cached_dummy_lora is None:
                 self._cached_dummy_lora = dummy_lora
         return self._lora_manager.add_lora(dummy_lora)
@@ -61,6 +61,7 @@ class RotaryEmbedding(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         self.is_neox_style = is_neox_style
+        self.dtype = dtype

         cache = self._compute_cos_sin_cache()
         cache = cache.to(dtype)
@@ -168,6 +169,29 @@ class RotaryEmbedding(nn.Module):
 class LinearScalingRotaryEmbedding(RotaryEmbedding):
     """RotaryEmbedding extended with linear scaling.

+    It supports multiple scaling factors. Since multiple LoRA adapters may have
+    different scaling factors, we need multiple cos/sin caches. In this way,
+    instead of running rotary embedding kernel per lora, we can run multiple
+    lora in a batched way.
+
+    In addition to that, we also keep the cos/sin cache for the scaling factor
+    of 1 (default) at all times.
+
+    For example, for three scaling factors x=1, y, and z with embeddings
+    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]],
+    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
+    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
+
+    we construct the cos/sin cache as follows:
+    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
+     ...
+     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
+
+    We then use offsets to index into the cos/sin cache for
+    the respective scaling factors.
+
+    The offset to cache can be accessed via `scaling_factor_to_offset` API.
+
     Credits to the Reddit user /u/kaiokendev
     """

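Concretely, if the base model has max_position_embeddings = 4096 and the layer is built with scaling factors (1, 4), the factor-1 cache covers 4096 positions and the factor-4 cache covers 16384, so the merged cache has 20480 rows and the offsets are {1.0: 0, 4.0: 4096}. A standalone sketch of that bookkeeping (not the class itself):

    import math

    max_position_embeddings = 4096   # base model's original context length
    scaling_factors = [1.0, 4.0]     # factor 1 is always kept for the base model

    offsets = {}
    offset = 0
    for factor in scaling_factors:
        offsets[factor] = offset
        # The cache for this factor covers factor * max_position_embeddings rows.
        offset += int(math.ceil(max_position_embeddings * factor))

    print(offsets)  # {1.0: 0, 4.0: 4096}
    print(offset)   # 20480 rows in the concatenated cos/sin cache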
@@ -183,13 +207,18 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
     ) -> None:
         if isinstance(scaling_factors, float):
             scaling_factors = [scaling_factors]
-        self.scaling_factors = scaling_factors
+        self.scaling_factors: List[float] = scaling_factors  # noqa
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
+        # Lazy initialized.
+        self._scaling_factor_to_offset: Dict[float, int]

     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.base)
-        cache_list = []
+        cache_list: List[torch.Tensor] = []
+        # offsets to the next cache in a tensor.
+        # Each offset corresponds to the same index in scaling_factors.
+        offsets: List[int] = []
         for scaling_factor in self.scaling_factors:
             # NOTE(woosuk): self.max_position_embeddings is the original
             # maximum length before applying the rope scaling.
@@ -203,9 +232,25 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
             cos = freqs.cos()
             sin = freqs.sin()
             cache = torch.cat((cos, sin), dim=-1)
+            if not cache_list:
+                offset = 0
+            else:
+                last_offset = offsets[-1]
+                next_max_len = cache_list[-1].shape[0]
+                offset = last_offset + next_max_len
+            offsets.append(offset)
             cache_list.append(cache)
+        self._scaling_factor_to_offset = {
+            float(scaling_factor): offsets[i]
+            for i, scaling_factor in enumerate(self.scaling_factors)
+        }
+        assert len(self.scaling_factors) == len(offsets)
         return torch.cat(cache_list, dim=0)

+    @property
+    def scaling_factor_to_offset(self) -> Dict[float, int]:
+        return self._scaling_factor_to_offset
+

 class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
     """RotaryEmbedding extended with Dynamic NTK scaling.
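A caller holding this layer can look up the offset for a request's scaling factor and shift token positions by it before indexing the merged cache; that is what lets a single kernel launch serve adapters with different scaling factors. A rough illustration with toy tensors (this is not the batched kernel itself):

    import torch

    # Toy merged cache: 4096 rows for factor 1 followed by 16384 rows for factor 4.
    rot_dim = 128
    cos_sin_cache = torch.randn(4096 + 16384, rot_dim)
    scaling_factor_to_offset = {1.0: 0, 4.0: 4096}

    # Positions of one sequence and the scaling factor of its LoRA.
    positions = torch.tensor([0, 1, 2, 3])
    offset = scaling_factor_to_offset[4.0]

    # Each token reads its cos/sin row at (position + offset) in the merged cache.
    cos_sin = cos_sin_cache[positions + offset]
    print(cos_sin.shape)  # torch.Size([4, 128])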
@@ -348,6 +348,8 @@ class ChatGLMForCausalLM(nn.Module):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
+        self.max_position_embeddings = getattr(config, "max_sequence_length",
+                                               8192)
         self.transformer = ChatGLMModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.output_layer.weight
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
@@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module):

     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "gate_up_proj",
-        "down_proj",
-        "embed_tokens",
-        "lm_head",
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
+        "lm_head"
     ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
@@ -46,6 +46,8 @@ class ChatGLMConfig(PretrainedConfig):
         self.kv_channels = kv_channels
         self.num_attention_heads = num_attention_heads
         self.seq_length = seq_length
+        # Set for compatibility with long context LoRA.
+        self.max_position_embeddings = seq_length
         self.hidden_dropout = hidden_dropout
         self.attention_dropout = attention_dropout
         self.layernorm_epsilon = layernorm_epsilon
@@ -34,12 +34,26 @@ class TokenizerGroup(BaseTokenizerGroup):
         """Get the maximum input length for the LoRA request."""
         return self.max_input_length

+    def _raise_if_input_too_long(self,
+                                 encoded_tokens: List[str],
+                                 lora_request: Optional[LoRARequest] = None):
+        input_length = len(encoded_tokens)
+        if lora_request:
+            max_input_length = (lora_request.long_lora_max_len
+                                or self.max_input_length)
+        else:
+            max_input_length = self.max_input_length
+        if max_input_length is not None and input_length > max_input_length:
+            raise ValueError("Input too long.", input_length, max_input_length)
+
     def encode(self,
                prompt: str,
                request_id: Optional[str] = None,
                lora_request: Optional[LoRARequest] = None) -> List[int]:
         tokenizer = self.get_lora_tokenizer(lora_request)
-        return tokenizer.encode(prompt)
+        ret = tokenizer.encode(prompt)
+        self._raise_if_input_too_long(ret, lora_request)
+        return ret

     async def encode_async(
             self,
@@ -47,7 +61,9 @@ class TokenizerGroup(BaseTokenizerGroup):
                request_id: Optional[str] = None,
                lora_request: Optional[LoRARequest] = None) -> List[int]:
         tokenizer = await self.get_lora_tokenizer_async(lora_request)
-        return tokenizer.encode(prompt)
+        ret = tokenizer.encode(prompt)
+        self._raise_if_input_too_long(ret, lora_request)
+        return ret

     def get_lora_tokenizer(
             self,
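The effect is that the input-length cap follows the adapter: a prompt routed to a long-context LoRA is checked against that request's long_lora_max_len, while every other prompt keeps the group's max_input_length. The same rule in isolation (a standalone sketch with made-up numbers):

    from typing import Optional

    def effective_max_len(base_max_len: Optional[int],
                          long_lora_max_len: Optional[int]) -> Optional[int]:
        # A long-context LoRA's limit, when present, overrides the base cap.
        return long_lora_max_len or base_max_len

    print(effective_max_len(4096, None))   # 4096  -> request without a long LoRA
    print(effective_max_len(4096, 16384))  # 16384 -> request using a 4x LoRA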
@@ -156,9 +156,15 @@ class ModelRunner:
             ), "Model does not have embedding_padding_modules"
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
-                self.scheduler_config.max_num_batched_tokens, self.vocab_size,
-                self.lora_config, self.device, self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.scheduler_config.max_num_batched_tokens,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
+                self.model.embedding_modules,
+                self.model.embedding_padding_modules,
+                max_position_embeddings=self.model.config.
+                max_position_embeddings,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)

         if self.kv_cache_dtype == "fp8" and is_hip():
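Taken together, these pieces let one engine serve adapters with different RoPE scaling factors in the same batch. A hedged end-to-end sketch, assuming the offline LLM entry point plumbs through the long_lora_scaling_factors option referenced above (the model name, adapter path, and lengths are placeholders):

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="meta-llama/Llama-2-13b-hf",   # placeholder base model
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        # Assumed engine option: scaling factors to pre-build cos/sin caches
        # for (factor 1 for the base model is always kept as well).
        long_lora_scaling_factors=(4.0, ),
        max_model_len=16384,
    )

    outputs = llm.generate(
        ["A very long prompt ..."],
        SamplingParams(max_tokens=64),
        lora_request=LoRARequest("long-ctx-lora", 1, "/path/to/adapter",
                                 long_lora_max_len=16384),
    )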