diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py
index 0acdaeac..7026f705 100644
--- a/tests/lora/test_llama_tp.py
+++ b/tests/lora/test_llama_tp.py
@@ -84,12 +84,14 @@ def v1(run_with_both_engines_lora):
 
 @create_new_process_for_each_test()
 def test_llama_lora(sql_lora_files):
-    llm = vllm.LLM(MODEL_PATH,
-                   enable_lora=True,
-                   max_num_seqs=16,
-                   max_loras=4,
-                   tensor_parallel_size=1,
-                   enable_chunked_prefill=True)
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        # also test odd max_num_seqs
+        max_num_seqs=13,
+        max_loras=4,
+        tensor_parallel_size=1,
+        enable_chunked_prefill=True)
     generate_and_test(llm, sql_lora_files)
 
 
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index eb6f5b1b..be9cbe24 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
 
 import torch
 
+import vllm.envs as envs
 from vllm.lora.layers import LoRAMapping
 from vllm.triton_utils import HAS_TRITON
 
@@ -42,8 +43,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
                                                       max_num_batched_tokens,
                                                       device=device)
+
+        # When the cudagraph capture size is greater than max_num_seqs
+        # (max_batches here), V0 captures the graph as if max_num_seqs were
+        # set to the capture size.
+        # V1 doesn't have this problem and always respects max_num_seqs.
+        max_num_prompts = (max_batches
+                           if envs.VLLM_USE_V1 else max_num_batched_tokens)
         self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
-                                                       max_batches,
+                                                       max_num_prompts,
                                                        device=device)
 
     def update_metadata(
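
The second hunk's sizing rule can be summarized with a minimal standalone sketch; the helper name prompt_mapping_capacity and its parameters below are illustrative only and are not part of the patch.

    # Hypothetical sketch of how the prompt-mapping capacity is chosen in the
    # hunk above; the function and its signature are not part of the patch.
    def prompt_mapping_capacity(max_num_seqs: int,
                                max_num_batched_tokens: int,
                                use_v1: bool) -> int:
        if use_v1:
            # V1 always respects max_num_seqs, so one slot per sequence suffices.
            return max_num_seqs
        # V0 may capture cudagraphs for batch sizes larger than max_num_seqs,
        # so size for max_num_batched_tokens, which bounds any capture size.
        return max_num_batched_tokens

For example, with max_num_seqs=13 (as in the updated test), V1 sizes the prompt mapping to 13 entries, while V0 falls back to max_num_batched_tokens so that larger cudagraph capture sizes are still covered.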