[CORE] Adding support for insertion of soft-tuned prompts (#4645)

Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Authored by Swapnil Parekh on 2024-07-09 16:26:36 -04:00; committed by GitHub.
parent a0550cbc80
commit 4d6ada947c
48 changed files with 1952 additions and 519 deletions
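For orientation, the end-to-end flow this PR enables looks roughly like the snippet below. It is a sketch distilled from the new prompt-adapter tests in this diff, not part of the patch itself; the model and adapter repos are the ones those tests use.

```python
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = "stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM"  # PEFT prompt-tuning checkpoint

# Prompt adapter support is opt-in and bounded by max_prompt_adapter_token.
llm = vllm.LLM(MODEL_PATH,
               enable_prompt_adapter=True,
               max_prompt_adapter_token=8)

# A PromptAdapterRequest names the adapter, gives it an integer id, points to
# its checkpoint, and states how many virtual tokens it inserts.
pa_request = PromptAdapterRequest("twitter_pa", 1, PA_PATH, 8)

outputs = llm.generate(
    "Tweet text : @nationalgridus Looks good thanks! Label : ",
    vllm.SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
    prompt_adapter_request=pa_request)
print(outputs[0].outputs[0].text.strip())  # the tests expect "no complaint"
```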

View File

@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml
mypy tests --config-file pyproject.toml

View File

@ -92,11 +92,10 @@ def batched_generate(
for input in inputs:
prompt, sampling_param, lora_req = input
# Add requests to the engine and run the engine
-llm._validate_and_add_requests(
-prompt,
+llm._validate_and_add_requests(prompt,
sampling_param,
lora_request=lora_req,
-)
+prompt_adapter_request=None)
outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]

View File

@ -127,37 +127,37 @@ def test_lora_model_manager(dist_init, dummy_model):
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_lora(model_lora1)
assert not manager.activate_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert not manager.add_adapter(model_lora1)
assert not manager.activate_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_lora(model_lora2)
assert not manager.activate_lora(2)
assert manager.add_lora(model_lora3)
assert not manager.add_adapter(model_lora2)
assert not manager.activate_adapter(2)
assert manager.add_adapter(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
with pytest.raises(ValueError):
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.remove_lora(model_lora2.id)
assert manager.remove_adapter(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_lora(model_lora2.id)
assert manager.remove_lora(model_lora1.id)
assert not manager.remove_lora(model_lora1.id)
assert manager.add_lora(model_lora1)
assert not manager.remove_adapter(model_lora2.id)
assert manager.remove_adapter(model_lora1.id)
assert not manager.remove_adapter(model_lora1.id)
assert manager.add_adapter(model_lora1)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] is None
assert manager.add_lora(model_lora2)
assert manager.activate_lora(3)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] is None
assert manager.activate_lora(2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
@ -173,70 +173,70 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
model, 2, 2, 2,
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
assert all(x is None for x in manager.lora_index_to_id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 1
assert not manager.add_lora(model_lora1)
assert not manager.activate_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert not manager.add_adapter(model_lora1)
assert not manager.activate_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert not manager.add_lora(model_lora2)
assert not manager.activate_lora(2)
assert manager.add_lora(model_lora3)
assert not manager.add_adapter(model_lora2)
assert not manager.activate_adapter(2)
assert manager.add_adapter(model_lora3)
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
assert manager.remove_lora(model_lora2.id)
assert manager.remove_adapter(model_lora2.id)
assert manager.lora_index_to_id[1] is None
assert not manager.remove_lora(model_lora2.id)
assert manager.remove_lora(model_lora1.id)
assert not manager.remove_lora(model_lora1.id)
assert manager.add_lora(model_lora1)
assert manager.activate_lora(1)
assert not manager.remove_adapter(model_lora2.id)
assert manager.remove_adapter(model_lora1.id)
assert not manager.remove_adapter(model_lora1.id)
assert manager.add_adapter(model_lora1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.add_lora(model_lora2)
assert manager.deactivate_lora(3)
assert manager.add_adapter(model_lora2)
assert manager.deactivate_adapter(3)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(2)
assert manager.activate_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.pin_lora(2)
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.activate_lora(1)
assert manager.activate_adapter(1)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.deactivate_lora(2)
assert manager.deactivate_adapter(2)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(3)
assert manager.activate_adapter(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.pin_lora(3)
assert manager.pin_lora(1)
assert manager.pin_adapter(3)
assert manager.pin_adapter(1)
with pytest.raises(RuntimeError):
assert manager.pin_lora(2)
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
with pytest.raises(RuntimeError):
assert manager.activate_lora(2)
assert manager.activate_adapter(2)
assert manager.deactivate_lora(3)
assert manager.pin_lora(2)
assert manager.deactivate_adapter(3)
assert manager.pin_adapter(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.remove_lora(3)
assert manager.remove_adapter(3)
with pytest.raises(ValueError):
assert manager.pin_lora(3)
assert manager.pin_adapter(3)
def test_lru_lora_model_manager(dist_init, dummy_model):
@ -256,168 +256,169 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
assert all(x is None for x in manager.lora_index_to_id)
# Add up to capacity
assert manager.add_lora(model_lora1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(1)
assert manager.activate_lora(2)
assert manager.add_adapter(model_lora1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(1)
assert manager.activate_adapter(2)
assert set(manager.list_loras()) == {1, 2}
assert set(manager.list_adapters()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
# Add over capacity
assert manager.add_lora(model_lora3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(3)
assert manager.activate_lora(4)
assert manager.add_adapter(model_lora3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(3)
assert manager.activate_adapter(4)
assert set(manager.list_loras()) == {3, 4}
assert set(manager.list_adapters()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
# Add 3 again to move it to the top and then add 2
# should return false since it's in already
assert not manager.add_lora(model_lora3)
assert not manager.activate_lora(3)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert not manager.add_adapter(model_lora3)
assert not manager.activate_adapter(3)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert set(manager.list_loras()) == {3, 2}
assert set(manager.list_adapters()) == {3, 2}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 2
# Remove manually
assert manager.remove_lora(3)
assert not manager.remove_lora(3)
assert manager.remove_adapter(3)
assert not manager.remove_adapter(3)
assert set(manager.list_loras()) == {2}
assert set(manager.list_adapters()) == {2}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 2
assert manager.add_lora(model_lora3)
assert manager.activate_lora(3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(4)
assert manager.add_adapter(model_lora3)
assert manager.activate_adapter(3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(4)
assert set(manager.list_loras()) == {3, 4}
assert set(manager.list_adapters()) == {3, 4}
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == {4}
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == set()
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == set()
assert all(x is None for x in manager.lora_index_to_id)
assert not manager.remove_oldest_lora()
assert set(manager.list_loras()) == set()
assert not manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == set()
assert all(x is None for x in manager.lora_index_to_id)
# pinning
assert manager.add_lora(model_lora3)
assert manager.activate_lora(3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(4)
assert set(manager.list_loras()) == {3, 4}
assert manager.add_adapter(model_lora3)
assert manager.activate_adapter(3)
assert manager.add_adapter(model_lora4)
assert manager.activate_adapter(4)
assert set(manager.list_adapters()) == {3, 4}
with pytest.raises(ValueError):
assert manager.pin_lora(1)
assert manager.pin_lora(3)
assert manager.pin_adapter(1)
assert manager.pin_adapter(3)
# Remove manually
assert manager.remove_lora(3)
assert not manager.remove_lora(3)
assert manager.remove_adapter(3)
assert not manager.remove_adapter(3)
assert set(manager.list_loras()) == {4}
assert set(manager.list_adapters()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4
assert manager.add_lora(model_lora1)
assert manager.pin_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)
assert manager.add_adapter(model_lora1)
assert manager.pin_adapter(1)
assert manager.add_adapter(model_lora2)
assert manager.activate_adapter(2)
assert set(manager.list_loras()) == {1, 2}
assert set(manager.list_adapters()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2
assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == {1}
assert manager.remove_oldest_adapter()
assert set(manager.list_adapters()) == {1}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] is None
with pytest.raises(RuntimeError):
assert manager.remove_oldest_lora()
assert manager.remove_oldest_adapter()
assert set(manager.list_loras()) == {1}
assert set(manager.list_adapters()) == {1}
-def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
+def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
-worker_lora_manager = LRUCacheWorkerLoRAManager(
+worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
-worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
+worker_adapter_manager.create_lora_manager(
+llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 3, 4}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 6, 7, 8}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6
# Over capacity
with pytest.raises(RuntimeError):
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
@ -426,68 +427,69 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
], mapping)
-def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
+def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
-worker_lora_manager = WorkerLoRAManager(
+worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
-worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
+worker_adapter_manager.create_lora_manager(
+llama_2_7b_model_extra_embeddings)
mapping = LoRAMapping([], [])
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 3, 4}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1, 2, 5}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {1}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None
assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None
assert worker_adapter_manager.list_adapters() == {1}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
], mapping)
assert worker_lora_manager.list_loras() == {6, 7, 8}
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7
# Over capacity
with pytest.raises(RuntimeError):
worker_lora_manager.set_active_loras([
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
@ -525,8 +527,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
assert isinstance(model.get_submodule("gate_up_proj"),
MergedColumnParallelLinearWithLoRA)
assert manager.add_lora(model_lora)
assert manager.add_lora(model_lora1)
assert manager.add_adapter(model_lora)
assert manager.add_adapter(model_lora1)
packed_lora = model_lora.get_lora("gate_up_proj")
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

View File

@ -0,0 +1,45 @@
import pytest
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
def do_sample(llm, pa_name: str, pa_id: int):
prompts = [
"Tweet text : @nationalgridus I have no water and the bill is \
current and paid. Can you do something about this? Label : ",
"Tweet text : @nationalgridus Looks good thanks! Label : "
]
sampling_params = vllm.SamplingParams(temperature=0.0,
max_tokens=3,
stop_token_ids=[3])
outputs = llm.generate(prompts,
sampling_params,
prompt_adapter_request=PromptAdapterRequest(
pa_name, pa_id, PA_PATH, 8) if pa_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
llm = vllm.LLM(MODEL_PATH,
enforce_eager=enforce_eager,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
expected_output = ['complaint', 'no complaint']
assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output

View File

@ -0,0 +1,53 @@
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'
def do_sample(engine):
prompts = [
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
("Tweet text: I have complaints! Label: ",
SamplingParams(temperature=0.0, max_tokens=3), None),
("Tweet text: I have no problems Label: ",
SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
PromptAdapterRequest("complain", 3, pa_path, 8)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_multi_prompt_adapters():
engine_args = EngineArgs(model=MODEL_PATH,
max_prompt_adapters=3,
enable_prompt_adapter=True,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
expected_output = {
' quot;I', 'hate speech', 'no complaint', 'not hate speech'
}
assert do_sample(engine) == expected_output

View File

@ -0,0 +1,61 @@
from huggingface_hub import snapshot_download
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def do_sample(engine):
prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501
# first prompt with a prompt adapter and second without adapter
prompts = [
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]),
PromptAdapterRequest("hate_speech", 1, pa_path,
8), LoRARequest("sql_test", 1, lora_path)),
(prompt_text,
SamplingParams(temperature=0.0, max_tokens=100,
stop=["[/assistant]"]), None,
LoRARequest("sql_test", 1, lora_path)),
]
request_id = 0
results = set()
while prompts or engine.has_unfinished_requests():
if prompts:
prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
engine.add_request(str(request_id),
prompt,
sampling_params,
prompt_adapter_request=pa_request,
lora_request=lora_request)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
results.add(request_output.outputs[0].text)
return results
def test_lora_prompt_adapter():
engine_args = EngineArgs(model=MODEL_PATH,
enable_prompt_adapter=True,
enable_lora=True,
max_num_seqs=60,
max_prompt_adapter_token=8)
engine = LLMEngine.from_engine_args(engine_args)
result = do_sample(engine)
expected_output = {
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501
}
assert result == expected_output

View File

@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
@ -92,6 +93,7 @@ class AsyncLLM:
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalDataDict] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> List[RequestOutput]:
if prompts is None:

View File

@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
is_driver_worker=True,
)
return model_runner

View File

View File

@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Tuple
@dataclass
class AdapterMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
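To make the two fields concrete, here is a small, hypothetical illustration (not part of the diff); the dataclass is re-declared only so the snippet runs on its own.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:  # mirrors the dataclass above
    index_mapping: Tuple[int, ...]
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


# Hypothetical batch: sequence A has 3 input tokens handled by adapter 1,
# sequence B has 2 input tokens handled by adapter 2; one sampled token each.
mapping = AdapterMapping(index_mapping=[1, 1, 1, 2, 2],
                         prompt_mapping=[1, 2])
assert mapping.index_mapping == (1, 1, 1, 2, 2)  # lists are normalized to tuples
```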

View File

@ -0,0 +1,104 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
from torch import nn
from vllm.logger import init_logger
from vllm.utils import LRUCache
logger = init_logger(__name__)
class AdapterModel(ABC):
def __init__(self, model_id=None):
self.id = model_id
@abstractmethod
def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
# Common initialization code
# Load weights or embeddings from local checkpoint
raise NotImplementedError("Subclasses must implement this method.")
T = TypeVar('T')
class AdapterLRUCache(LRUCache[T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: T):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
class AdapterModelManager(ABC):
def __init__(
self,
model: nn.Module,
):
"""Create a AdapterModelManager and adapter for a given model.
Args:
model: the model to be adapted.
"""
self.model: nn.Module = model
self._registered_adapters: Dict[int, Any] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: Dict[int, None] = {}
self.adapter_type = 'Adapter'
self._last_mapping = None
def __len__(self) -> int:
return len(self._registered_adapters)
@property
@abstractmethod
def adapter_slots(self):
...
@property
@abstractmethod
def capacity(self):
...
@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
...
@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
...
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def remove_all_adapters(self):
...
@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
...
@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
...
@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
...

View File

@ -0,0 +1,25 @@
from abc import abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest:
"""
Base class for adapter requests.
"""
@property
@abstractmethod
def adapter_id(self):
...
def __post_init__(self):
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, self.__class__) and self.adapter_id == value.adapter_id
def __hash__(self) -> int:
return hash(self.adapter_id)
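A concrete request type only has to expose `adapter_id`; the base class then rejects non-positive ids in `__post_init__`. A toy sketch follows; the import path is an assumption, since file paths are not visible in this view.

```python
from dataclasses import dataclass

# Assumed module path for the AdapterRequest base class shown above.
from vllm.adapter_commons.request import AdapterRequest


@dataclass
class ToyAdapterRequest(AdapterRequest):
    toy_name: str
    toy_id: int

    @property
    def adapter_id(self) -> int:
        return self.toy_id


ToyAdapterRequest("demo", 1)  # ok

try:
    ToyAdapterRequest("demo", 0)  # __post_init__ rejects non-positive ids
except ValueError as exc:
    print(exc)  # id must be > 0, got 0
```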

View File

@ -0,0 +1,90 @@
from typing import Any, Callable, Dict, Optional, Set
## model functions
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
deactivate_func: Callable) -> bool:
if adapter_id in active_adapters:
deactivate_func(adapter_id)
active_adapters.pop(adapter_id)
return True
return False
def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
capacity: int, add_func: Callable) -> bool:
if adapter.id not in registered_adapters:
if len(registered_adapters) >= capacity:
raise RuntimeError('No free adapter slots.')
add_func(adapter)
registered_adapters[adapter.id] = adapter
return True
return False
def set_adapter_mapping(mapping: Any, last_mapping: Any,
set_mapping_func: Callable) -> Any:
if last_mapping != mapping:
set_mapping_func(mapping)
return mapping
return last_mapping
def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
deactivate_func: Callable) -> bool:
deactivate_func(adapter_id)
return bool(registered_adapters.pop(adapter_id, None))
def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
return dict(registered_adapters)
def get_adapter(adapter_id: int,
registered_adapters: Dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id, None)
## worker functions
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
apply_adapters_func,
set_adapter_mapping_func) -> None:
apply_adapters_func(requests)
set_adapter_mapping_func(mapping)
def add_adapter_worker(adapter_request: Any, list_adapters_func,
load_adapter_func, add_adapter_func,
activate_adapter_func) -> bool:
if adapter_request.adapter_id in list_adapters_func():
return False
loaded_adapter = load_adapter_func(adapter_request)
loaded = add_adapter_func(loaded_adapter)
activate_adapter_func(loaded_adapter.id)
return loaded
def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
adapter_slots: int, remove_adapter_func,
add_adapter_func) -> None:
models_that_exist = list_adapters_func()
models_map = {
adapter_request.adapter_id: adapter_request
for adapter_request in adapter_requests if adapter_request
}
if len(models_map) > adapter_slots:
raise RuntimeError(
f"Number of requested models ({len(models_map)}) is greater "
f"than the number of GPU model slots "
f"({adapter_slots}).")
new_models = set(models_map)
models_to_add = new_models - models_that_exist
models_to_remove = models_that_exist - new_models
for adapter_id in models_to_remove:
remove_adapter_func(adapter_id)
for adapter_id in models_to_add:
add_adapter_func(models_map[adapter_id])
def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
return set(adapter_manager_list_adapters_func())
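These helpers are written against plain callables, so their behavior can be checked in isolation. Below is a small toy run of `apply_adapters_worker` reconciling a worker's loaded adapters with the needs of a new batch; the import path is assumed, since file paths are not shown in this view.

```python
# Assumed module path for the helper functions shown above.
from vllm.adapter_commons.utils import apply_adapters_worker


class Req:
    """Stand-in for an adapter request; only .adapter_id is needed here."""

    def __init__(self, adapter_id: int):
        self.adapter_id = adapter_id


loaded = {1, 2}          # adapters currently registered on the worker
added, removed = [], []

apply_adapters_worker(
    adapter_requests={Req(2), Req(3)},           # what the next batch needs
    list_adapters_func=lambda: set(loaded),
    adapter_slots=4,
    remove_adapter_func=removed.append,          # called for stale adapter 1
    add_adapter_func=lambda r: added.append(r.adapter_id),  # called for new adapter 3
)
print(added, removed)  # [3] [1]
```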

View File

@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Set
import torch
class AbstractWorkerManager(ABC):
def __init__(self, device: torch.device):
self.device = device
@property
@abstractmethod
def is_enabled(self) -> bool:
...
@abstractmethod
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
...
@abstractmethod
def add_adapter(self, adapter_request: Any) -> bool:
...
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
@abstractmethod
def remove_all_adapters(self):
...
@abstractmethod
def list_adapters(self) -> Set[int]:
...

View File

@ -1285,6 +1285,39 @@ class LoRAConfig:
raise ValueError("LoRA is not supported with chunked prefill yet.")
@dataclass
class PromptAdapterConfig:
max_prompt_adapters: int
max_prompt_adapter_token: int
max_cpu_prompt_adapters: Optional[int] = None
prompt_adapter_dtype: Optional[torch.dtype] = None
def __post_init__(self):
library_name = 'peft'
try:
__import__(library_name)
except ImportError as e:
raise ImportError(
f"'{library_name}' is not installed for prompt adapter support."
f"Please install it using 'pip install {library_name}'."
) from e
if self.max_prompt_adapters < 1:
raise ValueError(f"max_prompt_adapters "
f"({self.max_prompt_adapters}) must be >= 1.")
if self.max_prompt_adapter_token == 0:
raise ValueError("max_prompt_adapter_token must be set.")
if self.max_cpu_prompt_adapters is None:
self.max_cpu_prompt_adapters = self.max_prompt_adapters
def verify_with_model_config(self, model_config: ModelConfig):
if self.prompt_adapter_dtype in (None, "auto"):
self.prompt_adapter_dtype = model_config.dtype
elif isinstance(self.prompt_adapter_dtype, str):
self.prompt_adapter_dtype = getattr(torch,
self.prompt_adapter_dtype)
@dataclass
class MultiModalConfig:
"""Configs the input data format and how models should run for
@ -1518,6 +1551,7 @@ class EngineConfig:
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
@ -1529,6 +1563,9 @@ class EngineConfig:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.

View File

@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
@ -139,6 +140,8 @@ class SchedulerOutputs:
if self.num_loras > 0:
self._sort_by_lora_ids()
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
@ -157,6 +160,14 @@ class SchedulerOutputs:
if g.seq_group.lora_request is not None
}
@property
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
return {
g.seq_group.prompt_adapter_request
for g in self.scheduled_seq_groups
if g.seq_group.prompt_adapter_request is not None
}
@dataclass
class SchedulerRunningOutputs:
@ -1024,6 +1035,7 @@ class Scheduler:
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
seq_group_metadata_list.append(seq_group_metadata)

View File

@ -7,8 +7,8 @@ from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
-SchedulerConfig, SpeculativeConfig,
-TokenizerPoolConfig)
+PromptAdapterConfig, SchedulerConfig,
+SpeculativeConfig, TokenizerPoolConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
@ -66,6 +66,9 @@ class EngineArgs:
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
@ -449,6 +452,17 @@ class EngineArgs:
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapter tokens.')
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
@ -726,6 +740,11 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config,
)
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
@ -751,6 +770,7 @@ class EngineArgs:
load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
)

View File

@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext
@ -264,6 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
@ -279,6 +281,12 @@ class _AsyncLLMEngine(LLMEngine):
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = [
0
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
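The hunk above (and its twin in `LLMEngine.process_model_inputs` later in this diff) prepends one placeholder token id per virtual token of the adapter, so the soft prompt has positions to occupy at the front of the sequence. A quick illustration with made-up values:

```python
prompt_token_ids = [101, 102, 103]          # hypothetical tokenized prompt
prompt_adapter_num_virtual_tokens = 8       # from the PromptAdapterRequest

padded = [0] * prompt_adapter_num_virtual_tokens + prompt_token_ids
assert padded == [0, 0, 0, 0, 0, 0, 0, 0, 101, 102, 103]
assert len(padded) == 8 + len(prompt_token_ids)
```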
@ -293,6 +301,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@ -301,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async(
-request_id=request_id, inputs=inputs, lora_request=lora_request)
+request_id=request_id,
+inputs=inputs,
+lora_request=lora_request,
+prompt_adapter_request=prompt_adapter_request)
self._add_processed_request(
request_id=request_id,
@ -309,6 +321,7 @@ class _AsyncLLMEngine(LLMEngine):
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
)
@ -627,6 +640,7 @@ class AsyncLLMEngine:
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
@ -669,7 +683,7 @@ class AsyncLLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
-)
+prompt_adapter_request=prompt_adapter_request)
return stream
@ -680,6 +694,7 @@ class AsyncLLMEngine:
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@ -695,6 +710,8 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
Yields:
The output `RequestOutput` objects from the LLMEngine
@ -749,6 +766,7 @@ class AsyncLLMEngine:
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
):
yield LLMEngine.validate_output(output, RequestOutput)
@ -837,6 +855,7 @@ class AsyncLLMEngine:
*,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
@ -849,6 +868,7 @@ class AsyncLLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
)
try:

View File

@ -8,7 +8,8 @@ from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, MultiModalConfig,
-ObservabilityConfig, ParallelConfig, SchedulerConfig,
+ObservabilityConfig, ParallelConfig,
+PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
@ -27,6 +28,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
@ -93,6 +95,8 @@ class LLMEngine:
decoding.
executor_class: The model executor class for managing distributed
execution.
prompt_adapter_config (Optional): The configuration related to serving
prompt adapters.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection.
"""
@ -161,6 +165,7 @@ class LLMEngine:
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@ -222,6 +227,7 @@ class LLMEngine:
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
@ -250,6 +256,7 @@ class LLMEngine:
multimodal_config=multimodal_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
)
if not self.model_config.embedding_mode:
@ -282,6 +289,8 @@ class LLMEngine:
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prompt_adapter":
bool(prompt_adapter_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
@ -376,7 +385,6 @@ class LLMEngine:
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
@ -409,7 +417,6 @@ class LLMEngine:
else:
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
# Create the LLM engine.
engine = cls(
**engine_config.to_dict(),
@ -470,6 +477,9 @@ class LLMEngine:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
@ -487,6 +497,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None,
) -> None:
# Create the sequences.
@ -495,7 +506,7 @@ class LLMEngine:
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
-lora_request)
+lora_request, prompt_adapter_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
@ -506,7 +517,7 @@ class LLMEngine:
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
-)
+prompt_adapter_request=prompt_adapter_request)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
@ -514,7 +525,7 @@ class LLMEngine:
params,
arrival_time=arrival_time,
lora_request=lora_request,
-)
+prompt_adapter_request=prompt_adapter_request)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
@ -535,6 +546,7 @@ class LLMEngine:
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
@ -549,6 +561,11 @@ class LLMEngine:
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
prompt_token_ids = \
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+ prompt_token_ids
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
@ -563,6 +580,7 @@ class LLMEngine:
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Add a request to the engine's request pool.
@ -612,9 +630,11 @@ class LLMEngine:
if arrival_time is None:
arrival_time = time.time()
-processed_inputs = self.process_model_inputs(request_id=request_id,
+processed_inputs = self.process_model_inputs(
+request_id=request_id,
inputs=inputs,
-lora_request=lora_request)
+lora_request=lora_request,
+prompt_adapter_request=prompt_adapter_request)
self._add_processed_request(
request_id=request_id,
@ -622,6 +642,7 @@ class LLMEngine:
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
)
@ -633,6 +654,7 @@ class LLMEngine:
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@ -658,7 +680,7 @@ class LLMEngine:
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
-)
+prompt_adapter_request=prompt_adapter_request)
return seq_group
@ -669,16 +691,19 @@ class LLMEngine:
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params = pooling_params.clone()
# Create the sequence group.
-seq_group = SequenceGroup(request_id=request_id,
+seq_group = SequenceGroup(
+request_id=request_id,
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
-pooling_params=pooling_params)
+pooling_params=pooling_params,
+prompt_adapter_request=prompt_adapter_request)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
@ -1082,6 +1107,16 @@ class LLMEngine:
def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> List[int]:
return self.model_executor.list_prompt_adapters()
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()

View File

@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
@ -255,6 +256,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
@ -271,6 +273,8 @@ class LLM:
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `RequestOutput` objects containing the
@ -304,7 +308,7 @@ class LLM:
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
-)
+prompt_adapter_request=prompt_adapter_request)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
@ -397,6 +401,7 @@ class LLM:
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
@ -412,6 +417,8 @@ class LLM:
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
@ -445,6 +452,7 @@ class LLM:
inputs=inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
@ -504,6 +512,7 @@ class LLM:
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
@ -526,19 +535,23 @@ class LLM:
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
-)
+prompt_adapter_request=prompt_adapter_request)
def _add_request(
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
-lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+lora_request: Optional[Union[List[LoRARequest],
+LoRARequest]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None:
request_id = str(next(self.request_counter))
-self.llm_engine.add_request(request_id,
+self.llm_engine.add_request(
+request_id,
inputs,
params,
-lora_request=lora_request)
+lora_request=lora_request,
+prompt_adapter_request=prompt_adapter_request)
def _run_engine(
self, *, use_tqdm: bool

View File

@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest):
@app.get("/v1/models")
async def show_available_models():
-models = await openai_serving_chat.show_available_models()
+models = await openai_serving_completion.show_available_models()
return JSONResponse(content=models.model_dump())
@ -236,7 +236,8 @@ if __name__ == "__main__":
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
-engine, model_config, served_model_names, args.lora_modules)
+engine, model_config, served_model_names, args.lora_modules,
+args.prompt_adapters)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
app.root_path = args.root_path

View File

@ -9,7 +9,8 @@ import json
import ssl
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
-from vllm.entrypoints.openai.serving_engine import LoRAModulePath
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+PromptAdapterPath)
from vllm.utils import FlexibleArgumentParser
@ -23,6 +24,16 @@ class LoRAParserAction(argparse.Action):
setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
adapter_list = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
setattr(namespace, self.dest, adapter_list)
def make_arg_parser():
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
@ -65,6 +76,14 @@ def make_arg_parser():
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template",
type=nullable_str,
default=None,

View File

@ -258,7 +258,7 @@ class OpenAIServingChat(OpenAIServing):
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
-lora_request = self._maybe_get_lora(request)
+_, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend

View File

@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
TokenizeResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
OpenAIServing,
PromptAdapterPath)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]]):
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]]):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)
lora_modules=lora_modules,
prompt_adapters=prompt_adapters)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
@ -101,7 +104,12 @@ class OpenAIServingCompletion(OpenAIServing):
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request, prompt_adapter_request = None, None
if adapter_type == 'LoRA':
lora_request, prompt_adapter_request = adapter_request, None
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
@ -147,6 +155,7 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params,
f"{request_id}-{i}",
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
)

View File

@ -16,12 +16,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ModelPermission, TokenizeRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer
logger = init_logger(__name__)
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass
class LoRAModulePath:
name: str
@ -30,9 +37,14 @@ class LoRAModulePath:
class OpenAIServing:
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]]):
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
):
super().__init__()
self.engine = engine
@ -49,9 +61,8 @@ class OpenAIServing:
self.served_model_names = served_model_names
if lora_modules is None:
self.lora_requests = []
else:
if lora_modules is not None:
self.lora_requests = [
LoRARequest(
lora_name=lora.name,
@ -60,6 +71,20 @@ class OpenAIServing:
) for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with open(f"./{prompt_adapter.local_path}"
f"/adapter_config.json") as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
@ -75,7 +100,14 @@ class OpenAIServing:
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.served_model_names[0],
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards)
def create_error_response(
@ -109,20 +141,29 @@ class OpenAIServing:
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.prompt_adapter_requests
]:
return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(
def _maybe_get_adapter(
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]:
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
if request.model in self.served_model_names:
return None
return None, None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
return 'LoRA', lora
for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return 'PromptAdapter', prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
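To make the startup path above concrete, a hedged sketch of what the server builds for one configured adapter; the adapter name, path, and the exact contents of adapter_config.json are illustrative.

import json

from vllm.prompt_adapter.request import PromptAdapterRequest

# The adapter directory is expected to hold a PEFT-style adapter_config.json,
# e.g. {"peft_type": "PROMPT_TUNING", "num_virtual_tokens": 8, ...}.
with open("/adapters/tweet_sentiment/adapter_config.json") as f:
    num_virtual_tokens = json.load(f)["num_virtual_tokens"]

request = PromptAdapterRequest(
    prompt_adapter_name="tweet_sentiment",
    prompt_adapter_id=1,
    prompt_adapter_local_path="/adapters/tweet_sentiment",
    prompt_adapter_num_virtual_tokens=num_virtual_tokens)
# _maybe_get_adapter then resolves a request with "model": "tweet_sentiment"
# to ('PromptAdapter', request), which serving_completion forwards as
# prompt_adapter_request.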

View File

@ -7,6 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
@ -48,6 +49,7 @@ class CPUExecutor(ExecutorBase):
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=True,
)
self.driver_worker.init_device()
@ -90,6 +92,19 @@ class CPUExecutor(ExecutorBase):
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
return self.driver_worker.list_prompt_adapters()
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
def check_health(self) -> None:
# CPUExecutor will always be healthy as long as
# it's running.

View File

@ -4,8 +4,10 @@ from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
@ -28,6 +30,7 @@ class ExecutorBase(ABC):
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
@ -38,6 +41,7 @@ class ExecutorBase(ABC):
self.device_config = device_config
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self._init_executor()
@ -95,6 +99,23 @@ class ExecutorBase(ABC):
def list_loras(self) -> Set[int]:
raise NotImplementedError
@abstractmethod
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError
@abstractmethod
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError # type: ignore
@abstractmethod
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError
@abstractmethod
def check_health(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
@ -122,12 +143,14 @@ class ExecutorAsyncBase(ExecutorBase):
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None:
self.pp_locks: Optional[List[asyncio.Lock]] = None
super().__init__(model_config, cache_config, parallel_config,
scheduler_config, device_config, load_config,
lora_config, multimodal_config, speculative_config)
lora_config, multimodal_config, speculative_config,
prompt_adapter_config)
@abstractmethod
async def execute_model_async(

View File

@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
@ -45,6 +46,7 @@ class GPUExecutor(ExecutorBase):
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
@ -107,6 +109,25 @@ class GPUExecutor(ExecutorBase):
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
assert prompt_adapter_request.prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
assert prompt_adapter_id > 0, \
"prompt_adapter_id must be greater than 0."
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
return self.driver_worker.list_prompt_adapters()
def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.

View File

@ -8,7 +8,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
@ -44,6 +45,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
assert device_config.device_type == "xpu"
@ -58,6 +60,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
placement_group = self.parallel_config.placement_group

View File

@ -4,7 +4,8 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
@ -27,6 +28,7 @@ class XPUExecutor(GPUExecutor):
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
assert device_config.device_type == "xpu"
@ -43,6 +45,7 @@ class XPUExecutor(GPUExecutor):
self.scheduler_config = scheduler_config
self.device_config = device_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None
# Instantiate the worker and load the model to GPU.

View File

@ -8,6 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@ -134,15 +135,8 @@ def _apply_lora_packed_nslice(
@dataclass
class LoRAMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
class LoRAMapping(AdapterMapping):
pass
class BaseLayerWithLoRA(nn.Module):

View File

@ -4,12 +4,17 @@ import math
import os
import re
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors.torch
import torch
from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA,
@ -19,7 +24,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import LRUCache, is_pin_memory_available
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@ -153,7 +158,7 @@ def get_lora_id():
return _GLOBAL_LORA_ID
class LoRAModel:
class LoRAModel(AdapterModel):
"""A LoRA fine-tuned model."""
def __init__(
@ -388,7 +393,7 @@ class LoRAModel:
)
class LoRAModelManager:
class LoRAModelManager(AdapterModelManager):
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
@ -440,8 +445,7 @@ class LoRAModelManager:
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4
self.model = model
super().__init__(model)
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
@ -453,11 +457,11 @@ class LoRAModelManager:
self.model.packed_modules_mapping)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
self.adapter_type = 'LoRA'
@property
def capacity(self) -> int:
@ -467,15 +471,16 @@ class LoRAModelManager:
def lora_slots(self) -> int:
return self.lora_config.max_loras
def __len__(self) -> int:
return len(self._registered_loras)
@property
def adapter_slots(self) -> int:
return self.lora_slots
def activate_lora(
def activate_adapter(
self,
lora_id: int,
) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_loras:
if lora_id in self._active_adapters:
return False
first_free_slot = next(
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
@ -483,8 +488,8 @@ class LoRAModelManager:
if first_free_slot is None:
raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id]
self._active_adapters[lora_id] = None
lora_model = self._registered_adapters[lora_id]
logger.debug("Activating LoRA. int id: %d, slot index: %d",
lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id
@ -498,21 +503,13 @@ class LoRAModelManager:
module.reset_lora(index)
return True
def _deactivate_lora(self, lora_id: int):
def _deactivate_adapter(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
def deactivate_lora(self, lora_id: int) -> bool:
"""Remove a LoRA from a GPU buffer."""
if lora_id in self._active_loras:
self._deactivate_lora(lora_id)
self._active_loras.pop(lora_id)
return True
return False
def _set_long_lora_context(self, lora: LoRAModel):
if self.long_lora_context is None:
return
@ -528,40 +525,19 @@ class LoRAModelManager:
if offsets:
self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
def _add_lora(self, lora: LoRAModel):
def _add_adapter(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
self._registered_adapters[lora.id] = lora
self._set_long_lora_context(lora)
def add_lora(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager CPU cache."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
if len(self._registered_loras) >= self.capacity:
raise RuntimeError("No free LoRA slots.")
self._add_lora(lora)
return True
return False
def remove_lora(self, lora_id: int) -> bool:
"""Remove a LoRAModel from the manager CPU cache."""
# TODO: should we check active lora?
self.deactivate_lora(lora_id)
if self.long_lora_context:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None))
def pin_lora(self, lora_id: int) -> bool:
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
raise NotImplementedError(
"Pinning is not supported in LoRAModelManager. "
"Use LRUCacheLoRAModelManager for pinning") # type: ignore
# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_offsets_tensor,
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
@ -583,23 +559,11 @@ class LoRAModelManager:
# Maintain the reference
self.indices_len[:] = indices_len
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
if self._last_mapping != lora_mapping:
self._set_lora_mapping(lora_mapping)
self._last_mapping = lora_mapping
def list_loras(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras)
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
def remove_all_loras(self):
def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self._registered_adapters.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_loras.clear()
self._active_adapters.clear()
def _create_lora_modules(self):
for module_name, module in self.model.named_modules(
@ -743,18 +707,39 @@ class LoRAModelManager:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras)
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter)
class LoRALRUCache(LRUCache[LoRAModel]):
def add_adapter(self, adapter: LoRAModel) -> bool:
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", adapter.id, adapter.id,
adapter.scaling_factor)
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
self._set_adapter_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]:
return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
return get_adapter(adapter_id, self._registered_adapters)
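The bookkeeping previously inlined in this class now lives in vllm.adapter_commons.utils. A plausible sketch of those helpers, inferred from the call sites above and the per-LoRA code they replace; not the actual adapter_commons implementation.

from typing import Any, Callable, Dict

def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
                       deactivate_fn: Callable[[int], None]) -> bool:
    # Drop from the active set and undo GPU-side state via the callback.
    if adapter_id in active_adapters:
        deactivate_fn(adapter_id)
        active_adapters.pop(adapter_id)
        return True
    return False

def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
                capacity: int, add_fn: Callable[[Any], None]) -> bool:
    # Register in the CPU-side cache if there is room.
    if adapter.id not in registered_adapters:
        if len(registered_adapters) >= capacity:
            raise RuntimeError("No free adapter slots.")
        add_fn(adapter)
        return True
    return False

def set_adapter_mapping(mapping: Any, last_mapping: Any,
                        set_mapping_fn: Callable[[Any], None]) -> Any:
    # Only push a new token->adapter mapping when it actually changed.
    if last_mapping != mapping:
        set_mapping_fn(mapping)
        return mapping
    return last_mapping

def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
                   deactivate_fn: Callable[[int], bool]) -> bool:
    # Deactivate first, then drop from the registry.
    deactivate_fn(adapter_id)
    return bool(registered_adapters.pop(adapter_id, None))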
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
bool]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: int, value: LoRAModel):
logger.debug("Removing LoRA. int id: %d", key)
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
super().__init__(capacity, deactivate_lora_fn)
class LRUCacheLoRAModelManager(LoRAModelManager):
@ -770,49 +755,49 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
):
super().__init__(model, max_num_seqs, max_num_batched_tokens,
vocab_size, lora_config)
self._registered_loras: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_lora)
self._active_loras: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_lora)
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter)
self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_adapter)
def list_loras(self) -> Dict[int, LoRAModel]:
def list_adapters(self) -> Dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_loras.cache)
return dict(self._registered_adapters.cache)
def add_lora(self, lora: LoRAModel) -> bool:
def add_adapter(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
logger.debug(
"Adding lora. Model id: %d, "
"int id: %d, "
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
if lora.id not in self._registered_loras:
self._add_lora(lora)
if lora.id not in self._registered_adapters:
self._add_adapter(lora)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_loras.touch(lora.id)
self._registered_adapters.touch(lora.id)
was_added = False
return was_added
def activate_lora(
def activate_adapter(
self,
lora_id: int,
) -> bool:
if lora_id not in self._active_loras and len(
self._active_loras) >= self.lora_slots:
self._active_loras.remove_oldest()
result = super().activate_lora(lora_id)
if lora_id not in self._active_adapters and len(
self._active_adapters) >= self.lora_slots:
self._active_adapters.remove_oldest()
result = super().activate_adapter(lora_id)
# We always touch to update the LRU cache order
self._active_loras.touch(lora_id)
self._active_adapters.touch(lora_id)
return result
def remove_oldest_lora(self) -> bool:
if len(self._registered_loras) > 0:
self._registered_loras.remove_oldest()
def remove_oldest_adapter(self) -> bool:
if len(self._registered_adapters) > 0:
self._registered_adapters.remove_oldest()
return True
return False
def pin_lora(self, lora_id: int) -> bool:
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
self._pin_lora_in_cpu_cache(lora_id)
self._pin_lora_in_gpu_cache(lora_id)
@ -820,17 +805,17 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
def _pin_lora_in_cpu_cache(self, lora_id: int):
try:
self._registered_loras.pin(lora_id)
self._registered_adapters.pin(lora_id)
except ValueError as err:
raise ValueError("Pinning failed. "
f"LoRA {lora_id} is not registered.") from err
def _pin_lora_in_gpu_cache(self, lora_id: int):
if lora_id not in self._active_loras:
if lora_id not in self._active_adapters:
# move lora to gpu if not already active
self.activate_lora(lora_id)
self.activate_adapter(lora_id)
self._active_loras.pin(lora_id)
self._active_adapters.pin(lora_id)
def create_lora_manager(

View File

@ -1,13 +1,15 @@
from dataclasses import dataclass
from typing import Optional
from vllm.adapter_commons.request import AdapterRequest
@dataclass
class LoRARequest:
class LoRARequest(AdapterRequest):
"""
Request for a LoRA adapter.
Note that this class should be be used internally. For online
Note that this class should be used internally. For online
serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters.
@ -20,15 +22,16 @@ class LoRARequest:
lora_int_id: int
lora_local_path: str
long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(
f"lora_int_id must be > 0, got {self.lora_int_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, LoRARequest) and self.lora_int_id == value.lora_int_id
def __hash__(self) -> int:
@property
def adapter_id(self):
return self.lora_int_id
@property
def name(self):
return self.lora_name
@property
def local_path(self):
return self.lora_local_path

View File

@ -1,12 +1,15 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
@ -14,79 +17,13 @@ from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
class AbstractWorkerLoRAManager(ABC):
"""Abstract class for managing LoRA models on the worker side."""
def __init__(self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
max_position_embeddings: Optional[int] = None):
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_position_embeddings = max_position_embeddings
self.vocab_size = vocab_size
self.device = device
self.lora_config = lora_config
# If False, do not cache. If None, cache is empty.
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@property
@abstractmethod
def is_enabled(self) -> bool:
...
@abstractmethod
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
...
@abstractmethod
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
...
@abstractmethod
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
...
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
def remove_all_loras(self):
...
@abstractmethod
def list_loras(self) -> Set[int]:
...
class WorkerLoRAManager(AbstractWorkerLoRAManager):
class WorkerLoRAManager(AbstractWorkerManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
On every request, the requested LoRAs are loaded (unless they are already
loaded) and every other LoRA is unloaded."""
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager
_manager_cls: Type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
@ -103,16 +40,23 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.lora_config = lora_config
self.max_position_embeddings = max_position_embeddings
super().__init__(device)
# Lazily initialized by create_lora_manager.
self._lora_manager: LoRAModelManager
super().__init__(
max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
device,
max_position_embeddings=max_position_embeddings,
)
self._adapter_manager: LoRAModelManager
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@property
def is_enabled(self) -> bool:
@ -128,41 +72,14 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
max_num_batched_tokens=self.max_num_batched_tokens,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
lora_manager_cls=self._manager_cls,
)
self._lora_manager = lora_manager
self._adapter_manager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
new_loras = set(loras_map)
loras_to_add = new_loras - loras_that_exist
loras_to_remove = loras_that_exist - new_loras
for lora_id in loras_to_remove:
self.remove_lora(lora_id)
for lora_id in loras_to_add:
self.add_lora(loras_map[lora_id])
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
try:
model = self._lora_manager.model
model = self._adapter_manager.model
supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping
expected_lora_modules: List[str] = []
@ -198,37 +115,45 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_loras():
if lora_request.lora_int_id in self.list_adapters():
return False
if isinstance(self._cached_dummy_lora, LoRAModel):
dummy_lora = self._cached_dummy_lora.clone(
lora_request.lora_int_id)
else:
dummy_lora = self._lora_manager.create_dummy_lora(
dummy_lora = self._adapter_manager.create_dummy_lora(
lora_request.lora_int_id, rank, 1, self.embedding_modules)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._lora_manager.add_lora(dummy_lora)
return self._adapter_manager.add_adapter(dummy_lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id in self.list_loras():
return False
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
self._lora_manager.activate_lora(lora.id)
return loaded
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
set_active_adapters_worker(requests, mapping, self._apply_adapters,
self._adapter_manager.set_adapter_mapping)
def pin_lora(self, lora_id: int) -> bool:
return self._lora_manager.pin_lora(lora_id)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
def remove_all_loras(self):
self._lora_manager.remove_all_loras()
def add_adapter(self, adapter_request: Any) -> bool:
return add_adapter_worker(adapter_request, self.list_adapters,
self._load_adapter,
self._adapter_manager.add_adapter,
self._adapter_manager.activate_adapter)
def list_loras(self) -> Set[int]:
return set(self._lora_manager.list_loras())
def remove_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> Set[int]:
return list_adapters_worker(self._adapter_manager.list_adapters)
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
@ -238,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
(unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity."""
_lora_manager_cls: Type[
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
_manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
@ -247,40 +171,41 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
) -> Any:
lora_manager = create_lora_manager(
model,
lora_manager_cls=self._lora_manager_cls,
lora_manager_cls=self._manager_cls,
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._lora_manager = lora_manager
self._adapter_manager = lora_manager
return lora_manager.model
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
}
if len(loras_map) > self._lora_manager.lora_slots:
if len(loras_map) > self._adapter_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._lora_manager.lora_slots}).")
f"({self._adapter_manager.lora_slots}).")
for lora in loras_map.values():
self.add_lora(lora)
self.add_adapter(lora)
def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras():
def add_adapter(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_adapters():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
assert isinstance(self._adapter_manager,
LRUCacheLoRAModelManager)
self._adapter_manager.remove_oldest_adapter()
lora = self._load_adapter(lora_request)
loaded = self._adapter_manager.add_adapter(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = self._lora_manager.get_lora(
loaded = self._adapter_manager.get_adapter(
lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id)
self._adapter_manager.activate_adapter(lora_request.lora_int_id)
return loaded

View File

View File

@ -0,0 +1,80 @@
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@dataclass
class PromptAdapterMapping(AdapterMapping):
pass
class VocabParallelEmbeddingWithPromptAdapter(nn.Module):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.emb_layer = self.base_layer
if 'LoRA' in base_layer.__class__.__name__:
self.emb_layer = self.base_layer.base_layer
def create_prompt_adapter_weights(
self, prompt_adapter_config: PromptAdapterConfig):
self.embeddings_tensors = torch.zeros(
(
prompt_adapter_config.max_prompt_adapters,
prompt_adapter_config.max_prompt_adapter_token,
self.emb_layer.embedding_dim,
),
dtype=self.emb_layer.weight.dtype,
device=self.emb_layer.weight.device,
)
self.adapter_lengths = torch.zeros(
prompt_adapter_config.max_prompt_adapters,
dtype=torch.long,
device=self.emb_layer.weight.device)
self.indices_gpu: torch.Tensor
self.embedding_indices_gpu: torch.Tensor
def reset_prompt_adapter(self, index: int):
self.embeddings_tensors[index] = 0
def set_prompt_adapter(
self,
index: int,
adapter_model: Optional[torch.Tensor],
):
self.reset_prompt_adapter(index)
if adapter_model is not None:
length = adapter_model.shape[0]
self.embeddings_tensors[index, :length] = adapter_model
self.adapter_lengths[index] = length
def set_mapping(
self,
prompt_indices: torch.Tensor,
prompt_embedding_indices: torch.Tensor,
):
self.indices_gpu = prompt_indices.to(
device=self.emb_layer.weight.device)
self.embedding_indices_gpu = prompt_embedding_indices.to(
device=self.emb_layer.weight.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = self.base_layer(x)
if self.embedding_indices_gpu.ndim > 1:
valid_mask = self.indices_gpu != -1
gathered_embeddings = self.embeddings_tensors[
self.embedding_indices_gpu[:, 0],
self.embedding_indices_gpu[:, 1]]
# Update hidden states
hidden_states[valid_mask] = gathered_embeddings
return hidden_states
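A pure-torch trace of the indexing in forward() above, with illustrative shapes (two adapter slots, up to four virtual tokens, hidden size three); the real layer performs the same gather and masked assignment on the base embedding output.

import torch

embeddings_tensors = torch.randn(2, 4, 3)   # [slots, max_virtual_tokens, hidden]
indices = torch.tensor([0, 0, 0, -1, -1])   # per-token adapter slot, -1 = none
embedding_indices = torch.tensor([[0, 0], [0, 1], [0, 2]])  # (slot, position)

hidden_states = torch.zeros(5, 3)           # stand-in for base_layer(x)
valid_mask = indices != -1
gathered = embeddings_tensors[embedding_indices[:, 0], embedding_indices[:, 1]]
hidden_states[valid_mask] = gathered        # first three rows are replaced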

View File

@ -0,0 +1,355 @@
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import torch
from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import (
VocabParallelEmbeddingWithPromptAdapter) # yapf: disable
from vllm.prompt_adapter.layers import PromptAdapterMapping
logger = logging.getLogger(__name__)
_GLOBAL_PROMPT_ADAPTER_ID = 0
def get_prompt_adapter_id():
global _GLOBAL_PROMPT_ADAPTER_ID
_GLOBAL_PROMPT_ADAPTER_ID += 1
return _GLOBAL_PROMPT_ADAPTER_ID
def convert_to_embedding_indices(indices):
embedding_indices = []
count = 0
for value in indices:
if value == -1:
count = 0
else:
embedding_indices.append([value, count])
count += 1
return torch.tensor(embedding_indices)
def convert_mapping(
mapping: PromptAdapterMapping,
prompt_adapter_index_to_id: List[Optional[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Converts PromptAdapterMapping to index tensors.
Args:
mapping: PromptAdapterMapping mapping rows in a
batch to PromptAdapter ids.
prompt_adapter_index_to_id: List mapping PromptAdapter slot
indices to PromptAdapter ids.
Returns:
pa_indices: Tensor of shape [batch_size] mapping batch rows to
PromptAdapter indices (-1 for rows without a prompt adapter).
pa_embedding_mapping: Tensor of shape [num_adapter_rows, 2] holding
(adapter index, position within the adapter) for each adapter row.
"""
id_to_index = {
id_: idx
for idx, id_ in enumerate(prompt_adapter_index_to_id)
if id_ is not None
}
pa_indices = ([
id_to_index.get(id_, -1) if id_ > 0 else -1
for id_ in mapping.index_mapping
])
pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
pa_indices = torch.tensor(pa_indices)
return pa_indices, pa_embedding_mapping
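A small worked trace of the two helpers above (values are illustrative): adapter id 2 occupies slot 0, the first three tokens of the batch are its virtual-token positions, and id 0 means "no adapter".

# prompt_adapter_index_to_id = [2, None]     -> id_to_index = {2: 0}
# mapping.index_mapping      = (2, 2, 2, 0, 0)
# pa_indices                 = [0, 0, 0, -1, -1]
print(convert_to_embedding_indices([0, 0, 0, -1, -1]))
# tensor([[0, 0],
#         [0, 1],
#         [0, 2]])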
class PromptAdapterModel(AdapterModel):
def __init__(self,
prompt_adapter_id=None,
num_virtual_tokens=None,
prompt_embedding=None) -> None:
self.id = prompt_adapter_id
self.prompt_embedding = prompt_embedding
self.num_virtual_tokens = num_virtual_tokens
@classmethod
def from_local_checkpoint(
cls,
adapter_model_path: str,
prompt_adapter_id: int,
num_virtual_tokens: int,
config: PromptAdapterConfig,
device: str = "cuda",
) -> "PromptAdapterModel":
from peft.utils import load_peft_weights
if num_virtual_tokens > config.max_prompt_adapter_token:
raise ValueError(
f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
f'max_prompt_adapter_token ({config.max_prompt_adapter_token})')
adapters_weights = load_peft_weights(adapter_model_path, device)
prompt_embedding = adapters_weights["prompt_embeddings"].to(
config.prompt_adapter_dtype)
return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
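For context, a PEFT prompt-tuning checkpoint stores the learned soft prompt under the "prompt_embeddings" key, which is what from_local_checkpoint reads; a minimal sketch with an illustrative path and shape comment.

from peft.utils import load_peft_weights

weights = load_peft_weights("/adapters/tweet_sentiment", "cuda")
soft_prompt = weights["prompt_embeddings"]   # [num_virtual_tokens, hidden_size]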
class PromptAdapterModelManager(AdapterModelManager):
"""A manager that manages multiple Prompt Adapter models."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
):
"""Create a PromptAdapterModelManager for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences the model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens the model can run
in a single batch.
prompt_adapter_config: the PromptAdapter config.
"""
self.model: nn.Module = model
# Maps adapter slot index -> active prompt adapter id (None = free slot).
self.prompt_adapter_index_to_id: List[
Optional[int]] = [None] * self.prompt_adapter_slots
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.prompt_adapter_config = prompt_adapter_config
self.model.prompt_adapter_manager = self
self.adapter_type = 'PromptAdapter'
self.base_indices = torch.tensor([-1])
self.base_embedding_indices = torch.tensor([])
self.modules: Dict[str, nn.Module] = {}
self._create_prompt_adapter_modules()
self._last_mapping: Optional[PromptAdapterMapping] = None
@property
def prompt_adapter_slots(self) -> int:
return self.prompt_adapter_config.max_prompt_adapters
@property
def adapter_slots(self) -> int:
return self.prompt_adapter_slots
@property
def capacity(self) -> int:
return self.prompt_adapter_config.max_cpu_prompt_adapters
def activate_adapter(
self,
prompt_adapter_id: int,
) -> bool:
"""Move PromptAdapter into a GPU buffer
to be used in the forward pass."""
if prompt_adapter_id in self._active_adapters:
return False
first_free_slot = next(
((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate(
self.prompt_adapter_index_to_id) if prompt_adapter_id is None),
None)
if first_free_slot is None:
raise ValueError("No free prompt_adapter slots")
index, _ = first_free_slot
self._active_adapters[prompt_adapter_id] = None
prompt_adapter_model = (self._registered_adapters[prompt_adapter_id])
logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
prompt_adapter_model.id, index)
self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
for _, v in self.modules.items():
v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
return True
def _deactivate_adapter(self, prompt_adapter_id: int):
try:
index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
self.prompt_adapter_index_to_id[index] = None
for _, v in self.modules.items():
v.reset_prompt_adapter(index)
except ValueError:
pass
def _add_adapter(self, prompt_adapter: PromptAdapterModel):
self._registered_adapters[prompt_adapter.id] = prompt_adapter
def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
base_indices, base_embedding_indices = convert_mapping(
mapping, self.prompt_adapter_index_to_id)
for k, v in self.modules.items():
v.set_mapping(base_indices, base_embedding_indices)
def _create_prompt_adapter_modules(self):
for module_name, module in self.model.named_modules(
remove_duplicate=False):
if "VocabParallel" in module.__class__.__name__:
new_module = VocabParallelEmbeddingWithPromptAdapter(module)
new_module.create_prompt_adapter_weights(
self.prompt_adapter_config)
replaced_module = self.replace_submodule(
self.model, module_name, new_module)
self.register_module(module.__class__.__name__,
replaced_module)
replaced_module.set_mapping(self.base_indices,
self.base_embedding_indices)
break
def replace_submodule(self, model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def register_module(self, module_name: str, module: nn.Module):
self.modules[module_name] = module
def pin_adapter(self, prompt_adapter_id: int) -> bool:
"""Pin a PromptAdapterModel in the manager cache."""
raise NotImplementedError(
"Pinning is not supported in PromptAdapterModelManager. "
"Use LRUCachePromptAdapterModelManager for pinning"
) # type: ignore
def remove_all_adapters(self):
"""Remove all PromptAdapterModel from the manager."""
self._registered_adapters.clear()
self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
self._active_adapters.clear()
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter)
def add_adapter(self, adapter: PromptAdapterModel) -> bool:
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
self._set_adapter_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
def list_adapters(self) -> Dict[int, Any]:
return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
return get_adapter(adapter_id, self._registered_adapters)
class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
def __init__(self, capacity: int,
deactivate_prompt_adapter_fn: Callable[[int], bool]):
super().__init__(capacity, deactivate_prompt_adapter_fn)
class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
"""A model manager that manages multiple prompt_adapters with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
):
self.prompt_adapter_config = prompt_adapter_config
super().__init__(model, max_num_seqs, max_num_batched_tokens,
prompt_adapter_config)
self._registered_adapters = PromptAdapterLRUCache(
self.capacity, self.deactivate_adapter)
self._active_adapters = PromptAdapterLRUCache(
self.prompt_adapter_slots, self._deactivate_adapter)
def list_adapters(self) -> Dict[int, PromptAdapterModel]:
"""List all registered PromptAdapterModel."""
return dict(self._registered_adapters.cache)
def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
"""Add a PromptAdapterModel to the manager."""
if prompt_adapter.id not in self._registered_adapters:
self._add_adapter(prompt_adapter)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_adapters.touch(prompt_adapter.id)
was_added = False
return was_added
def activate_adapter(
self,
prompt_adapter_id: int,
) -> bool:
if prompt_adapter_id not in self._active_adapters and len(
self._active_adapters) >= self.prompt_adapter_slots:
self._active_adapters.remove_oldest()
result = super().activate_adapter(prompt_adapter_id)
# We always touch to update the LRU cache order
self._active_adapters.touch(prompt_adapter_id)
return result
def remove_oldest_adapter(self) -> bool:
if len(self._registered_adapters) > 0:
self._registered_adapters.remove_oldest()
return True
return False
def pin_adapter(self, prompt_adapter_id: int) -> bool:
"""Pin a PromptAdapterModel in the manager cache."""
self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
return True
def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
try:
self._registered_adapters.pin(prompt_adapter_id)
except ValueError as err:
raise ValueError(
"Pinning failed. "
f"Prompt Adapter {prompt_adapter_id} is not registered."
) from err
def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
if prompt_adapter_id not in self._active_adapters:
# move adapter to gpu if not already active
self.activate_adapter(prompt_adapter_id)
self._active_adapters.pin(prompt_adapter_id)
def create_prompt_adapter_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
prompt_adapter_config: PromptAdapterConfig,
prompt_adapter_manager_cls: Type[
PromptAdapterModelManager] = PromptAdapterModelManager,
**kwargs) -> PromptAdapterModelManager:
"""Create a PromptAdapterModelManager for a given model."""
prompt_adapter_manager = prompt_adapter_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
prompt_adapter_config=prompt_adapter_config,
**kwargs)
return prompt_adapter_manager

View File

@ -0,0 +1,30 @@
from dataclasses import dataclass
from vllm.adapter_commons.request import AdapterRequest
@dataclass
class PromptAdapterRequest(AdapterRequest):
"""
Request for a Prompt adapter.
"""
prompt_adapter_name: str
prompt_adapter_id: int
prompt_adapter_local_path: str
prompt_adapter_num_virtual_tokens: int
def __hash__(self):
return super().__hash__()
@property
def adapter_id(self):
return self.prompt_adapter_id
@property
def name(self):
return self.prompt_adapter_name
@property
def local_path(self):
return self.prompt_adapter_local_path

View File

@ -0,0 +1,176 @@
import logging
from typing import Any, Optional, Set, Type
import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
PromptAdapterModel,
PromptAdapterModelManager,
create_prompt_adapter_manager)
from vllm.prompt_adapter.request import PromptAdapterRequest
logger = logging.getLogger(__name__)
class WorkerPromptAdapterManager(AbstractWorkerManager):
"""WorkerPromptAdapterManager that manages
prompt_adapter models on the worker side.
On every request, the requested prompt_adapters are
loaded (unless they are already loaded)
and every other prompt_adapter is unloaded."""
_manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager
def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
device: torch.device,
prompt_adapter_config: PromptAdapterConfig,
prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
):
self._adapter_manager: PromptAdapterModelManager
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self._prompt_adapter_model_cls = prompt_adapter_model_cls
self.prompt_adapter_config = prompt_adapter_config
super().__init__(device)
@property
def is_enabled(self) -> bool:
return True
def create_prompt_adapter_manager(
self,
model: torch.nn.Module,
) -> Any:
prompt_adapter_manager = create_prompt_adapter_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
prompt_adapter_config=self.prompt_adapter_config,
prompt_adapter_manager_cls=self._manager_cls,
)
self._adapter_manager = prompt_adapter_manager
return prompt_adapter_manager.model
def _load_adapter(
self, prompt_adapter_request: PromptAdapterRequest
) -> PromptAdapterModel:
try:
prompt_adapter = (
self._prompt_adapter_model_cls.from_local_checkpoint(
prompt_adapter_request.prompt_adapter_local_path,
prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
num_virtual_tokens=prompt_adapter_request.
prompt_adapter_num_virtual_tokens,
config=self.prompt_adapter_config,
device=str(self.device),
))
except Exception as e:
raise RuntimeError(
f"Loading prompt_adapter "
f"{prompt_adapter_request.prompt_adapter_local_path}"
f" failed") from e
return prompt_adapter
def add_dummy_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return True
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
set_active_adapters_worker(requests, mapping, self._apply_adapters,
self._adapter_manager.set_adapter_mapping)
def add_adapter(self, adapter_request: Any) -> bool:
return add_adapter_worker(adapter_request, self.list_adapters,
self._load_adapter,
self._adapter_manager.add_adapter,
self._adapter_manager.activate_adapter)
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
def remove_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> Set[int]:
return list_adapters_worker(self._adapter_manager.list_adapters)
class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
"""WorkerPromptAdapterManager that manages
prompt_adapter models on the worker side.
Uses an LRU cache. On every request, the requested
prompt_adapters are loaded (unless they are already loaded)
and the least recently used prompt_adapters are
unloaded if the cache is above capacity."""
_prompt_adapter_manager_cls: Type[
LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager
def create_prompt_adapter_manager(
self,
model: torch.nn.Module,
) -> Any:
prompt_adapter_manager = create_prompt_adapter_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
prompt_adapter_config=self.prompt_adapter_config,
prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
self._adapter_manager: LRUCachePromptAdapterModelManager = (
prompt_adapter_manager)
return prompt_adapter_manager.model
def _apply_adapters(
self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
prompt_adapters_map = {
prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
for prompt_adapter_request in prompt_adapter_requests
if prompt_adapter_request
}
if len(prompt_adapters_map
) > self._adapter_manager.prompt_adapter_slots:
raise RuntimeError(
f"Number of requested prompt_adapters "
f"({len(prompt_adapters_map)}) is greater "
"than the number of GPU prompt_adapter slots "
f"({self._adapter_manager.prompt_adapter_slots}).")
for prompt_adapter in prompt_adapters_map.values():
self.add_adapter(prompt_adapter)
def add_adapter(self,
prompt_adapter_request: PromptAdapterRequest) -> bool:
if prompt_adapter_request.prompt_adapter_id not in self.list_adapters(
):
# Remove before we load the new prompt_adapter to save memory
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
self._adapter_manager.remove_oldest_adapter()
prompt_adapter = self._load_adapter(prompt_adapter_request)
loaded = self._adapter_manager.add_adapter(prompt_adapter)
else:
# If the prompt_adapter is already loaded, just touch it to
# update its position in the caches
loaded = self._adapter_manager.get_adapter(
prompt_adapter_request.prompt_adapter_id) is not None
self._adapter_manager.activate_adapter(
prompt_adapter_request.prompt_adapter_id)
return loaded
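A hedged sketch of the per-step worker flow built on the methods above. It assumes AdapterMapping keeps the (index_mapping, prompt_mapping) fields of the old LoRAMapping shown earlier in this diff, and that a worker manager has already been constructed by the model runner; names, paths, and the mapping layout are illustrative.

from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest

pa_request = PromptAdapterRequest(
    prompt_adapter_name="tweet_sentiment",
    prompt_adapter_id=1,
    prompt_adapter_local_path="/adapters/tweet_sentiment",
    prompt_adapter_num_virtual_tokens=8)

# One prefill sequence of 32 token slots whose first 8 positions are the
# adapter's virtual tokens (per-token values are adapter ids, 0 = none).
mapping = PromptAdapterMapping(
    index_mapping=tuple([1] * 8 + [0] * 24),
    prompt_mapping=(1,))

# worker_manager: an LRUCacheWorkerPromptAdapterManager owned by the model
# runner (assumed to exist). This loads the adapter if needed and pushes
# the new token->adapter mapping to the embedding layer.
worker_manager.set_active_adapters({pa_request}, mapping)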

View File

@ -10,6 +10,7 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
@ -238,6 +239,8 @@ class Sequence:
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
@ -247,12 +250,14 @@ class Sequence:
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None:
self.seq_id = seq_id
self.inputs = inputs
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
@ -287,6 +292,11 @@ class Sequence:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int):
# We return the full output text if the sequence is finished.
truncate = buffer_length and not self.is_finished()
@ -414,6 +424,7 @@ class SequenceGroup:
encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
@ -427,6 +438,7 @@ class SequenceGroup:
pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@ -441,6 +453,7 @@ class SequenceGroup:
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
@ -466,6 +479,16 @@ class SequenceGroup:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
if self.prompt_adapter_request else 0
def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
@ -624,6 +647,7 @@ class SequenceGroupMetadata:
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
@ -642,6 +666,7 @@ class SequenceGroupMetadata:
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
@ -650,6 +675,7 @@ class SequenceGroupMetadata:
self.block_tables = block_tables
self.pooling_params = pooling_params
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
@ -674,6 +700,16 @@ class SequenceGroupMetadata:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0
@property
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""

View File

@ -4,7 +4,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -48,6 +48,7 @@ class TP1DraftModelRunner(ModelRunner):
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
multimodal_config: Optional[MultiModalConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
):
if return_hidden_states:
@ -66,6 +67,7 @@ class TP1DraftModelRunner(ModelRunner):
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
)
@ -136,6 +138,13 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = []
for step in range(num_steps):
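The guard-assert-activate sequence above recurs verbatim in the embedding and main GPU runners later in this diff: adapter state is pushed only when the engine was configured for prompt adapters, and a populated model input is a hard precondition at that point. A compact restatement of the control flow (stand-in names, not the vLLM classes):

def maybe_activate_prompt_adapters(runner, model_input) -> None:
    # No-op unless prompt adapters were configured; fail loudly if the batch
    # was somehow built without adapter state.
    if runner.prompt_adapter_config is None:
        return
    assert model_input.prompt_adapter_requests is not None
    assert model_input.prompt_adapter_mapping is not None
    runner.set_active_prompt_adapters(model_input.prompt_adapter_requests,
                                      model_input.prompt_adapter_mapping)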

View File

@ -8,7 +8,7 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
@ -81,6 +81,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
@ -94,6 +95,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.cache_config = cache_config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker

View File

@ -7,7 +7,7 @@ import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -133,6 +133,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
@ -145,6 +146,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
@ -167,6 +169,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.

View File

@ -5,7 +5,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
@ -40,6 +40,7 @@ class EmbeddingModelRunner(
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
):
super().__init__(model_config,
@ -51,6 +52,7 @@ class EmbeddingModelRunner(
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config)
@torch.inference_mode()
@ -71,6 +73,13 @@ class EmbeddingModelRunner(
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata

View File

@ -25,7 +25,7 @@ except ImportError:
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
@ -40,6 +40,10 @@ from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -85,6 +89,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
@ -97,6 +103,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
@ -133,6 +141,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
@ -172,6 +182,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
@ -183,6 +194,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states
@ -232,6 +244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model: nn.Module # Set after load_model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: Optional[LRUCacheWorkerPromptAdapterManager] = None
self.flashinfer_decode_workspace_buffer = None
self.flashinfer_decode_wrapper = None
@ -240,16 +253,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
model_config=self.model_config,
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
)
cache_config=self.cache_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
@ -274,6 +285,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.device,
self.prompt_adapter_config)
self.model = (
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
@ -354,6 +374,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_adapter_index_mapping: List[int] = []
prompt_adapter_prompt_mapping: List[int] = []
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
seq_lens: List[int] = []
prefill_seq_lens: List[int] = []
@ -504,6 +527,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if is_prompt:
assert len(seq_ids) == 1
@ -534,6 +558,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if prompt_adapter_id > 0 and is_prompt:
prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
prompt_adapter_index_mapping += pm
prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables)
if is_profile_run:
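A worked example of the two mapping lists built above, with illustrative values: one prefill sequence using prompt adapter id 3, 4 virtual tokens, and a query length of 10. The per-token index mapping marks only the virtual-token positions; the per-sequence prompt mapping has a single entry unless prompt logprobs are requested.

prompt_adapter_id = 3
num_virtual_tokens = 4
query_len = 10
wants_prompt_logprobs = False

index_mapping = ([prompt_adapter_id] * num_virtual_tokens
                 + [0] * (query_len - num_virtual_tokens))
prompt_mapping = [prompt_adapter_id] * (query_len if wants_prompt_logprobs else 1)

print(index_mapping)   # [3, 3, 3, 3, 0, 0, 0, 0, 0, 0]
print(prompt_mapping)  # [3] -- would be ten 3s if prompt logprobs were requested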
@ -618,12 +657,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
prompt_adapter_index_mapping.append(0)
if self.attn_backend.get_name() == "flashinfer":
last_paged_kv_indptr = paged_kv_indptr[-1]
paged_kv_indptr.append(last_paged_kv_indptr)
paged_kv_last_page_len.append(0)
batch_size = graph_batch_size
num_decode_tokens = batch_size
@ -759,6 +797,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
else:
lora_mapping = None
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
request_ids_to_seq_ids = {
@ -776,7 +822,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids)
finished_requests_ids=finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests,
)
@torch.inference_mode()
def profile_run(self) -> None:
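Below the loop, the two lists are wrapped into a PromptAdapterMapping only when a PromptAdapterConfig is present; otherwise the model input carries None, so the broadcastable dict always exposes the same keys. A small sketch of that convention, with plain tuples standing in for the real mapping class:

from typing import List, Optional, Tuple

def package_mapping(prompt_adapter_enabled: bool,
                    index_mapping: List[int],
                    prompt_mapping: List[int]) -> Optional[Tuple[tuple, tuple]]:
    # None signals "feature disabled" to every consumer of the model input,
    # matching how lora_mapping is handled a few lines earlier.
    if not prompt_adapter_enabled:
        return None
    return (tuple(index_mapping), tuple(prompt_mapping))

print(package_mapping(False, [3, 0], [3]))   # None
print(package_mapping(True, [3, 0], [3]))    # ((3, 0), (3,))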
@ -878,33 +927,67 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_loras()
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_loras(lora_requests, lora_mapping)
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_lora(lora_request)
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_lora(lora_id)
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_lora(lora_id)
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_loras()
return self.lora_manager.list_adapters()
def remove_all_prompt_adapters(self):
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.remove_all_adapters()
def set_active_prompt_adapters(
self, prompt_adapter_requests: Set[PromptAdapterRequest],
prompt_adapter_mapping: PromptAdapterMapping) -> None:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
self.prompt_adapter_manager.set_active_adapters(
prompt_adapter_requests, prompt_adapter_mapping)
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
if not self.prompt_adapter_manager:
raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters()
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
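The prompt-adapter management methods above mirror the LoRA ones one-for-one and all raise RuntimeError unless the runner was built with a PromptAdapterConfig. A hedged usage sketch; runner is assumed to be such a GPUModelRunnerBase and request a PromptAdapterRequest obtained elsewhere, since neither constructor appears in this hunk:

def cycle_adapter(runner, request) -> bool:
    # Load (or re-touch) the adapter, confirm it is registered, then drop it.
    runner.add_prompt_adapter(request)
    active = request.prompt_adapter_id in runner.list_prompt_adapters()
    runner.remove_prompt_adapter(request.prompt_adapter_id)
    return active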
@ -1063,6 +1146,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
self.set_active_loras(set(), lora_mapping)
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
[-1] * batch_size,
[-1] * batch_size,
)
self.set_active_prompt_adapters(
set(), prompt_adapter_mapping)
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name())
@ -1189,6 +1280,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None

View File

@ -8,7 +8,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
@ -16,6 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
@ -45,6 +47,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
@ -59,6 +62,7 @@ class Worker(LocalOrDistributedWorkerBase):
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
@ -92,6 +96,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
**speculative_args,
)
@ -296,6 +301,19 @@ class Worker(LocalOrDistributedWorkerBase):
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
return self.model_runner.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_runner.remove_prompt_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
return self.model_runner.list_prompt_adapters()
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
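The Worker methods above are thin pass-throughs to the model runner, so an executor-level caller only needs adapter ids to manage adapter lifetime. A hedged sketch; worker and request are assumed to exist and are not constructed here:

def swap_adapter(worker, request, stale_adapter_id: int) -> bool:
    # Drop an adapter that is no longer needed, register the new one, and
    # confirm the worker now reports it.
    worker.remove_prompt_adapter(stale_adapter_id)
    worker.add_prompt_adapter(request)
    return request.prompt_adapter_id in worker.list_prompt_adapters()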

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
@ -88,6 +88,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
@ -98,6 +99,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.lora_config = lora_config
self.load_config = load_config
self.cache_config = cache_config
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker

View File

@ -10,7 +10,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
@ -47,6 +48,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
) -> None:
assert device_config.device_type == "xpu"
@ -63,6 +65,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."