[Core] Pipeline Parallel Support (#4412)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
parent
15aba081f3
commit
c5832d2ae9
@@ -74,6 +74,16 @@ steps:
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
  - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

- label: Pipeline Parallelism Test
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  commands:
  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
  - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py

- label: Engine Test
  mirror_hardwares: [amd]
  command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
@ -5,6 +5,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
|
||||
|
||||
from ..utils import wait_for_gpu_memory_to_clear
|
||||
@ -23,8 +24,11 @@ class MockEngine:
|
||||
self.add_request_calls = 0
|
||||
self.abort_request_calls = 0
|
||||
self.request_id = None
|
||||
# Ugly, remove dependency when possible
|
||||
self.parallel_config = ParallelConfig(1, 1, False)
|
||||
|
||||
async def step_async(self):
|
||||
async def step_async(self, virtual_engine):
|
||||
# PP size is 1, ignore virtual engine
|
||||
self.step_calls += 1
|
||||
return [RequestOutput(
|
||||
request_id=self.request_id)] if self.request_id else []
|
||||
@ -32,6 +36,9 @@ class MockEngine:
|
||||
async def process_model_inputs_async(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self):
|
||||
pass
|
||||
|
||||
def generate(self, request_id):
|
||||
self.request_id = request_id
|
||||
|
||||
@ -41,6 +48,7 @@ class MockEngine:
|
||||
def add_request(self, **kwargs):
|
||||
del kwargs # Unused
|
||||
self.add_request_calls += 1
|
||||
print(f'Request calls: {self.add_request_calls}')
|
||||
|
||||
async def add_request_async(self, **kwargs):
|
||||
self.add_request_calls += 1
|
||||
@ -53,6 +61,9 @@ class MockEngine:
|
||||
def has_unfinished_requests(self):
|
||||
return self.request_id is not None
|
||||
|
||||
def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
|
||||
return self.request_id is not None
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
|
||||
@ -76,6 +87,7 @@ async def test_new_requests_event():
|
||||
engine.engine.generate("2")
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.add_request_calls == 2
|
||||
assert engine.engine.step_calls >= 2
|
||||
await asyncio.sleep(0.001)
|
||||
|
@ -4,7 +4,7 @@ import pytest
|
||||
# and debugging.
|
||||
import ray
|
||||
|
||||
from ..utils import RemoteOpenAIServer
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
@ -12,7 +12,7 @@ MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_ctx():
|
||||
ray.init()
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
yield
|
||||
ray.shutdown()
|
||||
|
||||
|
@ -56,8 +56,8 @@ def test_chunked_prefill_recompute(
|
||||
max_num_seqs=max_num_seqs,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
@ -91,10 +91,10 @@ def test_preemption(
|
||||
disable_log_stats=False,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
total_preemption = (
|
||||
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
|
||||
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
@ -147,10 +147,10 @@ def test_swap(
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
|
||||
beam_width, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
total_preemption = (
|
||||
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
|
||||
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, _ = hf_outputs[i]
|
||||
@ -214,8 +214,8 @@ def test_swap_infeasible(
|
||||
example_prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
|
||||
# Verify the request is ignored and does not hang.
|
||||
assert req_outputs[0].outputs[0].finish_reason == "length"
|
||||
@ -252,8 +252,8 @@ def test_preemption_infeasible(
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
|
||||
# Verify the request is ignored and does not hang.
|
||||
for req_output in req_outputs:
|
||||
|
@ -32,7 +32,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
(r + 1) for r in range(tp_size)
|
||||
]
|
||||
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
|
||||
t = all_tensors[rank]
|
||||
t = all_tensors[rank % tp_size]
|
||||
t = tensor_model_parallel_all_reduce(t)
|
||||
assert torch.allclose(t, expected)
|
||||
|
||||
@ -60,7 +60,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
for r in range(tp_size)
|
||||
]
|
||||
expected = torch.cat(all_tensors, dim=all_gather_dimension)
|
||||
t = all_tensors[rank]
|
||||
t = all_tensors[rank % tp_size]
|
||||
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
|
||||
assert torch.allclose(t, expected)
|
||||
|
||||
@ -91,7 +91,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||
}
|
||||
|
||||
if rank == 0:
|
||||
if (rank % tp_size) == 0:
|
||||
broadcast_tensor_dict(test_dict, src=0)
|
||||
else:
|
||||
recv_dict = broadcast_tensor_dict(src=0)
|
||||
@ -184,3 +184,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
|
||||
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
|
||||
def test_multi_process_pipeline_parallel(pp_size, test_target):
|
||||
multi_process_parallel(1, pp_size, test_target)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("pp_size", [2])
|
||||
@pytest.mark.parametrize("test_target", [
|
||||
send_recv_test_worker, send_recv_tensor_dict_test_worker,
|
||||
all_reduce_test_worker, all_gather_test_worker,
|
||||
broadcast_tensor_dict_test_worker
|
||||
])
|
||||
def test_multi_process_tensor_parallel_pipeline_parallel(
|
||||
tp_size, pp_size, test_target):
|
||||
multi_process_parallel(tp_size, pp_size, test_target)
|
||||
|
149 tests/distributed/test_pipeline_parallel.py (new file)
@ -0,0 +1,149 @@
|
||||
import os
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
# using Ray for overall ease of process management, parallel requests,
|
||||
# and debugging.
|
||||
import ray
|
||||
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
# downloading lora to test lora requests
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
|
||||
EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
|
||||
CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
|
||||
TP_SIZE = int(os.getenv("TP_SIZE", 1))
|
||||
PP_SIZE = int(os.getenv("PP_SIZE", 1))
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_ctx():
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
yield
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(ray_ctx):
|
||||
args = [
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--pipeline-parallel-size",
|
||||
str(PP_SIZE),
|
||||
"--tensor-parallel-size",
|
||||
str(TP_SIZE),
|
||||
"--distributed-executor-backend",
|
||||
"ray",
|
||||
]
|
||||
if CHUNKED_PREFILL:
|
||||
args += [
|
||||
"--enable-chunked-prefill",
|
||||
]
|
||||
if EAGER_MODE:
|
||||
args += [
|
||||
"--enforce-eager",
|
||||
]
|
||||
return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
|
||||
|
||||
async def test_check_models(server, client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert all(model.root == MODEL_NAME for model in models)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_single_completion(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
assert completion.choices[0].text is not None and len(
|
||||
completion.choices[0].text) >= 5
|
||||
assert completion.choices[0].finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert completion.choices[0].text is not None and len(
|
||||
completion.choices[0].text) >= 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 model here
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_batch_completions(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
# test simple list
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=["Hello, my name is", "Hello, my name is"],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(batch.choices) == 2
|
||||
assert batch.choices[0].text == batch.choices[1].text
|
||||
|
||||
# test n = 2
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=["Hello, my name is", "Hello, my name is"],
|
||||
n=2,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body=dict(
|
||||
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
|
||||
# for official client.
|
||||
use_beam_search=True),
|
||||
)
|
||||
assert len(batch.choices) == 4
|
||||
assert batch.choices[0].text != batch.choices[
|
||||
1].text, "beam search should be different"
|
||||
assert batch.choices[0].text == batch.choices[
|
||||
2].text, "two copies of the same prompt should be the same"
|
||||
assert batch.choices[1].text == batch.choices[
|
||||
3].text, "two copies of the same prompt should be the same"
|
||||
|
||||
# test streaming
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=["Hello, my name is", "Hello, my name is"],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
texts = [""] * 2
|
||||
async for chunk in batch:
|
||||
assert len(chunk.choices) == 1
|
||||
choice = chunk.choices[0]
|
||||
texts[choice.index] += choice.text
|
||||
assert texts[0] == texts[1]
|
@ -32,7 +32,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
@ -86,7 +86,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
@ -148,7 +148,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
@ -215,7 +215,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
scheduler=[scheduler],
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
|
@ -14,7 +14,7 @@ import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from openai import BadRequestError
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
@ -77,7 +77,7 @@ def zephyr_lora_files():
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_ctx():
|
||||
ray.init()
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
yield
|
||||
ray.shutdown()
|
||||
|
||||
|
@ -16,7 +16,7 @@ from openai import BadRequestError
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
@ -79,7 +79,7 @@ def zephyr_lora_files():
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_ctx():
|
||||
ray.init()
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
yield
|
||||
ray.shutdown()
|
||||
|
||||
|
@ -5,14 +5,14 @@ import openai
|
||||
import pytest
|
||||
import ray
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_ctx():
|
||||
ray.init()
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
yield
|
||||
ray.shutdown()
|
||||
|
||||
|
@ -6,7 +6,7 @@ import ray
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
@ -22,7 +22,7 @@ def zephyr_lora_files():
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_ctx():
|
||||
ray.init()
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
yield
|
||||
ray.shutdown()
|
||||
|
||||
|
@ -24,13 +24,13 @@ TEST_IMAGE_URLS = [
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ray_ctx():
|
||||
ray.init()
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
yield
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
def server(ray_ctx):
|
||||
return RemoteOpenAIServer([
|
||||
"--model",
|
||||
MODEL_NAME,
|
||||
|
@ -54,9 +54,9 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
|
||||
return new_execute_model
|
||||
|
||||
|
||||
def zero_kv_cache(cache_engine: CacheEngine):
|
||||
assert cache_engine.gpu_cache
|
||||
for key_blocks, value_blocks in cache_engine.gpu_cache:
|
||||
def zero_kv_cache(cache_engine: List[CacheEngine]):
|
||||
assert cache_engine[0].gpu_cache
|
||||
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
|
||||
key_blocks.zero_()
|
||||
value_blocks.zero_()
|
||||
|
||||
|
@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
||||
tensorize_vllm_model)
|
||||
|
||||
from ..conftest import VllmRunner, cleanup
|
||||
from ..utils import RemoteOpenAIServer
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
|
||||
@ -220,6 +220,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
||||
json.dumps(model_loader_extra_config),
|
||||
]
|
||||
|
||||
ray.init(runtime_env={"working_dir": VLLM_PATH})
|
||||
|
||||
server = RemoteOpenAIServer(openai_args)
|
||||
print("Server ready.")
|
||||
|
||||
|
@@ -49,7 +49,6 @@ class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S = 600 # wait up to 600 seconds for the server to start

@ray.remote(num_gpus=1)
class _RemoteRunner:

def __init__(self, cli_args: List[str], *, wait_url: str,
@@ -92,7 +91,11 @@ class RemoteOpenAIServer:
if hasattr(self, "proc"):
self.proc.terminate()

def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
def __init__(self,
cli_args: List[str],
*,
auto_port: bool = True,
num_gpus: int = 1) -> None:
if auto_port:
if "-p" in cli_args or "--port" in cli_args:
raise ValueError("You have manually specified the port"
@@ -105,7 +108,8 @@ class RemoteOpenAIServer:
self.host = str(args.host or 'localhost')
self.port = int(args.port)

self._runner = self._RemoteRunner.remote( # type: ignore
self._runner = ray.remote(num_gpus=num_gpus)(
self._RemoteRunner).remote(
cli_args,
wait_url=self.url_for("health"),
wait_timeout=self.MAX_SERVER_START_WAIT_S)
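Note (not part of the diff): the change above replaces the fixed `@ray.remote(num_gpus=1)` decorator with `ray.remote(num_gpus=num_gpus)(...)` applied at construction time, so one runner class can reserve `TP_SIZE * PP_SIZE` GPUs. A minimal sketch of that Ray pattern with a toy class; the names and the `num_gpus=0` value here are illustrative only.

```python
import ray


class _Runner:
    """Toy stand-in for the _RemoteRunner actor above."""

    def __init__(self, label: str):
        self.label = label

    def ping(self) -> str:
        return f"{self.label} is alive"


if __name__ == "__main__":
    ray.init()
    # Equivalent to decorating _Runner with @ray.remote(num_gpus=N),
    # except N is now chosen at runtime (0 here so the sketch runs anywhere).
    runner = ray.remote(num_gpus=0)(_Runner).remote("pp-server")
    print(ray.get(runner.ping.remote()))
    ray.shutdown()
```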
@ -39,8 +39,8 @@ def test_swap() -> None:
|
||||
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
|
||||
|
||||
# Randomly initialize the cache.
|
||||
gpu_cache = worker.cache_engine.gpu_cache
|
||||
cpu_cache = worker.cache_engine.cpu_cache
|
||||
gpu_cache = worker.cache_engine[0].gpu_cache
|
||||
cpu_cache = worker.cache_engine[0].cpu_cache
|
||||
num_layers = len(gpu_cache)
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
|
@@ -27,6 +27,17 @@ logger = init_logger(__name__)
_GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768

_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
]


class ModelConfig:
"""Configuration for the model.
@@ -258,6 +269,13 @@ class ModelConfig:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pipeline_parallel_size = parallel_config.pipeline_parallel_size
architectures = getattr(self.hf_config, "architectures", [])
if not all(arch in _PP_SUPPORTED_MODELS
for arch in architectures) and pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported for the following "
f"architectures: {_PP_SUPPORTED_MODELS}.")

if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) "
@@ -665,9 +683,10 @@ class ParallelConfig:
self._verify_args()

def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend == "mp"):
raise NotImplementedError("Pipeline parallelism is not supported "
"yet with multiprocessing.")
if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
@ -471,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
# NOTE: fork does not allocate a new physical block.
|
||||
# Thus, it is always safe from OOM.
|
||||
if parent_seq.seq_id not in self.block_tables:
|
||||
# Parent sequence has either been freed or never existed.
|
||||
return
|
||||
src_block_table = self.block_tables[parent_seq.seq_id]
|
||||
self.block_tables[child_seq.seq_id] = src_block_table.copy()
|
||||
# When using a sliding window, blocks will be eventually reused.
|
||||
|
@ -317,6 +317,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
computed_seq_block_ids) # type: ignore
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
if parent_seq.seq_id not in self.block_tables:
|
||||
# Parent sequence has either been freed or never existed.
|
||||
return
|
||||
src_block_table = self.block_tables[parent_seq.seq_id]
|
||||
self.block_tables[child_seq.seq_id] = src_block_table.fork()
|
||||
|
||||
|
@@ -256,6 +256,7 @@ class Scheduler:
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
@@ -273,11 +274,19 @@ class Scheduler:
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version)

num_gpu_blocks = cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= pipeline_parallel_size

num_cpu_blocks = cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= pipeline_parallel_size

# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
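Note (illustrative, not part of the diff): with pipeline parallelism each `Scheduler` instance manages only its share of the profiled KV-cache blocks. Rough arithmetic with example numbers:

```python
num_gpu_blocks = 8192            # example value profiled by the executor
pipeline_parallel_size = 2       # one scheduler per virtual engine

if num_gpu_blocks:
    num_gpu_blocks //= pipeline_parallel_size

assert num_gpu_blocks == 4096    # each virtual engine's scheduler gets half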
@ -416,7 +416,7 @@ class GroupCoordinator:
|
||||
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
assert dst != self.rank, (
|
||||
assert dst != self.rank_in_group, (
|
||||
"Invalid destination rank. Destination rank is the same "
|
||||
"as the current rank.")
|
||||
|
||||
@ -446,7 +446,7 @@ class GroupCoordinator:
|
||||
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
assert src != self.rank, (
|
||||
assert src != self.rank_in_group, (
|
||||
"Invalid source rank. Source rank is the same as the current rank."
|
||||
)
|
||||
|
||||
@ -454,7 +454,7 @@ class GroupCoordinator:
|
||||
|
||||
# Receive object size
|
||||
rank_size = torch.distributed.recv(size_tensor,
|
||||
src=src,
|
||||
src=self.ranks[src],
|
||||
group=self.cpu_group)
|
||||
|
||||
# Tensor to receive serialized objects into.
|
||||
@ -464,7 +464,7 @@ class GroupCoordinator:
|
||||
device="cpu")
|
||||
|
||||
rank_object = torch.distributed.recv(object_tensor,
|
||||
src=src,
|
||||
src=self.ranks[src],
|
||||
group=self.cpu_group)
|
||||
|
||||
assert rank_object == rank_size, (
|
||||
@ -491,10 +491,9 @@ class GroupCoordinator:
|
||||
group = self.device_group
|
||||
metadata_group = self.cpu_group
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
src = self.ranks[src]
|
||||
|
||||
rank = self.rank
|
||||
if rank == src:
|
||||
rank_in_group = self.rank_in_group
|
||||
if rank_in_group == src:
|
||||
metadata_list: List[Tuple[Any, Any]] = []
|
||||
assert isinstance(
|
||||
tensor_dict,
|
||||
@ -512,13 +511,13 @@ class GroupCoordinator:
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=src,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group,
|
||||
async_op=True)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=src,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
@ -542,13 +541,14 @@ class GroupCoordinator:
|
||||
# use metadata_group for CPU tensors
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor,
|
||||
src=src,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group,
|
||||
async_op=True)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
handle = torch.distributed.broadcast(tensor,
|
||||
src=src,
|
||||
handle = torch.distributed.broadcast(
|
||||
tensor,
|
||||
src=self.ranks[src],
|
||||
group=group,
|
||||
async_op=True)
|
||||
async_handles.append(handle)
|
||||
@ -575,7 +575,7 @@ class GroupCoordinator:
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if dst is None:
|
||||
dst = self.next_rank
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
metadata_list: List[Tuple[Any, Any]] = []
|
||||
@ -593,10 +593,14 @@ class GroupCoordinator:
|
||||
continue
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.send(tensor, dst=dst, group=metadata_group)
|
||||
torch.distributed.send(tensor,
|
||||
dst=self.ranks[dst],
|
||||
group=metadata_group)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.send(tensor, dst=dst, group=group)
|
||||
torch.distributed.send(tensor,
|
||||
dst=self.ranks[dst],
|
||||
group=group)
|
||||
return None
|
||||
|
||||
def recv_tensor_dict(
|
||||
@ -614,7 +618,7 @@ class GroupCoordinator:
|
||||
metadata_group = self.cpu_group
|
||||
|
||||
if src is None:
|
||||
src = self.prev_rank
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
recv_metadata_list = self.recv_object(src=src)
|
||||
@ -631,11 +635,13 @@ class GroupCoordinator:
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.recv(tensor,
|
||||
src=src,
|
||||
src=self.ranks[src],
|
||||
group=metadata_group)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.recv(tensor, src=src, group=group)
|
||||
torch.distributed.recv(tensor,
|
||||
src=self.ranks[src],
|
||||
group=group)
|
||||
_update_nested_dict(tensor_dict, key, tensor)
|
||||
else:
|
||||
_update_nested_dict(tensor_dict, key, value)
|
||||
@ -654,7 +660,7 @@ class GroupCoordinator:
|
||||
"""Sends a tensor to the destination rank in a non-blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = self.next_rank
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
@ -669,7 +675,7 @@ class GroupCoordinator:
|
||||
"""Receives a tensor from the src rank."""
|
||||
"""NOTE: `src` is the local rank of the destination rank."""
|
||||
if src is None:
|
||||
src = self.prev_rank
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
pynccl_comm = self.pynccl_comm
|
||||
|
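Note (illustrative, not part of the diff): the `GroupCoordinator` changes above consistently treat `dst`/`src` as ranks *within* the group and translate them to global ranks via `self.ranks` before calling `torch.distributed`. A plain-Python sketch of that mapping, with a made-up group:

```python
from typing import List


def next_global_rank(ranks: List[int], rank_in_group: int) -> int:
    """Default destination for send_tensor_dict: one hop forward in the group."""
    dst = (rank_in_group + 1) % len(ranks)
    return ranks[dst]            # the global rank torch.distributed expects


def prev_global_rank(ranks: List[int], rank_in_group: int) -> int:
    """Default source for recv_tensor_dict: one hop backward in the group."""
    src = (rank_in_group - 1) % len(ranks)
    return ranks[src]


# e.g. a pipeline-parallel group made of global ranks [1, 3, 5, 7]
assert next_global_rank([1, 3, 5, 7], rank_in_group=1) == 5
assert prev_global_rank([1, 3, 5, 7], rank_in_group=0) == 7
```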
@@ -2,7 +2,7 @@
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence
from typing import Sequence, Tuple

import torch

@@ -46,3 +46,12 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list)

return tensor_list


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]:
layers_per_partition = divide(num_hidden_layers, pp_size)
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition

return (start_layer, end_layer)
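Note (illustrative, not part of the diff): a worked example of the layer partitioning done by `get_pp_indices`, restated without the vLLM import so it runs standalone (`divide` is assumed to be exact integer division with a divisibility check):

```python
from typing import Tuple


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    assert num_hidden_layers % pp_size == 0, "layers must split evenly"
    layers_per_partition = num_hidden_layers // pp_size
    start_layer = pp_rank * layers_per_partition
    end_layer = start_layer + layers_per_partition
    return (start_layer, end_layer)


# A 32-layer model across 4 pipeline stages: 8 contiguous layers per stage.
assert [get_pp_indices(32, r, 4) for r in range(4)] == \
    [(0, 8), (8, 16), (16, 24), (24, 32)]
```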
@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
|
||||
async def step_async(
|
||||
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
self, virtual_engine: int
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
The workers are run asynchronously if possible.
|
||||
|
||||
@ -221,7 +222,8 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||
virtual_engine].schedule()
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
# Execute the model.
|
||||
@ -230,6 +232,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
|
||||
running_queue_size=scheduler_outputs.running_queue_size,
|
||||
)
|
||||
@ -248,16 +251,12 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
# Tracing
|
||||
self.do_tracing(scheduler_outputs)
|
||||
|
||||
if not request_outputs:
|
||||
# Stop the execute model loop in parallel workers until there are
|
||||
# more requests to process. This avoids waiting indefinitely in
|
||||
# torch.distributed ops which may otherwise timeout, and unblocks
|
||||
# the RPC thread in the workers so that they can process any other
|
||||
# queued control plane messages, such as add/remove lora adapters.
|
||||
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
||||
return request_outputs
|
||||
|
||||
async def stop_remote_worker_execution_loop_async(self) -> None:
|
||||
"""Stop the remote worker execution loop."""
|
||||
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
||||
async def process_model_inputs_async(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -491,7 +490,8 @@ class AsyncLLMEngine:
|
||||
# order of the arguments.
|
||||
cache_config = kwargs["cache_config"]
|
||||
parallel_config = kwargs["parallel_config"]
|
||||
if parallel_config.tensor_parallel_size == 1:
|
||||
if (parallel_config.tensor_parallel_size == 1
|
||||
and parallel_config.pipeline_parallel_size == 1):
|
||||
num_gpus = cache_config.gpu_memory_utilization
|
||||
else:
|
||||
num_gpus = 1
|
||||
@ -499,7 +499,7 @@ class AsyncLLMEngine:
|
||||
self._engine_class).remote
|
||||
return engine_class(*args, **kwargs)
|
||||
|
||||
async def engine_step(self) -> bool:
|
||||
async def engine_step(self, virtual_engine: int) -> bool:
|
||||
"""Kick the engine to process the waiting requests.
|
||||
|
||||
Returns True if there are in-progress requests."""
|
||||
@ -530,7 +530,7 @@ class AsyncLLMEngine:
|
||||
if self.engine_use_ray:
|
||||
request_outputs = await self.engine.step.remote() # type: ignore
|
||||
else:
|
||||
request_outputs = await self.engine.step_async()
|
||||
request_outputs = await self.engine.step_async(virtual_engine)
|
||||
|
||||
# Put the outputs into the corresponding streams.
|
||||
for request_output in request_outputs:
|
||||
@ -546,18 +546,65 @@ class AsyncLLMEngine:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def run_engine_loop(self):
|
||||
has_requests_in_progress = False
|
||||
if self.engine_use_ray:
|
||||
pipeline_parallel_size = 1 # type: ignore
|
||||
else:
|
||||
pipeline_parallel_size = \
|
||||
self.engine.parallel_config.pipeline_parallel_size
|
||||
has_requests_in_progress = [False] * pipeline_parallel_size
|
||||
while True:
|
||||
if not has_requests_in_progress:
|
||||
if not any(has_requests_in_progress):
|
||||
logger.debug("Waiting for new requests...")
|
||||
# Stop the execute model loop in parallel workers until there
|
||||
# are more requests to process. This avoids waiting
|
||||
# indefinitely in torch.distributed ops which may otherwise
|
||||
# timeout, and unblocks the RPC thread in the workers so that
|
||||
# they can process any other queued control plane messages,
|
||||
# such as add/remove lora adapters.
|
||||
if self.engine_use_ray:
|
||||
await (self.engine.stop_remote_worker_execution_loop.
|
||||
remote() # type: ignore
|
||||
)
|
||||
else:
|
||||
await self.engine.stop_remote_worker_execution_loop_async()
|
||||
await self._request_tracker.wait_for_new_requests()
|
||||
logger.debug("Got new requests!")
|
||||
requests_in_progress = [
|
||||
asyncio.create_task(self.engine_step(ve))
|
||||
for ve in range(pipeline_parallel_size)
|
||||
]
|
||||
has_requests_in_progress = [True] * pipeline_parallel_size
|
||||
|
||||
# Abort if iteration takes too long due to unrecoverable errors
|
||||
# (eg. NCCL timeouts).
|
||||
try:
|
||||
async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
|
||||
has_requests_in_progress = await self.engine_step()
|
||||
done, _ = await asyncio.wait(
|
||||
requests_in_progress,
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
for _ in range(pipeline_parallel_size):
|
||||
await asyncio.sleep(0)
|
||||
for task in done:
|
||||
result = task.result()
|
||||
virtual_engine = requests_in_progress.index(task)
|
||||
if self.engine_use_ray:
|
||||
has_unfinished_requests = (
|
||||
await (self.engine.
|
||||
has_unfinished_requests_for_virtual_engine.
|
||||
remote( # type: ignore
|
||||
virtual_engine)))
|
||||
else:
|
||||
has_unfinished_requests = (
|
||||
self.engine.
|
||||
has_unfinished_requests_for_virtual_engine(
|
||||
virtual_engine))
|
||||
if result or has_unfinished_requests:
|
||||
requests_in_progress[virtual_engine] = (
|
||||
asyncio.create_task(
|
||||
self.engine_step(virtual_engine)))
|
||||
has_requests_in_progress[virtual_engine] = True
|
||||
else:
|
||||
has_requests_in_progress[virtual_engine] = False
|
||||
except asyncio.TimeoutError as exc:
|
||||
logger.error(
|
||||
"Engine iteration timed out. This should never happen!")
|
||||
|
@ -173,6 +173,7 @@ class LLMEngine:
|
||||
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
|
||||
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
|
||||
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
|
||||
"pipeline_parallel_size=%d, "
|
||||
"disable_custom_all_reduce=%s, quantization=%s, "
|
||||
"enforce_eager=%s, kv_cache_dtype=%s, "
|
||||
"quantization_param_path=%s, device_config=%s, "
|
||||
@ -195,6 +196,7 @@ class LLMEngine:
|
||||
load_config.download_dir,
|
||||
load_config.load_format,
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.disable_custom_all_reduce,
|
||||
model_config.quantization,
|
||||
model_config.enforce_eager,
|
||||
@ -296,7 +298,11 @@ class LLMEngine:
|
||||
# Create the scheduler.
|
||||
# NOTE: the cache_config here have been updated with the numbers of
|
||||
# GPU and CPU blocks, which are profiled in the distributed executor.
|
||||
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
||||
self.scheduler = [
|
||||
Scheduler(scheduler_config, cache_config, lora_config,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
for _ in range(parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Metric Logging.
|
||||
if self.log_stats:
|
||||
@ -513,8 +519,16 @@ class LLMEngine:
|
||||
raise ValueError(
|
||||
"Either SamplingParams or PoolingParams must be provided.")
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
# Add the sequence group to the scheduler with least unfinished seqs.
|
||||
costs = [
|
||||
scheduler.get_num_unfinished_seq_groups()
|
||||
for scheduler in self.scheduler
|
||||
]
|
||||
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
|
||||
min_cost_scheduler.add_seq_group(seq_group)
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
self.model_executor.stop_remote_worker_execution_loop()
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
@ -684,7 +698,8 @@ class LLMEngine:
|
||||
>>> # abort the request
|
||||
>>> engine.abort_request(request_id)
|
||||
"""
|
||||
self.scheduler.abort_seq_group(request_id)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.abort_seq_group(request_id)
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
"""Gets the model configuration."""
|
||||
@ -696,11 +711,20 @@ class LLMEngine:
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
"""Gets the number of unfinished requests."""
|
||||
return self.scheduler.get_num_unfinished_seq_groups()
|
||||
return sum(scheduler.get_num_unfinished_seq_groups()
|
||||
for scheduler in self.scheduler)
|
||||
|
||||
def has_unfinished_requests(self) -> bool:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return self.scheduler.has_unfinished_seqs()
|
||||
return any(scheduler.has_unfinished_seqs()
|
||||
for scheduler in self.scheduler)
|
||||
|
||||
def has_unfinished_requests_for_virtual_engine(
|
||||
self, virtual_engine: int) -> bool:
|
||||
"""
|
||||
Returns True if there are unfinished requests for the virtual engine.
|
||||
"""
|
||||
return self.scheduler[virtual_engine].has_unfinished_seqs()
|
||||
|
||||
def _process_sequence_group_outputs(
|
||||
self,
|
||||
@ -749,7 +773,8 @@ class LLMEngine:
|
||||
self.output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[Union[RequestOutput,
|
||||
@ -815,7 +840,12 @@ class LLMEngine:
|
||||
>>> if not (engine.has_unfinished_requests() or example_inputs):
|
||||
>>> break
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if self.parallel_config.pipeline_parallel_size > 1:
|
||||
raise NotImplementedError(
|
||||
"Pipeline parallelism is only supported through AsyncLLMEngine "
|
||||
"as performance will be severely degraded otherwise.")
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler[
|
||||
0].schedule()
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
@ -886,23 +916,28 @@ class LLMEngine:
|
||||
|
||||
# System State
|
||||
# Scheduler State
|
||||
num_running_sys = len(self.scheduler.running)
|
||||
num_swapped_sys = len(self.scheduler.swapped)
|
||||
num_waiting_sys = len(self.scheduler.waiting)
|
||||
num_running_sys = sum(
|
||||
len(scheduler.running) for scheduler in self.scheduler)
|
||||
num_swapped_sys = sum(
|
||||
len(scheduler.swapped) for scheduler in self.scheduler)
|
||||
num_waiting_sys = sum(
|
||||
len(scheduler.waiting) for scheduler in self.scheduler)
|
||||
|
||||
# KV Cache Usage in %
|
||||
num_total_gpu = self.cache_config.num_gpu_blocks
|
||||
gpu_cache_usage_sys = 0.
|
||||
if num_total_gpu is not None:
|
||||
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
|
||||
)
|
||||
num_free_gpu = sum(
|
||||
scheduler.block_manager.get_num_free_gpu_blocks()
|
||||
for scheduler in self.scheduler)
|
||||
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
|
||||
|
||||
num_total_cpu = self.cache_config.num_cpu_blocks
|
||||
cpu_cache_usage_sys = 0.
|
||||
if num_total_cpu is not None and num_total_cpu > 0:
|
||||
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
|
||||
)
|
||||
num_free_cpu = sum(
|
||||
scheduler.block_manager.get_num_free_cpu_blocks()
|
||||
for scheduler in self.scheduler)
|
||||
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
|
||||
|
||||
# Iteration stats
|
||||
|
@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC):
|
||||
def create_output_processor(
|
||||
scheduler_config: SchedulerConfig,
|
||||
detokenizer: Detokenizer,
|
||||
scheduler: Scheduler,
|
||||
scheduler: List[Scheduler],
|
||||
seq_counter: Counter,
|
||||
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
|
||||
stop_checker: "StopChecker",
|
||||
|
@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
detokenizer: Detokenizer,
|
||||
scheduler: Scheduler,
|
||||
scheduler: List[Scheduler],
|
||||
seq_counter: Counter,
|
||||
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
|
||||
stop_checker: StopChecker,
|
||||
@ -141,4 +141,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
break
|
||||
|
||||
if seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_seq(seq)
|
||||
|
@ -33,7 +33,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
self,
|
||||
scheduler_config: SchedulerConfig,
|
||||
detokenizer: Detokenizer,
|
||||
scheduler: Scheduler,
|
||||
scheduler: List[Scheduler],
|
||||
seq_counter: Counter,
|
||||
stop_checker: StopChecker,
|
||||
):
|
||||
@ -95,7 +95,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
# not be used in the future iterations.
|
||||
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||
seq_group.remove(parent.seq_id)
|
||||
self.scheduler.free_seq(parent)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_seq(parent)
|
||||
continue
|
||||
# Fork the parent sequence if there are multiple child samples.
|
||||
for child_sample in child_samples[:-1]:
|
||||
@ -133,7 +134,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
@ -141,7 +143,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
# old sequences.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_seq(seq)
|
||||
return
|
||||
|
||||
# Beam search case
|
||||
@ -226,13 +229,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_seq(seq)
|
||||
|
||||
# Remove the unselected parent sequences from the sequence group and
|
||||
# free their memory in block manager.
|
||||
@ -241,7 +246,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
# Remove the parent sequence if it is not selected for next
|
||||
# iteration
|
||||
seq_group.remove(seq.seq_id)
|
||||
self.scheduler.free_seq(seq)
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_seq(seq)
|
||||
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
|
@ -69,7 +69,7 @@ class DistributedGPUExecutor(GPUExecutor):
|
||||
if self.parallel_worker_tasks is None:
|
||||
self.parallel_worker_tasks = self._run_workers(
|
||||
"start_worker_execution_loop",
|
||||
async_run_remote_workers_only=True,
|
||||
async_run_tensor_parallel_workers_only=True,
|
||||
**self.extra_execute_model_run_workers_kwargs)
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
@ -138,17 +138,17 @@ class DistributedGPUExecutor(GPUExecutor):
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
async_run_remote_workers_only: If True the method will be run only
|
||||
in the remote workers, not the driver worker. It will also be
|
||||
run asynchronously and return a list of futures rather than
|
||||
blocking on the results.
|
||||
async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple

@@ -110,6 +111,30 @@ class ExecutorBase(ABC):

class ExecutorAsyncBase(ExecutorBase):

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
# This locks each pipeline parallel stage so multiple virtual engines
# can't execute on the same stage at the same time
self.pp_locks = [
asyncio.Lock()
for _ in range(parallel_config.pipeline_parallel_size)
]

super().__init__(model_config, cache_config, parallel_config,
scheduler_config, device_config, load_config,
lora_config, vision_language_config,
speculative_config)

@abstractmethod
async def execute_model_async(
self,
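Note (illustrative, not part of the diff): `pp_locks` gives each pipeline stage its own `asyncio.Lock`, and the Ray executor's `_run_task_with_lock` (further below) wraps each stage's `execute_model` call in the matching lock so two virtual engines never drive the same stage concurrently. A small runnable sketch of that pattern:

```python
import asyncio


async def run_task_with_lock(task, lock, *args, **kwargs):
    async with lock:
        return await task(*args, **kwargs)


async def main() -> None:
    pipeline_parallel_size = 2
    pp_locks = [asyncio.Lock() for _ in range(pipeline_parallel_size)]

    async def execute_on_stage(stage: int, virtual_engine: int) -> str:
        await asyncio.sleep(0)               # stand-in for execute_model
        return f"virtual engine {virtual_engine} ran on stage {stage}"

    # Two virtual engines each touch both stages; the per-stage locks
    # serialize access to a stage while different stages may overlap.
    results = await asyncio.gather(*[
        run_task_with_lock(execute_on_stage, pp_locks[stage], stage, ve)
        for ve in range(2) for stage in range(pipeline_parallel_size)
    ])
    print("\n".join(results))


asyncio.run(main())
```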
@ -45,7 +45,8 @@ class GPUExecutor(ExecutorBase):
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
speculative_config=self.speculative_config,
|
||||
is_driver_worker=rank == 0,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
|
||||
def _create_worker(self,
|
||||
|
@ -91,17 +91,17 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
async_run_remote_workers_only: If True the method will be run only
|
||||
in the remote workers, not the driver worker. It will also be
|
||||
run asynchronously and return a list of futures rather than
|
||||
blocking on the results.
|
||||
async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
"""
|
||||
|
||||
if max_concurrent_workers:
|
||||
@ -114,7 +114,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
for worker in self.workers
|
||||
]
|
||||
|
||||
if async_run_remote_workers_only:
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return worker_outputs
|
||||
|
||||
|
@ -62,7 +62,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
if (self.parallel_config.tensor_parallel_size == 1
|
||||
and self.parallel_config.pipeline_parallel_size == 1):
|
||||
# For single GPU case, we use a ray worker with constrained memory.
|
||||
num_gpus = self.cache_config.gpu_memory_utilization
|
||||
else:
|
||||
@ -189,6 +190,26 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
|
||||
# This is the list of workers that are rank 0 of each TP group EXCEPT
|
||||
# global rank 0. These are the workers that will broadcast to the
|
||||
# rest of the workers.
|
||||
self.tp_driver_workers: List[RayWorkerWrapper] = []
|
||||
# This is the list of workers that are not drivers and not the first
|
||||
# worker in a TP group. These are the workers that will be
|
||||
# broadcasted to.
|
||||
self.non_driver_workers: List[RayWorkerWrapper] = []
|
||||
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
for tp_rank in range(self.parallel_config.tensor_parallel_size):
|
||||
rank = (pp_rank *
|
||||
self.parallel_config.tensor_parallel_size) + tp_rank
|
||||
if rank == 0:
|
||||
pass
|
||||
elif rank % self.parallel_config.tensor_parallel_size == 0:
|
||||
self.tp_driver_workers.append(self.workers[rank - 1])
|
||||
else:
|
||||
self.non_driver_workers.append(self.workers[rank - 1])
|
||||
|
||||
def _driver_execute_model(
|
||||
self, execute_model_req: Optional[ExecuteModelRequest]
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
@ -204,7 +225,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
async_run_remote_workers_only: bool = False,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
all_args: Optional[List[Tuple[Any, ...]]] = None,
|
||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||
use_dummy_driver: bool = False,
|
||||
@ -215,10 +236,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
ways:
|
||||
|
||||
- async_run_remote_workers_only: If True the method will be run only
|
||||
in the remote workers, not the driver worker. It will also be
|
||||
run asynchronously and return a list of futures rather than blocking
|
||||
on the results.
|
||||
Args:
|
||||
- async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
- args/kwargs: All workers share the same args/kwargs
|
||||
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||
individually
|
||||
@ -228,7 +250,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
count = len(self.workers)
|
||||
count = len(self.workers) if not \
|
||||
async_run_tensor_parallel_workers_only \
|
||||
else len(self.non_driver_workers)
|
||||
all_worker_args = repeat(args, count) if all_args is None \
|
||||
else islice(all_args, 1, None)
|
||||
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||
@ -242,14 +266,17 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
ray_worker_outputs = []
|
||||
else:
|
||||
# Start the ray workers first.
|
||||
ray_workers = self.workers
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
ray_workers = self.non_driver_workers
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args,
|
||||
**worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
|
||||
if async_run_remote_workers_only:
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
@ -319,12 +346,32 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
return await self.driver_exec_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
async def _run_task_with_lock(task, lock, *args, **kwargs):
|
||||
async with lock:
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
tasks = []
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
|
||||
"execute_model", execute_model_req)))
|
||||
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
|
||||
start=1):
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(driver_worker.execute_method.remote,
|
||||
self.pp_locks[pp_rank],
|
||||
"execute_model", execute_model_req)))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Only the last PP stage has the final results.
|
||||
return results[-1]
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
coros = [
|
||||
worker.execute_method.remote("start_worker_execution_loop")
|
||||
for worker in self.workers
|
||||
for worker in self.non_driver_workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
|
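Note (illustrative, not part of the diff): a worked example of the worker bookkeeping added in `_init_workers_ray`, for TP=2 and PP=2 (global ranks 0-3). Rank 0 is the global driver; every other rank with `rank % tensor_parallel_size == 0` drives its own TP group; the rest only receive broadcasts (the `rank - 1` indexing in the diff suggests `self.workers` excludes the global driver).

```python
tensor_parallel_size = 2
pipeline_parallel_size = 2

tp_driver_ranks = []     # rank 0 of each TP group, except the global driver
non_driver_ranks = []    # everyone else: broadcast receivers only

for pp_rank in range(pipeline_parallel_size):
    for tp_rank in range(tensor_parallel_size):
        rank = pp_rank * tensor_parallel_size + tp_rank
        if rank == 0:
            continue                                  # global driver
        if rank % tensor_parallel_size == 0:
            tp_driver_ranks.append(rank)
        else:
            non_driver_ranks.append(rank)

assert tp_driver_ranks == [2]
assert non_driver_ranks == [1, 3]
```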
@ -29,7 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs.arctic import ArcticConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -426,6 +426,7 @@ class ArcticForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -338,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
@ -286,6 +286,7 @@ class BloomForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -25,7 +25,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
@ -365,6 +365,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -46,7 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
@torch.compile
|
||||
@ -353,6 +353,7 @@ class CohereForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -23,7 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
|
||||
@ -381,6 +381,7 @@ class DbrxForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
@ -387,6 +387,7 @@ class DeepseekForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
@ -475,6 +475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
||||
@ -410,6 +410,7 @@ class FalconForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
|
@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -339,6 +339,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
@ -338,6 +338,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -17,7 +17,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -25,7 +25,9 @@ from transformers import GPT2Config
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tensor_model_parallel_world_size)
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -38,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class GPT2Attention(nn.Module):
|
||||
@ -181,9 +183,17 @@ class GPT2Model(nn.Module):
self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
self.start_layer, self.end_layer = get_pp_indices(
config.num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
self.h = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
GPT2Block(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

@ -193,14 +203,24 @@ class GPT2Model(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]

for i in range(len(self.h)):
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)

if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})

hidden_states = self.ln_f(hidden_states)
return hidden_states
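For reference, get_pp_indices-style partitioning gives each pipeline rank one contiguous slice of layers; the other positions are padded with nn.Identity so layer indices and checkpoint names stay global while only the local slice allocates weights and KV cache. A small sketch of the even-split case (the exact rounding rules of vllm.distributed.utils.get_pp_indices may differ; this is a simplified assumption):

def naive_pp_indices(num_layers: int, rank: int, world_size: int):
    # Even split; assumes num_layers is divisible by world_size.
    per_rank = num_layers // world_size
    start = rank * per_rank
    return start, start + per_rank

# e.g. a 48-layer model with PP=4 gives each stage 12 layers:
#   rank 0 -> (0, 12), rank 1 -> (12, 24),
#   rank 2 -> (24, 36), rank 3 -> (36, 48)
# The local kv_caches list only holds end - start entries, which is why
# the loop above indexes kv_caches[i - self.start_layer].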
@ -228,9 +248,10 @@ class GPT2LMHeadModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
@ -247,6 +268,16 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
@ -260,6 +291,7 @@ class GPT2LMHeadModel(nn.Module):
continue
if not name.startswith("transformer."):
name = "transformer." + name
try:
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
@ -273,3 +305,5 @@ class GPT2LMHeadModel(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:
continue

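make_empty_intermediate_tensors gives non-first ranks correctly shaped placeholder buffers (used for profiling and CUDA-graph capture later in this diff) before any real activations arrive. A quick sanity check of the shapes it produces, using made-up numbers (GPT-2 small's hidden_size of 768 and a flattened token batch of 8) and the IntermediateTensors container added in vllm/sequence.py further down:

import torch

batch_size, hidden_size = 8, 768  # assumed values for illustration
placeholder = IntermediateTensors({
    "hidden_states": torch.zeros((batch_size, hidden_size),
                                 dtype=torch.float16),
})
# One row of zeros per flattened token in the scheduled batch.
assert placeholder["hidden_states"].shape == (8, 768)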
@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -273,6 +273,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class GPTJAttention(nn.Module):
|
||||
@ -239,6 +239,7 @@ class GPTJForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class GPTNeoXAttention(nn.Module):
|
||||
@ -251,6 +251,7 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class InternLM2MLP(nn.Module):
|
||||
@ -263,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: IntermediateTensors,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs import JAISConfig
|
||||
|
||||
|
||||
@ -289,6 +289,7 @@ class JAISLMHeadModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -29,7 +29,8 @@ from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
from vllm.distributed import (get_pp_group, get_pp_indices,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -46,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, kv_cache_scales_loader)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import is_hip, print_warning_once
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
@ -261,11 +262,19 @@ class LlamaModel(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
self.start_layer, self.end_layer = get_pp_indices(
|
||||
config.num_hidden_layers,
|
||||
get_pp_group().rank_in_group,
|
||||
get_pp_group().world_size)
|
||||
self.layers = nn.ModuleList(
|
||||
[nn.Identity() for _ in range(self.start_layer)] + [
|
||||
LlamaDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
for _ in range(self.start_layer, self.end_layer)
|
||||
] + [
|
||||
nn.Identity()
|
||||
for _ in range(self.end_layer, config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -278,22 +287,36 @@ class LlamaModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
kv_caches[i - self.start_layer],
|
||||
attn_metadata,
|
||||
residual,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
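The forward pass above only returns the IntermediateTensors; actually moving them to the next pipeline rank is the runner/communication layer's job and is not part of this file. Purely as an illustration of the hand-off (plain torch.distributed point-to-point ops with a fixed key order and an already initialized process group -- not the mechanism this commit uses), shipping such a dict between adjacent stages could look like:

import torch
import torch.distributed as dist

KEYS = ("hidden_states", "residual")  # fixed order so both sides agree

def send_intermediate(tensors: dict, dst: int) -> None:
    # Sender side: the stage that returned IntermediateTensors.
    for key in KEYS:
        dist.send(tensors[key].contiguous(), dst=dst)

def recv_intermediate(num_tokens: int, hidden: int, dtype, device,
                      src: int) -> dict:
    # Receiver side: allocates buffers of the agreed shape, then fills
    # them from the previous stage.
    out = {}
    for key in KEYS:
        buf = torch.empty((num_tokens, hidden), dtype=dtype, device=device)
        dist.recv(buf, src=src)
        out[key] = buf
    return out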
@ -372,10 +395,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
return hidden_states
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
return model_output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
@ -391,6 +415,20 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
@ -416,9 +454,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
try:
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
except KeyError:
|
||||
pass
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
@ -437,10 +478,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
try:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
@ -452,6 +496,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
||||
quantization_param_path, tp_rank, tp_size,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.__class__.model_type):
|
||||
if not isinstance(self.model.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.model.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
|
@ -18,7 +18,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .interfaces import SupportsVision
|
||||
@ -202,6 +202,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
"""Run forward pass for LLaVA-1.5.
|
||||
@ -247,6 +248,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
None,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
@ -22,7 +22,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
get_clip_patch_grid_length)
|
||||
@ -376,6 +376,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
"""Run forward pass for LlaVA-NeXT.
|
||||
@ -430,6 +431,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
None,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -462,6 +462,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -51,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
@ -536,6 +536,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -47,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
@ -354,6 +354,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
|
||||
|
||||
@ -273,6 +273,7 @@ class MPTForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class OlmoAttention(nn.Module):
|
||||
@ -301,6 +301,7 @@ class OlmoForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
|
@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||
@ -304,6 +304,7 @@ class OPTForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -26,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class OrionMLP(nn.Module):
|
||||
@ -269,6 +269,7 @@ class OrionForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -57,7 +57,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -278,6 +278,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
def load_column_parallel_weight(param: torch.nn.Parameter,
|
||||
@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
positions: Optional[torch.LongTensor],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
output_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
|
@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .interfaces import SupportsVision
|
||||
@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
|
||||
return None
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata, **kwargs: object):
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object):
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
|
||||
@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class StablelmMLP(nn.Module):
|
||||
@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
|
||||
class Starcoder2Attention(nn.Module):
|
||||
@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA
|
||||
|
||||
@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata)
|
||||
|
@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings


@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
"""

tensors: Dict[str, torch.Tensor]

def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})

def __setitem__(self, key: str, value):
self.tensors[key] = value

def __len__(self):
return len(self.tensors)

def __eq__(self, other: object):
return isinstance(other, self.__class__) and self

def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"


@dataclass
class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object,
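A short usage example of the container defined above (values are made up; only torch and the dataclass itself are needed):

import torch

t = IntermediateTensors({
    "hidden_states": torch.zeros(4, 16),
    "residual": torch.zeros(4, 16),
})
print(len(t))                    # 2 entries
print(t["hidden_states"].shape)  # torch.Size([4, 16])
half = t[:2]                     # slicing applies to every tensor at once
print(half["residual"].shape)    # torch.Size([2, 16])
t["hidden_states"] = torch.ones(4, 16)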
@ -896,6 +924,8 @@ class ExecuteModelRequest:
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
num_lookahead_slots: int = 0
# The number of requests in the running queue.
@ -914,6 +944,7 @@ class ExecuteModelRequest:
blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(),
virtual_engine=self.virtual_engine,
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
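virtual_engine identifies which of the pipeline_parallel_size independent scheduling streams a request batch belongs to, so each stream can keep its own KV-cache blocks and CUDA graphs. A hedged sketch of round-robin tagging (illustrative only; the values and helper below are assumptions, and the engine itself decides how batches map to virtual engines):

pp_size = 2  # parallel_config.pipeline_parallel_size (assumed value)

def pick_virtual_engine(step_index: int) -> int:
    # Alternate between virtual engines so every pipeline stage has a
    # batch in flight; the chosen id is what gets passed as
    # ExecuteModelRequest(..., virtual_engine=...).
    return step_index % pp_size

assert [pick_virtual_engine(i) for i in range(4)] == [0, 1, 0, 1]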
|
||||
|
@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.sequence import (IntermediateTensors, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
|
||||
ModelRunner)
|
||||
|
||||
@ -76,7 +77,7 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
|
||||
"""A temporary solution that caches the seq_group_metadata_list
|
||||
for multi-step execution.
|
||||
TODO: In-place update model_input and remove this function.
|
||||
@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
self,
|
||||
model_input: ModelInputForGPUWithSamplingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
# Since we do not broadcast data inside execute_model anymore,
|
||||
@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
virtual_engine = model_input.virtual_engine
|
||||
outputs: List[SamplerOutput] = []
|
||||
for step in range(num_steps):
|
||||
# Currently cuda graph is only supported by the decode phase.
|
||||
@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[graph_batch_size]
|
||||
model_executable = (
|
||||
self.graph_runners[virtual_engine][graph_batch_size])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner):
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**multi_modal_kwargs,
|
||||
)
|
||||
|
||||
|
@ -38,7 +38,11 @@ class CacheEngine:

self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size

if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
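Because a worker now creates one cache engine per virtual engine, dividing the profiled block counts by pipeline_parallel_size keeps the total KV-cache memory on a device unchanged; each virtual engine simply gets its own slice. A quick arithmetic check with made-up numbers:

num_gpu_blocks = 12_000          # as profiled for the whole GPU (assumed)
pipeline_parallel_size = 2

per_engine = num_gpu_blocks // pipeline_parallel_size  # 6000 blocks each
# Total allocation never exceeds the profiled budget.
assert per_engine * pipeline_parallel_size <= num_gpu_blocks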
@ -13,7 +13,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.sequence import (IntermediateTensors, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase,
|
||||
@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
) -> CPUModelInput:
|
||||
multi_modal_kwargs = None
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
|
||||
self,
|
||||
model_input: CPUModelInput,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if num_steps > 1:
|
||||
|
@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
is_driver_worker=is_driver_worker)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine: CPUCacheEngine
|
||||
self.cpu_cache: List[torch.Tensor]
|
||||
self.cache_engine: List[CPUCacheEngine]
|
||||
self.cpu_cache: List[List[torch.Tensor]]
|
||||
|
||||
def init_device(self) -> None:
|
||||
self.init_distributed_environment()
|
||||
@ -242,17 +242,24 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"initializing the engine.")

def _init_cache_engine(self) -> None:
self.cache_engine = CPUCacheEngine(self.cache_config,
self.model_config,
self.parallel_config,
self.device_config)
self.cpu_cache = self.cache_engine.cpu_cache
self.model_runner.block_size = self.cache_engine.block_size
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
self.model_runner.block_size = self.cache_engine[0].block_size

assert self.cpu_cache is not None
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))

# Populate the cache to warmup the memory
for layer_cache in self.cpu_cache:
for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0)

@property
|
||||
@ -260,7 +267,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[torch.Tensor]]:
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
return self.cpu_cache
|
||||
|
||||
def execute_worker(
|
||||
@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
) -> None:
|
||||
if (worker_input.blocks_to_copy is not None
|
||||
and worker_input.blocks_to_copy.numel() > 0):
|
||||
self.cache_engine.copy(worker_input.blocks_to_copy)
|
||||
self.cache_engine[worker_input.virtual_engine].copy(
|
||||
worker_input.blocks_to_copy)
|
||||
|
||||
@torch.inference_mode()
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
assert execute_model_req is not None
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
|
||||
blocks_to_copy = execute_model_req.blocks_to_copy
|
||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||
@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
)
|
||||
|
||||
def init_distributed_environment(self) -> None:
|
||||
|
@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -57,6 +58,7 @@ class EmbeddingModelRunner(
|
||||
self,
|
||||
model_input: ModelInputForGPUWithPoolingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[PoolerOutput]]:
|
||||
if num_steps > 1:
|
||||
@ -73,10 +75,12 @@ class EmbeddingModelRunner(
|
||||
assert model_input.attn_metadata is not None
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
decode_meta = model_input.attn_metadata.decode_metadata
|
||||
virtual_engine = model_input.virtual_engine
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[graph_batch_size]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
graph_batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
@ -115,6 +119,7 @@ class EmbeddingModelRunner(
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
virtual_engine: int = 0,
|
||||
) -> ModelInputForGPUWithPoolingMetadata:
|
||||
assert seq_group_metadata_list is not None
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
|
@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
@ -25,6 +26,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import init_logger
|
||||
@ -37,7 +39,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.model_executor.models.interfaces import supports_lora
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.sequence import (IntermediateTensors, SamplerOutput,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
|
||||
is_pin_memory_available, make_tensor_with_pad)
|
||||
from vllm.worker.model_runner_base import (
|
||||
@ -81,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
||||
lora_requests: Optional[Set[LoRARequest]] = None
|
||||
attn_metadata: Optional["AttentionMetadata"] = None
|
||||
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
||||
virtual_engine: int = 0
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
@ -89,6 +93,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
return tensor_dict
|
||||
@ -122,6 +127,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||
@ -179,7 +185,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.graph_runners: Dict[int, CUDAGraphRunner] = {}

self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
# When using CUDA graph, the input block tables must be padded to
@ -787,9 +796,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size))
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len

seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
@ -811,7 +822,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return

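With graph_runners now a list of per-virtual-engine dicts, a captured CUDA graph is looked up first by virtual engine and then by the padded batch size it was captured for. A minimal sketch of that lookup pattern (the .get() fallback to eager execution is illustrative; the runner in this diff only consults the table when the decode metadata says use_cuda_graph):

def select_model_executable(runner, model_input):
    # Per-virtual-engine table, keyed by captured batch size.
    graphs = runner.graph_runners[model_input.virtual_engine]
    graph_batch_size = model_input.input_tokens.shape[0]
    # Fall back to the plain model if no graph exists for this size.
    return graphs.get(graph_batch_size, runner.model)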
@ -847,7 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
@ -880,10 +897,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        slot_mapping.fill_(_PAD_SLOT_ID)
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        # Prepare buffer for outputs. These will be reused for all batch sizes.
        # It will be filled after the first graph capture.
        hidden_states: Optional[torch.Tensor] = None
        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
            None
        ] * self.parallel_config.pipeline_parallel_size

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
@ -912,13 +937,17 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    if self.attn_backend.get_name() == "flashinfer":
                        indptr_buffer = indptr_buffer[:batch_size + 1]
                        last_page_len_buffer = last_page_len_buffer[:batch_size]
                        last_page_len_buffer = last_page_len_buffer[:
                                                                    batch_size]

                        num_qo_heads = self.model_config.get_num_attention_heads(
                            self.parallel_config)
                        num_qo_heads = (
                            self.model_config.get_num_attention_heads(
                                self.parallel_config))
                        num_kv_heads = self.model_config.get_num_kv_heads(
                            self.parallel_config)
                        if num_qo_heads // num_kv_heads >= 4:
@ -927,8 +956,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                            use_tensor_cores = False
                        decode_wrapper = \
                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                            decode_workspace_buffer, indptr_buffer, indices_buffer,
                            last_page_len_buffer, "NHD", use_tensor_cores)
                            decode_workspace_buffer, indptr_buffer,
                            indices_buffer, last_page_len_buffer, "NHD",
                            use_tensor_cores)
                        kv_cache_dtype = get_kv_cache_torch_dtype(
                            self.kv_cache_dtype, self.model_config.dtype)

@ -990,8 +1020,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                        )
                        self.set_active_loras(set(), lora_mapping)

                    graph_runner = CUDAGraphRunner(self.model,
                                                   self.attn_backend.get_name())
                    graph_runner = CUDAGraphRunner(
                        self.model, self.attn_backend.get_name())

                    if self.attn_backend.get_name() == "flashinfer":
                        graph_runner.flashinfer_indptr_buffer = indptr_buffer
@ -1006,15 +1036,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                    graph_runner.capture(
                        input_tokens[:batch_size],
                        input_positions[:batch_size],
                        hidden_states[:batch_size]
                        if hidden_states is not None else None,
                        kv_caches,
                        hidden_or_intermediate_states[
                            virtual_engine]  # type: ignore
                        [:batch_size]
                        if hidden_or_intermediate_states[virtual_engine]
                        is not None else None,
                        intermediate_inputs[:batch_size]
                        if intermediate_inputs is not None else None,
                        kv_caches[virtual_engine],
                        attn_metadata,
                        memory_pool=self.graph_memory_pool,
                        stream=graph_capture_context.stream,
                    )
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[batch_size] = graph_runner
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
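Summarizing the capture changes above: graph capture iterates over pipeline stages, then over batch sizes largest-first, and files each captured graph under both keys. A toy sketch of the control flow (the capture itself is faked; only the indexing mirrors the diff):

from typing import Dict, List

pipeline_parallel_size = 2
batch_size_capture_list = [1, 2, 4, 8]
graph_runners: List[Dict[int, str]] = [
    {} for _ in range(pipeline_parallel_size)
]

for virtual_engine in range(pipeline_parallel_size):
    # Largest batch first, matching the capture order used above.
    for batch_size in reversed(batch_size_capture_list):
        captured = f"graph(ve={virtual_engine}, bs={batch_size})"  # fake capture
        graph_runners[virtual_engine][batch_size] = captured

assert graph_runners[1][8] == "graph(ve=1, bs=8)"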
@ -1047,6 +1083,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.
@ -1072,15 +1109,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt)
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

@ -1124,27 +1163,34 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        hidden_states = model_executable(
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multi_modal_kwargs,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

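Only the last pipeline stage turns hidden states into logits and sampler output; earlier stages return their activations so the worker can ship them to the next stage. A simplified sketch with a fake forward pass and a greedy stand-in for the sampler:

from typing import List, Optional, Union
import torch

def execute_stage(hidden_in: Optional[torch.Tensor],
                  is_last_rank: bool) -> Union[torch.Tensor, List[int]]:
    # Fake forward pass for this pipeline stage.
    hidden = (hidden_in if hidden_in is not None
              else torch.randn(4, 16)) + 1.0
    if not is_last_rank:
        # Intermediate stages return activations instead of sampler output.
        return hidden
    # Last stage: compute "logits" and sample greedily (stand-ins for
    # compute_logits and the sampler).
    logits = hidden @ torch.randn(16, 32)
    return logits.argmax(dim=-1).tolist()

activations = execute_stage(None, is_last_rank=False)
tokens = execute_stage(activations, is_last_rank=True)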
@ -1159,9 +1205,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_states.index_select(0, indices)
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta.use_cuda_graph:
                hidden_states = hidden_states[:len(indices)]
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

@ -1195,13 +1244,15 @@ class CUDAGraphRunner:
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: Optional[torch.Tensor],
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> torch.Tensor:
    ) -> Union[torch.Tensor, IntermediateTensors]:
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
@ -1213,6 +1264,7 @@ class CUDAGraphRunner:
            positions,
            kv_caches,
            attn_metadata,
            intermediate_inputs,
            **kwargs,
        )
        torch.cuda.synchronize()
@ -1220,18 +1272,27 @@ class CUDAGraphRunner:
        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_states = self.model(
            output_hidden_or_intermediate_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
            if hidden_states is not None:
                hidden_states.copy_(output_hidden_states)
            if hidden_or_intermediate_states is not None:
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
            else:
                hidden_states = output_hidden_states
            del output_hidden_states
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
@ -1255,8 +1316,15 @@ class CUDAGraphRunner:
                attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return hidden_states
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
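The buffer handling above is what lets CUDA graphs replay across stages: the last rank owns a single hidden-states output buffer, while earlier ranks own a dict of intermediate tensors. A rough sketch of that buffer layout; the tensor names and shapes are assumptions for illustration only:

import torch

is_last_rank = False  # pretend this is an intermediate pipeline stage

# Static input buffers; graph replay always reads from these same tensors.
input_buffers = {
    "input_ids": torch.zeros(8, dtype=torch.long),
    "positions": torch.zeros(8, dtype=torch.long),
}
if not is_last_rank:
    # Intermediate stages also receive activations from the previous stage.
    input_buffers["hidden_states"] = torch.zeros(8, 16)
    input_buffers["residual"] = torch.zeros(8, 16)

# Static output buffers; what the replayed graph writes into.
if is_last_rank:
    output_buffers = {"hidden_states": torch.zeros(8, 16)}
else:
    output_buffers = {
        "hidden_states": torch.zeros(8, 16),
        "residual": torch.zeros(8, 16),
    }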
@ -1264,6 +1332,7 @@ class CUDAGraphRunner:
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> torch.Tensor:
        # KV caches are fixed tensors, so we don't need to copy them.
@ -1280,12 +1349,19 @@ class CUDAGraphRunner:
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                self.input_buffers[key].copy_(intermediate_tensors[key],
                                              non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

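Replay then follows a copy-in, replay, read-out pattern: incoming intermediate tensors overwrite the static input buffers, the graph runs, and the stage-appropriate output buffer is returned. A minimal sketch with no real CUDA graph; replay is faked as an add-one over the buffers:

import torch
from typing import Dict, Union

class FakeGraphRunner:
    """Toy stand-in for CUDAGraphRunner.forward's buffer handling."""

    def __init__(self, is_last_rank: bool) -> None:
        self.is_last_rank = is_last_rank
        self.input_buffers: Dict[str, torch.Tensor] = {
            "hidden_states": torch.zeros(4, 8)
        }
        self.output_buffers: Dict[str, torch.Tensor] = {
            "hidden_states": torch.zeros(4, 8)
        }

    def replay(self) -> None:
        # Pretend the captured graph adds 1 to the inputs.
        self.output_buffers["hidden_states"].copy_(
            self.input_buffers["hidden_states"] + 1)

    def forward(self, intermediate: Dict[str, torch.Tensor]
                ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Copy incoming activations into the static input buffers.
        for key, value in intermediate.items():
            self.input_buffers[key].copy_(value, non_blocking=True)
        self.replay()
        if self.is_last_rank:
            return self.output_buffers["hidden_states"]
        return self.output_buffers

runner = FakeGraphRunner(is_last_rank=False)
out = runner.forward({"hidden_states": torch.ones(4, 8)})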
@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]):
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]):
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """

@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
    ) -> ModelInputForNeuron:
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:

@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
        return False

    @property
    def kv_cache(self) -> Optional[List[torch.Tensor]]:
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return None

    @torch.inference_mode()

@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase):
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                "Driver worker should be rank 0 of tensor parallel group."
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase):
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[torch.tensor]] = None
        self.gpu_cache: Optional[List[List[torch.tensor]]] = None

    def init_device(self) -> None:
        if self.device_config.device.type == "cuda":
@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase):

    def _init_cache_engine(self):
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config,
                                        self.device_config)
        self.gpu_cache = self.cache_engine.gpu_cache
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.gpu_cache = [
            self.cache_engine[ve].gpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]

    def _warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
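The worker now owns one cache engine, and hence one GPU KV cache, per virtual engine. A hedged sketch of the resulting layout with a dummy cache engine class; shapes and sizes are made up:

from typing import List
import torch

class DummyCacheEngine:
    """Illustrative stand-in for vLLM's CacheEngine."""
    def __init__(self, num_layers: int, num_blocks: int) -> None:
        # One block tensor per layer; shapes here are invented for the sketch.
        self.gpu_cache: List[torch.Tensor] = [
            torch.zeros(num_blocks, 16) for _ in range(num_layers)
        ]

pipeline_parallel_size = 2
cache_engine = [
    DummyCacheEngine(num_layers=4, num_blocks=32)
    for _ in range(pipeline_parallel_size)
]
gpu_cache = [
    cache_engine[ve].gpu_cache for ve in range(pipeline_parallel_size)
]
# gpu_cache[virtual_engine][layer] is what the model runner sees.
assert gpu_cache[1][0].shape == (32, 16)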
@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase):
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[torch.Tensor]]:
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch cudamemcpyasync.
@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase):
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        virtual_engine = worker_input.virtual_engine
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
            self.cache_engine[virtual_engine].swap_in(
                worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
            self.cache_engine[virtual_engine].swap_out(
                worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine.copy(worker_input.blocks_to_copy)
            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

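Cache operations are likewise routed to the cache engine owned by the scheduled virtual engine. A small self-contained sketch of that dispatch, with a recording stub in place of the real CacheEngine:

import torch
from typing import List

class DummyCacheEngine:
    """Stand-in cache engine that just records which ops it was asked to do."""
    def __init__(self) -> None:
        self.ops: List[str] = []

    def swap_in(self, blocks: torch.Tensor) -> None:
        self.ops.append(f"swap_in({blocks.numel()} entries)")

    def swap_out(self, blocks: torch.Tensor) -> None:
        self.ops.append(f"swap_out({blocks.numel()} entries)")

    def copy(self, blocks: torch.Tensor) -> None:
        self.ops.append(f"copy({blocks.numel()} entries)")

cache_engines = [DummyCacheEngine() for _ in range(2)]

# A scheduler step tagged with virtual_engine=1 only touches that stage's cache.
virtual_engine = 1
blocks_to_swap_in = torch.tensor([[0, 3], [1, 4]])
if blocks_to_swap_in.numel() > 0:
    cache_engines[virtual_engine].swap_in(blocks_to_swap_in)

assert cache_engines[0].ops == []
assert cache_engines[1].ops == ["swap_in(4 entries)"]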
@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
                        update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
@ -124,6 +125,7 @@ class WorkerInput:
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0

    @classmethod
    def from_broadcasted_tensor_dict(
@ -139,6 +141,7 @@ class WorkerInput:
            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
            virtual_engine=tensor_dict["virtual_engine"],
        )

    def as_broadcastable_tensor_dict(
@ -151,6 +154,7 @@ class WorkerInput:
            "blocks_to_swap_in": self.blocks_to_swap_in,
            "blocks_to_swap_out": self.blocks_to_swap_out,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
        }

        return tensor_dict
@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[torch.Tensor]]:
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Get the kv cache to pass to the worker's model runner. Used by the
        default `execute_model`. If the worker's model runner does not follow
        the ModelRunnerBase interface, then inherit from WorkerBase instead.
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

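As the new docstring says, kv_cache is now a list of lists: the outer index selects the virtual engine, the inner index the layer, and WorkerInput broadcasts the chosen virtual_engine so every worker indexes the same stage. A brief sketch of how a subclass might satisfy the abstract property; the class names are hypothetical:

from abc import ABC, abstractmethod
from typing import List, Optional
import torch

class WorkerSketch(ABC):
    """Illustrative analogue of LocalOrDistributedWorkerBase's kv_cache."""

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        ...

class GpuWorkerSketch(WorkerSketch):
    def __init__(self, pp_size: int, num_layers: int) -> None:
        # gpu_cache[virtual_engine][layer] -> per-layer KV cache tensor.
        self.gpu_cache = [[torch.zeros(8, 16) for _ in range(num_layers)]
                          for _ in range(pp_size)]

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return self.gpu_cache

worker = GpuWorkerSketch(pp_size=2, num_layers=4)
stage_cache = worker.kv_cache[1]  # what execute_model hands the model runner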
@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
                execute_model_req=execute_model_req)
            model_input: ModelRunnerInputBase = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list))
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine))
            num_steps = execute_model_req.num_steps

            if self.do_metadata_broadcast:
@ -255,9 +262,24 @@ class LocalOrDistributedWorkerBase(WorkerBase):
        if worker_input.num_seq_groups == 0:
            return []

        return self.model_runner.execute_model(model_input, self.kv_cache,
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict())

        output = self.model_runner.execute_model(
            model_input, self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None, intermediate_tensors,
            num_steps)

        if not get_pp_group().is_last_rank:
            get_pp_group().send_tensor_dict(output.tensors)
            return [None]

        # Worker only supports single-step execution. Wrap the output in a
        # list to conform to interface.
        return output


class WorkerWrapperBase:
    """

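This hunk is where activations actually move between pipeline ranks: a non-first stage receives a tensor dict from its predecessor before running, and a non-last stage sends its outputs onward and reports [None] instead of sampling. A rough single-process sketch of that relay, with an in-memory queue standing in for get_pp_group()'s send/recv:

import queue
import torch
from typing import Dict, List, Optional

# One queue per stage boundary stands in for the real NCCL send/recv.
links: List["queue.Queue[Dict[str, torch.Tensor]]"] = [queue.Queue()]

def run_stage(rank: int, world: int,
              hidden_size: int = 16) -> Optional[List[int]]:
    is_first, is_last = rank == 0, rank == world - 1
    # Non-first stages wait for the previous stage's activations.
    tensors = {"hidden_states": torch.zeros(4, hidden_size)} if is_first \
        else links[rank - 1].get()
    # Fake per-stage computation.
    tensors = {k: v + 1 for k, v in tensors.items()}
    if not is_last:
        links[rank].put(tensors)   # hand off to the next stage
        return None                # mirrors the `return [None]` above
    # Last stage "samples" instead of forwarding.
    return tensors["hidden_states"].argmax(dim=-1).tolist()

assert run_stage(0, 2) is None
tokens = run_stage(1, 2)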
@ -12,7 +12,8 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
@ -190,6 +191,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
    ) -> ModelInputForXPU:
        multi_modal_input = None
        if self.is_driver_worker:
@ -334,6 +336,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        self,
        model_input: ModelInputForXPU,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:

@ -85,8 +85,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        self.gpu_cache: List[torch.Tensor]
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        if self.device_config.device.type == "xpu" and is_xpu():