[V1] LoRA Support (#10957)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
commit 467a96a541 (parent 8108ac841d)
@@ -306,3 +306,20 @@ def llama_2_7b_engine_extra_embeddings():
 def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
     yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
            model_runner.model)
+
+
+@pytest.fixture(params=[True, False])
+def run_with_both_engines_lora(request, monkeypatch):
+    # Automatically runs tests twice, once with V1 and once without
+    use_v1 = request.param
+    # Tests decorated with `@skip_v1` are only run without v1
+    skip_v1 = request.node.get_closest_marker("skip_v1")
+
+    if use_v1:
+        if skip_v1:
+            pytest.skip("Skipping test on vllm V1")
+        monkeypatch.setenv('VLLM_USE_V1', '1')
+    else:
+        monkeypatch.setenv('VLLM_USE_V1', '0')
+
+    yield
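Note (illustration, not part of the commit): the parametrized fixture above is what makes every LoRA test run once against V0 and once against V1. A minimal, self-contained sketch of the same pytest pattern, with a hypothetical test body, looks like this:

import os

import pytest


@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request, monkeypatch):
    # Each dependent test is collected twice, once per param value.
    use_v1 = request.param
    if use_v1 and request.node.get_closest_marker("skip_v1"):
        pytest.skip("Skipping test on vllm V1")
    monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
    yield


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Autouse wrapper: pulls the parametrization into every test in the module.
    pass


@pytest.mark.skip_v1  # register "skip_v1" as a marker to silence pytest's unknown-marker warning
def test_runs_only_on_v0():
    # Hypothetical test body: only ever sees VLLM_USE_V1 == "0".
    assert os.environ["VLLM_USE_V1"] == "0"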
@@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_baichuan_lora(baichuan_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
@@ -2,6 +2,8 @@

 from typing import List

+import pytest
+
 import vllm
 from tests.utils import fork_new_process_for_each_test
 from vllm.lora.request import LoRARequest
@@ -47,6 +49,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+@pytest.mark.skip_v1
 @fork_new_process_for_each_test
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(MODEL_PATH,
@@ -66,6 +77,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
         assert output2[i] == EXPECTED_LORA_OUTPUT[i]


+@pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
 @fork_new_process_for_each_test
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
@@ -87,6 +99,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
         assert output2[i] == EXPECTED_LORA_OUTPUT[i]


+@pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
 @fork_new_process_for_each_test
 def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
@@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.xfail(current_platform.is_rocm(),
                    reason="There can be output mismatch on ROCm")
 def test_gemma_lora(gemma_lora_files):
@@ -2,6 +2,7 @@

 from typing import List

+import pytest
 import ray

 import vllm
@@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files):
     print("removing lora")


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @fork_new_process_for_each_test
 def test_llama_lora(sql_lora_files):
@@ -85,6 +94,9 @@ def test_llama_lora(sql_lora_files):
     generate_and_test(llm, sql_lora_files)


+# Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks
+# used by the engine yet.
+@pytest.mark.skip_v1
 @fork_new_process_for_each_test
 def test_llama_lora_warmup(sql_lora_files):
     """Test that the LLM initialization works with a warmup LORA path and
@@ -30,6 +30,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+# Skipping for V1 for now as we are hitting,
+# "Head size 80 is not supported by FlashAttention." error.
+@pytest.mark.skip_v1
 @pytest.mark.parametrize("lora_bias", [True])
 @pytest.mark.parametrize("fully_sharded", [True, False])
 def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
@@ -2,6 +2,8 @@

 from typing import List

+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest

@@ -48,6 +50,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+# Skipping for V1 for now as we are hitting,
+# "Head size 80 is not supported by FlashAttention." error.
+@pytest.mark.skip_v1
 def test_phi2_lora(phi2_lora_files):
     # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
     # Otherwise, the lora-test will fail due to CUDA OOM.
@@ -70,6 +70,14 @@ def do_sample(llm: vllm.LLM,
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", [1])
 def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
@@ -163,7 +163,7 @@ def test_generate_block_hash_extra_keys():

     # Test with no overlap
     extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
-    assert extra_keys == ()
+    assert extra_keys is None
     assert next_mm_idx == 1

     # Test with multiple extra keys
@@ -16,8 +16,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce,
-                              tensor_model_parallel_gather)
+                              tensor_model_parallel_all_reduce)
 from vllm.distributed.utils import divide
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -1043,7 +1042,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         logits = lm_head.linear_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
+
+        # Gather logits for TP
+        logits = self.base_layer._gather_logits(logits)
+
         if logits is None:
             return None

@@ -51,7 +51,6 @@ class LogitsProcessor(nn.Module):
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-
         parallel_config = get_current_vllm_config().parallel_config
         self.use_all_gather = current_platform.is_tpu() \
             or envs.VLLM_USE_V1 \
@@ -88,6 +87,20 @@

         return logits

+    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
+        """gather/all-gather the logits tensor across model parallel group."""
+        if self.use_all_gather:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
+        else:
+            # None may be returned for rank > 0
+            logits = tensor_model_parallel_gather(logits)
+        return logits
+
     def _get_logits(
         self,
         hidden_states: torch.Tensor,
@@ -99,16 +112,9 @@
                                             hidden_states,
                                             bias=embedding_bias)

-        if self.use_all_gather:
-            # Gather is not supported for some devices such as TPUs.
-            # Use all-gather instead.
-            # NOTE(woosuk): Here, the outputs of every device should not be None
-            # because XLA requires strict SPMD among all devices. Every device
-            # should execute the same operations after gathering the logits.
-            logits = tensor_model_parallel_all_gather(logits)
-        else:
-            # None may be returned for rank > 0
-            logits = tensor_model_parallel_gather(logits)
+        # Gather logits for TP
+        logits = self._gather_logits(logits)
+
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
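Context for the refactor (illustration, not part of the commit): both `LogitsProcessor._get_logits` and the LoRA wrapper now funnel through `_gather_logits`, which selects between the two tensor-parallel collectives. A rough sketch of the semantic difference, using a stand-in for the real collectives:

import torch
from typing import Optional


def gather_logits_sketch(shard: torch.Tensor, rank: int, world_size: int,
                         use_all_gather: bool) -> Optional[torch.Tensor]:
    # Each TP rank holds a vocab shard of shape [num_tokens, vocab // world_size].
    # Stand-in for the collective: concatenating world_size copies only mimics
    # the resulting shape, not the actual communication.
    full = torch.cat([shard] * world_size, dim=-1)
    if use_all_gather:
        # all-gather: every rank ends up with the full logits tensor
        # (required for XLA's strict SPMD, and the path taken when VLLM_USE_V1 is set).
        return full
    # gather: only rank 0 receives the full tensor; other ranks see None,
    # which is why callers must handle `logits is None`.
    return full if rank == 0 else None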
@@ -170,14 +170,28 @@ class FreeKVCacheBlockQueue:
         return ret


-def generate_block_hash_extra_keys(
-        request: Request, start_token_idx: int, end_token_idx: int,
-        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
-    """Generate extra keys for the block hash. The extra keys can come from
-    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
-    For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
-    indicate a mm input contained in the block and its starting offset in
-    the block tokens.
+def need_extra_keys(request: Request) -> bool:
+    """Check whether the blocks allocated to this request need extra hash keys.
+
+    Args:
+        request (Request): The request.
+
+    Returns:
+        bool: Whether blocks allocated to this request need extra hash keys.
+    """
+
+    # Multimodal requests need to include the MM hash.
+    # LoRA requests need to include the LoRA ID.
+    return bool(request.mm_positions) or (request.lora_request is not None)
+
+
+def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
+                            end_token_idx: int,
+                            start_mm_idx: int) -> Tuple[List[Any], int]:
+    """Generate extra keys related to MultiModal request for block hash
+    computation. For multi-modal inputs, the extra keys are
+    (mm_hash, start_offset) that indicate a mm input contained in the
+    block and its starting offset in the block tokens.

     Args:
         request: The request object.
@@ -188,10 +202,11 @@ def generate_block_hash_extra_keys(
     Returns:
         A tuple of extra keys and the next multi-modal index.
     """
+    extra_keys: List[Any] = []

     mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
     if not mm_positions:
-        return None, start_mm_idx
+        return extra_keys, start_mm_idx

     if mm_positions and len(mm_positions) != len(mm_hashes):
         raise ValueError(
@@ -204,14 +219,13 @@
     # range. This usually happens in the late prefill phase and decoding phase.
     if mm_positions[-1]["offset"] + mm_positions[-1][
             "length"] < start_token_idx:
-        return None, start_mm_idx
+        return extra_keys, start_mm_idx

     # Support start_mm_idx == -1 to indicate the last mm input.
     if start_mm_idx < 0:
        assert -start_mm_idx <= len(mm_positions)
        start_mm_idx = len(mm_positions) + start_mm_idx

-    extra_keys = []
     curr_mm_idx = start_mm_idx
     while mm_positions and curr_mm_idx < len(mm_positions):
         assert mm_hashes[curr_mm_idx] is not None
@@ -237,7 +251,50 @@
         else:
             # This block has not reached the current mm input.
             break
-    return tuple(extra_keys), curr_mm_idx
+    return extra_keys, curr_mm_idx
+
+
+def _gen_lora_extra_hash_keys(request: Request) -> List[int]:
+    """Generate extra keys related to LoRA for block hash computation.
+
+    Args:
+        request: The request object.
+
+    Returns:
+        Return LoRA id of the request if it is a LoRA request. Return empty
+        list otherwise.
+    """
+    if not request.lora_request:
+        return []
+    return [request.lora_request.lora_int_id]
+
+
+def generate_block_hash_extra_keys(
+        request: Request, start_token_idx: int, end_token_idx: int,
+        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
+    """Generate extra keys for the block hash. The extra keys can come from
+    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
+
+    Args:
+        request: The request object.
+        start_token_idx: The start token index of the block.
+        end_token_idx: The end token index of the block.
+        start_mm_idx: The start multi-modal index of the block.
+
+    Returns:
+        A tuple of extra keys and the next multi-modal index.
+    """
+    mm_extra_keys: List[Any]
+    mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
+        request, start_token_idx, end_token_idx, start_mm_idx)
+    lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request)
+
+    extra_keys: List[Any] = lora_extra_keys + mm_extra_keys
+
+    if not extra_keys:
+        return None, new_start_mm_idx
+
+    return tuple(extra_keys), new_start_mm_idx


 def hash_block_tokens(
@@ -249,9 +306,6 @@
     prefix caching. We use LRU cache for this function to avoid recomputing
     hash values for the same block contents.

-    TODO: Support arbitrary metadata so that we could support more
-    features such as LoRA adapter.
-
     Args:
         parent_block_hash: The hash of the parent block. None
             if this is the first block.
@@ -291,14 +345,9 @@ def hash_request_tokens(block_size: int,
         The list of computed hash values.
     """
     token_ids = request.all_token_ids
-    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
-    if mm_positions and len(mm_positions) != len(mm_hashes):
-        raise ValueError(
-            "The number of multi-modal positions and hashes must match.")

-    # TODO: Extend this to support other features such as LoRA.
-    need_extra_keys = bool(mm_positions)
-    extra_keys = None
+    req_need_extra_keys = need_extra_keys(request)
+    req_extra_keys = None
     curr_mm_idx = 0

     ret = []
@@ -310,13 +359,13 @@
         if len(block_token_ids) < block_size:
             break

-        # Add extra keys if the block is a multi-modal block.
-        if need_extra_keys:
-            extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+        if req_need_extra_keys:
+            # MM and LoRA requests need extra keys for block-hash computation.
+            req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
                 request, start, end, curr_mm_idx)

         block_hash = hash_block_tokens(parent_block_hash_value,
-                                       block_token_ids, extra_keys)
+                                       block_token_ids, req_extra_keys)
         ret.append(block_hash)
         parent_block_hash_value = block_hash.hash_value
     return ret
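Effect of the kv_cache_utils changes above (illustration, not part of the commit): block hashes for a LoRA request now always carry the adapter id, so identical token prefixes served with different adapters no longer share prefix-cache entries. With a stand-in object that only has the attributes these helpers read:

from types import SimpleNamespace

# Hypothetical stand-in for vllm.v1.request.Request.
lora_req = SimpleNamespace(lora_int_id=7)
request = SimpleNamespace(mm_positions=[], mm_hashes=[], lora_request=lora_req)

# With the helpers above in scope, one would expect:
#   need_extra_keys(request)                          -> True
#   _gen_lora_extra_hash_keys(request)                -> [7]
#   generate_block_hash_extra_keys(request, 0, 16, 0) -> ((7,), 0)
# A request with no LoRA and no multi-modal inputs keeps extra_keys=None,
# which is what the updated assertion in the prefix-caching test checks.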
@@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,

 from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
@@ -35,8 +36,6 @@ class Scheduler:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
         self.lora_config = lora_config
-        # TODO: Support LoRA.
-        assert lora_config is None, "V1 does not support LoRA yet."

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -180,6 +179,14 @@
                 self.encoder_cache_manager.allocate(request, i)
                 encoder_budget = new_encoder_budget

+        # Record the LoRAs in scheduled_running_reqs
+        requested_loras: Set[int] = set()
+        if self.lora_config:
+            requested_loras = set(
+                req.lora_request.lora_int_id for req in scheduled_running_reqs
+                if req.lora_request and req.lora_request.lora_int_id > 0)
+            assert len(requested_loras) <= self.lora_config.max_loras
+
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
             while self.waiting and token_budget > 0:
@@ -187,6 +194,23 @@
                     break

                 request = self.waiting[0]
+
+                # Check that adding the request still respects the max_loras
+                # constraint.
+                if self.lora_config and request.lora_request:
+                    req_lora_id = request.lora_request.lora_int_id
+                    if len(requested_loras) == self.lora_config.max_loras and (
+                            req_lora_id not in requested_loras):
+                        # Cannot schedule.
+                        # TODO (varun): This means all the other requests in
+                        # the WAITING queue will be blocked by this request,
+                        # even if,
+                        # 1. these other requests do not use LoRA, or,
+                        # 2. these other requests use the already requested
+                        # LoRAs.
+                        # This is too conservative and could be optimized.
+                        break

                 # Get already-cached tokens.
                 computed_blocks, num_computed_tokens = \
                     self.kv_cache_manager.get_computed_blocks(request)
@@ -234,6 +258,8 @@
                     raise RuntimeError(
                         f"Invalid request status: {request.status}")

+                if self.lora_config and request.lora_request:
+                    requested_loras.add(request.lora_request.lora_int_id)
                 req_to_new_block_ids[request.request_id] = [
                     b.block_id for b in computed_blocks + new_blocks
                 ]
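The admission rule added to the WAITING loop can be summarized in isolation (helper name and toy values are ours, not from the commit):

from typing import Optional, Set


def can_admit(requested_loras: Set[int], max_loras: int,
              req_lora_id: Optional[int]) -> bool:
    # A waiting request passes the check if it uses no LoRA, reuses an already
    # scheduled LoRA, or a LoRA slot is still free; otherwise the loop breaks,
    # which is the head-of-line blocking flagged in the TODO above.
    if req_lora_id is None or req_lora_id == 0:
        return True
    return req_lora_id in requested_loras or len(requested_loras) < max_loras


assert can_admit({1, 2}, max_loras=2, req_lora_id=2)       # reuses an active LoRA
assert not can_admit({1, 2}, max_loras=2, req_lora_id=3)   # would exceed max_loras
assert can_admit({1, 2}, max_loras=2, req_lora_id=None)    # non-LoRA request passes the check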
@@ -568,6 +594,7 @@ class NewRequestData:
     sampling_params: SamplingParams
     block_ids: List[int]
     num_computed_tokens: int
+    lora_request: Optional[LoRARequest]

     @classmethod
     def from_request(
@@ -586,6 +613,7 @@ class NewRequestData:
             sampling_params=request.sampling_params,
             block_ids=block_ids,
             num_computed_tokens=num_computed_tokens,
+            lora_request=request.lora_request,
         )


@@ -3,11 +3,12 @@
 # Datastructures defining an input batch

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

 import numpy as np
 import torch

+from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -35,6 +36,8 @@ class CachedRequestState:
     mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[int] = None

+    lora_request: Optional[LoRARequest] = None
+
     @property
     def num_tokens(self) -> int:
         return len(self.prompt_token_ids) + len(self.output_token_ids)
@@ -161,6 +164,12 @@ class InputBatch:
         ]
         self.prompt_token_ids: Optional[torch.Tensor] = None

+        # lora related
+        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
+                                             dtype=np.int32)
+        self.lora_id_to_request_ids: Dict[int, Set[str]] = {}
+        self.lora_id_to_lora_request: Dict[int, LoRARequest] = {}
+
         # req_index -> generator
         # NOTE(woosuk): The indices of the requests that do not have their own
         # generator should not be included in the dictionary.
|
||||
if sampling_params.prompt_logprobs:
|
||||
self.prompt_logprob_reqs.add(req_id)
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
if lora_id not in self.lora_id_to_request_ids:
|
||||
self.lora_id_to_request_ids[lora_id] = set()
|
||||
|
||||
self.request_lora_mapping[req_index] = lora_id
|
||||
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
||||
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
||||
else:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
@ -251,6 +273,16 @@ class InputBatch:
|
||||
self.generators.pop(req_index, None)
|
||||
self.num_logprobs.pop(req_id, None)
|
||||
self.prompt_logprob_reqs.discard(req_id)
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
||||
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
||||
self.lora_id_to_request_ids.pop(lora_id)
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
return req_index
|
||||
|
||||
def clear(self) -> None:
|
||||
@ -266,6 +298,9 @@ class InputBatch:
|
||||
self.generators.clear()
|
||||
self.num_logprobs.clear()
|
||||
self.prompt_logprob_reqs.clear()
|
||||
self.request_lora_mapping.fill(0)
|
||||
self.lora_id_to_lora_request.clear()
|
||||
self.lora_id_to_request_ids.clear()
|
||||
|
||||
def condense(self, empty_req_indices: List[int]) -> None:
|
||||
if self.num_reqs == 0:
|
||||
@ -318,6 +353,9 @@ class InputBatch:
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
@ -401,6 +439,29 @@ class InputBatch:
|
||||
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
||||
non_blocking=True)
|
||||
|
||||
def make_lora_inputs(
|
||||
self, num_scheduled_tokens: np.ndarray
|
||||
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
|
||||
"""
|
||||
Given the num_scheduled_tokens for each request in the batch, return
|
||||
datastructures used to activate the current LoRAs.
|
||||
Returns:
|
||||
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
||||
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
||||
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
||||
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
||||
3. lora_requests: Set of relevant LoRA requests.
|
||||
"""
|
||||
|
||||
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
||||
prompt_lora_mapping = tuple(req_lora_mapping)
|
||||
token_lora_mapping = tuple(
|
||||
req_lora_mapping.repeat(num_scheduled_tokens))
|
||||
active_lora_requests: Set[LoRARequest] = set(
|
||||
self.lora_id_to_lora_request.values())
|
||||
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
@@ -33,6 +33,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
@@ -40,7 +41,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)


-class GPUModelRunner:
+class GPUModelRunner(LoRAModelRunnerMixin):

     def __init__(
         self,
@@ -279,6 +280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
+                lora_request=new_req_data.lora_request,
             )

             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -372,15 +374,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = []
+        num_scheduled_tokens_list: List[int] = []
         max_num_scheduled_tokens = 0
         for req_id in self.input_batch.req_ids[:num_reqs]:
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens.append(num_tokens)
+            num_scheduled_tokens_list.append(num_tokens)
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
-        num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
+        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
+                                                    dtype=np.int32)
         assert max_num_scheduled_tokens > 0

         # Get request indices.
@@ -565,6 +568,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
         )
+
+        # Hot-Swap lora model
+        if self.lora_config:
+            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+
         # NOTE(woosuk): Due to chunked prefills, the batch may contain partial
         # requests. While we should not sample any token from these partial
         # requests, we do so for simplicity. We will ignore the sampled
@@ -867,6 +875,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
+            if self.lora_config:
+                self.model = self.load_lora_model(self.model,
+                                                  self.model_config,
+                                                  self.scheduler_config,
+                                                  self.lora_config,
+                                                  self.device)

         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
@@ -1005,14 +1019,32 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Cache the dummy encoder outputs.
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

-        # Trigger compilation for general shape.
-        hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
-        logits = self.model.compute_logits(hidden_states, None)
-        logits = logits[:self.max_num_tokens]
-        # TODO(woosuk): Consider the memory usage of the sampler.
-        torch.cuda.synchronize()
-        del hidden_states, logits
-        self.encoder_cache.clear()
+        # For profile, have maximum num_reqs and that collectively have
+        # maximum num_tokens.
+        num_reqs = self.scheduler_config.max_num_seqs
+        num_tokens = self.max_num_tokens
+        min_tokens_per_req: int = num_tokens // num_reqs
+
+        num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
+
+        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
+                                                    dtype=np.int32)
+        logit_indices = np.cumsum(num_scheduled_tokens) - 1
+
+        with self.maybe_profile_with_lora(self.lora_config,
+                                          num_scheduled_tokens):
+            # Trigger compilation for general shape.
+            hidden_states = self._dummy_run(self.max_num_tokens,
+                                            dummy_kv_caches)
+            hidden_states = hidden_states[logit_indices]
+            logits = self.model.compute_logits(hidden_states, None)
+            # TODO(woosuk): Consider the memory usage of the sampler.
+            torch.cuda.synchronize()
+            del hidden_states, logits
+            self.encoder_cache.clear()
         gc.collect()

     def capture_model(self) -> None:
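The reworked profile run spreads max_num_tokens across max_num_seqs dummy requests and keeps only the last-token logits per request. A small numeric illustration with toy sizes (not the real defaults):

import numpy as np

num_reqs, num_tokens = 4, 10          # e.g. max_num_seqs=4, max_num_tokens=10
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs   # last request absorbs the remainder

num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1      # index of each request's last token

assert num_scheduled_tokens.tolist() == [2, 2, 2, 4]
assert logit_indices.tolist() == [1, 3, 5, 9]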
vllm/v1/worker/lora_model_runner_mixin.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Define LoRA functionality mixin for model runners.
+"""
+
+from contextlib import contextmanager
+from typing import Set, Tuple
+
+import numpy as np
+import torch.nn as nn
+
+from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+logger = init_logger(__name__)
+
+
+# Defined as a mixin for GPUModelRunner
+class LoRAModelRunnerMixin:
+
+    LORA_WARMUP_RANK = 8
+
+    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
+                        scheduler_config: SchedulerConfig,
+                        lora_config: LoRAConfig, device: str) -> nn.Module:
+
+        assert supports_lora(
+            model), f"{model.__class__.__name__} does not support LoRA yet."
+
+        if supports_multimodal(model):
+            logger.warning("Regarding multimodal models, vLLM currently "
+                           "only supports adding LoRA to language model.")
+
+        # It's necessary to distinguish between the max_position_embeddings
+        # of VLMs and LLMs.
+        if hasattr(model.config, "max_position_embeddings"):
+            max_pos_embeddings = model.config.max_position_embeddings
+        else:
+            max_pos_embeddings = (
+                model.config.text_config.max_position_embeddings)
+
+        # Add LoRA Manager to the Model Runner
+        self.lora_manager = LRUCacheWorkerLoRAManager(
+            scheduler_config.max_num_seqs,
+            scheduler_config.max_num_batched_tokens,
+            model_config.get_vocab_size(),
+            lora_config,
+            device,
+            model.embedding_modules,
+            model.embedding_padding_modules,
+            max_position_embeddings=max_pos_embeddings,
+        )
+        return self.lora_manager.create_lora_manager(model)
+
+    def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...],
+                          token_lora_mapping: Tuple[int, ...],
+                          lora_requests: Set[LoRARequest]) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+
+        # We dont make any distinction between prefills and decodes in the
+        # scheduler. To that effect, set is_prefill to True so we use the
+        # sgmv punica kernels always.
+        lora_mapping = LoRAMapping(token_lora_mapping,
+                                   prompt_lora_mapping,
+                                   is_prefill=True)
+        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
+
+    def set_active_loras(self, input_batch: InputBatch,
+                         num_scheduled_tokens: np.ndarray) -> None:
+
+        prompt_lora_mapping: Tuple[int, ...]  # of size input_batch.num_reqs
+        token_lora_mapping: Tuple[int,
+                                  ...]  # of size np.sum(num_scheduled_tokens)
+        lora_requests: Set[LoRARequest]
+        prompt_lora_mapping, token_lora_mapping, lora_requests = \
+            input_batch.make_lora_inputs(num_scheduled_tokens)
+        return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
+                                      lora_requests)
+
+    @contextmanager
+    def maybe_profile_with_lora(self, lora_config: LoRAConfig,
+                                num_scheduled_tokens: np.ndarray):
+        if lora_config is None:
+            yield
+        else:
+            # __enter__ code
+            assert self.lora_manager is not None, "LoRA is not enabled"

+            num_reqs = len(num_scheduled_tokens)
+            num_loras = lora_config.max_loras
+
+            # Make prompt lora mapping
+            # Assign LoRA IDs cyclically to simulate a worst-case scenario.
+            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
+                                   num_loras) + 1
+
+            # Make token lora mapping
+            token_lora_mapping = np.repeat(prompt_lora_mapping,
+                                           num_scheduled_tokens)
+
+            # Make dummy lora requests
+            lora_requests: Set[LoRARequest] = {
+                LoRARequest(lora_name=f"warmup_{lora_id}",
+                            lora_int_id=lora_id,
+                            lora_path="/not/a/real/path")
+                for lora_id in range(1, num_loras + 1)
+            }
+
+            with self.lora_manager.dummy_lora_cache():
+                # Add the dummy LoRAs here so _set_active_loras doesn't try to
+                # load from disk.
+                for lr in lora_requests:
+                    self.lora_manager.add_dummy_lora(
+                        lr, rank=self.LORA_WARMUP_RANK)
+
+                self._set_active_loras(tuple(prompt_lora_mapping),
+                                       tuple(token_lora_mapping),
+                                       lora_requests)
+
+                yield
+
+            # __exit__ code
+            self.lora_manager.remove_all_adapters()