[Lora] Support long context lora (#4787)

Currently we need to call the rotary embedding kernel once per LoRA, which makes it hard to serve multiple long-context LoRAs. Add a batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer with one that is aware of multiple cos/sin caches, one per scaling factor.
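Conceptually, the batched kernel works by giving every token an offset into a single concatenated cos/sin cache, so one launch can serve adapters with different scaling factors. A minimal Python sketch of that idea follows (illustrative only, not the CUDA kernel added in this commit; the function name is made up):

import torch

def batched_rope_cache_lookup(positions: torch.Tensor,
                              offsets: torch.Tensor,
                              cos_sin_cache: torch.Tensor) -> torch.Tensor:
    # positions: [num_tokens]; offsets: [num_tokens], the per-token start of the
    # cache slice for that token's LoRA scaling factor.
    # cos_sin_cache: caches for all scaling factors concatenated along dim 0.
    return cos_sin_cache[positions + offsets]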

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
SangBin Cho 2024-05-18 16:05:23 +09:00 committed by GitHub
parent c0724fc915
commit 2e9a2227ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 998 additions and 71 deletions

View File

@ -119,9 +119,23 @@ steps:
- label: LoRA Test %N
#mirror_hardwares: [amd]
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4
- label: LoRA Long Context (Distributed)
#mirror_hardwares: [amd]
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len
- label: Tensorizer Test
#mirror_hardwares: [amd]
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader

View File

@ -112,7 +112,7 @@ mypy vllm/model_executor --config-file pyproject.toml
CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
'--skip' '*docs/source/_build/**,./tests/lora/data'
)
# check spelling of specified files
@ -133,10 +133,9 @@ spell_check_changed() {
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell "${CODESPELL_EXCLUDES[@]}"
codespell "${CODESPELL_EXCLUDES[@]}"
fi
}

View File

@ -60,7 +60,7 @@ exclude = [
[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data"
[tool.isort]
use_parentheses = true

View File

@ -21,6 +21,17 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
LONG_LORA_INFOS = [{
"lora_id": 1,
"context_length": "16k",
}, {
"lora_id": 2,
"context_length": "16k",
}, {
"lora_id": 3,
"context_length": "32k",
}]
def cleanup():
destroy_model_parallel()
@ -154,6 +165,45 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")
@pytest.fixture(scope="session")
def long_context_lora_files_32k():
return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")
# SANG-TODO Download long lora files.
@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
long_context_lora_files_16k_2,
long_context_lora_files_32k):
cleanup()
infos = {}
for lora_checkpoint_info in LONG_LORA_INFOS:
lora_id = lora_checkpoint_info["lora_id"]
if lora_id == 1:
lora = long_context_lora_files_16k_1
elif lora_id == 2:
lora = long_context_lora_files_16k_2
elif lora_id == 3:
lora = long_context_lora_files_32k
else:
raise AssertionError("Unknown lora id")
infos[lora_id] = {
"context_length": lora_checkpoint_info["context_length"],
"lora": lora,
}
return infos
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()

View File

File diff suppressed because one or more lines are too long

View File

@ -15,6 +15,7 @@ from vllm.lora.fully_sharded_layers import (
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
@ -22,13 +23,14 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
convert_mapping)
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
PackedLoRALayerWeights, convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.utils import set_random_seed
@ -771,3 +773,97 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
expected_result,
rtol=rtol,
atol=atol)
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 8])
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0),
(6.0, 1.0)])
@pytest.mark.parametrize("max_position", [11, 4096, 32768])
@pytest.mark.parametrize("is_neox_style", [True, False])
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
def test_rotary_embedding_long_context(dist_init, num_loras, device,
scaling_factors, max_position,
is_neox_style, rotary_dim, head_size,
seq_len) -> None:
dtype = torch.float16
seed = 0
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
long_lora_scaling_factors=scaling_factors,
lora_dtype=dtype)
if rotary_dim is None:
rotary_dim = head_size
base = 10000
batch_size = 5 * num_loras
num_heads = 7
# Verify lora is equivalent to linear scaling rotary embedding.
rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
)
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"type": "linear",
"factor": scaling_factors
})
linear_rope = linear_rope.to(dtype=dtype)
id_to_index = get_random_id_to_index(num_loras, max_loras)
_, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=batch_size,
input_size=(1, max_position),
input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
long_lora_context = LongContextLoRAContext(list(scaling_factors),
rotary_dim)
next_expected_offset = 0
# Make sure the offset is correct.
scaling_factor_to_offset = lora_rope.scaling_factor_to_offset
for scaling_factor, offset in scaling_factor_to_offset.items():
assert offset == next_expected_offset
next_expected_offset += scaling_factor * max_position
for i in range(len(scaling_factors)):
long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
scaling_factors[i], 0)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
max_loras,
512,
lora_config.lora_extra_vocab_size,
long_lora_context=long_lora_context,
)
lora_rope.set_mapping(*mapping_info)
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)
ref_q, ref_k = linear_rope(positions, query, key)
actual_q, actual_k = lora_rope(positions, query, key)
assert torch.allclose(ref_q, actual_q)
assert torch.allclose(ref_k, actual_k)

View File

@ -0,0 +1,292 @@
import ast
from typing import List, Optional, Tuple
import numpy as np
import pytest
import vllm
from vllm import SamplingParams
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding)
from .data.long_context_test_data import prompts_and_responses
context_len_to_scaling_factor = {
"16k": 4,
"32k": 8,
}
# We use the same sampling params for all requests
sampling_params = SamplingParams(
temperature=0,
max_tokens=100,
)
def _create_lora_request(lora_id, long_context_infos):
context_len = long_context_infos[lora_id]["context_length"]
scaling_factor = context_len_to_scaling_factor[context_len]
return LoRARequest(context_len, lora_id,
long_context_infos[lora_id]["lora"],
4096 * scaling_factor)
def evaluate_json_response(model_response, golden_response):
"""Evaluates the model response against the golden response.
Returns a score between 0 and 1, where 1 is a perfect match and 0 is no
match. The score quantifies how well the model is able to extract the
golden JSON from the long context.
"""
try:
model_response = ast.literal_eval(model_response)
except Exception as e:
raise ValueError(
f"Model response is not a valid JSON. Expected {golden_response}, "
f"got {model_response}") from e
# Normally, we would flatten the dictionary and compare the values, but in
# this case, we know that the dictionary is only 2 levels deep
positive_values = 0
total_values = 0
# We look at all the attributes of the person that we are extracting a
# biography of and compare them to the golden response
for person_attribute, person_attribute_value in golden_response.items():
if person_attribute in model_response:
if isinstance(person_attribute_value, dict):
for (sub_attribute,
sub_attribute_value) in person_attribute_value.items():
total_values += 1
if sub_attribute in model_response[
person_attribute] and model_response[
person_attribute][
sub_attribute] == sub_attribute_value:
positive_values += 1
else:
total_values += 1
if model_response[person_attribute] == person_attribute_value:
positive_values += 1
else:
# We count a missing sub-dict as a single missed value.
total_values += 1
# Return a score between 0 and 1
return positive_values / total_values
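For illustration, a hypothetical call to this scorer (not part of the diff; the values below are made up and are not from the test data):

golden = {"name": "Alice", "address": {"city": "Paris", "zip": "75000"}}
model_output = "{'name': 'Alice', 'address': {'city': 'Paris', 'zip': '75001'}}"
# "name" and "city" match, "zip" does not: 2 of 3 extracted values are correct.
score = evaluate_json_response(model_output, golden)
assert abs(score - 2 / 3) < 1e-6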
def generate(
llm,
inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
):
prompts, sampling_param, lora_request = inputs
outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
return outputs[0].outputs[0].text.strip()
def batched_generate(
llm,
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
):
for input in inputs:
prompt, sampling_param, lora_req = input
requests_data = llm._validate_and_prepare_requests(
prompt,
sampling_param,
lora_request=lora_req,
)
# Add requests to the engine and run the engine
for request_data in requests_data:
llm._add_request(**request_data)
outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
@pytest.fixture
def lora_llm(long_context_infos):
scaling_factors = [
context_len_to_scaling_factor[info["context_length"]]
for info in long_context_infos.values()
]
llm = vllm.LLM(
"meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
)
yield llm
del llm
def test_rotary_emb_replaced(dist_init):
"""Verify rotary emb in all the layers are replaced"""
from vllm.engine.arg_utils import EngineArgs
from vllm.worker.model_runner import ModelRunner
engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
long_lora_scaling_factors=(4.0, ),
enable_lora=True)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
is_driver_worker=True,
)
model_runner.load_model()
rotary_emb_count = 0
for module_name, module in model_runner.model.named_modules(
remove_duplicate=False):
if "rotary_emb" in module_name:
if "base_layer" not in module_name:
rotary_emb_count += 1
assert isinstance(module, LinearScalingRotaryEmbeddingWithLora)
else:
assert isinstance(module, LinearScalingRotaryEmbedding)
# Llama 2 has 32 layers.
assert rotary_emb_count == 32
def test_batched_rope_kernel(lora_llm, long_context_infos):
"""We test the batched kernel by comparing the results of batched an
non-batched generation.
"""
# Create non batched results first to compare against batched results
non_batched_results = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
sampling_params,
_create_lora_request(lora_id, long_context_infos))
lora_output = generate(lora_llm, lora_prompt)
non_batched_results.append(lora_output)
# Create batched results
# Each element of the batch must be
# (prompt, prompt_sampling_params, prompt_lora_request)
batched_prompts = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
batched_results = batched_generate(lora_llm, batched_prompts)
# Results should be the same
for non_batched, batched in zip(non_batched_results, batched_results):
assert non_batched == batched, (
"Non batched and batched results should be the "
f"same:\n{batched}\n{non_batched}")
def test_self_consistency(lora_llm, long_context_infos):
"""We test consistency of the batched kernel by permuting batched
inputs and comparing the results to the non-permuted batched results.
"""
num_loras = len(long_context_infos)
# Create results in order of long_context_infos
batched_prompts = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
batched_results = batched_generate(lora_llm, batched_prompts)
permutation = np.random.default_rng(seed=42).permutation(num_loras)
# Create results in random order of permutation
batched_prompts = []
for i in permutation:
lora_id, info = list(long_context_infos.items())[i]
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
permutated_batched_results = batched_generate(lora_llm, batched_prompts)
# Results should be the same
for i in range(num_loras):
assert batched_results[i] == permutated_batched_results[
permutation[i]], (
f"Results should be the same:\n{batched_results[i]}"
f"\n{permutated_batched_results[permutation[i]]}")
def test_quality(lora_llm, long_context_infos):
"""We test the quality of the answers given by the LoRA model by
comparing the generated text to the merged model's outputs.
This is effectively a mini-benchmark over four prompts.
If this test fails, this indicates that the quality of the LoRA model
is suboptimal compared to the merged model. For example, if the model
does not output valid dictionaries, this test will fail.
If needed for testing, the merged versions of the models are available
as part of the `conftest`.
The test is expected to run for about 1 minute on a p4de.24xlarge
instance.
"""
scores = []
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
for prompt_and_response in prompts_and_responses[context_len]:
lora_prompt = (prompt_and_response["prompt"], sampling_params,
_create_lora_request(lora_id, long_context_infos))
response = generate(lora_llm, lora_prompt)
golden_answer = prompt_and_response["golden_answer"]
score = evaluate_json_response(response, golden_answer)
scores.append(score)
assert score > 0.3, ("Quality of the answer is not good enough. "
f"Expected {golden_answer}, got {response}")
assert np.mean(scores) > 0.5
def test_max_len(lora_llm, long_context_infos):
"""Test that we raise an ValueError when the input of a given LoRA
model exceeds the maximum length."""
# Since each LoRA model has a different maximum length, we need to
# test each one separately
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
lora_request = _create_lora_request(lora_id, long_context_infos)
# Good prompt should be fine
good_prompt = prompts_and_responses[context_len][0]["prompt"]
generate(lora_llm, (good_prompt, sampling_params, lora_request))
# Bad prompt should raise an error
bad_prompt = good_prompt * 2
with pytest.raises(ValueError):
generate(lora_llm, (bad_prompt, sampling_params, lora_request))
# Also test batched
batched_prompts = []
for lora_id_with_bad_inputs in long_context_infos:
for lora_id, info in long_context_infos.items():
context_len = info["context_length"]
batched_prompts.extend([
(prompts_and_responses[context_len][0]["prompt"] *
(2 if lora_id == lora_id_with_bad_inputs else 1),
sampling_params,
_create_lora_request(lora_id, long_context_infos))
])
# Turn good prompt into bad prompt inside of batched prompts
with pytest.raises(ValueError):
batched_generate(lora_llm, batched_prompts)

View File

@ -1,7 +1,7 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
import torch
from transformers import PretrainedConfig
@ -968,6 +968,7 @@ class LoRAConfig:
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h

View File

@ -264,13 +264,6 @@ class Scheduler:
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
if self.scheduler_config.chunked_prefill_enabled:
self.prompt_limit = self.scheduler_config.max_model_len
else:
self.prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
version = "v1"
if self.scheduler_config.use_v2_block_manager:
version = "v2"
@ -596,6 +589,21 @@ class Scheduler:
infeasible_seq_groups=infeasible_seq_groups,
)
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
if self.scheduler_config.chunked_prefill_enabled:
prompt_limit = self.scheduler_config.max_model_len
else:
prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
# Model is fine tuned with long context. Return the fine tuned max_len.
if (seq_group.lora_request
and seq_group.lora_request.long_lora_max_len):
assert prompt_limit <= seq_group.lora_request.long_lora_max_len
return seq_group.lora_request.long_lora_max_len
else:
return prompt_limit
def _schedule_prefills(
self,
waiting_queue: deque,
@ -650,11 +658,11 @@ class Scheduler:
num_prompt_tokens = waiting_seqs[0].get_len()
assert num_new_tokens == num_prompt_tokens
if num_new_tokens > self.prompt_limit:
prompt_limit = self._get_prompt_limit(seq_group)
if num_new_tokens > prompt_limit:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d", num_new_tokens,
self.prompt_limit)
" and exceeds limit of %d", num_new_tokens, prompt_limit)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)

View File

@ -1,7 +1,7 @@
import argparse
import dataclasses
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@ -63,6 +63,7 @@ class EngineArgs:
max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
@ -397,6 +398,17 @@ class EngineArgs:
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--long-lora-scaling-factors',
type=nullable_str,
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can '
'be different from base model scaling factor '
'- see eg. Long LoRA) to allow for multiple '
'LoRA adapters trained with those scaling '
'factors to be used at the same time. If not '
'specified, only adapters trained with the '
'base model scaling factor are allowed.'))
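For reference, the Python-API equivalent of this new flag, as exercised by the new tests/lora/test_long_context.py, looks like the sketch below (illustrative, not part of this diff; model name and factor values are examples):

import vllm

# Two long-LoRA scaling factors, e.g. 16k and 32k adapters on a 4k base model.
llm = vllm.LLM(
    "meta-llama/Llama-2-13b-chat-hf",
    enable_lora=True,
    max_loras=2,
    long_lora_scaling_factors=(4.0, 8.0),
    max_num_batched_tokens=4096 * 8,
)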
parser.add_argument(
'--max-cpu-loras',
type=int,
@ -593,6 +605,7 @@ class EngineArgs:
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
long_lora_scaling_factors=self.long_lora_scaling_factors,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None

View File

@ -131,10 +131,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
# TODO(sang): Support lora.
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count=new_char_count,
sampling_params=sampling_params)
sampling_params=sampling_params,
)
if seq.is_finished():
break

View File

@ -118,8 +118,12 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
seq, seq_group.sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(seq, new_char_count,
seq_group.sampling_params)
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
seq_group.sampling_params,
lora_req=seq_group.lora_request,
)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:

View File

@ -2,6 +2,7 @@ from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
@ -16,11 +17,23 @@ class StopChecker:
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
self.max_model_len = max_model_len
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None:
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len:
return lora_req.long_lora_max_len
else:
return self._max_model_len
def maybe_stop_sequence(
self,
seq: Sequence,
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
@ -59,7 +72,7 @@ class StopChecker:
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.max_model_len:
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

View File

@ -1,7 +1,7 @@
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -22,6 +22,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@ -185,6 +187,7 @@ class BaseLayerWithLoRA(nn.Module):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
"""Sets the mapping indices."""
@ -306,6 +309,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
@ -431,6 +435,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
@ -951,6 +956,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = base_indices
@ -1127,6 +1133,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.indices = sampler_indices
@ -1193,3 +1200,101 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False
class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
"""Implements RoPE-scaled embeddings with linear scaling for
multiple LoRA adapters with a specialized kernel.
Replaces LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding,
which can handle multiple LoRA adapters in a specialized kernel.
"""
def __init__(self, base_layer: RotaryEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
# Lazily initialized
self.long_lora_indices: torch.Tensor
self.indices_len: List[int]
@property
def scaling_factors(self):
return self.base_layer.scaling_factors
@property
def rotary_dim(self):
return self.base_layer.rotary_dim
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> None:
scaling_factors = list(
lora_config.long_lora_scaling_factors
) if lora_config.long_lora_scaling_factors else []
base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
scaling_factors = sorted(
list(set([base_scaling_factor] + scaling_factors)))
self.base_layer = LinearScalingRotaryEmbedding(
self.base_layer.head_size,
self.base_layer.rotary_dim,
self.base_layer.max_position_embeddings,
self.base_layer.base,
self.base_layer.is_neox_style,
scaling_factors,
self.base_layer.dtype,
)
def reset_lora(self, index: int):
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
...
def set_mapping(
self,
base_indices: torch.Tensor,
sampler_indices: torch.Tensor,
sampler_indices_padded: torch.Tensor,
embeddings_indices: torch.Tensor,
long_lora_indices: torch.Tensor,
indices_len: List[int],
):
self.long_lora_indices = long_lora_indices
self.indices_len = indices_len
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.base_layer(
positions,
query,
key,
offsets=self.long_lora_indices[:self.indices_len[4]])
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
return self.base_layer.scaling_factor_to_offset
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
return type(source_layer) is LinearScalingRotaryEmbedding or type(
source_layer) is RotaryEmbedding
def extra_repr(self) -> str:
return self.base_layer.extra_repr()

View File

@ -3,7 +3,8 @@ import json
import math
import os
import re
from typing import Callable, Dict, List, Optional, Tuple, Type
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors.torch
import torch
@ -11,7 +12,9 @@ from torch import nn
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
@ -22,10 +25,27 @@ logger = init_logger(__name__)
_GLOBAL_LORA_ID = 0
@dataclass
class LongContextLoRAContext:
"""Context for lora adapters that support long context."""
# The scaling factors to support long context lora fine tuned models.
scaling_factors: List[float]
# dimension to apply rotary embedding.
rot_dim: int
# offsets to the sin_cos_cache for each lora_id loaded.
# This value is dynamically modified.
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
def convert_mapping(
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
mapping: LoRAMapping,
lora_index_to_id: List[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional[LongContextLoRAContext] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], List[int]]:
"""Converts LoRAMapping to index tensors.
Args:
@ -34,6 +54,7 @@ def convert_mapping(
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long-context LoRAs in the batch.
Returns:
A tuple of tensors:
@ -51,11 +72,23 @@ def convert_mapping(
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors.
Used to index into each tensor. It contains length for
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices). If long_lora doesn't
exist, it only contains first 4 entries.
"""
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device="cuda",
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
@ -66,13 +99,22 @@ def convert_mapping(
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0 else -1)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
index_mapping_indices[i] = i
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
# SANG-TODO
# index_mapping_indices[i] = i
indices = torch.tensor(
[index_mapping_indices, lora_indices, embedding_indices],
dtype=torch.long,
device="cuda")
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices, lora_indices, embedding_indices
]
if long_lora_context:
assert long_lora_offsets is not None
indices_list.append(long_lora_offsets)
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
prompt_mapping_tensor = torch.tensor(prompt_mapping,
device="cuda",
dtype=torch.long)
@ -89,13 +131,21 @@ def convert_mapping(
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
long_lora_indices = None
long_lora_indices_len: Optional[int] = None
if long_lora_context:
long_lora_indices = indices[3]
long_lora_indices_len = long_lora_indices.shape[-1]
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1], sampler_indices.shape[-1],
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
]
if long_lora_indices_len is not None:
indices_len.append(long_lora_indices_len)
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
embeddings_indices, long_lora_indices, indices_len)
def get_lora_id():
@ -112,8 +162,20 @@ class LoRAModel:
lora_model_id: int,
rank: int,
loras: Dict[str, LoRALayerWeights],
scaling_factor: Optional[float] = None,
) -> None:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
scaling_factor: Scaling factor to support long context lora model.
None if the lora is not tuned for long context support.
"""
self.id = lora_model_id
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
self.scaling_factor = scaling_factor
assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
@ -150,6 +212,7 @@ class LoRAModel:
dtype: Optional[torch.dtype] = None,
embeddings: Optional[Dict[str, torch.Tensor]] = None,
target_embedding_padding: Optional[int] = None,
scaling_factor: Optional[float] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
@ -199,13 +262,15 @@ class LoRAModel:
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, rank, loras)
return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
*,
max_position_embeddings: Optional[int] = None,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
@ -213,7 +278,23 @@ class LoRAModel:
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint."""
"""Create a LoRAModel from a local checkpoint.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
max_position_embeddings: Max position embedding length. Used to
scale the largest context length. If None, the lora model's
context length is not scaled.
lora_model_id: Lora model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
Returns:
Loaded LoRA Model.
"""
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
@ -253,6 +334,14 @@ class LoRAModel:
rank = config["r"]
lora_alpha = config["lora_alpha"]
context_length = config.get("context_length", None)
scaling_factor = None
if context_length:
if max_position_embeddings is None:
max_position_embeddings = context_length
scaling_factor = float(
math.ceil(context_length / max_position_embeddings))
return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
@ -263,6 +352,7 @@ class LoRAModel:
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
scaling_factor=scaling_factor,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
)
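As a worked illustration of the scaling-factor derivation above (not part of the diff; the values are assumed, matching the 16k adapters in the new tests, which map to factor 4 on a 4096-token base model):

import math

context_length = 16384          # read from the adapter's adapter_config.json
max_position_embeddings = 4096  # base model's maximum position embeddings
scaling_factor = float(math.ceil(context_length / max_position_embeddings))
assert scaling_factor == 4.0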
@ -296,6 +386,7 @@ class LoRAModelManager:
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.base_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
@ -309,6 +400,12 @@ class LoRAModelManager:
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
# Scaling factor -> offset into the sin_cos_cache.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}
# 4 is the number of indices tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
@ -318,6 +415,10 @@ class LoRAModelManager:
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
@ -383,12 +484,32 @@ class LoRAModelManager:
return True
return False
def _set_long_lora_context(self, lora: LoRAModel):
if self.long_lora_context is None:
return
if lora.scaling_factor is None:
return
if (lora.scaling_factor not in self.scaling_factor_to_offset):
raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
" has not been initialized.")
offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
if offsets:
self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
self._set_long_lora_context(lora)
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
@ -400,15 +521,18 @@ class LoRAModelManager:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
if self.long_lora_context:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None))
# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices,
embeddings_indices, long_lora_offsets_tensor,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
self.lora_slots + 1, self.vocab_size,
self.lora_config.lora_extra_vocab_size)
self.lora_config.lora_extra_vocab_size,
self.long_lora_context)
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
@ -416,6 +540,11 @@ class LoRAModelManager:
self.embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self.long_lora_indices.zero_()
# Maintain the reference
self.indices_len[:] = indices_len
@ -438,7 +567,8 @@ class LoRAModelManager:
self._active_loras.clear()
def _create_lora_modules(self):
for module_name, module in self.model.named_modules():
for module_name, module in self.model.named_modules(
remove_duplicate=False):
if not self._match_target_modules(module_name):
continue
parts = module_name.split(".")[-1]
@ -447,6 +577,13 @@ class LoRAModelManager:
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
self.long_lora_context = LongContextLoRAContext(
new_module.scaling_factors, new_module.rotary_dim)
self.scaling_factor_to_offset = \
new_module.scaling_factor_to_offset
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module = self.model.get_submodule(
@ -461,7 +598,8 @@ class LoRAModelManager:
self._register_packed_modules(module_name)
new_module.set_mapping(self.base_indices, self.sampler_indices,
self.sampler_indices_padded,
self.embeddings_indices, self.indices_len)
self.embeddings_indices,
self.long_lora_indices, self.indices_len)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA)
@ -471,12 +609,14 @@ class LoRAModelManager:
self,
lora_id: int,
rank: int,
scaling_factor: Optional[float],
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA):
module, BaseLayerWithLoRA) or isinstance(
module, LinearScalingRotaryEmbeddingWithLora):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
@ -606,6 +746,10 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
self._add_lora(lora)
was_added = True

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
@ -18,6 +19,7 @@ class LoRARequest:
lora_name: str
lora_int_id: int
lora_local_path: str
long_lora_max_len: Optional[int] = None
def __post_init__(self):
if self.lora_int_id < 1:

View File

@ -13,6 +13,7 @@ from vllm.lora.fully_sharded_layers import (
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLora,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
@ -26,12 +27,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
MergedQKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA,
LinearScalingRotaryEmbeddingWithLora,
}

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod, abstractproperty
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Set, Type, Union
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
import torch
@ -17,11 +17,16 @@ logger = init_logger(__name__)
class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
vocab_size: int, lora_config: LoRAConfig,
device: torch.device):
def __init__(self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
max_position_embeddings: Optional[int] = None):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
@ -92,14 +97,21 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_modules: Dict[str, str],
embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
super().__init__(
max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
device,
max_position_embeddings=max_position_embeddings,
)
@property
def is_enabled(self) -> bool:
@ -162,6 +174,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
expected_lora_modules,
max_position_embeddings=self.max_position_embeddings,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
@ -191,7 +204,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora_request.lora_int_id)
else:
dummy_lora = self._lora_manager.create_dummy_lora(
lora_request.lora_int_id, rank, self.embedding_modules)
lora_request.lora_int_id, rank, 1, self.embedding_modules)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._lora_manager.add_lora(dummy_lora)

View File

@ -61,6 +61,7 @@ class RotaryEmbedding(nn.Module):
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
@ -168,6 +169,29 @@ class RotaryEmbedding(nn.Module):
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
It supports multiple scaling factors. Since multiple LoRA adapters may have
different scaling factors, we need multiple cos/sin caches. In this way,
instead of running the rotary embedding kernel per LoRA, we can run multiple
LoRAs in a batched way.
In addition to that, we also keep the cos/sin cache for the scaling factor
of 1 (default) at all times.
For example, for scaling factors x=1, y, and z with embeddings
[[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
[[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
[[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
we construct the cos/sin cache as follows:
[[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
...
[xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
We then use offsets to index into the cos/sin cache for
the respective scaling factors.
The per-factor offsets into the cache can be accessed via the `scaling_factor_to_offset` API.
Credits to the Reddit user /u/kaiokendev
"""
@ -183,13 +207,18 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors = scaling_factors
self.scaling_factors: List[float] = scaling_factors # noqa
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list = []
cache_list: List[torch.Tensor] = []
# Offset of each cache within the concatenated cos/sin tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: List[int] = []
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
@ -203,9 +232,25 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0)
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
return self._scaling_factor_to_offset
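A small sketch (separate from the diff) of how the concatenated cache layout described above surfaces through scaling_factor_to_offset; the positional argument order follows the call in LinearScalingRotaryEmbeddingWithLora.create_lora_weights, and the sizes are illustrative:

import torch
from vllm.model_executor.layers.rotary_embedding import LinearScalingRotaryEmbedding

# Factors 1.0 and 4.0 on a 4096-position base model: the per-factor caches are
# concatenated, so factor 1.0 starts at row 0 and factor 4.0 starts after the
# first cache (4096 * 1.0 = 4096 rows).
rope = LinearScalingRotaryEmbedding(
    64,             # head_size
    64,             # rotary_dim
    4096,           # max_position_embeddings
    10000,          # base
    True,           # is_neox_style
    [1.0, 4.0],     # scaling_factors
    torch.float16,  # dtype
)
print(rope.scaling_factor_to_offset)  # expected: {1.0: 0, 4.0: 4096}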
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.

View File

@ -348,6 +348,8 @@ class ChatGLMForCausalLM(nn.Module):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)

View File

@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module):
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",

View File

@ -46,6 +46,8 @@ class ChatGLMConfig(PretrainedConfig):
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
# Needed for compatibility with long LoRA.
self.max_position_embeddings = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon

View File

@ -34,12 +34,26 @@ class TokenizerGroup(BaseTokenizerGroup):
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: List[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
max_input_length = (lora_request.long_lora_max_len
or self.max_input_length)
else:
max_input_length = self.max_input_length
if max_input_length is not None and input_length > max_input_length:
raise ValueError("Input too long.", input_length, max_input_length)
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
ret = tokenizer.encode(prompt)
self._raise_if_input_too_long(ret, lora_request)
return ret
async def encode_async(
self,
@ -47,7 +61,9 @@ class TokenizerGroup(BaseTokenizerGroup):
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
ret = tokenizer.encode(prompt)
self._raise_if_input_too_long(ret, lora_request)
return ret
def get_lora_tokenizer(
self,

View File

@ -156,9 +156,15 @@ class ModelRunner:
), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=self.model.config.
max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():