[V1] LoRA Support (#10957)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2025-02-06 23:02:51 +05:30 committed by GitHub
parent 8108ac841d
commit 467a96a541
16 changed files with 453 additions and 56 deletions

View File

@ -306,3 +306,20 @@ def llama_2_7b_engine_extra_embeddings():
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)
@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")
if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
monkeypatch.setenv('VLLM_USE_V1', '1')
else:
monkeypatch.setenv('VLLM_USE_V1', '0')
yield
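
For reference, a test module opts into this dual-engine run with a small autouse wrapper and marks V0-only tests with `skip_v1`. The sketch below (illustrative names, assuming the conftest fixture above is in scope) simply consolidates the pattern that the LoRA test files in this commit apply:

    import pytest

    @pytest.fixture(autouse=True)
    def v1(run_with_both_engines_lora):
        # Runs every test in this module once with VLLM_USE_V1=1 and once with 0.
        pass

    @pytest.mark.skip_v1
    def test_v0_only_behavior():
        # Skipped when the fixture selects the V1 engine.
        pass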

View File

@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,

View File

@ -2,6 +2,8 @@
from typing import List
import pytest
import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest
@ -47,6 +49,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
@ -66,6 +77,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
@ -87,6 +99,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):

View File

@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):

View File

@ -2,6 +2,7 @@
from typing import List
import pytest
import ray
import vllm
@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files):
print("removing lora")
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):
@ -85,6 +94,9 @@ def test_llama_lora(sql_lora_files):
generate_and_test(llm, sql_lora_files)
# Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks
# used by the engine yet.
@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and

View File

@ -30,6 +30,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
# Skipping for V1 for now as we are hitting the
# "Head size 80 is not supported by FlashAttention." error.
@pytest.mark.skip_v1
@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):

View File

@ -2,6 +2,8 @@
from typing import List
import pytest
import vllm
from vllm.lora.request import LoRARequest
@ -48,6 +50,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
# Skipping for V1 for now as we are hitting the
# "Head size 80 is not supported by FlashAttention." error.
@pytest.mark.skip_v1
def test_phi2_lora(phi2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI.
# Otherwise, the lora-test will fail due to CUDA OOM.

View File

@ -70,6 +70,14 @@ def do_sample(llm: vllm.LLM,
return generated_texts
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,

View File

@ -163,7 +163,7 @@ def test_generate_block_hash_extra_keys():
# Test with no overlap
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
assert extra_keys == ()
assert extra_keys is None
assert next_mm_idx == 1
# Test with multiple extra keys

View File

@ -16,8 +16,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -1043,7 +1042,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
logits = lm_head.linear_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Gather logits for TP
logits = self.base_layer._gather_logits(logits)
if logits is None:
return None

View File

@ -51,7 +51,6 @@ class LogitsProcessor(nn.Module):
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or envs.VLLM_USE_V1 \
@ -88,6 +87,20 @@ class LogitsProcessor(nn.Module):
return logits
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""gather/all-gather the logits tensor across model parallel group."""
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
return logits
def _get_logits(
self,
hidden_states: torch.Tensor,
@ -99,16 +112,9 @@ class LogitsProcessor(nn.Module):
hidden_states,
bias=embedding_bias)
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
# Gather logits for TP
logits = self._gather_logits(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
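
The net effect of this file's change is that the gather/all-gather decision now lives in a single _gather_logits helper, which the LoRA logits processor above reuses. A minimal standalone sketch of that dispatch pattern (stub collectives stand in for tensor_model_parallel_gather / tensor_model_parallel_all_gather; class names are illustrative, not vLLM code):

    from typing import List, Optional

    def _all_gather_stub(logits: List[float]) -> List[float]:
        # Every rank keeps a full copy (needed for SPMD backends such as XLA).
        return logits

    def _gather_stub(logits: List[float], rank: int = 0) -> Optional[List[float]]:
        # Only rank 0 receives the gathered result; other ranks get None.
        return logits if rank == 0 else None

    class TinyLogitsProcessor:
        def __init__(self, use_all_gather: bool):
            self.use_all_gather = use_all_gather

        def _gather_logits(self, logits: List[float]) -> Optional[List[float]]:
            return (_all_gather_stub(logits) if self.use_all_gather
                    else _gather_stub(logits))

    class TinyLoRALogitsProcessor:
        def __init__(self, base_layer: TinyLogitsProcessor):
            self.base_layer = base_layer

        def compute(self, logits: List[float]) -> Optional[List[float]]:
            # Delegate, so V1 (which all-gathers) needs no LoRA-specific branch.
            return self.base_layer._gather_logits(logits)

    assert TinyLoRALogitsProcessor(TinyLogitsProcessor(True)).compute([0.1]) == [0.1]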

View File

@ -170,14 +170,28 @@ class FreeKVCacheBlockQueue:
return ret
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
def need_extra_keys(request: Request) -> bool:
"""Check whether the blocks allocated to this request need extra hash keys.
Args:
request (Request): The request.
Returns:
bool: Whether blocks allocated to this request need extra hash keys.
"""
# Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID.
return bool(request.mm_positions) or (request.lora_request is not None)
def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
end_token_idx: int,
start_mm_idx: int) -> Tuple[List[Any], int]:
"""Generate extra keys related to MultiModal request for block hash
computation. For multi-modal inputs, the extra keys are
(mm_hash, start_offset) that indicate a mm input contained in the
block and its starting offset in the block tokens.
Args:
request: The request object.
@ -188,10 +202,11 @@ def generate_block_hash_extra_keys(
Returns:
A tuple of extra keys and the next multi-modal index.
"""
extra_keys: List[Any] = []
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
return None, start_mm_idx
return extra_keys, start_mm_idx
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
@ -204,14 +219,13 @@ def generate_block_hash_extra_keys(
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
return None, start_mm_idx
return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx
extra_keys = []
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
@ -237,7 +251,50 @@ def generate_block_hash_extra_keys(
else:
# This block has not reached the current mm input.
break
return tuple(extra_keys), curr_mm_idx
return extra_keys, curr_mm_idx
def _gen_lora_extra_hash_keys(request: Request) -> List[int]:
"""Generate extra keys related to LoRA for block hash computation.
Args:
request: The request object.
Returns:
A list containing the LoRA ID of the request if it is a LoRA request,
or an empty list otherwise.
"""
if not request.lora_request:
return []
return [request.lora_request.lora_int_id]
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_extra_keys: List[Any]
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
request, start_token_idx, end_token_idx, start_mm_idx)
lora_extra_keys: List[int] = _gen_lora_extra_hash_keys(request)
extra_keys: List[Any] = lora_extra_keys + mm_extra_keys
if not extra_keys:
return None, new_start_mm_idx
return tuple(extra_keys), new_start_mm_idx
def hash_block_tokens(
@ -249,9 +306,6 @@ def hash_block_tokens(
prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents.
TODO: Support arbitrary metadata so that we could support more
features such as LoRA adapter.
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
@ -291,14 +345,9 @@ def hash_request_tokens(block_size: int,
The list of computed hash values.
"""
token_ids = request.all_token_ids
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")
# TODO: Extend this to support other features such as LoRA.
need_extra_keys = bool(mm_positions)
extra_keys = None
req_need_extra_keys = need_extra_keys(request)
req_extra_keys = None
curr_mm_idx = 0
ret = []
@ -310,13 +359,13 @@ def hash_request_tokens(block_size: int,
if len(block_token_ids) < block_size:
break
# Add extra keys if the block is a multi-modal block.
if need_extra_keys:
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
if req_need_extra_keys:
# MM and LoRA requests need extra keys for block-hash computation.
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids, extra_keys)
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
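
A standalone worked example of the behaviour these helpers add (simplified stand-ins, not the functions above): the LoRA ID is placed ahead of any multi-modal hashes in the extra-key tuple, so identical token blocks served under different adapters hash differently and are not shared in the prefix cache.

    from dataclasses import dataclass, field
    from typing import Any, List, Optional, Tuple

    @dataclass
    class FakeLoRARequest:
        lora_int_id: int

    @dataclass
    class FakeRequest:
        mm_hashes: List[str] = field(default_factory=list)
        lora_request: Optional[FakeLoRARequest] = None

    def block_extra_keys(req: FakeRequest) -> Optional[Tuple[Any, ...]]:
        keys: List[Any] = []
        if req.lora_request is not None:
            keys.append(req.lora_request.lora_int_id)  # LoRA ID first
        keys.extend(req.mm_hashes)                     # then multi-modal hashes
        return tuple(keys) if keys else None

    plain = FakeRequest()
    lora = FakeRequest(lora_request=FakeLoRARequest(lora_int_id=3))
    assert block_extra_keys(plain) is None  # no extra keys: blocks can be shared
    assert block_extra_keys(lora) == (3,)   # LoRA ID becomes part of the hash key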

View File

@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
@ -35,8 +36,6 @@ class Scheduler:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
# TODO: Support LoRA.
assert lora_config is None, "V1 does not support LoRA yet."
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@ -180,6 +179,14 @@ class Scheduler:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
if self.lora_config:
requested_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0:
@ -187,6 +194,23 @@ class Scheduler:
break
request = self.waiting[0]
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request:
req_lora_id = request.lora_request.lora_int_id
if len(requested_loras) == self.lora_config.max_loras and (
req_lora_id not in requested_loras):
# Cannot schedule.
# TODO (varun): This means all the other requests in
# the WAITING queue will be blocked by this request,
# even if:
# 1. these other requests do not use LoRA, or
# 2. these other requests use the already requested
# LoRAs.
# This is too conservative and could be optimized.
break
# Get already-cached tokens.
computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(request)
@ -234,6 +258,8 @@ class Scheduler:
raise RuntimeError(
f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
requested_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
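
A minimal sketch of the admission rule applied to WAITING requests above (standalone helper, not the scheduler itself): a LoRA request is only schedulable if its adapter is already active in this step or a LoRA slot is still free, while non-LoRA requests are never rejected by this check. As the TODO above notes, in the scheduler a failed check breaks out of the WAITING loop entirely, blocking the requests queued behind it.

    from typing import Optional, Set

    def lora_slot_available(requested_loras: Set[int], max_loras: int,
                            req_lora_id: Optional[int]) -> bool:
        if req_lora_id is None:
            return True                      # request does not use LoRA
        if req_lora_id in requested_loras:
            return True                      # adapter already active this step
        return len(requested_loras) < max_loras

    active = {1, 2}
    assert lora_slot_available(active, max_loras=2, req_lora_id=2)      # reuse
    assert not lora_slot_available(active, max_loras=2, req_lora_id=3)  # full
    assert lora_slot_available(active, max_loras=4, req_lora_id=3)      # free slot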
@ -568,6 +594,7 @@ class NewRequestData:
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@classmethod
def from_request(
@ -586,6 +613,7 @@ class NewRequestData:
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
lora_request=request.lora_request,
)

View File

@ -3,11 +3,12 @@
# Datastructures defining an input batch
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
@ -35,6 +36,8 @@ class CachedRequestState:
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
@ -161,6 +164,12 @@ class InputBatch:
]
self.prompt_token_ids: Optional[torch.Tensor] = None
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
self.lora_id_to_request_ids: Dict[int, Set[str]] = {}
self.lora_id_to_lora_request: Dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
@ -235,6 +244,19 @@ class InputBatch:
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
@ -251,6 +273,16 @@ class InputBatch:
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
return req_index
def clear(self) -> None:
@ -266,6 +298,9 @@ class InputBatch:
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear()
def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
@ -318,6 +353,9 @@ class InputBatch:
if generator is not None:
self.generators[empty_index] = generator
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
@ -401,6 +439,29 @@ class InputBatch:
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
the data structures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where
prompt_lora_mapping[i] is the LoRA ID to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where token_lora_mapping[i] is the LoRA ID to use for the ith token.
3. lora_requests: The set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: Set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
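
A worked example (plain numpy, independent of the class above) of the expansion make_lora_inputs performs: the per-request LoRA IDs are repeated according to each request's scheduled token count to produce the token-level mapping consumed by the LoRA kernels.

    import numpy as np

    request_lora_mapping = np.array([1, 0, 2], dtype=np.int32)  # one ID per request
    num_scheduled_tokens = np.array([3, 2, 1], dtype=np.int32)  # tokens per request

    prompt_lora_mapping = tuple(request_lora_mapping)
    token_lora_mapping = tuple(request_lora_mapping.repeat(num_scheduled_tokens))

    assert prompt_lora_mapping == (1, 0, 2)
    assert token_lora_mapping == (1, 1, 1, 0, 0, 2)  # one LoRA ID per token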

View File

@ -33,6 +33,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
@ -40,7 +41,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
class GPUModelRunner:
class GPUModelRunner(LoRAModelRunnerMixin):
def __init__(
self,
@ -279,6 +280,7 @@ class GPUModelRunner:
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@ -372,15 +374,16 @@ class GPUModelRunner:
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
num_scheduled_tokens_list: List[int] = []
max_num_scheduled_tokens = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)
num_scheduled_tokens_list.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
dtype=np.int32)
assert max_num_scheduled_tokens > 0
# Get request indices.
@ -565,6 +568,11 @@ class GPUModelRunner:
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
@ -867,6 +875,12 @@ class GPUModelRunner:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
self.scheduler_config,
self.lora_config,
self.device)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
@ -1005,14 +1019,32 @@ class GPUModelRunner:
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
self.encoder_cache.clear()
# For the profile run, use the maximum num_reqs and have them
# collectively cover the maximum num_tokens.
num_reqs = self.scheduler_config.max_num_seqs
num_tokens = self.max_num_tokens
min_tokens_per_req: int = num_tokens // num_reqs
num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
with self.maybe_profile_with_lora(self.lora_config,
num_scheduled_tokens):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens,
dummy_kv_caches)
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
self.encoder_cache.clear()
gc.collect()
def capture_model(self) -> None:
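
As a quick sanity check of the token split introduced above for the profile run (illustrative numbers, not vLLM defaults): max_num_tokens is divided as evenly as possible across max_num_seqs requests, the remainder goes to the last request, and logit_indices selects the last token of each request from the packed batch.

    import numpy as np

    num_reqs, num_tokens = 4, 10
    min_tokens_per_req = num_tokens // num_reqs
    num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
    num_scheduled_tokens_list[-1] += num_tokens % num_reqs

    assert num_scheduled_tokens_list == [2, 2, 2, 4]
    assert sum(num_scheduled_tokens_list) == num_tokens

    logit_indices = np.cumsum(np.array(num_scheduled_tokens_list)) - 1
    assert logit_indices.tolist() == [1, 3, 5, 9]  # last token of each request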

View File

@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""
Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
from typing import Set, Tuple
import numpy as np
import torch.nn as nn
from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch
logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:
LORA_WARMUP_RANK = 8
def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
scheduler_config: SchedulerConfig,
lora_config: LoRAConfig, device: str) -> nn.Module:
assert supports_lora(
model), f"{model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# It's necessary to distinguish between the max_position_embeddings
# of VLMs and LLMs.
if hasattr(model.config, "max_position_embeddings"):
max_pos_embeddings = model.config.max_position_embeddings
else:
max_pos_embeddings = (
model.config.text_config.max_position_embeddings)
# Add LoRA Manager to the Model Runner
self.lora_manager = LRUCacheWorkerLoRAManager(
scheduler_config.max_num_seqs,
scheduler_config.max_num_batched_tokens,
model_config.get_vocab_size(),
lora_config,
device,
model.embedding_modules,
model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
return self.lora_manager.create_lora_manager(model)
def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...],
token_lora_mapping: Tuple[int, ...],
lora_requests: Set[LoRARequest]) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
# We don't make any distinction between prefills and decodes in the
# scheduler. To that effect, set is_prefill to True so we always use
# the SGMV punica kernels.
lora_mapping = LoRAMapping(token_lora_mapping,
prompt_lora_mapping,
is_prefill=True)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def set_active_loras(self, input_batch: InputBatch,
num_scheduled_tokens: np.ndarray) -> None:
prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs
token_lora_mapping: Tuple[int,
...] # of size np.sum(num_scheduled_tokens)
lora_requests: Set[LoRARequest]
prompt_lora_mapping, token_lora_mapping, lora_requests = \
input_batch.make_lora_inputs(num_scheduled_tokens)
return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
lora_requests)
@contextmanager
def maybe_profile_with_lora(self, lora_config: LoRAConfig,
num_scheduled_tokens: np.ndarray):
if lora_config is None:
yield
else:
# __enter__ code
assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens)
num_loras = lora_config.max_loras
# Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario.
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
num_loras) + 1
# Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping,
num_scheduled_tokens)
# Make dummy lora requests
lora_requests: Set[LoRARequest] = {
LoRARequest(lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path")
for lora_id in range(1, num_loras + 1)
}
with self.lora_manager.dummy_lora_cache():
# Add the dummy LoRAs here so _set_active_loras doesn't try to
# load from disk.
for lr in lora_requests:
self.lora_manager.add_dummy_lora(
lr, rank=self.LORA_WARMUP_RANK)
self._set_active_loras(tuple(prompt_lora_mapping),
tuple(token_lora_mapping),
lora_requests)
yield
# __exit__ code
self.lora_manager.remove_all_adapters()
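
Putting it together, a hedged end-to-end sketch of what this commit enables: serving a LoRA adapter on the V1 engine. The model name and adapter path are placeholders, and the flags mirror the LoRA tests touched above rather than prescribing exact values.

    import os
    os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine, as the tests do

    import vllm
    from vllm.lora.request import LoRARequest

    llm = vllm.LLM("meta-llama/Llama-2-7b-hf",  # placeholder base model
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=8,
                   max_model_len=1024)
    outputs = llm.generate(
        ["Write a SQL query that lists all users."],
        sampling_params=vllm.SamplingParams(max_tokens=64),
        lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora"))
    print(outputs[0].outputs[0].text)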