[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

This commit is contained in:
parent a0550cbc80
commit 4d6ada947c
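
For context, the feature can be exercised end to end as in the prompt-adapter tests added by this diff (tests/prompt_adapter/test_bloom.py): enable the adapter path on the engine, then attach a PromptAdapterRequest to an individual generate call. A minimal sketch using the model and soft-prompt checkpoint from that test:

import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = "stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM"

# Engine flags introduced by this PR.
llm = vllm.LLM(MODEL_PATH,
               enable_prompt_adapter=True,
               max_prompt_adapter_token=8)

# Per-request soft prompt: name, integer id, checkpoint path, virtual-token count.
outputs = llm.generate(
    ["Tweet text : @nationalgridus Looks good thanks! Label : "],
    vllm.SamplingParams(temperature=0.0, max_tokens=3),
    prompt_adapter_request=PromptAdapterRequest("twitter_pa", 1, PA_PATH, 8))
print(outputs[0].outputs[0].text.strip())

Internally, the engine reserves the adapter's virtual tokens by prepending prompt_adapter_num_virtual_tokens placeholder token ids to the prompt, as the process_model_inputs changes further down show.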
@@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/model_executor --config-file pyproject.toml
 mypy vllm/lora --config-file pyproject.toml
 mypy vllm/logging --config-file pyproject.toml
+mypy vllm/prompt_adapter --config-file pyproject.toml
 mypy tests --config-file pyproject.toml
@@ -92,11 +92,10 @@ def batched_generate(
     for input in inputs:
         prompt, sampling_param, lora_req = input
         # Add requests to the engine and run the engine
-        llm._validate_and_add_requests(
-            prompt,
-            sampling_param,
-            lora_request=lora_req,
-        )
+        llm._validate_and_add_requests(prompt,
+                                       sampling_param,
+                                       lora_request=lora_req,
+                                       prompt_adapter_request=None)
 
     outputs = llm._run_engine(use_tqdm=True)
     return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
@@ -127,37 +127,37 @@ def test_lora_model_manager(dist_init, dummy_model):
         model, 2, 2, 2,
         LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
     assert all(x is None for x in manager.lora_index_to_id)
-    assert manager.add_lora(model_lora1)
-    assert manager.activate_lora(1)
+    assert manager.add_adapter(model_lora1)
+    assert manager.activate_adapter(1)
     assert manager.lora_index_to_id[0] == 1
-    assert not manager.add_lora(model_lora1)
-    assert not manager.activate_lora(1)
-    assert manager.add_lora(model_lora2)
-    assert manager.activate_lora(2)
+    assert not manager.add_adapter(model_lora1)
+    assert not manager.activate_adapter(1)
+    assert manager.add_adapter(model_lora2)
+    assert manager.activate_adapter(2)
     assert manager.lora_index_to_id[0] == 1
     assert manager.lora_index_to_id[1] == 2
-    assert not manager.add_lora(model_lora2)
-    assert not manager.activate_lora(2)
-    assert manager.add_lora(model_lora3)
+    assert not manager.add_adapter(model_lora2)
+    assert not manager.activate_adapter(2)
+    assert manager.add_adapter(model_lora3)
     assert manager.lora_index_to_id[0] == 1
     assert manager.lora_index_to_id[1] == 2
     with pytest.raises(ValueError):
-        assert manager.activate_lora(3)
+        assert manager.activate_adapter(3)
     assert manager.lora_index_to_id[0] == 1
     assert manager.lora_index_to_id[1] == 2
-    assert manager.remove_lora(model_lora2.id)
+    assert manager.remove_adapter(model_lora2.id)
     assert manager.lora_index_to_id[1] is None
-    assert not manager.remove_lora(model_lora2.id)
-    assert manager.remove_lora(model_lora1.id)
-    assert not manager.remove_lora(model_lora1.id)
-    assert manager.add_lora(model_lora1)
+    assert not manager.remove_adapter(model_lora2.id)
+    assert manager.remove_adapter(model_lora1.id)
+    assert not manager.remove_adapter(model_lora1.id)
+    assert manager.add_adapter(model_lora1)
     assert manager.lora_index_to_id[0] is None
     assert manager.lora_index_to_id[1] is None
-    assert manager.add_lora(model_lora2)
-    assert manager.activate_lora(3)
+    assert manager.add_adapter(model_lora2)
+    assert manager.activate_adapter(3)
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] is None
-    assert manager.activate_lora(2)
+    assert manager.activate_adapter(2)
     assert manager.lora_index_to_id[0] == 3
     assert manager.lora_index_to_id[1] == 2
@ -173,70 +173,70 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
|
||||
model, 2, 2, 2,
|
||||
LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
assert manager.add_lora(model_lora1)
|
||||
assert manager.activate_lora(1)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
assert manager.activate_adapter(1)
|
||||
assert manager.lora_index_to_id[0] == 1
|
||||
assert not manager.add_lora(model_lora1)
|
||||
assert not manager.activate_lora(1)
|
||||
assert manager.add_lora(model_lora2)
|
||||
assert manager.activate_lora(2)
|
||||
assert not manager.add_adapter(model_lora1)
|
||||
assert not manager.activate_adapter(1)
|
||||
assert manager.add_adapter(model_lora2)
|
||||
assert manager.activate_adapter(2)
|
||||
assert manager.lora_index_to_id[0] == 1
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
assert not manager.add_lora(model_lora2)
|
||||
assert not manager.activate_lora(2)
|
||||
assert manager.add_lora(model_lora3)
|
||||
assert not manager.add_adapter(model_lora2)
|
||||
assert not manager.activate_adapter(2)
|
||||
assert manager.add_adapter(model_lora3)
|
||||
assert manager.lora_index_to_id[0] == 1
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
assert manager.activate_lora(3)
|
||||
assert manager.activate_adapter(3)
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
assert manager.remove_lora(model_lora2.id)
|
||||
assert manager.remove_adapter(model_lora2.id)
|
||||
assert manager.lora_index_to_id[1] is None
|
||||
assert not manager.remove_lora(model_lora2.id)
|
||||
assert manager.remove_lora(model_lora1.id)
|
||||
assert not manager.remove_lora(model_lora1.id)
|
||||
assert manager.add_lora(model_lora1)
|
||||
assert manager.activate_lora(1)
|
||||
assert not manager.remove_adapter(model_lora2.id)
|
||||
assert manager.remove_adapter(model_lora1.id)
|
||||
assert not manager.remove_adapter(model_lora1.id)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
assert manager.activate_adapter(1)
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
assert manager.add_lora(model_lora2)
|
||||
assert manager.deactivate_lora(3)
|
||||
assert manager.add_adapter(model_lora2)
|
||||
assert manager.deactivate_adapter(3)
|
||||
assert manager.lora_index_to_id[0] is None
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
assert manager.activate_lora(2)
|
||||
assert manager.activate_adapter(2)
|
||||
assert manager.lora_index_to_id[0] == 2
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
assert manager.activate_lora(3)
|
||||
assert manager.activate_adapter(3)
|
||||
assert manager.lora_index_to_id[0] == 2
|
||||
assert manager.lora_index_to_id[1] == 3
|
||||
assert manager.pin_lora(2)
|
||||
assert manager.pin_adapter(2)
|
||||
assert manager.lora_index_to_id[0] == 2
|
||||
assert manager.lora_index_to_id[1] == 3
|
||||
assert manager.activate_lora(1)
|
||||
assert manager.activate_adapter(1)
|
||||
assert manager.lora_index_to_id[0] == 2
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
assert manager.deactivate_lora(2)
|
||||
assert manager.deactivate_adapter(2)
|
||||
assert manager.lora_index_to_id[0] is None
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
assert manager.activate_lora(3)
|
||||
assert manager.activate_adapter(3)
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
assert manager.pin_lora(3)
|
||||
assert manager.pin_lora(1)
|
||||
assert manager.pin_adapter(3)
|
||||
assert manager.pin_adapter(1)
|
||||
with pytest.raises(RuntimeError):
|
||||
assert manager.pin_lora(2)
|
||||
assert manager.pin_adapter(2)
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
with pytest.raises(RuntimeError):
|
||||
assert manager.activate_lora(2)
|
||||
assert manager.activate_adapter(2)
|
||||
|
||||
assert manager.deactivate_lora(3)
|
||||
assert manager.pin_lora(2)
|
||||
assert manager.deactivate_adapter(3)
|
||||
assert manager.pin_adapter(2)
|
||||
assert manager.lora_index_to_id[0] == 2
|
||||
assert manager.lora_index_to_id[1] == 1
|
||||
assert manager.remove_lora(3)
|
||||
assert manager.remove_adapter(3)
|
||||
with pytest.raises(ValueError):
|
||||
assert manager.pin_lora(3)
|
||||
assert manager.pin_adapter(3)
|
||||
|
||||
|
||||
def test_lru_lora_model_manager(dist_init, dummy_model):
|
||||
@ -256,168 +256,169 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
|
||||
# Add up to capacity
|
||||
assert manager.add_lora(model_lora1)
|
||||
assert manager.add_lora(model_lora2)
|
||||
assert manager.activate_lora(1)
|
||||
assert manager.activate_lora(2)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
assert manager.add_adapter(model_lora2)
|
||||
assert manager.activate_adapter(1)
|
||||
assert manager.activate_adapter(2)
|
||||
|
||||
assert set(manager.list_loras()) == {1, 2}
|
||||
assert set(manager.list_adapters()) == {1, 2}
|
||||
assert manager.lora_index_to_id[0] == 1
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
|
||||
# Add over capacity
|
||||
assert manager.add_lora(model_lora3)
|
||||
assert manager.add_lora(model_lora4)
|
||||
assert manager.activate_lora(3)
|
||||
assert manager.activate_lora(4)
|
||||
assert manager.add_adapter(model_lora3)
|
||||
assert manager.add_adapter(model_lora4)
|
||||
assert manager.activate_adapter(3)
|
||||
assert manager.activate_adapter(4)
|
||||
|
||||
assert set(manager.list_loras()) == {3, 4}
|
||||
assert set(manager.list_adapters()) == {3, 4}
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 4
|
||||
|
||||
# Add 3 again to move it to the top and then add 2
|
||||
# should return false since it's in already
|
||||
assert not manager.add_lora(model_lora3)
|
||||
assert not manager.activate_lora(3)
|
||||
assert manager.add_lora(model_lora2)
|
||||
assert manager.activate_lora(2)
|
||||
assert not manager.add_adapter(model_lora3)
|
||||
assert not manager.activate_adapter(3)
|
||||
assert manager.add_adapter(model_lora2)
|
||||
assert manager.activate_adapter(2)
|
||||
|
||||
assert set(manager.list_loras()) == {3, 2}
|
||||
assert set(manager.list_adapters()) == {3, 2}
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
|
||||
# Remove manually
|
||||
assert manager.remove_lora(3)
|
||||
assert not manager.remove_lora(3)
|
||||
assert manager.remove_adapter(3)
|
||||
assert not manager.remove_adapter(3)
|
||||
|
||||
assert set(manager.list_loras()) == {2}
|
||||
assert set(manager.list_adapters()) == {2}
|
||||
assert manager.lora_index_to_id[0] is None
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
|
||||
assert manager.add_lora(model_lora3)
|
||||
assert manager.activate_lora(3)
|
||||
assert manager.add_lora(model_lora4)
|
||||
assert manager.activate_lora(4)
|
||||
assert manager.add_adapter(model_lora3)
|
||||
assert manager.activate_adapter(3)
|
||||
assert manager.add_adapter(model_lora4)
|
||||
assert manager.activate_adapter(4)
|
||||
|
||||
assert set(manager.list_loras()) == {3, 4}
|
||||
assert set(manager.list_adapters()) == {3, 4}
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 4
|
||||
|
||||
assert manager.remove_oldest_lora()
|
||||
assert set(manager.list_loras()) == {4}
|
||||
assert manager.remove_oldest_adapter()
|
||||
assert set(manager.list_adapters()) == {4}
|
||||
assert manager.lora_index_to_id[0] is None
|
||||
assert manager.lora_index_to_id[1] == 4
|
||||
|
||||
assert manager.remove_oldest_lora()
|
||||
assert set(manager.list_loras()) == set()
|
||||
assert manager.remove_oldest_adapter()
|
||||
assert set(manager.list_adapters()) == set()
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
|
||||
assert not manager.remove_oldest_lora()
|
||||
assert set(manager.list_loras()) == set()
|
||||
assert not manager.remove_oldest_adapter()
|
||||
assert set(manager.list_adapters()) == set()
|
||||
assert all(x is None for x in manager.lora_index_to_id)
|
||||
|
||||
# pinning
|
||||
assert manager.add_lora(model_lora3)
|
||||
assert manager.activate_lora(3)
|
||||
assert manager.add_lora(model_lora4)
|
||||
assert manager.activate_lora(4)
|
||||
assert set(manager.list_loras()) == {3, 4}
|
||||
assert manager.add_adapter(model_lora3)
|
||||
assert manager.activate_adapter(3)
|
||||
assert manager.add_adapter(model_lora4)
|
||||
assert manager.activate_adapter(4)
|
||||
assert set(manager.list_adapters()) == {3, 4}
|
||||
with pytest.raises(ValueError):
|
||||
assert manager.pin_lora(1)
|
||||
assert manager.pin_lora(3)
|
||||
assert manager.pin_adapter(1)
|
||||
assert manager.pin_adapter(3)
|
||||
# Remove manually
|
||||
assert manager.remove_lora(3)
|
||||
assert not manager.remove_lora(3)
|
||||
assert manager.remove_adapter(3)
|
||||
assert not manager.remove_adapter(3)
|
||||
|
||||
assert set(manager.list_loras()) == {4}
|
||||
assert set(manager.list_adapters()) == {4}
|
||||
assert manager.lora_index_to_id[0] is None
|
||||
assert manager.lora_index_to_id[1] == 4
|
||||
|
||||
assert manager.add_lora(model_lora1)
|
||||
assert manager.pin_lora(1)
|
||||
assert manager.add_lora(model_lora2)
|
||||
assert manager.activate_lora(2)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
assert manager.pin_adapter(1)
|
||||
assert manager.add_adapter(model_lora2)
|
||||
assert manager.activate_adapter(2)
|
||||
|
||||
assert set(manager.list_loras()) == {1, 2}
|
||||
assert set(manager.list_adapters()) == {1, 2}
|
||||
assert manager.lora_index_to_id[0] == 1
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
|
||||
assert manager.remove_oldest_lora()
|
||||
assert set(manager.list_loras()) == {1}
|
||||
assert manager.remove_oldest_adapter()
|
||||
assert set(manager.list_adapters()) == {1}
|
||||
assert manager.lora_index_to_id[0] == 1
|
||||
assert manager.lora_index_to_id[1] is None
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
assert manager.remove_oldest_lora()
|
||||
assert manager.remove_oldest_adapter()
|
||||
|
||||
assert set(manager.list_loras()) == {1}
|
||||
assert set(manager.list_adapters()) == {1}
|
||||
|
||||
|
||||
def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
||||
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files):
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||
worker_lora_manager = LRUCacheWorkerLoRAManager(
|
||||
worker_adapter_manager = LRUCacheWorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
|
||||
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
|
||||
worker_adapter_manager.create_lora_manager(
|
||||
llama_2_7b_model_extra_embeddings)
|
||||
|
||||
mapping = LoRAMapping([], [])
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 2}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("3", 3, sql_lora_files),
|
||||
LoRARequest("4", 4, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 2, 3, 4}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files),
|
||||
LoRARequest("5", 5, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 2, 4, 5}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("6", 6, sql_lora_files),
|
||||
LoRARequest("7", 7, sql_lora_files),
|
||||
LoRARequest("8", 8, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 6, 7, 8}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6
|
||||
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6
|
||||
|
||||
# Over capacity
|
||||
with pytest.raises(RuntimeError):
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("10", 10, sql_lora_files),
|
||||
LoRARequest("11", 11, sql_lora_files),
|
||||
LoRARequest("12", 12, sql_lora_files),
|
||||
@ -426,68 +427,69 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
||||
], mapping)
|
||||
|
||||
|
||||
def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
|
||||
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
|
||||
sql_lora_files):
|
||||
# Should remove every LoRA not specified in the request.
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
|
||||
worker_lora_manager = WorkerLoRAManager(
|
||||
worker_adapter_manager = WorkerLoRAManager(
|
||||
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
|
||||
lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
|
||||
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
|
||||
worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
|
||||
worker_adapter_manager.create_lora_manager(
|
||||
llama_2_7b_model_extra_embeddings)
|
||||
|
||||
mapping = LoRAMapping([], [])
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 2}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("3", 3, sql_lora_files),
|
||||
LoRARequest("4", 4, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 3, 4}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4
|
||||
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("2", 2, sql_lora_files),
|
||||
LoRARequest("5", 5, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1, 2, 5}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5
|
||||
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files),
|
||||
LoRARequest("1", 1, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {1}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None
|
||||
assert worker_adapter_manager.list_adapters() == {1}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None
|
||||
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("6", 6, sql_lora_files),
|
||||
LoRARequest("7", 7, sql_lora_files),
|
||||
LoRARequest("8", 8, sql_lora_files)
|
||||
], mapping)
|
||||
assert worker_lora_manager.list_loras() == {6, 7, 8}
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6
|
||||
assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7
|
||||
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
|
||||
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7
|
||||
|
||||
# Over capacity
|
||||
with pytest.raises(RuntimeError):
|
||||
worker_lora_manager.set_active_loras([
|
||||
worker_adapter_manager.set_active_adapters([
|
||||
LoRARequest("10", 10, sql_lora_files),
|
||||
LoRARequest("11", 11, sql_lora_files),
|
||||
LoRARequest("12", 12, sql_lora_files),
|
||||
@@ -525,8 +527,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
 
     assert isinstance(model.get_submodule("gate_up_proj"),
                       MergedColumnParallelLinearWithLoRA)
-    assert manager.add_lora(model_lora)
-    assert manager.add_lora(model_lora1)
+    assert manager.add_adapter(model_lora)
+    assert manager.add_adapter(model_lora1)
 
     packed_lora = model_lora.get_lora("gate_up_proj")
     assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
45 tests/prompt_adapter/test_bloom.py Normal file
@@ -0,0 +1,45 @@
import pytest

import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'


def do_sample(llm, pa_name: str, pa_id: int):

    prompts = [
        "Tweet text : @nationalgridus I have no water and the bill is \
        current and paid. Can you do something about this? Label : ",
        "Tweet text : @nationalgridus Looks good thanks! Label : "
    ]
    sampling_params = vllm.SamplingParams(temperature=0.0,
                                          max_tokens=3,
                                          stop_token_ids=[3])

    outputs = llm.generate(prompts,
                           sampling_params,
                           prompt_adapter_request=PromptAdapterRequest(
                               pa_name, pa_id, PA_PATH, 8) if pa_id else None)

    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("enforce_eager", [True, False])
def test_twitter_prompt_adapter(enforce_eager: bool):
    llm = vllm.LLM(MODEL_PATH,
                   enforce_eager=enforce_eager,
                   enable_prompt_adapter=True,
                   max_prompt_adapter_token=8)

    expected_output = ['complaint', 'no complaint']

    assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
53 tests/prompt_adapter/test_multi_adapter_inference.py Normal file
@@ -0,0 +1,53 @@
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'


def do_sample(engine):

    prompts = [
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3), None),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("complain", 3, pa_path, 8)),
    ]

    request_id = 0
    results = set()
    while prompts or engine.has_unfinished_requests():
        if prompts:
            prompt, sampling_params, pa_request = prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               prompt_adapter_request=pa_request)
            request_id += 1

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                results.add(request_output.outputs[0].text)
    return results


def test_multi_prompt_adapters():
    engine_args = EngineArgs(model=MODEL_PATH,
                             max_prompt_adapters=3,
                             enable_prompt_adapter=True,
                             max_prompt_adapter_token=8)
    engine = LLMEngine.from_engine_args(engine_args)
    expected_output = {
        ' quot;I', 'hate speech', 'no complaint', 'not hate speech'
    }
    assert do_sample(engine) == expected_output
61 tests/prompt_adapter/test_pa_lora.py Normal file
@@ -0,0 +1,61 @@
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


def do_sample(engine):

    prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501

    # first prompt with a prompt adapter and second without adapter
    prompts = [
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]),
         PromptAdapterRequest("hate_speech", 1, pa_path,
                              8), LoRARequest("sql_test", 1, lora_path)),
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]), None,
         LoRARequest("sql_test", 1, lora_path)),
    ]

    request_id = 0
    results = set()
    while prompts or engine.has_unfinished_requests():
        if prompts:
            prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               prompt_adapter_request=pa_request,
                               lora_request=lora_request)
            request_id += 1

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                results.add(request_output.outputs[0].text)
    return results


def test_lora_prompt_adapter():
    engine_args = EngineArgs(model=MODEL_PATH,
                             enable_prompt_adapter=True,
                             enable_lora=True,
                             max_num_seqs=60,
                             max_prompt_adapter_token=8)
    engine = LLMEngine.from_engine_args(engine_args)
    result = do_sample(engine)

    expected_output = {
        " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' "  # noqa: E501
    }
    assert result == expected_output
@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor.utils import set_random_seed
 from vllm.multimodal import MultiModalDataDict
 from vllm.outputs import RequestOutput
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob
 from vllm.usage.usage_lib import UsageContext
@@ -92,6 +93,7 @@ class AsyncLLM:
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         multi_modal_data: Optional[MultiModalDataDict] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> List[RequestOutput]:
 
         if prompts is None:
@@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
         cache_config=engine_config.cache_config,
         load_config=engine_config.load_config,
         lora_config=engine_config.lora_config,
+        prompt_adapter_config=engine_config.prompt_adapter_config,
         is_driver_worker=True,
     )
     return model_runner
0 vllm/adapter_commons/__init__.py Normal file
14 vllm/adapter_commons/layers.py Normal file
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
104 vllm/adapter_commons/models.py Normal file
@@ -0,0 +1,104 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar

from torch import nn

from vllm.logger import init_logger
from vllm.utils import LRUCache

logger = init_logger(__name__)


class AdapterModel(ABC):

    def __init__(self, model_id=None):
        self.id = model_id

    @abstractmethod
    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
        # Common initialization code
        # Load weights or embeddings from local checkpoint
        raise NotImplementedError("Subclasses must implement this method.")


T = TypeVar('T')


class AdapterLRUCache(LRUCache[T]):

    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
                                                               None]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: Hashable, value: T):
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


class AdapterModelManager(ABC):

    def __init__(
        self,
        model: nn.Module,
    ):
        """Create a AdapterModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
        """
        self.model: nn.Module = model
        self._registered_adapters: Dict[int, Any] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_adapters: Dict[int, None] = {}
        self.adapter_type = 'Adapter'
        self._last_mapping = None

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    @abstractmethod
    def adapter_slots(self):
        ...

    @property
    @abstractmethod
    def capacity(self):
        ...

    @abstractmethod
    def activate_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def deactivate_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def add_adapter(self, adapter: Any) -> bool:
        ...

    @abstractmethod
    def set_adapter_mapping(self, mapping: Any) -> None:
        ...

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def remove_all_adapters(self):
        ...

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        ...

    @abstractmethod
    def list_adapters(self) -> Dict[int, Any]:
        ...

    @abstractmethod
    def pin_adapter(self, adapter_id: int) -> bool:
        ...
25 vllm/adapter_commons/request.py Normal file
@@ -0,0 +1,25 @@
from abc import abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest:
    """
    Base class for adapter requests.
    """

    @property
    @abstractmethod
    def adapter_id(self):
        ...

    def __post_init__(self):
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

    def __eq__(self, value: object) -> bool:
        return isinstance(
            value, self.__class__) and self.adapter_id == value.adapter_id

    def __hash__(self) -> int:
        return hash(self.adapter_id)
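
To make the contract concrete, a request type participates by subclassing AdapterRequest and exposing its integer id through adapter_id. The sketch below is illustrative only: apart from prompt_adapter_num_virtual_tokens, which other hunks in this diff reference, the field names are assumptions rather than the actual PromptAdapterRequest definition.

from dataclasses import dataclass

from vllm.adapter_commons.request import AdapterRequest


# Hypothetical concrete request; field names (except the virtual-token count)
# are assumptions for illustration, not the PR's PromptAdapterRequest.
@dataclass(eq=False)  # keep the id-based __eq__/__hash__ from AdapterRequest
class ExamplePromptAdapterRequest(AdapterRequest):
    prompt_adapter_name: str
    prompt_adapter_id: int
    prompt_adapter_local_path: str
    prompt_adapter_num_virtual_tokens: int

    @property
    def adapter_id(self):
        # Validated by AdapterRequest.__post_init__ (must be >= 1).
        return self.prompt_adapter_id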
90 vllm/adapter_commons/utils.py Normal file
@@ -0,0 +1,90 @@
from typing import Any, Callable, Dict, Optional, Set


## model functions
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
                       deactivate_func: Callable) -> bool:
    if adapter_id in active_adapters:
        deactivate_func(adapter_id)
        active_adapters.pop(adapter_id)
        return True
    return False


def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
                capacity: int, add_func: Callable) -> bool:
    if adapter.id not in registered_adapters:
        if len(registered_adapters) >= capacity:
            raise RuntimeError('No free adapter slots.')
        add_func(adapter)
        registered_adapters[adapter.id] = adapter
        return True
    return False


def set_adapter_mapping(mapping: Any, last_mapping: Any,
                        set_mapping_func: Callable) -> Any:
    if last_mapping != mapping:
        set_mapping_func(mapping)
        return mapping
    return last_mapping


def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
                   deactivate_func: Callable) -> bool:
    deactivate_func(adapter_id)
    return bool(registered_adapters.pop(adapter_id, None))


def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
    return dict(registered_adapters)


def get_adapter(adapter_id: int,
                registered_adapters: Dict[int, Any]) -> Optional[Any]:
    return registered_adapters.get(adapter_id, None)


## worker functions
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
                               apply_adapters_func,
                               set_adapter_mapping_func) -> None:
    apply_adapters_func(requests)
    set_adapter_mapping_func(mapping)


def add_adapter_worker(adapter_request: Any, list_adapters_func,
                       load_adapter_func, add_adapter_func,
                       activate_adapter_func) -> bool:
    if adapter_request.adapter_id in list_adapters_func():
        return False
    loaded_adapter = load_adapter_func(adapter_request)
    loaded = add_adapter_func(loaded_adapter)
    activate_adapter_func(loaded_adapter.id)
    return loaded


def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
                          adapter_slots: int, remove_adapter_func,
                          add_adapter_func) -> None:
    models_that_exist = list_adapters_func()
    models_map = {
        adapter_request.adapter_id: adapter_request
        for adapter_request in adapter_requests if adapter_request
    }
    if len(models_map) > adapter_slots:
        raise RuntimeError(
            f"Number of requested models ({len(models_map)}) is greater "
            f"than the number of GPU model slots "
            f"({adapter_slots}).")
    new_models = set(models_map)
    models_to_add = new_models - models_that_exist
    models_to_remove = models_that_exist - new_models
    for adapter_id in models_to_remove:
        remove_adapter_func(adapter_id)
    for adapter_id in models_to_add:
        add_adapter_func(models_map[adapter_id])


def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
    return set(adapter_manager_list_adapters_func())
36 vllm/adapter_commons/worker_manager.py Normal file
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Set

import torch


class AbstractWorkerManager(ABC):

    def __init__(self, device: torch.device):
        self.device = device

    @property
    @abstractmethod
    def is_enabled(self) -> bool:
        ...

    @abstractmethod
    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        ...

    @abstractmethod
    def add_adapter(self, adapter_request: Any) -> bool:
        ...

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def remove_all_adapters(self):
        ...

    @abstractmethod
    def list_adapters(self) -> Set[int]:
        ...
@@ -1285,6 +1285,39 @@ class LoRAConfig:
             raise ValueError("LoRA is not supported with chunked prefill yet.")
 
 
+@dataclass
+class PromptAdapterConfig:
+    max_prompt_adapters: int
+    max_prompt_adapter_token: int
+    max_cpu_prompt_adapters: Optional[int] = None
+    prompt_adapter_dtype: Optional[torch.dtype] = None
+
+    def __post_init__(self):
+        library_name = 'peft'
+        try:
+            __import__(library_name)
+        except ImportError as e:
+            raise ImportError(
+                f"'{library_name}' is not installed for prompt adapter support."
+                f"Please install it using 'pip install {library_name}'."
+            ) from e
+
+        if self.max_prompt_adapters < 1:
+            raise ValueError(f"max_prompt_adapters "
+                             f"({self.max_prompt_adapters}) must be >= 1.")
+        if self.max_prompt_adapter_token == 0:
+            raise ValueError("max_prompt_adapter_token must be set.")
+        if self.max_cpu_prompt_adapters is None:
+            self.max_cpu_prompt_adapters = self.max_prompt_adapters
+
+    def verify_with_model_config(self, model_config: ModelConfig):
+        if self.prompt_adapter_dtype in (None, "auto"):
+            self.prompt_adapter_dtype = model_config.dtype
+        elif isinstance(self.prompt_adapter_dtype, str):
+            self.prompt_adapter_dtype = getattr(torch,
+                                                self.prompt_adapter_dtype)
+
+
 @dataclass
 class MultiModalConfig:
     """Configs the input data format and how models should run for
@@ -1518,6 +1551,7 @@ class EngineConfig:
     speculative_config: Optional[SpeculativeConfig]
     decoding_config: Optional[DecodingConfig]
     observability_config: Optional[ObservabilityConfig]
+    prompt_adapter_config: Optional[PromptAdapterConfig]
 
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
@@ -1529,6 +1563,9 @@ class EngineConfig:
             self.lora_config.verify_with_model_config(self.model_config)
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
+        if self.prompt_adapter_config:
+            self.prompt_adapter_config.verify_with_model_config(
+                self.model_config)
 
     def to_dict(self):
         """Return the configs as a dictionary, for use in **kwargs.
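
For reference, a minimal sketch of how this config gets built; the CLI flag mapping is taken from the EngineArgs changes later in this diff, and peft must be installed or __post_init__ raises:

from vllm.config import PromptAdapterConfig

# Equivalent of passing --enable-prompt-adapter with the two knobs below.
prompt_adapter_config = PromptAdapterConfig(
    max_prompt_adapters=1,       # --max-prompt-adapters
    max_prompt_adapter_token=8)  # --max-prompt-adapter-token, must be non-zero

# EngineConfig.__post_init__ then resolves the adapter dtype against the model:
# prompt_adapter_config.verify_with_model_config(model_config)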
@ -11,6 +11,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.core.policy import Policy, PolicyFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceStatus)
|
||||
|
||||
@ -139,6 +140,8 @@ class SchedulerOutputs:
|
||||
if self.num_loras > 0:
|
||||
self._sort_by_lora_ids()
|
||||
|
||||
self.num_prompt_adapters: int = len(self.prompt_adapter_requests)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
# NOTE: We do not consider the ignored sequence groups.
|
||||
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
|
||||
@ -157,6 +160,14 @@ class SchedulerOutputs:
|
||||
if g.seq_group.lora_request is not None
|
||||
}
|
||||
|
||||
@property
|
||||
def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]:
|
||||
return {
|
||||
g.seq_group.prompt_adapter_request
|
||||
for g in self.scheduled_seq_groups
|
||||
if g.seq_group.prompt_adapter_request is not None
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerRunningOutputs:
|
||||
@ -1024,6 +1035,7 @@ class Scheduler:
|
||||
# `multi_modal_data` will be None.
|
||||
multi_modal_data=seq_group.multi_modal_data
|
||||
if scheduler_outputs.num_prefill_groups > 0 else None,
|
||||
prompt_adapter_request=seq_group.prompt_adapter_request,
|
||||
)
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
|
@ -7,8 +7,8 @@ from typing import List, Optional, Tuple, Union
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
MultiModalConfig, ObservabilityConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig,
|
||||
TokenizerPoolConfig)
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig, TokenizerPoolConfig)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -66,6 +66,9 @@ class EngineArgs:
|
||||
enable_lora: bool = False
|
||||
max_loras: int = 1
|
||||
max_lora_rank: int = 16
|
||||
enable_prompt_adapter: bool = False
|
||||
max_prompt_adapters: int = 1
|
||||
max_prompt_adapter_token: int = 0
|
||||
fully_sharded_loras: bool = False
|
||||
lora_extra_vocab_size: int = 256
|
||||
long_lora_scaling_factors: Optional[Tuple[float]] = None
|
||||
@ -449,6 +452,17 @@ class EngineArgs:
|
||||
'Enabling this will use the fully sharded layers. '
|
||||
'At high sequence length, max rank or '
|
||||
'tensor parallel size, this is likely faster.'))
|
||||
parser.add_argument('--enable-prompt-adapter',
|
||||
action='store_true',
|
||||
help='If True, enable handling of PromptAdapters.')
|
||||
parser.add_argument('--max-prompt-adapters',
|
||||
type=int,
|
||||
default=EngineArgs.max_prompt_adapters,
|
||||
help='Max number of PromptAdapters in a batch.')
|
||||
parser.add_argument('--max-prompt-adapter-token',
|
||||
type=int,
|
||||
default=EngineArgs.max_prompt_adapter_token,
|
||||
help='Max number of PromptAdapters tokens')
|
||||
parser.add_argument("--device",
|
||||
type=str,
|
||||
default=EngineArgs.device,
|
||||
@ -726,6 +740,11 @@ class EngineArgs:
|
||||
model_loader_extra_config=self.model_loader_extra_config,
|
||||
)
|
||||
|
||||
prompt_adapter_config = PromptAdapterConfig(
|
||||
max_prompt_adapters=self.max_prompt_adapters,
|
||||
max_prompt_adapter_token=self.max_prompt_adapter_token) \
|
||||
if self.enable_prompt_adapter else None
|
||||
|
||||
decoding_config = DecodingConfig(
|
||||
guided_decoding_backend=self.guided_decoding_backend)
|
||||
|
||||
@ -751,6 +770,7 @@ class EngineArgs:
|
||||
load_config=load_config,
|
||||
decoding_config=decoding_config,
|
||||
observability_config=observability_config,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
)
|
||||
|
||||
|
||||
|
@ -18,6 +18,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@ -264,6 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
request_id: str,
|
||||
inputs: PromptInputs,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
@ -279,6 +281,12 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = [
|
||||
0
|
||||
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
|
||||
prompt_token_ids
|
||||
|
||||
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=inputs.get("prompt"),
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
@ -293,6 +301,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
) -> None:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
@ -301,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
arrival_time = time.time()
|
||||
|
||||
processed_inputs = await self.process_model_inputs_async(
|
||||
request_id=request_id, inputs=inputs, lora_request=lora_request)
|
||||
request_id=request_id,
|
||||
inputs=inputs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
@ -309,6 +321,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
@ -627,6 +640,7 @@ class AsyncLLMEngine:
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
if isinstance(inputs, str):
|
||||
@ -669,7 +683,7 @@ class AsyncLLMEngine:
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
return stream
|
||||
|
||||
@ -680,6 +694,7 @@ class AsyncLLMEngine:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
) -> AsyncIterator[RequestOutput]:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
@ -695,6 +710,8 @@ class AsyncLLMEngine:
|
||||
request_id: The unique id of the request.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: Prompt Adapter request to use
|
||||
for generation, if any.
|
||||
|
||||
Yields:
|
||||
The output `RequestOutput` objects from the LLMEngine
|
||||
@ -749,6 +766,7 @@ class AsyncLLMEngine:
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
):
|
||||
yield LLMEngine.validate_output(output, RequestOutput)
|
||||
|
||||
@ -837,6 +855,7 @@ class AsyncLLMEngine:
|
||||
*,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Common logic to process requests with SamplingParams or
|
||||
PoolingParams."""
|
||||
@ -849,6 +868,7 @@ class AsyncLLMEngine:
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -8,7 +8,8 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
LoRAConfig, ModelConfig, MultiModalConfig,
|
||||
ObservabilityConfig, ParallelConfig, SchedulerConfig,
|
||||
ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
|
||||
SchedulerOutputs)
|
||||
@ -27,6 +28,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
RequestOutputFactory)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
PoolerOutput, SamplerOutput, Sequence,
|
||||
@ -93,6 +95,8 @@ class LLMEngine:
|
||||
decoding.
|
||||
executor_class: The model executor class for managing distributed
|
||||
execution.
|
||||
prompt_adapter_config (Optional): The configuration related to serving
|
||||
prompt adapters.
|
||||
log_stats: Whether to log statistics.
|
||||
usage_context: Specified entry point, used for usage info collection.
|
||||
"""
|
||||
@ -161,6 +165,7 @@ class LLMEngine:
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
decoding_config: Optional[DecodingConfig],
|
||||
observability_config: Optional[ObservabilityConfig],
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig],
|
||||
executor_class: Type[ExecutorBase],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
@ -222,6 +227,7 @@ class LLMEngine:
|
||||
self.speculative_config = speculative_config
|
||||
self.load_config = load_config
|
||||
self.decoding_config = decoding_config or DecodingConfig()
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.observability_config = observability_config or ObservabilityConfig(
|
||||
)
|
||||
self.log_stats = log_stats
|
||||
@ -250,6 +256,7 @@ class LLMEngine:
|
||||
multimodal_config=multimodal_config,
|
||||
speculative_config=speculative_config,
|
||||
load_config=load_config,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
)
|
||||
|
||||
if not self.model_config.embedding_mode:
|
||||
@ -282,6 +289,8 @@ class LLMEngine:
|
||||
# Feature flags
|
||||
"enable_lora":
|
||||
bool(lora_config),
|
||||
"enable_prompt_adapter":
|
||||
bool(prompt_adapter_config),
|
||||
"enable_prefix_caching":
|
||||
cache_config.enable_prefix_caching,
|
||||
"enforce_eager":
|
||||
@ -376,7 +385,6 @@ class LLMEngine:
|
||||
engine_config = engine_args.create_engine_config()
|
||||
distributed_executor_backend = (
|
||||
engine_config.parallel_config.distributed_executor_backend)
|
||||
|
||||
# Initialize the cluster and specify the executor class.
|
||||
if engine_config.device_config.device_type == "neuron":
|
||||
from vllm.executor.neuron_executor import NeuronExecutor
|
||||
@ -409,7 +417,6 @@ class LLMEngine:
|
||||
else:
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
executor_class = GPUExecutor
|
||||
|
||||
# Create the LLM engine.
|
||||
engine = cls(
|
||||
**engine_config.to_dict(),
|
||||
@ -470,6 +477,9 @@ class LLMEngine:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
self.scheduler_config)
|
||||
if self.prompt_adapter_config:
|
||||
self.prompt_adapter_config.verify_with_model_config(
|
||||
self.model_config)
|
||||
|
||||
def _get_eos_token_id(
|
||||
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
|
||||
@ -487,6 +497,7 @@ class LLMEngine:
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
# Create the sequences.
|
||||
@ -495,7 +506,7 @@ class LLMEngine:
|
||||
eos_token_id = self._get_eos_token_id(lora_request)
|
||||
|
||||
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
|
||||
lora_request)
|
||||
lora_request, prompt_adapter_request)
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
@ -506,7 +517,7 @@ class LLMEngine:
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
elif isinstance(params, PoolingParams):
|
||||
seq_group = self._create_sequence_group_with_pooling(
|
||||
request_id,
|
||||
@ -514,7 +525,7 @@ class LLMEngine:
|
||||
params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either SamplingParams or PoolingParams must be provided.")
|
||||
@ -535,6 +546,7 @@ class LLMEngine:
|
||||
request_id: str,
|
||||
inputs: PromptInputs,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
@ -549,6 +561,11 @@ class LLMEngine:
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = \
|
||||
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
|
||||
+ prompt_token_ids
|
||||
|
||||
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=inputs.get("prompt"),
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
@ -563,6 +580,7 @@ class LLMEngine:
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> None:
|
||||
"""Add a request to the engine's request pool.
|
||||
|
||||
@ -612,9 +630,11 @@ class LLMEngine:
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
processed_inputs = self.process_model_inputs(request_id=request_id,
|
||||
processed_inputs = self.process_model_inputs(
|
||||
request_id=request_id,
|
||||
inputs=inputs,
|
||||
lora_request=lora_request)
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
@ -622,6 +642,7 @@ class LLMEngine:
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
@ -633,6 +654,7 @@ class LLMEngine:
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with SamplingParams."""
|
||||
max_logprobs = self.get_model_config().max_logprobs
|
||||
@ -658,7 +680,7 @@ class LLMEngine:
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
return seq_group
|
||||
|
||||
@ -669,16 +691,19 @@ class LLMEngine:
|
||||
pooling_params: PoolingParams,
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with PoolingParams."""
|
||||
# Defensive copy of PoolingParams, which are used by the pooler
|
||||
pooling_params = pooling_params.clone()
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id=request_id,
|
||||
seq_group = SequenceGroup(
|
||||
request_id=request_id,
|
||||
seqs=[seq],
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params)
|
||||
pooling_params=pooling_params,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
return seq_group
|
||||
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
@ -1082,6 +1107,16 @@ class LLMEngine:
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_executor.pin_lora(lora_id)
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
return self.model_executor.add_prompt_adapter(prompt_adapter_request)
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.model_executor.remove_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def list_prompt_adapters(self) -> List[int]:
|
||||
return self.model_executor.list_prompt_adapters()
|
||||
|
||||
def check_health(self) -> None:
|
||||
if self.tokenizer:
|
||||
self.tokenizer.check_health()
|
||||
|
@ -13,6 +13,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@ -255,6 +256,7 @@ class LLM:
|
||||
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[RequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
|
||||
@ -271,6 +273,8 @@ class LLM:
|
||||
prompts and it is paired one by one with the prompt.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of `RequestOutput` objects containing the
|
||||
@ -304,7 +308,7 @@ class LLM:
|
||||
inputs=inputs,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return LLMEngine.validate_outputs(outputs, RequestOutput)
|
||||
@ -397,6 +401,7 @@ class LLM:
|
||||
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
|
||||
@ -412,6 +417,8 @@ class LLM:
|
||||
use the default pooling parameters.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prompt_adapter_request: Prompt Adapter request to use for
|
||||
generation, if any.
|
||||
|
||||
Returns:
|
||||
A list of `EmbeddingRequestOutput` objects containing the
|
||||
@ -445,6 +452,7 @@ class LLM:
|
||||
inputs=inputs,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
@ -504,6 +512,7 @@ class LLM:
|
||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||
Sequence[PoolingParams]],
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> None:
|
||||
if isinstance(inputs, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
@ -526,19 +535,23 @@ class LLM:
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
lora_request=lora_request[i] if isinstance(
|
||||
lora_request, Sequence) else lora_request,
|
||||
)
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest],
|
||||
LoRARequest]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
) -> None:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_engine.add_request(request_id,
|
||||
self.llm_engine.add_request(
|
||||
request_id,
|
||||
inputs,
|
||||
params,
|
||||
lora_request=lora_request)
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
def _run_engine(
|
||||
self, *, use_tqdm: bool
|
||||
|
@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest):
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def show_available_models():
|
||||
models = await openai_serving_chat.show_available_models()
|
||||
models = await openai_serving_completion.show_available_models()
|
||||
return JSONResponse(content=models.model_dump())
|
||||
|
||||
|
||||
@ -236,7 +236,8 @@ if __name__ == "__main__":
|
||||
args.lora_modules,
|
||||
args.chat_template)
|
||||
openai_serving_completion = OpenAIServingCompletion(
|
||||
engine, model_config, served_model_names, args.lora_modules)
|
||||
engine, model_config, served_model_names, args.lora_modules,
|
||||
args.prompt_adapters)
|
||||
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
|
||||
served_model_names)
|
||||
app.root_path = args.root_path
|
||||
|
@ -9,7 +9,8 @@ import json
|
||||
import ssl
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -23,6 +24,16 @@ class LoRAParserAction(argparse.Action):
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
adapter_list = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
adapter_list.append(PromptAdapterPath(name, path))
|
||||
setattr(namespace, self.dest, adapter_list)
|
||||
|
||||
|
||||
def make_arg_parser():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
@ -65,6 +76,14 @@ def make_arg_parser():
|
||||
action=LoRAParserAction,
|
||||
help="LoRA module configurations in the format name=path. "
|
||||
"Multiple modules can be specified.")
|
||||
parser.add_argument(
|
||||
"--prompt-adapters",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=PromptAdapterParserAction,
|
||||
help="Prompt adapter configurations in the format name=path. "
|
||||
"Multiple adapters can be specified.")
|
||||
parser.add_argument("--chat-template",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
|
@ -258,7 +258,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
prompt=prompt,
|
||||
add_special_tokens=request.add_special_tokens)
|
||||
sampling_params = request.to_sampling_params()
|
||||
lora_request = self._maybe_get_lora(request)
|
||||
_, lora_request = self._maybe_get_adapter(request)
|
||||
decoding_config = await self.engine.get_decoding_config()
|
||||
guided_decoding_backend = request.guided_decoding_backend \
|
||||
or decoding_config.guided_decoding_backend
|
||||
|
@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
TokenizeResponse, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
OpenAIServing)
|
||||
OpenAIServing,
|
||||
PromptAdapterPath)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
lora_modules: Optional[List[LoRAModulePath]]):
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]]):
|
||||
super().__init__(engine=engine,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
lora_modules=lora_modules)
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters)
|
||||
|
||||
async def create_completion(self, request: CompletionRequest,
|
||||
raw_request: Request):
|
||||
@ -101,7 +104,12 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
generators: List[AsyncIterator[RequestOutput]] = []
|
||||
try:
|
||||
sampling_params = request.to_sampling_params()
|
||||
lora_request = self._maybe_get_lora(request)
|
||||
adapter_type, adapter_request = self._maybe_get_adapter(request)
|
||||
lora_request, prompt_adapter_request = None, None
|
||||
if adapter_type == 'LoRA':
|
||||
lora_request, prompt_adapter_request = adapter_request, None
|
||||
elif adapter_type == 'PromptAdapter':
|
||||
lora_request, prompt_adapter_request = None, adapter_request
|
||||
decoding_config = await self.engine.get_decoding_config()
|
||||
guided_decoding_backend = request.guided_decoding_backend \
|
||||
or decoding_config.guided_decoding_backend
|
||||
@ -147,6 +155,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
sampling_params,
|
||||
f"{request_id}-{i}",
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
|
@ -16,12 +16,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ModelPermission, TokenizeRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterPath:
|
||||
name: str
|
||||
local_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
@ -30,9 +37,14 @@ class LoRAModulePath:
|
||||
|
||||
class OpenAIServing:
|
||||
|
||||
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
|
||||
def __init__(
|
||||
self,
|
||||
engine: AsyncLLMEngine,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
lora_modules: Optional[List[LoRAModulePath]]):
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.engine = engine
|
||||
@ -49,9 +61,8 @@ class OpenAIServing:
|
||||
|
||||
self.served_model_names = served_model_names
|
||||
|
||||
if lora_modules is None:
|
||||
self.lora_requests = []
|
||||
else:
|
||||
if lora_modules is not None:
|
||||
self.lora_requests = [
|
||||
LoRARequest(
|
||||
lora_name=lora.name,
|
||||
@ -60,6 +71,20 @@ class OpenAIServing:
|
||||
) for i, lora in enumerate(lora_modules, start=1)
|
||||
]
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
with open(f"./{prompt_adapter.local_path}"
|
||||
f"/adapter_config.json") as f:
|
||||
adapter_config = json.load(f)
|
||||
num_virtual_tokens = adapter_config["num_virtual_tokens"]
|
||||
self.prompt_adapter_requests.append(
|
||||
PromptAdapterRequest(
|
||||
prompt_adapter_name=prompt_adapter.name,
|
||||
prompt_adapter_id=i,
|
||||
prompt_adapter_local_path=prompt_adapter.local_path,
|
||||
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
@ -75,7 +100,14 @@ class OpenAIServing:
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests
|
||||
]
|
||||
prompt_adapter_cards = [
|
||||
ModelCard(id=prompt_adapter.prompt_adapter_name,
|
||||
root=self.served_model_names[0],
|
||||
permission=[ModelPermission()])
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
model_cards.extend(prompt_adapter_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
def create_error_response(
|
||||
@ -109,20 +141,29 @@ class OpenAIServing:
|
||||
return None
|
||||
if request.model in [lora.lora_name for lora in self.lora_requests]:
|
||||
return None
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
return self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
def _maybe_get_lora(
|
||||
def _maybe_get_adapter(
|
||||
self, request: Union[CompletionRequest, ChatCompletionRequest,
|
||||
EmbeddingRequest]
|
||||
) -> Optional[LoRARequest]:
|
||||
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
|
||||
PromptAdapterRequest]]]:
|
||||
if request.model in self.served_model_names:
|
||||
return None
|
||||
return None, None
|
||||
for lora in self.lora_requests:
|
||||
if request.model == lora.lora_name:
|
||||
return lora
|
||||
return 'LoRA', lora
|
||||
for prompt_adapter in self.prompt_adapter_requests:
|
||||
if request.model == prompt_adapter.prompt_adapter_name:
|
||||
return 'PromptAdapter', prompt_adapter
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
|
@ -7,6 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
@ -48,6 +49,7 @@ class CPUExecutor(ExecutorBase):
|
||||
lora_config=self.lora_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
prompt_adapter_config=self.prompt_adapter_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
self.driver_worker.init_device()
|
||||
@ -90,6 +92,19 @@ class CPUExecutor(ExecutorBase):
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.driver_worker.list_loras()
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
return self.driver_worker.list_prompt_adapters()
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# CPUExecutor will always be healthy as long as
|
||||
# it's running.
|
||||
|
@ -4,8 +4,10 @@ from typing import List, Optional, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
|
||||
|
||||
@ -28,6 +30,7 @@ class ExecutorBase(ABC):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
@ -38,6 +41,7 @@ class ExecutorBase(ABC):
|
||||
self.device_config = device_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.speculative_config = speculative_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
|
||||
self._init_executor()
|
||||
|
||||
@ -95,6 +99,23 @@ class ExecutorBase(ABC):
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
raise NotImplementedError # type: ignore
|
||||
|
||||
@abstractmethod
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def check_health(self) -> None:
|
||||
"""Checks if the executor is healthy. If not, it should raise an
|
||||
@ -122,12 +143,14 @@ class ExecutorAsyncBase(ExecutorBase):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig],
|
||||
) -> None:
|
||||
self.pp_locks: Optional[List[asyncio.Lock]] = None
|
||||
|
||||
super().__init__(model_config, cache_config, parallel_config,
|
||||
scheduler_config, device_config, load_config,
|
||||
lora_config, multimodal_config, speculative_config)
|
||||
lora_config, multimodal_config, speculative_config,
|
||||
prompt_adapter_config)
|
||||
|
||||
@abstractmethod
|
||||
async def execute_model_async(
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
@ -45,6 +46,7 @@ class GPUExecutor(ExecutorBase):
|
||||
lora_config=self.lora_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
speculative_config=self.speculative_config,
|
||||
prompt_adapter_config=self.prompt_adapter_config,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
@ -107,6 +109,25 @@ class GPUExecutor(ExecutorBase):
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.driver_worker.list_loras()
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
assert prompt_adapter_request.prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return self.driver_worker.add_prompt_adapter(prompt_adapter_request)
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
assert prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
assert prompt_adapter_id > 0, \
|
||||
"prompt_adapter_id must be greater than 0."
|
||||
return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
return self.driver_worker.list_prompt_adapters()
|
||||
|
||||
def check_health(self) -> None:
|
||||
# GPUExecutor will always be healthy as long as
|
||||
# it's running.
|
||||
|
@ -8,7 +8,8 @@ from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
|
||||
DistributedGPUExecutor, DistributedGPUExecutorAsync)
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
@ -44,6 +45,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
assert device_config.device_type == "xpu"
|
||||
@ -58,6 +60,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
|
@ -4,7 +4,8 @@ import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
@ -27,6 +28,7 @@ class XPUExecutor(GPUExecutor):
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
assert device_config.device_type == "xpu"
|
||||
@ -43,6 +45,7 @@ class XPUExecutor(GPUExecutor):
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.speculative_config = None
|
||||
|
||||
# Instantiate the worker and load the model to GPU.
|
||||
|
@ -8,6 +8,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.adapter_commons.layers import AdapterMapping
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -134,15 +135,8 @@ def _apply_lora_packed_nslice(
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAMapping:
|
||||
# Per every token in input_ids:
|
||||
index_mapping: Tuple[int, ...]
|
||||
# Per sampled token:
|
||||
prompt_mapping: Tuple[int, ...]
|
||||
|
||||
def __post_init__(self):
|
||||
self.index_mapping = tuple(self.index_mapping)
|
||||
self.prompt_mapping = tuple(self.prompt_mapping)
|
||||
class LoRAMapping(AdapterMapping):
|
||||
pass
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
|
@ -4,12 +4,17 @@ import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
|
||||
AdapterModelManager)
|
||||
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
||||
get_adapter, list_adapters,
|
||||
remove_adapter, set_adapter_mapping)
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import (BaseLayerWithLoRA,
|
||||
@ -19,7 +24,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA
|
||||
from vllm.utils import LRUCache, is_pin_memory_available
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -153,7 +158,7 @@ def get_lora_id():
|
||||
return _GLOBAL_LORA_ID
|
||||
|
||||
|
||||
class LoRAModel:
|
||||
class LoRAModel(AdapterModel):
|
||||
"""A LoRA fine-tuned model."""
|
||||
|
||||
def __init__(
|
||||
@ -388,7 +393,7 @@ class LoRAModel:
|
||||
)
|
||||
|
||||
|
||||
class LoRAModelManager:
|
||||
class LoRAModelManager(AdapterModelManager):
|
||||
"""A manager that manages multiple LoRA-fine-tuned models."""
|
||||
|
||||
def __init__(
|
||||
@ -440,8 +445,7 @@ class LoRAModelManager:
|
||||
# base_indices, sampler_indices, sampler_indices_padded,
|
||||
# embeddings_indices
|
||||
self.indices_len: List[Optional[int]] = [None] * 4
|
||||
|
||||
self.model = model
|
||||
super().__init__(model)
|
||||
if hasattr(self.model, "supported_lora_modules"):
|
||||
self.supported_lora_modules = copy.deepcopy(
|
||||
self.model.supported_lora_modules)
|
||||
@ -453,11 +457,11 @@ class LoRAModelManager:
|
||||
self.model.packed_modules_mapping)
|
||||
self.packed_modules: Dict[str, List[str]] = {}
|
||||
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
|
||||
self._registered_loras: Dict[int, LoRAModel] = {}
|
||||
# Dict instead of a Set for compatibility with LRUCache.
|
||||
self._active_loras: Dict[int, None] = {}
|
||||
self._last_mapping: Optional[LoRAMapping] = None
|
||||
self._create_lora_modules()
|
||||
self.model.lora_manager = self
|
||||
self.adapter_type = 'LoRa'
|
||||
|
||||
@property
|
||||
def capacity(self) -> int:
|
||||
@ -467,15 +471,16 @@ class LoRAModelManager:
|
||||
def lora_slots(self) -> int:
|
||||
return self.lora_config.max_loras
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._registered_loras)
|
||||
@property
|
||||
def adapter_slots(self) -> int:
|
||||
return self.lora_slots
|
||||
|
||||
def activate_lora(
|
||||
def activate_adapter(
|
||||
self,
|
||||
lora_id: int,
|
||||
) -> bool:
|
||||
"""Move LoRA into a GPU buffer to be used in the forward pass."""
|
||||
if lora_id in self._active_loras:
|
||||
if lora_id in self._active_adapters:
|
||||
return False
|
||||
first_free_slot = next(
|
||||
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
|
||||
@ -483,8 +488,8 @@ class LoRAModelManager:
|
||||
if first_free_slot is None:
|
||||
raise ValueError("No free lora slots")
|
||||
index, _ = first_free_slot
|
||||
self._active_loras[lora_id] = None
|
||||
lora_model = self._registered_loras[lora_id]
|
||||
self._active_adapters[lora_id] = None
|
||||
lora_model = self._registered_adapters[lora_id]
|
||||
logger.debug("Activating LoRA. int id: %d, slot index: %d",
|
||||
lora_model.id, index)
|
||||
self.lora_index_to_id[index] = lora_model.id
|
||||
@ -498,21 +503,13 @@ class LoRAModelManager:
|
||||
module.reset_lora(index)
|
||||
return True
|
||||
|
||||
def _deactivate_lora(self, lora_id: int):
|
||||
def _deactivate_adapter(self, lora_id: int):
|
||||
try:
|
||||
index = self.lora_index_to_id.index(lora_id)
|
||||
self.lora_index_to_id[index] = None
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def deactivate_lora(self, lora_id: int) -> bool:
|
||||
"""Remove a LoRA from a GPU buffer."""
|
||||
if lora_id in self._active_loras:
|
||||
self._deactivate_lora(lora_id)
|
||||
self._active_loras.pop(lora_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _set_long_lora_context(self, lora: LoRAModel):
|
||||
if self.long_lora_context is None:
|
||||
return
|
||||
@ -528,40 +525,19 @@ class LoRAModelManager:
|
||||
if offsets:
|
||||
self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
|
||||
|
||||
def _add_lora(self, lora: LoRAModel):
|
||||
def _add_adapter(self, lora: LoRAModel):
|
||||
self._create_merged_loras_inplace(lora)
|
||||
self._registered_loras[lora.id] = lora
|
||||
self._registered_adapters[lora.id] = lora
|
||||
self._set_long_lora_context(lora)
|
||||
|
||||
def add_lora(self, lora: LoRAModel) -> bool:
|
||||
"""Add a LoRAModel to the manager CPU cache."""
|
||||
logger.debug(
|
||||
"Adding lora. Model id: %d, "
|
||||
"int id: %d, "
|
||||
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
|
||||
if lora.id not in self._registered_loras:
|
||||
if len(self._registered_loras) >= self.capacity:
|
||||
raise RuntimeError("No free LoRA slots.")
|
||||
self._add_lora(lora)
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
"""Remove a LoRAModel from the manager CPU cache."""
|
||||
# TODO: should we check active lora?
|
||||
self.deactivate_lora(lora_id)
|
||||
if self.long_lora_context:
|
||||
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
|
||||
return bool(self._registered_loras.pop(lora_id, None))
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
def pin_adapter(self, lora_id: int) -> bool:
|
||||
"""Pin a LoRAModel in the manager cache."""
|
||||
raise NotImplementedError(
|
||||
"Pinning is not supported in LoRAModelManager."
|
||||
"Use LRUCacheLoRAModelManager for pinning") # type: ignore
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
|
||||
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||
(base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices, long_lora_offsets_tensor,
|
||||
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
|
||||
@ -583,23 +559,11 @@ class LoRAModelManager:
|
||||
# Maintain the reference
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
|
||||
if self._last_mapping != lora_mapping:
|
||||
self._set_lora_mapping(lora_mapping)
|
||||
self._last_mapping = lora_mapping
|
||||
|
||||
def list_loras(self) -> Dict[int, LoRAModel]:
|
||||
"""List all registered LoRAModels."""
|
||||
return dict(self._registered_loras)
|
||||
|
||||
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
|
||||
return self._registered_loras.get(lora_id, None)
|
||||
|
||||
def remove_all_loras(self):
|
||||
def remove_all_adapters(self):
|
||||
"""Remove all LoRAModels from the manager."""
|
||||
self._registered_loras.clear()
|
||||
self._registered_adapters.clear()
|
||||
self.lora_index_to_id = [None] * self.lora_slots
|
||||
self._active_loras.clear()
|
||||
self._active_adapters.clear()
|
||||
|
||||
def _create_lora_modules(self):
|
||||
for module_name, module in self.model.named_modules(
|
||||
@ -743,18 +707,39 @@ class LoRAModelManager:
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
|
||||
replacement_loras)
|
||||
|
||||
def deactivate_adapter(self, adapter_id: int) -> bool:
|
||||
return deactivate_adapter(adapter_id, self._active_adapters,
|
||||
self._deactivate_adapter)
|
||||
|
||||
class LoRALRUCache(LRUCache[LoRAModel]):
|
||||
def add_adapter(self, adapter: LoRAModel) -> bool:
|
||||
logger.debug(
|
||||
"Adding lora. Model id: %d, "
|
||||
"int id: %d, "
|
||||
"scaling factor: %s", adapter.id, adapter.id,
|
||||
adapter.scaling_factor)
|
||||
return add_adapter(adapter, self._registered_adapters, self.capacity,
|
||||
self._add_adapter)
|
||||
|
||||
def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
|
||||
self._set_adapter_mapping)
|
||||
|
||||
def remove_adapter(self, adapter_id: int) -> bool:
|
||||
return remove_adapter(adapter_id, self._registered_adapters,
|
||||
self.deactivate_adapter)
|
||||
|
||||
def list_adapters(self) -> Dict[int, Any]:
|
||||
return list_adapters(self._registered_adapters)
|
||||
|
||||
def get_adapter(self, adapter_id: int) -> Optional[Any]:
|
||||
return get_adapter(adapter_id, self._registered_adapters)
|
||||
|
||||
|
||||
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
|
||||
|
||||
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
|
||||
bool]):
|
||||
super().__init__(capacity)
|
||||
self.deactivate_lora_fn = deactivate_lora_fn
|
||||
|
||||
def _on_remove(self, key: int, value: LoRAModel):
|
||||
logger.debug("Removing LoRA. int id: %d", key)
|
||||
self.deactivate_lora_fn(key)
|
||||
return super()._on_remove(key, value)
|
||||
super().__init__(capacity, deactivate_lora_fn)
|
||||
|
||||
|
||||
class LRUCacheLoRAModelManager(LoRAModelManager):
|
||||
@ -770,49 +755,49 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
|
||||
):
|
||||
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
||||
vocab_size, lora_config)
|
||||
self._registered_loras: LoRALRUCache = LoRALRUCache(
|
||||
self.capacity, self.deactivate_lora)
|
||||
self._active_loras: LoRALRUCache = LoRALRUCache(
|
||||
self.lora_slots, self._deactivate_lora)
|
||||
self._registered_adapters: LoRALRUCache = LoRALRUCache(
|
||||
self.capacity, self.deactivate_adapter)
|
||||
self._active_adapters: LoRALRUCache = LoRALRUCache(
|
||||
self.lora_slots, self._deactivate_adapter)
|
||||
|
||||
def list_loras(self) -> Dict[int, LoRAModel]:
|
||||
def list_adapters(self) -> Dict[int, LoRAModel]:
|
||||
"""List all registered LoRAModels."""
|
||||
return dict(self._registered_loras.cache)
|
||||
return dict(self._registered_adapters.cache)
|
||||
|
||||
def add_lora(self, lora: LoRAModel) -> bool:
|
||||
def add_adapter(self, lora: LoRAModel) -> bool:
|
||||
"""Add a LoRAModel to the manager."""
|
||||
logger.debug(
|
||||
"Adding lora. Model id: %d, "
|
||||
"int id: %d, "
|
||||
"scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
|
||||
if lora.id not in self._registered_loras:
|
||||
self._add_lora(lora)
|
||||
if lora.id not in self._registered_adapters:
|
||||
self._add_adapter(lora)
|
||||
was_added = True
|
||||
else:
|
||||
# We always touch to update the LRU cache order
|
||||
self._registered_loras.touch(lora.id)
|
||||
self._registered_adapters.touch(lora.id)
|
||||
was_added = False
|
||||
return was_added
|
||||
|
||||
def activate_lora(
|
||||
def activate_adapter(
|
||||
self,
|
||||
lora_id: int,
|
||||
) -> bool:
|
||||
if lora_id not in self._active_loras and len(
|
||||
self._active_loras) >= self.lora_slots:
|
||||
self._active_loras.remove_oldest()
|
||||
result = super().activate_lora(lora_id)
|
||||
if lora_id not in self._active_adapters and len(
|
||||
self._active_adapters) >= self.lora_slots:
|
||||
self._active_adapters.remove_oldest()
|
||||
result = super().activate_adapter(lora_id)
|
||||
# We always touch to update the LRU cache order
|
||||
self._active_loras.touch(lora_id)
|
||||
self._active_adapters.touch(lora_id)
|
||||
return result
|
||||
|
||||
def remove_oldest_lora(self) -> bool:
|
||||
if len(self._registered_loras) > 0:
|
||||
self._registered_loras.remove_oldest()
|
||||
def remove_oldest_adapter(self) -> bool:
|
||||
if len(self._registered_adapters) > 0:
|
||||
self._registered_adapters.remove_oldest()
|
||||
return True
|
||||
return False
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
def pin_adapter(self, lora_id: int) -> bool:
|
||||
"""Pin a LoRAModel in the manager cache."""
|
||||
self._pin_lora_in_cpu_cache(lora_id)
|
||||
self._pin_lora_in_gpu_cache(lora_id)
|
||||
@ -820,17 +805,17 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
|
||||
|
||||
def _pin_lora_in_cpu_cache(self, lora_id: int):
|
||||
try:
|
||||
self._registered_loras.pin(lora_id)
|
||||
self._registered_adapters.pin(lora_id)
|
||||
except ValueError as err:
|
||||
raise ValueError("Pinning failed. "
|
||||
f"LoRA {lora_id} is not registered.") from err
|
||||
|
||||
def _pin_lora_in_gpu_cache(self, lora_id: int):
|
||||
if lora_id not in self._active_loras:
|
||||
if lora_id not in self._active_adapters:
|
||||
# move lora to gpu if not already active
|
||||
self.activate_lora(lora_id)
|
||||
self.activate_adapter(lora_id)
|
||||
|
||||
self._active_loras.pin(lora_id)
|
||||
self._active_adapters.pin(lora_id)
|
||||
|
||||
|
||||
def create_lora_manager(
|
||||
|
@ -1,13 +1,15 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from vllm.adapter_commons.request import AdapterRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRARequest:
|
||||
class LoRARequest(AdapterRequest):
|
||||
"""
|
||||
Request for a LoRA adapter.
|
||||
|
||||
Note that this class should be be used internally. For online
|
||||
Note that this class should be used internally. For online
|
||||
serving, it is recommended to not allow users to use this class but
|
||||
instead provide another layer of abstraction to prevent users from
|
||||
accessing unauthorized LoRA adapters.
|
||||
@ -20,15 +22,16 @@ class LoRARequest:
|
||||
lora_int_id: int
|
||||
lora_local_path: str
|
||||
long_lora_max_len: Optional[int] = None
|
||||
__hash__ = AdapterRequest.__hash__
|
||||
|
||||
def __post_init__(self):
|
||||
if self.lora_int_id < 1:
|
||||
raise ValueError(
|
||||
f"lora_int_id must be > 0, got {self.lora_int_id}")
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(
|
||||
value, LoRARequest) and self.lora_int_id == value.lora_int_id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@property
|
||||
def adapter_id(self):
|
||||
return self.lora_int_id
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.lora_name
|
||||
|
||||
@property
|
||||
def local_path(self):
|
||||
return self.lora_local_path
|
||||
|
@ -1,12 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.adapter_commons.utils import (add_adapter_worker,
|
||||
apply_adapters_worker,
|
||||
list_adapters_worker,
|
||||
set_active_adapters_worker)
|
||||
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||
LRUCacheLoRAModelManager, create_lora_manager)
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -14,79 +17,13 @@ from vllm.lora.request import LoRARequest
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AbstractWorkerLoRAManager(ABC):
|
||||
"""Abstract class for managing LoRA models on the worker side."""
|
||||
|
||||
def __init__(self,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
device: torch.device,
|
||||
max_position_embeddings: Optional[int] = None):
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.vocab_size = vocab_size
|
||||
self.device = device
|
||||
self.lora_config = lora_config
|
||||
|
||||
# If False, do not cache. If None, cache is empty.
|
||||
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
|
||||
|
||||
@contextmanager
|
||||
def dummy_lora_cache(self):
|
||||
"""Use this context manager to reuse the dummy lora model
|
||||
to avoid creating it repeatedly."""
|
||||
self._cached_dummy_lora = None
|
||||
yield
|
||||
self._cached_dummy_lora = False
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_enabled(self) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def create_lora_manager(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def remove_all_loras(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def list_loras(self) -> Set[int]:
|
||||
...
|
||||
|
||||
|
||||
class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||
class WorkerLoRAManager(AbstractWorkerManager):
|
||||
"""WorkerLoRAManager that manages LoRA models on the worker side.
|
||||
|
||||
Every request, the requested LoRAs will be loaded (unless they are already
|
||||
loaded), and every other LoRA will be unloaded."""
|
||||
|
||||
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager
|
||||
_manager_cls: Type[LoRAModelManager] = LoRAModelManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -103,16 +40,23 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||
self._lora_model_cls = lora_model_cls
|
||||
self.embedding_modules = embedding_modules
|
||||
self.embedding_padding_modules = embedding_padding_modules
|
||||
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.vocab_size = vocab_size
|
||||
self.lora_config = lora_config
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
super().__init__(device)
|
||||
# Lazily initialized by create_lora_manager.
|
||||
self._lora_manager: LoRAModelManager
|
||||
super().__init__(
|
||||
max_num_seqs,
|
||||
max_num_batched_tokens,
|
||||
vocab_size,
|
||||
lora_config,
|
||||
device,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
)
|
||||
self._adapter_manager: LoRAModelManager
|
||||
|
||||
@contextmanager
|
||||
def dummy_lora_cache(self):
|
||||
"""Use this context manager to reuse the dummy lora model
|
||||
to avoid creating it repeatedly."""
|
||||
self._cached_dummy_lora = None
|
||||
yield
|
||||
self._cached_dummy_lora = False
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
@ -128,41 +72,14 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
lora_config=self.lora_config,
|
||||
lora_manager_cls=self._lora_manager_cls,
|
||||
lora_manager_cls=self._manager_cls,
|
||||
)
|
||||
self._lora_manager = lora_manager
|
||||
self._adapter_manager = lora_manager
|
||||
return lora_manager.model
|
||||
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
self._apply_loras(lora_requests)
|
||||
self._lora_manager.set_lora_mapping(lora_mapping)
|
||||
|
||||
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
|
||||
loras_that_exist = self.list_loras()
|
||||
loras_map = {
|
||||
lora_request.lora_int_id: lora_request
|
||||
for lora_request in lora_requests if lora_request
|
||||
}
|
||||
if len(loras_map) > self._lora_manager.lora_slots:
|
||||
raise RuntimeError(
|
||||
f"Number of requested LoRAs ({len(loras_map)}) is greater "
|
||||
"than the number of GPU LoRA slots "
|
||||
f"({self._lora_manager.lora_slots}).")
|
||||
|
||||
new_loras = set(loras_map)
|
||||
loras_to_add = new_loras - loras_that_exist
|
||||
loras_to_remove = loras_that_exist - new_loras
|
||||
|
||||
for lora_id in loras_to_remove:
|
||||
self.remove_lora(lora_id)
|
||||
|
||||
for lora_id in loras_to_add:
|
||||
self.add_lora(loras_map[lora_id])
|
||||
|
||||
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
|
||||
def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
|
||||
try:
|
||||
model = self._lora_manager.model
|
||||
model = self._adapter_manager.model
|
||||
supported_lora_modules = model.supported_lora_modules
|
||||
packed_modules_mapping = model.packed_modules_mapping
|
||||
expected_lora_modules: List[str] = []
|
||||
@ -198,37 +115,45 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||
return lora
|
||||
|
||||
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
|
||||
if lora_request.lora_int_id in self.list_loras():
|
||||
if lora_request.lora_int_id in self.list_adapters():
|
||||
return False
|
||||
if isinstance(self._cached_dummy_lora, LoRAModel):
|
||||
dummy_lora = self._cached_dummy_lora.clone(
|
||||
lora_request.lora_int_id)
|
||||
else:
|
||||
dummy_lora = self._lora_manager.create_dummy_lora(
|
||||
dummy_lora = self._adapter_manager.create_dummy_lora(
|
||||
lora_request.lora_int_id, rank, 1, self.embedding_modules)
|
||||
if self._cached_dummy_lora is None:
|
||||
self._cached_dummy_lora = dummy_lora
|
||||
return self._lora_manager.add_lora(dummy_lora)
|
||||
return self._adapter_manager.add_adapter(dummy_lora)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if lora_request.lora_int_id in self.list_loras():
|
||||
return False
|
||||
lora = self._load_lora(lora_request)
|
||||
loaded = self._lora_manager.add_lora(lora)
|
||||
self._lora_manager.activate_lora(lora.id)
|
||||
return loaded
|
||||
def pin_adapter(self, adapter_id: int) -> bool:
|
||||
return self._adapter_manager.pin_adapter(adapter_id)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self._lora_manager.remove_lora(lora_id)
|
||||
def set_active_adapters(self, requests: Set[Any],
|
||||
mapping: Optional[Any]) -> None:
|
||||
set_active_adapters_worker(requests, mapping, self._apply_adapters,
|
||||
self._adapter_manager.set_adapter_mapping)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self._lora_manager.pin_lora(lora_id)
|
||||
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
|
||||
apply_adapters_worker(adapter_requests, self.list_adapters,
|
||||
self._adapter_manager.adapter_slots,
|
||||
self.remove_adapter, self.add_adapter)
|
||||
|
||||
def remove_all_loras(self):
|
||||
self._lora_manager.remove_all_loras()
|
||||
def add_adapter(self, adapter_request: Any) -> bool:
|
||||
return add_adapter_worker(adapter_request, self.list_adapters,
|
||||
self._load_adapter,
|
||||
self._adapter_manager.add_adapter,
|
||||
self._adapter_manager.activate_adapter)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return set(self._lora_manager.list_loras())
|
||||
def remove_adapter(self, adapter_id: int) -> bool:
|
||||
return self._adapter_manager.remove_adapter(adapter_id)
|
||||
|
||||
def remove_all_adapters(self):
|
||||
self._adapter_manager.remove_all_adapters()
|
||||
|
||||
def list_adapters(self) -> Set[int]:
|
||||
return list_adapters_worker(self._adapter_manager.list_adapters)
|
||||
|
||||
|
||||
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
||||
@ -238,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
||||
(unless they are already loaded) and least recently used LoRAs will
|
||||
be unloaded if the cache is above capacity."""
|
||||
|
||||
_lora_manager_cls: Type[
|
||||
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
|
||||
_manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
|
||||
|
||||
def create_lora_manager(
|
||||
self,
|
||||
@ -247,40 +171,41 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
||||
) -> Any:
|
||||
lora_manager = create_lora_manager(
|
||||
model,
|
||||
lora_manager_cls=self._lora_manager_cls,
|
||||
lora_manager_cls=self._manager_cls,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
vocab_size=self.vocab_size,
|
||||
lora_config=self.lora_config,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
)
|
||||
self._lora_manager = lora_manager
|
||||
self._adapter_manager = lora_manager
|
||||
return lora_manager.model
|
||||
|
||||
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
|
||||
def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
|
||||
loras_map = {
|
||||
lora_request.lora_int_id: lora_request
|
||||
for lora_request in lora_requests if lora_request
|
||||
}
|
||||
if len(loras_map) > self._lora_manager.lora_slots:
|
||||
if len(loras_map) > self._adapter_manager.lora_slots:
|
||||
raise RuntimeError(
|
||||
f"Number of requested LoRAs ({len(loras_map)}) is greater "
|
||||
"than the number of GPU LoRA slots "
|
||||
f"({self._lora_manager.lora_slots}).")
|
||||
f"({self._adapter_manager.lora_slots}).")
|
||||
for lora in loras_map.values():
|
||||
self.add_lora(lora)
|
||||
self.add_adapter(lora)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if lora_request.lora_int_id not in self.list_loras():
|
||||
def add_adapter(self, lora_request: LoRARequest) -> bool:
|
||||
if lora_request.lora_int_id not in self.list_adapters():
|
||||
# Remove before we load the new lora to save memory
|
||||
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
|
||||
assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
|
||||
self._lora_manager.remove_oldest_lora()
|
||||
lora = self._load_lora(lora_request)
|
||||
loaded = self._lora_manager.add_lora(lora)
|
||||
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
|
||||
assert isinstance(self._adapter_manager,
|
||||
LRUCacheLoRAModelManager)
|
||||
self._adapter_manager.remove_oldest_adapter()
|
||||
lora = self._load_adapter(lora_request)
|
||||
loaded = self._adapter_manager.add_adapter(lora)
|
||||
else:
|
||||
# If the lora is already loaded, just touch it to
|
||||
# update its position in the caches
|
||||
loaded = self._lora_manager.get_lora(
|
||||
loaded = self._adapter_manager.get_adapter(
|
||||
lora_request.lora_int_id) is not None
|
||||
self._lora_manager.activate_lora(lora_request.lora_int_id)
|
||||
self._adapter_manager.activate_adapter(lora_request.lora_int_id)
|
||||
return loaded
|
||||
|
0
vllm/prompt_adapter/__init__.py
Normal file
0
vllm/prompt_adapter/__init__.py
Normal file
80
vllm/prompt_adapter/layers.py
Normal file
80
vllm/prompt_adapter/layers.py
Normal file
@ -0,0 +1,80 @@
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)


@dataclass
class PromptAdapterMapping(AdapterMapping):
    pass


class VocabParallelEmbeddingWithPromptAdapter(nn.Module):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig):
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)

        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int):
        self.embeddings_tensors[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ):
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ):
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = self.base_layer(x)
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Update hidden states
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states
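The forward pass above is the heart of soft-prompt insertion: rows of the token-embedding output that correspond to an adapter's virtual tokens are overwritten with the learned prompt vectors, selected by (slot, position) pairs. A self-contained sketch of that substitution in plain PyTorch (toy sizes, illustrative only, no vLLM classes):

import torch

vocab_size, hidden, max_adapters, max_virtual = 100, 16, 4, 8
embed = torch.nn.Embedding(vocab_size, hidden)
soft_prompts = torch.zeros(max_adapters, max_virtual, hidden)
soft_prompts[0, :4] = torch.randn(4, hidden)   # adapter in slot 0, 4 virtual tokens

# One request: 4 virtual-token placeholders followed by 3 real tokens.
token_ids = torch.tensor([0, 0, 0, 0, 17, 42, 7])
indices = torch.tensor([0, 0, 0, 0, -1, -1, -1])            # slot per token, -1 = none
embedding_indices = torch.tensor([[0, 0], [0, 1], [0, 2], [0, 3]])  # (slot, position)

hidden_states = embed(token_ids)
valid_mask = indices != -1
hidden_states[valid_mask] = soft_prompts[embedding_indices[:, 0],
                                         embedding_indices[:, 1]]
# Rows 0..3 now hold the soft prompt; rows 4..6 keep the ordinary embeddings.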
355 vllm/prompt_adapter/models.py Normal file
@ -0,0 +1,355 @@
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
|
||||
AdapterModelManager)
|
||||
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
||||
get_adapter, list_adapters,
|
||||
remove_adapter, set_adapter_mapping)
|
||||
from vllm.config import PromptAdapterConfig
|
||||
from vllm.prompt_adapter.layers import (
|
||||
VocabParallelEmbeddingWithPromptAdapter) # yapf: disable
|
||||
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GLOBAL_PROMPT_ADAPTER_ID = 0
|
||||
|
||||
|
||||
def get_prompt_adapter_id():
|
||||
global _GLOBAL_PROMPT_ADAPTER_ID
|
||||
_GLOBAL_PROMPT_ADAPTER_ID += 1
|
||||
return _GLOBAL_PROMPT_ADAPTER_ID
|
||||
|
||||
|
||||
def convert_to_embedding_indices(indices):
|
||||
embedding_indices = []
|
||||
count = 0
|
||||
|
||||
for value in indices:
|
||||
if value == -1:
|
||||
count = 0
|
||||
else:
|
||||
embedding_indices.append([value, count])
|
||||
count += 1
|
||||
|
||||
return torch.tensor(embedding_indices)
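For intuition, a worked example (illustrative, not part of the diff) of the function defined directly above:

example = convert_to_embedding_indices([0, 0, 0, -1, -1, 1, 1, -1])
# Two prompts in one flattened batch: slot 0 supplies 3 virtual tokens and
# slot 1 supplies 2; the position counter restarts after every -1.
# example == tensor([[0, 0],
#                    [0, 1],
#                    [0, 2],
#                    [1, 0],
#                    [1, 1]])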
|
||||
|
||||
|
||||
def convert_mapping(
|
||||
mapping: PromptAdapterMapping,
|
||||
prompt_adapter_index_to_id: List[Optional[int]],
|
||||
) -> torch.Tensor:
|
||||
"""Converts PromptAdapterMapping to index tensors.
|
||||
|
||||
Args:
|
||||
mapping: PromptAdapterMapping mapping rows in a
|
||||
batch to PromptAdapter ids.
|
||||
prompt_adapter_index_to_id: List mapping PromptAdapter
|
||||
ids to PromptAdapter indices.
|
||||
|
||||
Returns:
|
||||
pa_indices: Tensor of shape [batch_size] mapping batch rows to
|
||||
PromptAdapter indices.
|
||||
"""
|
||||
id_to_index = {
|
||||
id_: idx
|
||||
for idx, id_ in enumerate(prompt_adapter_index_to_id)
|
||||
if id_ is not None
|
||||
}
|
||||
pa_indices = ([
|
||||
id_to_index.get(id_, -1) if id_ > 0 else -1
|
||||
for id_ in mapping.index_mapping
|
||||
])
|
||||
|
||||
pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
|
||||
pa_indices = torch.tensor(pa_indices)
|
||||
return pa_indices, pa_embedding_mapping
|
||||
|
||||
|
||||
class PromptAdapterModel(AdapterModel):
|
||||
|
||||
def __init__(self,
|
||||
prompt_adapter_id=None,
|
||||
num_virtual_tokens=None,
|
||||
prompt_embedding=None) -> None:
|
||||
self.id = prompt_adapter_id
|
||||
self.prompt_embedding = prompt_embedding
|
||||
self.num_virtual_tokens = num_virtual_tokens
|
||||
|
||||
@classmethod
|
||||
def from_local_checkpoint(
|
||||
cls,
|
||||
adapter_model_path: str,
|
||||
prompt_adapter_id: int,
|
||||
num_virtual_tokens: int,
|
||||
config: PromptAdapterConfig,
|
||||
device: str = "cuda",
|
||||
) -> "PromptAdapterModel":
|
||||
from peft.utils import load_peft_weights
|
||||
|
||||
if num_virtual_tokens > config.max_prompt_adapter_token:
|
||||
raise ValueError(
|
||||
f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
|
||||
f'max_prompt_adapter_token({config.max_prompt_adapter_token})')
|
||||
|
||||
adapters_weights = load_peft_weights(adapter_model_path, device)
|
||||
prompt_embedding = adapters_weights["prompt_embeddings"].to(
|
||||
config.prompt_adapter_dtype)
|
||||
|
||||
return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
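from_local_checkpoint expects a standard PEFT prompt-tuning checkpoint whose adapter weights contain a "prompt_embeddings" tensor of shape [num_virtual_tokens, hidden_size]. A hedged sketch of producing such a checkpoint with Hugging Face PEFT (the base model and output path are illustrative; exact PEFT API details may vary by version):

from peft import PromptTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative
peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM,
                                 num_virtual_tokens=8)
peft_model = get_peft_model(base, peft_config)
# ... train the soft prompt ...
peft_model.save_pretrained("/tmp/opt125m-soft-prompt")
# The saved adapter weights carry the "prompt_embeddings" tensor that
# load_peft_weights() returns and the classmethod above reads.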
|
||||
|
||||
|
||||
class PromptAdapterModelManager(AdapterModelManager):
|
||||
"""A manager that manages multiple Prompt Adapter models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
prompt_adapter_config: PromptAdapterConfig,
|
||||
):
|
||||
"""Create a PromptAdapterModel and adapter for a given model.
|
||||
|
||||
Args:
|
||||
model: the model to be adapted.
|
||||
max_num_seqs: the maximum number of sequences model can run in a
|
||||
single batch.
|
||||
max_num_batched_tokens: the maximum number of tokens model can run
|
||||
in a single batch.
|
||||
prompt_adapter_config: the PromptAdapter config.
|
||||
"""
|
||||
self.model: nn.Module = model
|
||||
# Dict instead of a Set for compatibility with LRUCache.
|
||||
self.prompt_adapter_index_to_id: List[
|
||||
Optional[int]] = [None] * self.prompt_adapter_slots
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.model.prompt_adapter_manager = self
|
||||
self.adapter_type = 'PromptAdapter'
|
||||
|
||||
self.base_indices = torch.tensor([-1])
|
||||
self.base_embedding_indices = torch.tensor([])
|
||||
|
||||
self.modules: Dict[str, nn.Module] = {}
|
||||
self._create_prompt_adapter_modules()
|
||||
self._last_mapping: Optional[PromptAdapterMapping] = None
|
||||
|
||||
@property
|
||||
def prompt_adapter_slots(self) -> int:
|
||||
return self.prompt_adapter_config.max_prompt_adapters
|
||||
|
||||
@property
|
||||
def adapter_slots(self) -> int:
|
||||
return self.prompt_adapter_slots
|
||||
|
||||
@property
|
||||
def capacity(self) -> int:
|
||||
return self.prompt_adapter_config.max_cpu_prompt_adapters
|
||||
|
||||
def activate_adapter(
|
||||
self,
|
||||
prompt_adapter_id: int,
|
||||
) -> bool:
|
||||
"""Move PromptAdapter into a GPU buffer
|
||||
to be used in the forward pass."""
|
||||
if prompt_adapter_id in self._active_adapters:
|
||||
return False
|
||||
first_free_slot = next(
|
||||
((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate(
|
||||
self.prompt_adapter_index_to_id) if prompt_adapter_id is None),
|
||||
None)
|
||||
if first_free_slot is None:
|
||||
raise ValueError("No free prompt_adapter slots")
|
||||
index, _ = first_free_slot
|
||||
self._active_adapters[prompt_adapter_id] = None
|
||||
prompt_adapter_model = (self._registered_adapters[prompt_adapter_id])
|
||||
logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
|
||||
prompt_adapter_model.id, index)
|
||||
self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
|
||||
for _, v in self.modules.items():
|
||||
v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
|
||||
return True
|
||||
|
||||
def _deactivate_adapter(self, prompt_adapter_id: int):
|
||||
try:
|
||||
index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
|
||||
self.prompt_adapter_index_to_id[index] = None
|
||||
for _, v in self.modules.items():
|
||||
v.reset_prompt_adapter(index)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _add_adapter(self, prompt_adapter: PromptAdapterModel):
|
||||
self._registered_adapters[prompt_adapter.id] = prompt_adapter
|
||||
|
||||
def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
|
||||
base_indices, base_embedding_indices = convert_mapping(
|
||||
mapping, self.prompt_adapter_index_to_id)
|
||||
for k, v in self.modules.items():
|
||||
v.set_mapping(base_indices, base_embedding_indices)
|
||||
|
||||
def _create_prompt_adapter_modules(self):
|
||||
for module_name, module in self.model.named_modules(
|
||||
remove_duplicate=False):
|
||||
if "VocabParallel" in module.__class__.__name__:
|
||||
new_module = VocabParallelEmbeddingWithPromptAdapter(module)
|
||||
new_module.create_prompt_adapter_weights(
|
||||
self.prompt_adapter_config)
|
||||
replaced_module = self.replace_submodule(
|
||||
self.model, module_name, new_module)
|
||||
self.register_module(module.__class__.__name__,
|
||||
replaced_module)
|
||||
replaced_module.set_mapping(self.base_indices,
|
||||
self.base_embedding_indices)
|
||||
break
|
||||
|
||||
def replace_submodule(self, model: nn.Module, module_name: str,
|
||||
new_module: nn.Module) -> nn.Module:
|
||||
"""Replace a submodule in a model with a new module."""
|
||||
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
|
||||
target_name = module_name.split(".")[-1]
|
||||
setattr(parent, target_name, new_module)
|
||||
return new_module
|
||||
|
||||
def register_module(self, module_name: str, module: nn.Module):
|
||||
self.modules[module_name] = module
|
||||
|
||||
def pin_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
"""Pin a PromptAdapterModel in the manager cache."""
|
||||
raise NotImplementedError(
|
||||
"Pinning is not supported in PromptAdapterModelManager."
|
||||
"Use LRUCachePromptAdapterModelManager for pinning"
|
||||
) # type: ignore
|
||||
|
||||
def remove_all_adapters(self):
|
||||
"""Remove all PromptAdapterModel from the manager."""
|
||||
self._registered_adapters.clear()
|
||||
self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
|
||||
self._active_adapters.clear()
|
||||
|
||||
def deactivate_adapter(self, adapter_id: int) -> bool:
|
||||
return deactivate_adapter(adapter_id, self._active_adapters,
|
||||
self._deactivate_adapter)
|
||||
|
||||
def add_adapter(self, adapter: PromptAdapterModel) -> bool:
|
||||
return add_adapter(adapter, self._registered_adapters, self.capacity,
|
||||
self._add_adapter)
|
||||
|
||||
def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
|
||||
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
|
||||
self._set_adapter_mapping)
|
||||
|
||||
def remove_adapter(self, adapter_id: int) -> bool:
|
||||
return remove_adapter(adapter_id, self._registered_adapters,
|
||||
self.deactivate_adapter)
|
||||
|
||||
def list_adapters(self) -> Dict[int, Any]:
|
||||
return list_adapters(self._registered_adapters)
|
||||
|
||||
def get_adapter(self, adapter_id: int) -> Optional[Any]:
|
||||
return get_adapter(adapter_id, self._registered_adapters)
|
||||
|
||||
|
||||
class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):
|
||||
|
||||
def __init__(self, capacity: int,
|
||||
deactivate_prompt_adapter_fn: Callable[[int], bool]):
|
||||
super().__init__(capacity, deactivate_prompt_adapter_fn)
|
||||
|
||||
|
||||
class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
|
||||
"""A model manager that manages multiple prompt_adapters with LRU cache."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
prompt_adapter_config: PromptAdapterConfig,
|
||||
):
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
||||
prompt_adapter_config)
|
||||
self._registered_adapters = PromptAdapterLRUCache(
|
||||
self.capacity, self.deactivate_adapter)
|
||||
self._active_adapters = PromptAdapterLRUCache(
|
||||
self.prompt_adapter_slots, self._deactivate_adapter)
|
||||
|
||||
def list_adapters(self) -> Dict[int, PromptAdapterModel]:
|
||||
"""List all registered PromptAdapterModel."""
|
||||
return dict(self._registered_adapters.cache)
|
||||
|
||||
def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
|
||||
"""Add a PromptAdapterModel to the manager."""
|
||||
if prompt_adapter.id not in self._registered_adapters:
|
||||
self._add_adapter(prompt_adapter)
|
||||
was_added = True
|
||||
else:
|
||||
# We always touch to update the LRU cache order
|
||||
self._registered_adapters.touch(prompt_adapter.id)
|
||||
was_added = False
|
||||
return was_added
|
||||
|
||||
def activate_adapter(
|
||||
self,
|
||||
prompt_adapter_id: int,
|
||||
) -> bool:
|
||||
if prompt_adapter_id not in self._active_adapters and len(
|
||||
self._active_adapters) >= self.prompt_adapter_slots:
|
||||
self._active_adapters.remove_oldest()
|
||||
result = super().activate_adapter(prompt_adapter_id)
|
||||
# We always touch to update the LRU cache order
|
||||
self._active_adapters.touch(prompt_adapter_id)
|
||||
return result
|
||||
|
||||
def remove_oldest_adapter(self) -> bool:
|
||||
if len(self._registered_adapters) > 0:
|
||||
self._registered_adapters.remove_oldest()
|
||||
return True
|
||||
return False
|
||||
|
||||
def pin_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
"""Pin a PromptAdapterModel in the manager cache."""
|
||||
self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
|
||||
self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
|
||||
return True
|
||||
|
||||
def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
|
||||
try:
|
||||
self._registered_adapters.pin(prompt_adapter_id)
|
||||
except ValueError as err:
|
||||
raise ValueError(
|
||||
"Pinning failed. "
|
||||
f"Prompt Adapter {prompt_adapter_id} is not registered."
|
||||
) from err
|
||||
|
||||
def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
|
||||
if prompt_adapter_id not in self._active_adapters:
|
||||
# move adapter to gpu if not already active
|
||||
self.activate_adapter(prompt_adapter_id)
|
||||
self._active_adapters.pin(prompt_adapter_id)
|
||||
|
||||
|
||||
def create_prompt_adapter_manager(
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
prompt_adapter_config: PromptAdapterConfig,
|
||||
prompt_adapter_manager_cls: Type[
|
||||
PromptAdapterModelManager] = PromptAdapterModelManager,
|
||||
**kwargs) -> PromptAdapterModelManager:
|
||||
"""Create a PromptAdapterModel for a given model."""
|
||||
prompt_adapter_manager = prompt_adapter_manager_cls(
|
||||
model=model,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
**kwargs)
|
||||
return prompt_adapter_manager
|
30 vllm/prompt_adapter/request.py Normal file
@ -0,0 +1,30 @@
from dataclasses import dataclass

from vllm.adapter_commons.request import AdapterRequest


@dataclass
class PromptAdapterRequest(AdapterRequest):
    """
    Request for a Prompt adapter.
    """

    prompt_adapter_name: str
    prompt_adapter_id: int
    prompt_adapter_local_path: str
    prompt_adapter_num_virtual_tokens: int

    def __hash__(self):
        return super().__hash__()

    @property
    def adapter_id(self):
        return self.prompt_adapter_id

    @property
    def name(self):
        return self.prompt_adapter_name

    @property
    def local_path(self):
        return self.prompt_adapter_local_path
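A user-facing sketch of how such a request is built and attached to a generation call. The model name and checkpoint path are illustrative, and the enable_prompt_adapter / max_prompt_adapter_token engine flags plus the prompt_adapter_request keyword on generate() are assumed from the engine-side hunks of this diff:

from vllm import LLM, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

# Assumed engine flags; names taken from this diff's config/engine-arg changes.
llm = LLM(model="facebook/opt-125m",
          enable_prompt_adapter=True,
          max_prompt_adapter_token=8)
request = PromptAdapterRequest(
    prompt_adapter_name="my-soft-prompt",
    prompt_adapter_id=1,
    prompt_adapter_local_path="/tmp/opt125m-soft-prompt",  # illustrative path
    prompt_adapter_num_virtual_tokens=8,
)
outputs = llm.generate("Tweet text: I hate it. Label: ",
                       SamplingParams(temperature=0.0, max_tokens=3),
                       prompt_adapter_request=request)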
176 vllm/prompt_adapter/worker_manager.py Normal file
@ -0,0 +1,176 @@
|
||||
import logging
|
||||
from typing import Any, Optional, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.adapter_commons.utils import (add_adapter_worker,
|
||||
apply_adapters_worker,
|
||||
list_adapters_worker,
|
||||
set_active_adapters_worker)
|
||||
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
|
||||
from vllm.config import PromptAdapterConfig
|
||||
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
|
||||
PromptAdapterModel,
|
||||
PromptAdapterModelManager,
|
||||
create_prompt_adapter_manager)
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerPromptAdapterManager(AbstractWorkerManager):
|
||||
"""WorkerPromptAdapterManager that manages
|
||||
prompt_adapter models on the worker side.
|
||||
|
||||
Every request, the requested prompt_adapters will be
|
||||
loaded (unless they are already loaded),
|
||||
and every other prompt_adapter will be unloaded."""
|
||||
|
||||
_manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
prompt_adapter_config: PromptAdapterConfig,
|
||||
prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
|
||||
):
|
||||
self._adapter_manager: PromptAdapterModelManager
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self._prompt_adapter_model_cls = prompt_adapter_model_cls
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
super().__init__(device)
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
return True
|
||||
|
||||
def create_prompt_adapter_manager(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
) -> Any:
|
||||
prompt_adapter_manager = create_prompt_adapter_manager(
|
||||
model,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
prompt_adapter_config=self.prompt_adapter_config,
|
||||
prompt_adapter_manager_cls=self._manager_cls,
|
||||
)
|
||||
self._adapter_manager = prompt_adapter_manager
|
||||
return prompt_adapter_manager.model
|
||||
|
||||
def _load_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest
|
||||
) -> PromptAdapterModel:
|
||||
try:
|
||||
prompt_adapter = (
|
||||
self._prompt_adapter_model_cls.from_local_checkpoint(
|
||||
prompt_adapter_request.prompt_adapter_local_path,
|
||||
prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
|
||||
num_virtual_tokens=prompt_adapter_request.
|
||||
prompt_adapter_num_virtual_tokens,
|
||||
config=self.prompt_adapter_config,
|
||||
device=str(self.device),
|
||||
))
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Loading prompt_adapter "
|
||||
f"{prompt_adapter_request.prompt_adapter_local_path}"
|
||||
f" failed") from e
|
||||
return prompt_adapter
|
||||
|
||||
def add_dummy_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
return True
|
||||
|
||||
def pin_adapter(self, adapter_id: int) -> bool:
|
||||
return self._adapter_manager.pin_adapter(adapter_id)
|
||||
|
||||
def set_active_adapters(self, requests: Set[Any],
|
||||
mapping: Optional[Any]) -> None:
|
||||
set_active_adapters_worker(requests, mapping, self._apply_adapters,
|
||||
self._adapter_manager.set_adapter_mapping)
|
||||
|
||||
def add_adapter(self, adapter_request: Any) -> bool:
|
||||
return add_adapter_worker(adapter_request, self.list_adapters,
|
||||
self._load_adapter,
|
||||
self._adapter_manager.add_adapter,
|
||||
self._adapter_manager.activate_adapter)
|
||||
|
||||
def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
|
||||
apply_adapters_worker(adapter_requests, self.list_adapters,
|
||||
self._adapter_manager.adapter_slots,
|
||||
self.remove_adapter, self.add_adapter)
|
||||
|
||||
def remove_adapter(self, adapter_id: int) -> bool:
|
||||
return self._adapter_manager.remove_adapter(adapter_id)
|
||||
|
||||
def remove_all_adapters(self):
|
||||
self._adapter_manager.remove_all_adapters()
|
||||
|
||||
def list_adapters(self) -> Set[int]:
|
||||
return list_adapters_worker(self._adapter_manager.list_adapters)
|
||||
|
||||
|
||||
class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
|
||||
"""WorkerPromptAdapterManager that manages
|
||||
prompt_adapter models on the worker side.
|
||||
|
||||
Uses an LRU Cache. Every request, the requested
|
||||
prompt_adapters will be loaded (unless they are already loaded)
|
||||
and least recently used prompt_adapters will
|
||||
be unloaded if the cache is above capacity."""
|
||||
|
||||
_prompt_adapter_manager_cls: Type[
|
||||
LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager
|
||||
|
||||
def create_prompt_adapter_manager(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
) -> Any:
|
||||
prompt_adapter_manager = create_prompt_adapter_manager(
|
||||
model,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
prompt_adapter_config=self.prompt_adapter_config,
|
||||
prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
|
||||
self._adapter_manager: LRUCachePromptAdapterModelManager = (
|
||||
prompt_adapter_manager)
|
||||
return prompt_adapter_manager.model
|
||||
|
||||
def _apply_adapters(
|
||||
self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
|
||||
prompt_adapters_map = {
|
||||
prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
|
||||
for prompt_adapter_request in prompt_adapter_requests
|
||||
if prompt_adapter_request
|
||||
}
|
||||
if len(prompt_adapters_map
|
||||
) > self._adapter_manager.prompt_adapter_slots:
|
||||
raise RuntimeError(
|
||||
f"Number of requested prompt_adapters "
|
||||
f"({len(prompt_adapters_map)}) is greater "
|
||||
"than the number of GPU prompt_adapter slots "
|
||||
f"({self._adapter_manager.prompt_adapter_slots}).")
|
||||
for prompt_adapter in prompt_adapters_map.values():
|
||||
self.add_adapter(prompt_adapter)
|
||||
|
||||
def add_adapter(self,
|
||||
prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
if prompt_adapter_request.prompt_adapter_id not in self.list_adapters(
|
||||
):
|
||||
# Remove before we load the new prompt_adapter to save memory
|
||||
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
|
||||
self._adapter_manager.remove_oldest_adapter()
|
||||
prompt_adapter = self._load_adapter(prompt_adapter_request)
|
||||
loaded = self._adapter_manager.add_adapter(prompt_adapter)
|
||||
else:
|
||||
# If the prompt_adapter is already loaded, just touch it to
|
||||
# update its position in the caches
|
||||
loaded = self._adapter_manager.get_adapter(
|
||||
prompt_adapter_request.prompt_adapter_id) is not None
|
||||
self._adapter_manager.activate_adapter(
|
||||
prompt_adapter_request.prompt_adapter_id)
|
||||
return loaded
|
@ -10,6 +10,7 @@ import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -238,6 +239,8 @@ class Sequence:
|
||||
block_size: The block size of the sequence. Should be the same as the
|
||||
block size used by the block manager and cache engine.
|
||||
lora_request: LoRA request.
|
||||
prompt_adapter_request: Prompt Adapter request.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -247,12 +250,14 @@ class Sequence:
|
||||
block_size: int,
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.inputs = inputs
|
||||
self.block_size = block_size
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
|
||||
self.data = SequenceData(self.prompt_token_ids)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
@ -287,6 +292,11 @@ class Sequence:
|
||||
def lora_int_id(self) -> int:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
@property
|
||||
def prompt_adapter_id(self) -> int:
|
||||
return self.prompt_adapter_request.prompt_adapter_id \
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
def get_output_text_to_return(self, buffer_length: int):
|
||||
# We return the full output text if the sequence is finished.
|
||||
truncate = buffer_length and not self.is_finished()
|
||||
@ -414,6 +424,7 @@ class SequenceGroup:
|
||||
encoder_seq: Optional, the single encoder sequence. Should be None
|
||||
unless you are working with an encoder/decoder model.
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: Prompt Adapter request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -427,6 +438,7 @@ class SequenceGroup:
|
||||
pooling_params: Optional[PoolingParams] = None,
|
||||
encoder_seq: Optional[Sequence] = None,
|
||||
trace_headers: Optional[Dict[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
||||
@ -441,6 +453,7 @@ class SequenceGroup:
|
||||
self.state = SequenceGroupState()
|
||||
self.embeddings = embeddings
|
||||
self.pooling_params = pooling_params
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.encoder_seq = encoder_seq
|
||||
self.trace_headers = trace_headers
|
||||
|
||||
@ -466,6 +479,16 @@ class SequenceGroup:
|
||||
def lora_int_id(self) -> int:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
@property
|
||||
def prompt_adapter_id(self) -> int:
|
||||
return self.prompt_adapter_request.prompt_adapter_id \
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
@property
|
||||
def prompt_adapter_num_virtual_tokens(self) -> int:
|
||||
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
def get_last_latency(self, now: float) -> Optional[float]:
|
||||
"""Sets the last token time for Request level timings."""
|
||||
# If still in prefill phase, raise Error.
|
||||
@ -624,6 +647,7 @@ class SequenceGroupMetadata:
|
||||
(SequenceGroup.encoder_seq). Should be None
|
||||
unless you are working with an encoder/decoder
|
||||
model.
|
||||
prompt_adapter_request: Prompt Adapter request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -642,6 +666,7 @@ class SequenceGroupMetadata:
|
||||
multi_modal_data: Optional["MultiModalDataDict"] = None,
|
||||
encoder_seq_data: Optional[SequenceData] = None,
|
||||
cross_block_table: Optional[List[int]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.is_prompt = is_prompt
|
||||
@ -650,6 +675,7 @@ class SequenceGroupMetadata:
|
||||
self.block_tables = block_tables
|
||||
self.pooling_params = pooling_params
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.computed_block_nums = computed_block_nums
|
||||
self.multi_modal_data = multi_modal_data
|
||||
self.state = SequenceGroupState() if state is None else state
|
||||
@ -674,6 +700,16 @@ class SequenceGroupMetadata:
|
||||
def lora_int_id(self) -> int:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
@property
|
||||
def prompt_adapter_id(self) -> int:
|
||||
return self.prompt_adapter_request.prompt_adapter_id \
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
@property
|
||||
def prompt_adapter_num_virtual_tokens(self) -> int:
|
||||
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
|
||||
if self.prompt_adapter_request else 0
|
||||
|
||||
@property
|
||||
def token_chunk_size(self) -> int:
|
||||
"""Return the number of tokens to be processed (chunk size)."""
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import (IntermediateTensors, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
@ -48,6 +48,7 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
if return_hidden_states:
|
||||
@ -66,6 +67,7 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
multimodal_config=multimodal_config,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
|
||||
@ -136,6 +138,13 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
assert model_input.prompt_adapter_requests is not None
|
||||
assert model_input.prompt_adapter_mapping is not None
|
||||
self.set_active_prompt_adapters(
|
||||
model_input.prompt_adapter_requests,
|
||||
model_input.prompt_adapter_mapping)
|
||||
|
||||
virtual_engine = model_input.virtual_engine
|
||||
outputs: List[SamplerOutput] = []
|
||||
for step in range(num_steps):
|
||||
|
@ -8,7 +8,7 @@ from torch import nn
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@ -81,6 +81,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
@ -94,6 +95,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.load_config = load_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
|
@ -7,7 +7,7 @@ import torch.distributed
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
@ -133,6 +133,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
@ -145,6 +146,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.lora_config = lora_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if self.is_driver_worker:
|
||||
@ -167,6 +169,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
lora_config=self.lora_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
prompt_adapter_config=self.prompt_adapter_config,
|
||||
is_driver_worker=is_driver_worker)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@ -40,6 +40,7 @@ class EmbeddingModelRunner(
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
):
|
||||
super().__init__(model_config,
|
||||
@ -51,6 +52,7 @@ class EmbeddingModelRunner(
|
||||
lora_config=lora_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
multimodal_config=multimodal_config)
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -71,6 +73,13 @@ class EmbeddingModelRunner(
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
assert model_input.prompt_adapter_requests is not None
|
||||
assert model_input.prompt_adapter_mapping is not None
|
||||
self.set_active_prompt_adapters(
|
||||
model_input.prompt_adapter_requests,
|
||||
model_input.prompt_adapter_mapping)
|
||||
|
||||
# Currently cuda graph is only supported by the decode phase.
|
||||
assert model_input.attn_metadata is not None
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
|
@ -25,7 +25,7 @@ except ImportError:
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
@ -40,6 +40,10 @@ from vllm.model_executor.models.interfaces import (supports_lora,
|
||||
supports_vision)
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
|
||||
MultiModalInputs)
|
||||
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.prompt_adapter.worker_manager import (
|
||||
LRUCacheWorkerPromptAdapterManager)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (IntermediateTensors, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
@ -85,6 +89,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
||||
lora_mapping: Optional["LoRAMapping"] = None
|
||||
lora_requests: Optional[Set[LoRARequest]] = None
|
||||
attn_metadata: Optional["AttentionMetadata"] = None
|
||||
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
|
||||
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
|
||||
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
|
||||
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
@ -97,6 +103,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
"prompt_adapter_mapping": self.prompt_adapter_mapping,
|
||||
"prompt_adapter_requests": self.prompt_adapter_requests,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||
"finished_requests_ids": self.finished_requests_ids,
|
||||
@ -133,6 +141,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
"prompt_adapter_mapping": self.prompt_adapter_mapping,
|
||||
"prompt_adapter_requests": self.prompt_adapter_requests,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
|
||||
"finished_requests_ids": self.finished_requests_ids,
|
||||
@ -172,6 +182,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
return_hidden_states: bool = False,
|
||||
):
|
||||
@ -183,6 +194,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
@ -232,6 +244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.model: nn.Module # Set after load_model
|
||||
# Set after load_model.
|
||||
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
|
||||
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
|
||||
|
||||
self.flashinfer_decode_workspace_buffer = None
|
||||
self.flashinfer_decode_wrapper = None
|
||||
@ -240,16 +253,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
|
||||
def load_model(self) -> None:
|
||||
with CudaMemoryProfiler() as m:
|
||||
self.model = get_model(
|
||||
model_config=self.model_config,
|
||||
self.model = get_model(model_config=self.model_config,
|
||||
device_config=self.device_config,
|
||||
load_config=self.load_config,
|
||||
lora_config=self.lora_config,
|
||||
multimodal_config=self.multimodal_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
cache_config=self.cache_config,
|
||||
)
|
||||
cache_config=self.cache_config)
|
||||
|
||||
self.model_memory_usage = m.consumed_memory
|
||||
logger.info("Loading model weights took %.4f GB",
|
||||
@ -274,6 +285,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
)
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens, self.device,
|
||||
self.prompt_adapter_config)
|
||||
self.model = (
|
||||
self.prompt_adapter_manager.create_prompt_adapter_manager(
|
||||
self.model))
|
||||
|
||||
if self.kv_cache_dtype == "fp8" and is_hip():
|
||||
# Currently only ROCm accepts kv-cache scaling factors
|
||||
# via quantization_param_path and this will be deprecated
|
||||
@ -354,6 +374,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
lora_index_mapping: List[int] = []
|
||||
lora_prompt_mapping: List[int] = []
|
||||
lora_requests: Set[LoRARequest] = set()
|
||||
prompt_adapter_index_mapping: List[int] = []
|
||||
prompt_adapter_prompt_mapping: List[int] = []
|
||||
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
|
||||
|
||||
seq_lens: List[int] = []
|
||||
prefill_seq_lens: List[int] = []
|
||||
@ -504,6 +527,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
input_tokens.extend(tokens)
|
||||
input_positions.extend(list(range(context_len, seq_len)))
|
||||
lora_id = seq_group_metadata.lora_int_id
|
||||
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
|
||||
|
||||
if is_prompt:
|
||||
assert len(seq_ids) == 1
|
||||
@ -534,6 +558,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
mm_kwargs = self.multi_modal_input_mapper(mm_data)
|
||||
multi_modal_inputs_list.append(mm_kwargs)
|
||||
|
||||
if prompt_adapter_id > 0 and is_prompt:
|
||||
prompt_adapter_requests.add(
|
||||
seq_group_metadata.prompt_adapter_request)
|
||||
|
||||
num_tokens = seq_group_metadata.\
|
||||
prompt_adapter_num_virtual_tokens
|
||||
pm = [prompt_adapter_id
|
||||
] * num_tokens + [0] * (query_len - num_tokens)
|
||||
prompt_adapter_index_mapping += pm
|
||||
prompt_adapter_prompt_mapping.extend(
|
||||
[prompt_adapter_id] *
|
||||
(query_len if seq_group_metadata.sampling_params
|
||||
and seq_group_metadata.sampling_params.prompt_logprobs
|
||||
else 1))
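A worked example of the two mappings built here, for a single prefill row (illustrative numbers: adapter id 1 with 4 virtual tokens, a 10-token query, prompt_logprobs disabled):

prompt_adapter_id = 1
num_tokens = 4          # prompt_adapter_num_virtual_tokens
query_len = 10

# Per-token index mapping: the leading virtual-token positions map to the
# adapter id, the remaining prompt tokens map to 0 ("no adapter").
pm = [prompt_adapter_id] * num_tokens + [0] * (query_len - num_tokens)
assert pm == [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]

# Per-sequence prompt mapping: one entry per sequence unless prompt
# logprobs are requested, in which case one entry per query token.
prompt_mapping = [prompt_adapter_id] * 1
assert prompt_mapping == [1]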
|
||||
|
||||
is_profile_run = _is_block_tables_empty(
|
||||
seq_group_metadata.block_tables)
|
||||
if is_profile_run:
|
||||
@ -618,12 +657,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
seq_lens.append(1)
|
||||
block_tables.append([])
|
||||
lora_index_mapping.append(0)
|
||||
|
||||
prompt_adapter_index_mapping.append(0)
|
||||
if self.attn_backend.get_name() == "flashinfer":
|
||||
last_paged_kv_indptr = paged_kv_indptr[-1]
|
||||
paged_kv_indptr.append(last_paged_kv_indptr)
|
||||
paged_kv_last_page_len.append(0)
|
||||
|
||||
batch_size = graph_batch_size
|
||||
num_decode_tokens = batch_size
|
||||
|
||||
@ -759,6 +797,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
else:
|
||||
lora_mapping = None
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
prompt_adapter_mapping = PromptAdapterMapping(
|
||||
prompt_adapter_index_mapping,
|
||||
prompt_adapter_prompt_mapping,
|
||||
)
|
||||
else:
|
||||
prompt_adapter_mapping = None
|
||||
|
||||
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
|
||||
device=self.device)
|
||||
request_ids_to_seq_ids = {
|
||||
@ -776,7 +822,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
lora_requests=lora_requests,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
request_ids_to_seq_ids=request_ids_to_seq_ids,
|
||||
finished_requests_ids=finished_requests_ids)
|
||||
finished_requests_ids=finished_requests_ids,
|
||||
prompt_adapter_mapping=prompt_adapter_mapping,
|
||||
prompt_adapter_requests=prompt_adapter_requests,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
@ -878,33 +927,67 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
def remove_all_loras(self):
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
self.lora_manager.remove_all_loras()
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
self.lora_manager.set_active_loras(lora_requests, lora_mapping)
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.add_lora(lora_request)
|
||||
return self.lora_manager.add_adapter(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.remove_lora(lora_id)
|
||||
return self.lora_manager.remove_adapter(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.pin_lora(lora_id)
|
||||
return self.lora_manager.pin_adapter(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.list_loras()
|
||||
return self.lora_manager.list_adapters()
|
||||
|
||||
def remove_all_prompt_adapters(self):
|
||||
if not self.prompt_adapter_manager:
|
||||
raise RuntimeError("PromptAdapter is not enabled.")
|
||||
self.prompt_adapter_manager.remove_all_adapters()
|
||||
|
||||
def set_active_prompt_adapters(
|
||||
self, prompt_adapter_requests: Set[PromptAdapterRequest],
|
||||
prompt_adapter_mapping: PromptAdapterMapping) -> None:
|
||||
if not self.prompt_adapter_manager:
|
||||
raise RuntimeError("PromptAdapter is not enabled.")
|
||||
self.prompt_adapter_manager.set_active_adapters(
|
||||
prompt_adapter_requests, prompt_adapter_mapping)
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
if not self.prompt_adapter_manager:
|
||||
raise RuntimeError("PromptAdapter is not enabled.")
|
||||
return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
if not self.prompt_adapter_manager:
|
||||
raise RuntimeError("PromptAdapter is not enabled.")
|
||||
return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
if not self.prompt_adapter_manager:
|
||||
raise RuntimeError("PromptAdapter is not enabled.")
|
||||
return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
if not self.prompt_adapter_manager:
|
||||
raise RuntimeError("PromptAdapter is not enabled.")
|
||||
return self.prompt_adapter_manager.list_adapters()
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
|
||||
@ -1063,6 +1146,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
)
|
||||
self.set_active_loras(set(), lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
prompt_adapter_mapping = PromptAdapterMapping(
|
||||
[-1] * batch_size,
|
||||
[-1] * batch_size,
|
||||
)
|
||||
self.set_active_prompt_adapters(
|
||||
set(), prompt_adapter_mapping)
|
||||
|
||||
graph_runner = CUDAGraphRunner(
|
||||
self.model, self.attn_backend.get_name())
|
||||
|
||||
@ -1189,6 +1280,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
assert model_input.prompt_adapter_requests is not None
|
||||
assert model_input.prompt_adapter_mapping is not None
|
||||
self.set_active_prompt_adapters(
|
||||
model_input.prompt_adapter_requests,
|
||||
model_input.prompt_adapter_mapping)
|
||||
|
||||
if self.attn_backend.get_name() == "flashinfer":
|
||||
assert model_input.attn_metadata is not None
|
||||
assert model_input.input_tokens is not None
|
||||
|
@ -8,7 +8,8 @@ import torch.distributed
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
@ -16,6 +17,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
|
||||
@ -45,6 +47,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
speculative_config: Optional[SpeculativeConfig] = None,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
|
||||
) -> None:
|
||||
@ -59,6 +62,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if parallel_config and is_driver_worker:
|
||||
assert rank % parallel_config.tensor_parallel_size == 0, \
|
||||
@ -92,6 +96,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
lora_config=self.lora_config,
|
||||
kv_cache_dtype=self.cache_config.cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
multimodal_config=multimodal_config,
|
||||
**speculative_args,
|
||||
)
|
||||
@ -296,6 +301,19 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def add_prompt_adapter(
|
||||
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
|
||||
return self.model_runner.add_prompt_adapter(prompt_adapter_request)
|
||||
|
||||
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.model_runner.remove_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
|
||||
return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
|
||||
|
||||
def list_prompt_adapters(self) -> Set[int]:
|
||||
return self.model_runner.list_prompt_adapters()
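Together with the add/remove/pin/list methods on the model runner above, these give the Worker the same management surface for prompt adapters that it already exposes for LoRA. A minimal sketch of a driver exercising them (assumes `worker` is an initialized Worker constructed with a prompt_adapter_config, and `request` is a PromptAdapterRequest as defined earlier in this diff):

worker.add_prompt_adapter(request)                    # load from disk and activate
assert request.prompt_adapter_id in worker.list_prompt_adapters()
worker.pin_prompt_adapter(request.prompt_adapter_id)  # exempt from LRU eviction
worker.remove_prompt_adapter(request.prompt_adapter_id)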
|
||||
|
||||
@property
|
||||
def max_model_len(self) -> int:
|
||||
return self.model_config.max_model_len
|
||||
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import init_logger
|
||||
@ -88,6 +88,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
multimodal_config: Optional[MultiModalConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
@ -98,6 +99,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
|
||||
self.lora_config = lora_config
|
||||
self.load_config = load_config
|
||||
self.cache_config = cache_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
|
@ -10,7 +10,8 @@ import torch.distributed
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
@ -47,6 +48,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
speculative_config: Optional[SpeculativeConfig] = None,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
is_driver_worker: bool = False,
|
||||
) -> None:
|
||||
assert device_config.device_type == "xpu"
|
||||
@ -63,6 +65,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.lora_config = lora_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if self.is_driver_worker:
|
||||
assert self.rank == 0, "The driver worker must have rank 0."
|
||||
|