[Core] Pipeline Parallel Support (#4412)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Murali Andoorveedu, 2024-07-02 10:58:08 -07:00 (committed by GitHub)
parent 15aba081f3
commit c5832d2ae9
82 changed files with 1096 additions and 400 deletions

View File

@ -74,6 +74,16 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
- label: Pipeline Parallelism Test
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- label: Engine Test - label: Engine Test
mirror_hardwares: [amd] mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py

View File

@ -5,6 +5,7 @@ import pytest
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from ..utils import wait_for_gpu_memory_to_clear from ..utils import wait_for_gpu_memory_to_clear
@ -23,8 +24,11 @@ class MockEngine:
self.add_request_calls = 0 self.add_request_calls = 0
self.abort_request_calls = 0 self.abort_request_calls = 0
self.request_id = None self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)
async def step_async(self): async def step_async(self, virtual_engine):
# PP size is 1, ignore virtual engine
self.step_calls += 1 self.step_calls += 1
return [RequestOutput( return [RequestOutput(
request_id=self.request_id)] if self.request_id else [] request_id=self.request_id)] if self.request_id else []
@ -32,6 +36,9 @@ class MockEngine:
async def process_model_inputs_async(self, *args, **kwargs): async def process_model_inputs_async(self, *args, **kwargs):
pass pass
async def stop_remote_worker_execution_loop_async(self):
pass
def generate(self, request_id): def generate(self, request_id):
self.request_id = request_id self.request_id = request_id
@ -41,6 +48,7 @@ class MockEngine:
def add_request(self, **kwargs): def add_request(self, **kwargs):
del kwargs # Unused del kwargs # Unused
self.add_request_calls += 1 self.add_request_calls += 1
print(f'Request calls: {self.add_request_calls}')
async def add_request_async(self, **kwargs): async def add_request_async(self, **kwargs):
self.add_request_calls += 1 self.add_request_calls += 1
@ -53,6 +61,9 @@ class MockEngine:
def has_unfinished_requests(self): def has_unfinished_requests(self):
return self.request_id is not None return self.request_id is not None
def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
return self.request_id is not None
class MockAsyncLLMEngine(AsyncLLMEngine): class MockAsyncLLMEngine(AsyncLLMEngine):
@ -76,6 +87,7 @@ async def test_new_requests_event():
engine.engine.generate("2") engine.engine.generate("2")
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2 assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls >= 2 assert engine.engine.step_calls >= 2
await asyncio.sleep(0.001) await asyncio.sleep(0.001)

View File

@ -4,7 +4,7 @@ import pytest
# and debugging. # and debugging.
import ray import ray
from ..utils import RemoteOpenAIServer from ..utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m" MODEL_NAME = "facebook/opt-125m"
@ -12,7 +12,7 @@ MODEL_NAME = "facebook/opt-125m"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ray_ctx(): def ray_ctx():
ray.init() ray.init(runtime_env={"working_dir": VLLM_PATH})
yield yield
ray.shutdown() ray.shutdown()

View File

@ -56,8 +56,8 @@ def test_chunked_prefill_recompute(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i] hf_output_ids, hf_output_str = hf_outputs[i]
@ -91,10 +91,10 @@ def test_preemption(
disable_log_stats=False, disable_log_stats=False,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = ( total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
check_outputs_equal( check_outputs_equal(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
@ -147,10 +147,10 @@ def test_swap(
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts, vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens) beam_width, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = ( total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i] hf_output_ids, _ = hf_outputs[i]
@ -214,8 +214,8 @@ def test_swap_infeasible(
example_prompts, example_prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
) )
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and does not hang. # Verify the request is ignored and does not hang.
assert req_outputs[0].outputs[0].finish_reason == "length" assert req_outputs[0].outputs[0].finish_reason == "length"
@ -252,8 +252,8 @@ def test_preemption_infeasible(
sampling_params=sampling_params, sampling_params=sampling_params,
) )
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and does not hang. # Verify the request is ignored and does not hang.
for req_output in req_outputs: for req_output in req_outputs:

View File

@ -32,7 +32,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
(r + 1) for r in range(tp_size) (r + 1) for r in range(tp_size)
] ]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank] t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t) t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected) assert torch.allclose(t, expected)
@ -60,7 +60,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
for r in range(tp_size) for r in range(tp_size)
] ]
expected = torch.cat(all_tensors, dim=all_gather_dimension) expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank] t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension) t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected) assert torch.allclose(t, expected)
@ -91,7 +91,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
"f": torch.tensor([], dtype=torch.float32, device="cuda"), "f": torch.tensor([], dtype=torch.float32, device="cuda"),
} }
if rank == 0: if (rank % tp_size) == 0:
broadcast_tensor_dict(test_dict, src=0) broadcast_tensor_dict(test_dict, src=0)
else: else:
recv_dict = broadcast_tensor_dict(src=0) recv_dict = broadcast_tensor_dict(src=0)
@ -184,3 +184,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target): def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target) multi_process_parallel(1, pp_size, test_target)
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
send_recv_test_worker, send_recv_tensor_dict_test_worker,
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel_pipeline_parallel(
tp_size, pp_size, test_target):
multi_process_parallel(tp_size, pp_size, test_target)
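The rank % tp_size indexing added throughout these workers is just a conversion from a global rank to a TP-local rank; a throwaway sketch with the new (hypothetical here) 2x2 configuration, where global ranks 2 and 3 reuse tensor indices 0 and 1:

tp_size, pp_size = 2, 2
for rank in range(tp_size * pp_size):
    print("global rank", rank, "-> TP-local index", rank % tp_size)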

View File

@ -0,0 +1,149 @@
import os
import openai # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
from ..utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0)))
CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0)))
TP_SIZE = int(os.getenv("TP_SIZE", 1))
PP_SIZE = int(os.getenv("PP_SIZE", 1))
pytestmark = pytest.mark.asyncio
@pytest.fixture(scope="module")
def ray_ctx():
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
@pytest.fixture(scope="module")
def server(ray_ctx):
args = [
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--pipeline-parallel-size",
str(PP_SIZE),
"--tensor-parallel-size",
str(TP_SIZE),
"--distributed-executor-backend",
"ray",
]
if CHUNKED_PREFILL:
args += [
"--enable-chunked-prefill",
]
if EAGER_MODE:
args += [
"--enforce-eager",
]
return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE)
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
model_name: str):
# test simple list
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
)
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for the official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
# test streaming
batch = await client.completions.create(
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
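Outside of pytest, the same deployment can presumably be exercised with the stock OpenAI client. The sketch below is an assumption-laden standalone example: it supposes a vLLM OpenAI-compatible server was started with the flags used in the fixture above (--pipeline-parallel-size, --tensor-parallel-size, --distributed-executor-backend ray), is listening on localhost:8000, and uses the dummy API key from tests/utils.py.

import asyncio
import openai

async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="token-abc123")
    completion = await client.completions.create(
        model="meta-llama/Meta-Llama-3-8B",
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0)
    print(completion.choices[0].text)

asyncio.run(main())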

View File

@ -32,7 +32,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
output_processor = MultiStepOutputProcessor( output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer, detokenizer=detokenizer,
scheduler=scheduler, scheduler=[scheduler],
seq_counter=seq_counter, seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(), get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker, stop_checker=stop_checker,
@ -86,7 +86,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
output_processor = MultiStepOutputProcessor( output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer, detokenizer=detokenizer,
scheduler=scheduler, scheduler=[scheduler],
seq_counter=seq_counter, seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(), get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker, stop_checker=stop_checker,
@ -148,7 +148,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
output_processor = MultiStepOutputProcessor( output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer, detokenizer=detokenizer,
scheduler=scheduler, scheduler=[scheduler],
seq_counter=seq_counter, seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker, stop_checker=stop_checker,
@ -215,7 +215,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
output_processor = MultiStepOutputProcessor( output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer, detokenizer=detokenizer,
scheduler=scheduler, scheduler=[scheduler],
seq_counter=seq_counter, seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker, stop_checker=stop_checker,

View File

@ -14,7 +14,7 @@ import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from openai import BadRequestError from openai import BadRequestError
from ...utils import RemoteOpenAIServer from ...utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@ -77,7 +77,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ray_ctx(): def ray_ctx():
ray.init() ray.init(runtime_env={"working_dir": VLLM_PATH})
yield yield
ray.shutdown() ray.shutdown()

View File

@ -16,7 +16,7 @@ from openai import BadRequestError
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer from ...utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@ -79,7 +79,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ray_ctx(): def ray_ctx():
ray.init() ray.init(runtime_env={"working_dir": VLLM_PATH})
yield yield
ray.shutdown() ray.shutdown()

View File

@ -5,14 +5,14 @@ import openai
import pytest import pytest
import ray import ray
from ...utils import RemoteOpenAIServer from ...utils import VLLM_PATH, RemoteOpenAIServer
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ray_ctx(): def ray_ctx():
ray.init() ray.init(runtime_env={"working_dir": VLLM_PATH})
yield yield
ray.shutdown() ray.shutdown()

View File

@ -6,7 +6,7 @@ import ray
# downloading lora to test lora requests # downloading lora to test lora requests
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from ...utils import RemoteOpenAIServer from ...utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@ -22,7 +22,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ray_ctx(): def ray_ctx():
ray.init() ray.init(runtime_env={"working_dir": VLLM_PATH})
yield yield
ray.shutdown() ray.shutdown()

View File

@ -24,13 +24,13 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ray_ctx(): def ray_ctx():
ray.init() ray.init(runtime_env={"working_dir": VLLM_PATH})
yield yield
ray.shutdown() ray.shutdown()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server(ray_ctx):
return RemoteOpenAIServer([ return RemoteOpenAIServer([
"--model", "--model",
MODEL_NAME, MODEL_NAME,

View File

@ -54,9 +54,9 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return new_execute_model return new_execute_model
def zero_kv_cache(cache_engine: CacheEngine): def zero_kv_cache(cache_engine: List[CacheEngine]):
assert cache_engine.gpu_cache assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine.gpu_cache: for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_() key_blocks.zero_()
value_blocks.zero_() value_blocks.zero_()

View File

@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
tensorize_vllm_model) tensorize_vllm_model)
from ..conftest import VllmRunner, cleanup from ..conftest import VllmRunner, cleanup
from ..utils import RemoteOpenAIServer from ..utils import VLLM_PATH, RemoteOpenAIServer
# yapf conflicts with isort for this docstring # yapf conflicts with isort for this docstring
@ -220,6 +220,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
json.dumps(model_loader_extra_config), json.dumps(model_loader_extra_config),
] ]
ray.init(runtime_env={"working_dir": VLLM_PATH})
server = RemoteOpenAIServer(openai_args) server = RemoteOpenAIServer(openai_args)
print("Server ready.") print("Server ready.")

View File

@ -49,7 +49,6 @@ class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S = 600 # wait up to 600 seconds for the server to start MAX_SERVER_START_WAIT_S = 600 # wait up to 600 seconds for the server to start
@ray.remote(num_gpus=1)
class _RemoteRunner: class _RemoteRunner:
def __init__(self, cli_args: List[str], *, wait_url: str, def __init__(self, cli_args: List[str], *, wait_url: str,
@ -92,7 +91,11 @@ class RemoteOpenAIServer:
if hasattr(self, "proc"): if hasattr(self, "proc"):
self.proc.terminate() self.proc.terminate()
def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None: def __init__(self,
cli_args: List[str],
*,
auto_port: bool = True,
num_gpus: int = 1) -> None:
if auto_port: if auto_port:
if "-p" in cli_args or "--port" in cli_args: if "-p" in cli_args or "--port" in cli_args:
raise ValueError("You have manually specified the port" raise ValueError("You have manually specified the port"
@ -105,7 +108,8 @@ class RemoteOpenAIServer:
self.host = str(args.host or 'localhost') self.host = str(args.host or 'localhost')
self.port = int(args.port) self.port = int(args.port)
self._runner = self._RemoteRunner.remote( # type: ignore self._runner = ray.remote(num_gpus=num_gpus)(
self._RemoteRunner).remote(
cli_args, cli_args,
wait_url=self.url_for("health"), wait_url=self.url_for("health"),
wait_timeout=self.MAX_SERVER_START_WAIT_S) wait_timeout=self.MAX_SERVER_START_WAIT_S)
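The decorator change is the standard Ray pattern for choosing resource requirements at call time rather than at class-definition time, so the GPU reservation can be computed from PP_SIZE * TP_SIZE. A minimal sketch (the actor class and resource numbers below are placeholders, not vLLM code):

import ray

class _Runner:
    def ping(self) -> str:
        return "ok"

ray.init(ignore_reinit_error=True)
# equivalent of @ray.remote(num_gpus=...) applied lazily at construction time
runner = ray.remote(num_gpus=0)(_Runner).remote()
print(ray.get(runner.ping.remote()))
ray.shutdown()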

View File

@ -39,8 +39,8 @@ def test_swap() -> None:
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
# Randomly initialize the cache. # Randomly initialize the cache.
gpu_cache = worker.cache_engine.gpu_cache gpu_cache = worker.cache_engine[0].gpu_cache
cpu_cache = worker.cache_engine.cpu_cache cpu_cache = worker.cache_engine[0].cpu_cache
num_layers = len(gpu_cache) num_layers = len(gpu_cache)
for i in range(num_layers): for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i] gpu_key_cache, gpu_value_cache = gpu_cache[i]

View File

@ -27,6 +27,17 @@ logger = init_logger(__name__)
_GB = 1 << 30 _GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
]
class ModelConfig: class ModelConfig:
"""Configuration for the model. """Configuration for the model.
@ -258,6 +269,13 @@ class ModelConfig:
total_num_hidden_layers = getattr(self.hf_text_config, total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0) "num_hidden_layers", 0)
pipeline_parallel_size = parallel_config.pipeline_parallel_size pipeline_parallel_size = parallel_config.pipeline_parallel_size
architectures = getattr(self.hf_config, "architectures", [])
if not all(arch in _PP_SUPPORTED_MODELS
for arch in architectures) and pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")
if total_num_hidden_layers % pipeline_parallel_size != 0: if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError( raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) " f"Total number of hidden layers ({total_num_hidden_layers}) "
@ -665,9 +683,10 @@ class ParallelConfig:
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1: if (self.pipeline_parallel_size > 1
raise NotImplementedError( and self.distributed_executor_backend == "mp"):
"Pipeline parallelism is not supported yet.") raise NotImplementedError("Pipeline parallelism is not supported "
"yet with multiprocessing.")
if self.distributed_executor_backend not in ("ray", "mp", None): if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError( raise ValueError(
"Unrecognized distributed executor backend. Supported values " "Unrecognized distributed executor backend. Supported values "

View File

@ -471,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block. # NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM. # Thus, it is always safe from OOM.
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy() self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused. # When using a sliding window, blocks will be eventually reused.

View File

@ -317,6 +317,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
computed_seq_block_ids) # type: ignore computed_seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork() self.block_tables[child_seq.seq_id] = src_block_table.fork()

View File

@ -256,6 +256,7 @@ class Scheduler:
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
@ -273,11 +274,19 @@ class Scheduler:
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version) version)
num_gpu_blocks = cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= pipeline_parallel_size
num_cpu_blocks = cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= pipeline_parallel_size
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManagerImpl( self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window, sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching) enable_caching=self.cache_config.enable_prefix_caching)
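For reference, a tiny sketch of the cache split the constructor now performs (the block counts below are made up): each of the pipeline_parallel_size schedulers manages an equal share of the profiled GPU and CPU KV-cache blocks.

num_gpu_blocks, num_cpu_blocks = 8192, 1024     # hypothetical profiled totals
pipeline_parallel_size = 2
per_scheduler_gpu = num_gpu_blocks // pipeline_parallel_size    # 4096 blocks each
per_scheduler_cpu = num_cpu_blocks // pipeline_parallel_size    # 512 blocks each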

View File

@ -416,7 +416,7 @@ class GroupCoordinator:
assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank, ( assert dst != self.rank_in_group, (
"Invalid destination rank. Destination rank is the same " "Invalid destination rank. Destination rank is the same "
"as the current rank.") "as the current rank.")
@ -446,7 +446,7 @@ class GroupCoordinator:
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
assert src != self.rank, ( assert src != self.rank_in_group, (
"Invalid source rank. Source rank is the same as the current rank." "Invalid source rank. Source rank is the same as the current rank."
) )
@ -454,7 +454,7 @@ class GroupCoordinator:
# Receive object size # Receive object size
rank_size = torch.distributed.recv(size_tensor, rank_size = torch.distributed.recv(size_tensor,
src=src, src=self.ranks[src],
group=self.cpu_group) group=self.cpu_group)
# Tensor to receive serialized objects into. # Tensor to receive serialized objects into.
@ -464,7 +464,7 @@ class GroupCoordinator:
device="cpu") device="cpu")
rank_object = torch.distributed.recv(object_tensor, rank_object = torch.distributed.recv(object_tensor,
src=src, src=self.ranks[src],
group=self.cpu_group) group=self.cpu_group)
assert rank_object == rank_size, ( assert rank_object == rank_size, (
@ -491,10 +491,9 @@ class GroupCoordinator:
group = self.device_group group = self.device_group
metadata_group = self.cpu_group metadata_group = self.cpu_group
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
src = self.ranks[src]
rank = self.rank rank_in_group = self.rank_in_group
if rank == src: if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = [] metadata_list: List[Tuple[Any, Any]] = []
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
@ -512,13 +511,13 @@ class GroupCoordinator:
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor, handle = torch.distributed.broadcast(tensor,
src=src, src=self.ranks[src],
group=metadata_group, group=metadata_group,
async_op=True) async_op=True)
else: else:
# use group for GPU tensors # use group for GPU tensors
handle = torch.distributed.broadcast(tensor, handle = torch.distributed.broadcast(tensor,
src=src, src=self.ranks[src],
group=group, group=group,
async_op=True) async_op=True)
async_handles.append(handle) async_handles.append(handle)
@ -542,13 +541,14 @@ class GroupCoordinator:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
handle = torch.distributed.broadcast( handle = torch.distributed.broadcast(
tensor, tensor,
src=src, src=self.ranks[src],
group=metadata_group, group=metadata_group,
async_op=True) async_op=True)
else: else:
# use group for GPU tensors # use group for GPU tensors
handle = torch.distributed.broadcast(tensor, handle = torch.distributed.broadcast(
src=src, tensor,
src=self.ranks[src],
group=group, group=group,
async_op=True) async_op=True)
async_handles.append(handle) async_handles.append(handle)
@ -575,7 +575,7 @@ class GroupCoordinator:
metadata_group = self.cpu_group metadata_group = self.cpu_group
if dst is None: if dst is None:
dst = self.next_rank dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = [] metadata_list: List[Tuple[Any, Any]] = []
@ -593,10 +593,14 @@ class GroupCoordinator:
continue continue
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=dst, group=metadata_group) torch.distributed.send(tensor,
dst=self.ranks[dst],
group=metadata_group)
else: else:
# use group for GPU tensors # use group for GPU tensors
torch.distributed.send(tensor, dst=dst, group=group) torch.distributed.send(tensor,
dst=self.ranks[dst],
group=group)
return None return None
def recv_tensor_dict( def recv_tensor_dict(
@ -614,7 +618,7 @@ class GroupCoordinator:
metadata_group = self.cpu_group metadata_group = self.cpu_group
if src is None: if src is None:
src = self.prev_rank src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})" assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src) recv_metadata_list = self.recv_object(src=src)
@ -631,11 +635,13 @@ class GroupCoordinator:
if tensor.is_cpu: if tensor.is_cpu:
# use metadata_group for CPU tensors # use metadata_group for CPU tensors
torch.distributed.recv(tensor, torch.distributed.recv(tensor,
src=src, src=self.ranks[src],
group=metadata_group) group=metadata_group)
else: else:
# use group for GPU tensors # use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group) torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
_update_nested_dict(tensor_dict, key, tensor) _update_nested_dict(tensor_dict, key, tensor)
else: else:
_update_nested_dict(tensor_dict, key, value) _update_nested_dict(tensor_dict, key, value)
@ -654,7 +660,7 @@ class GroupCoordinator:
"""Sends a tensor to the destination rank in a non-blocking way""" """Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank.""" """NOTE: `dst` is the local rank of the destination rank."""
if dst is None: if dst is None:
dst = self.next_rank dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled: if pynccl_comm is not None and not pynccl_comm.disabled:
@ -669,7 +675,7 @@ class GroupCoordinator:
"""Receives a tensor from the src rank.""" """Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the destination rank.""" """NOTE: `src` is the local rank of the destination rank."""
if src is None: if src is None:
src = self.prev_rank src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device) tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm pynccl_comm = self.pynccl_comm
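The recurring self.ranks[src] / rank_in_group substitutions all encode the same translation, sketched below with made-up numbers: callers address peers by their position inside the group, while torch.distributed point-to-point ops expect global ranks.

group_ranks = [4, 5, 6, 7]        # hypothetical global ranks forming one PP group
rank_in_group = 1                 # what a caller passes as src/dst
global_rank = group_ranks[rank_in_group]             # what send/recv need -> 5
prev_rank = (rank_in_group - 1) % len(group_ranks)   # default peer, in-group index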

View File

@ -2,7 +2,7 @@
# Adapted from # Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Sequence from typing import Sequence, Tuple
import torch import torch
@ -46,3 +46,12 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list) return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list return tensor_list
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]:
layers_per_partition = divide(num_hidden_layers, pp_size)
start_layer = pp_rank * layers_per_partition
end_layer = start_layer + layers_per_partition
return (start_layer, end_layer)
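As a quick illustration of the arithmetic above (a standalone sketch, not vLLM code): with 32 hidden layers and a pipeline-parallel size of 4, each stage owns a contiguous slice of 8 layers.

def pp_indices_example(num_hidden_layers: int, pp_rank: int, pp_size: int):
    # mirrors get_pp_indices: an even, contiguous split of the layer stack
    assert num_hidden_layers % pp_size == 0
    layers_per_partition = num_hidden_layers // pp_size
    start_layer = pp_rank * layers_per_partition
    return start_layer, start_layer + layers_per_partition

for rank in range(4):
    print(rank, pp_indices_example(32, rank, 4))
# 0 (0, 8), 1 (8, 16), 2 (16, 24), 3 (24, 32)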

View File

@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods.""" """Extension of LLMEngine to add async methods."""
async def step_async( async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible. The workers are ran asynchronously if possible.
@ -221,7 +222,8 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
# Execute the model. # Execute the model.
@ -230,6 +232,7 @@ class _AsyncLLMEngine(LLMEngine):
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots, num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size, running_queue_size=scheduler_outputs.running_queue_size,
) )
@ -248,16 +251,12 @@ class _AsyncLLMEngine(LLMEngine):
# Tracing # Tracing
self.do_tracing(scheduler_outputs) self.do_tracing(scheduler_outputs)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()
return request_outputs return request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async( async def process_model_inputs_async(
self, self,
request_id: str, request_id: str,
@ -491,7 +490,8 @@ class AsyncLLMEngine:
# order of the arguments. # order of the arguments.
cache_config = kwargs["cache_config"] cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"] parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1: if (parallel_config.tensor_parallel_size == 1
and parallel_config.pipeline_parallel_size == 1):
num_gpus = cache_config.gpu_memory_utilization num_gpus = cache_config.gpu_memory_utilization
else: else:
num_gpus = 1 num_gpus = 1
@ -499,7 +499,7 @@ class AsyncLLMEngine:
self._engine_class).remote self._engine_class).remote
return engine_class(*args, **kwargs) return engine_class(*args, **kwargs)
async def engine_step(self) -> bool: async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests. """Kick the engine to process the waiting requests.
Returns True if there are in-progress requests.""" Returns True if there are in-progress requests."""
@ -530,7 +530,7 @@ class AsyncLLMEngine:
if self.engine_use_ray: if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore request_outputs = await self.engine.step.remote() # type: ignore
else: else:
request_outputs = await self.engine.step_async() request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams. # Put the outputs into the corresponding streams.
for request_output in request_outputs: for request_output in request_outputs:
@ -546,18 +546,65 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
has_requests_in_progress = False if self.engine_use_ray:
pipeline_parallel_size = 1 # type: ignore
else:
pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True: while True:
if not has_requests_in_progress: if not any(has_requests_in_progress):
logger.debug("Waiting for new requests...") logger.debug("Waiting for new requests...")
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
if self.engine_use_ray:
await (self.engine.stop_remote_worker_execution_loop.
remote() # type: ignore
)
else:
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests() await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!") logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(self.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
# Abort if iteration takes too long due to unrecoverable errors # Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts). # (eg. NCCL timeouts).
try: try:
async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
has_requests_in_progress = await self.engine_step() done, _ = await asyncio.wait(
requests_in_progress,
return_when=asyncio.FIRST_COMPLETED)
for _ in range(pipeline_parallel_size):
await asyncio.sleep(0)
for task in done:
result = task.result()
virtual_engine = requests_in_progress.index(task)
if self.engine_use_ray:
has_unfinished_requests = (
await (self.engine.
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = (
self.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
self.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc: except asyncio.TimeoutError as exc:
logger.error( logger.error(
"Engine iteration timed out. This should never happen!") "Engine iteration timed out. This should never happen!")

View File

@ -173,6 +173,7 @@ class LLMEngine:
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, quantization=%s, " "disable_custom_all_reduce=%s, quantization=%s, "
"enforce_eager=%s, kv_cache_dtype=%s, " "enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
@ -195,6 +196,7 @@ class LLMEngine:
load_config.download_dir, load_config.download_dir,
load_config.load_format, load_config.load_format,
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce, parallel_config.disable_custom_all_reduce,
model_config.quantization, model_config.quantization,
model_config.enforce_eager, model_config.enforce_eager,
@ -296,7 +298,11 @@ class LLMEngine:
# Create the scheduler. # Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of # NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor. # GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) self.scheduler = [
Scheduler(scheduler_config, cache_config, lora_config,
parallel_config.pipeline_parallel_size)
for _ in range(parallel_config.pipeline_parallel_size)
]
# Metric Logging. # Metric Logging.
if self.log_stats: if self.log_stats:
@ -513,8 +519,16 @@ class LLMEngine:
raise ValueError( raise ValueError(
"Either SamplingParams or PoolingParams must be provided.") "Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler with least unfinished seqs.
self.scheduler.add_seq_group(seq_group) costs = [
scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler
]
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
def process_model_inputs( def process_model_inputs(
self, self,
@ -684,7 +698,8 @@ class LLMEngine:
>>> # abort the request >>> # abort the request
>>> engine.abort_request(request_id) >>> engine.abort_request(request_id)
""" """
self.scheduler.abort_seq_group(request_id) for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig: def get_model_config(self) -> ModelConfig:
"""Gets the model configuration.""" """Gets the model configuration."""
@ -696,11 +711,20 @@ class LLMEngine:
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests.""" """Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups() return sum(scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler)
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests.""" """Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs() return any(scheduler.has_unfinished_seqs()
for scheduler in self.scheduler)
def has_unfinished_requests_for_virtual_engine(
self, virtual_engine: int) -> bool:
"""
Returns True if there are unfinished requests for the virtual engine.
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()
def _process_sequence_group_outputs( def _process_sequence_group_outputs(
self, self,
@ -749,7 +773,8 @@ class LLMEngine:
self.output_processor.process_outputs(seq_group, outputs) self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# Create the outputs. # Create the outputs.
request_outputs: List[Union[RequestOutput, request_outputs: List[Union[RequestOutput,
@ -815,7 +840,12 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs): >>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break >>> break
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
if not scheduler_outputs.is_empty(): if not scheduler_outputs.is_empty():
execute_model_req = ExecuteModelRequest( execute_model_req = ExecuteModelRequest(
@ -886,23 +916,28 @@ class LLMEngine:
# System State # System State
# Scheduler State # Scheduler State
num_running_sys = len(self.scheduler.running) num_running_sys = sum(
num_swapped_sys = len(self.scheduler.swapped) len(scheduler.running) for scheduler in self.scheduler)
num_waiting_sys = len(self.scheduler.waiting) num_swapped_sys = sum(
len(scheduler.swapped) for scheduler in self.scheduler)
num_waiting_sys = sum(
len(scheduler.waiting) for scheduler in self.scheduler)
# KV Cache Usage in % # KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks num_total_gpu = self.cache_config.num_gpu_blocks
gpu_cache_usage_sys = 0. gpu_cache_usage_sys = 0.
if num_total_gpu is not None: if num_total_gpu is not None:
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( num_free_gpu = sum(
) scheduler.block_manager.get_num_free_gpu_blocks()
for scheduler in self.scheduler)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage_sys = 0. cpu_cache_usage_sys = 0.
if num_total_cpu is not None and num_total_cpu > 0: if num_total_cpu is not None and num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( num_free_cpu = sum(
) scheduler.block_manager.get_num_free_cpu_blocks()
for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Iteration stats # Iteration stats
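Both the request routing in add_request and the stats aggregation above reduce to simple list operations over the per-virtual-engine schedulers; a toy sketch with invented counts:

unfinished_per_scheduler = [3, 1, 2]   # hypothetical load per virtual engine
target = unfinished_per_scheduler.index(min(unfinished_per_scheduler))   # route to 1
total_unfinished = sum(unfinished_per_scheduler)                         # 6 for metrics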

View File

@ -27,7 +27,7 @@ class SequenceGroupOutputProcessor(ABC):
def create_output_processor( def create_output_processor(
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker", stop_checker: "StopChecker",

View File

@ -34,7 +34,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def __init__( def __init__(
self, self,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker, stop_checker: StopChecker,
@ -141,4 +141,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
break break
if seq.is_finished(): if seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)

View File

@ -33,7 +33,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
self, self,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
detokenizer: Detokenizer, detokenizer: Detokenizer,
scheduler: Scheduler, scheduler: List[Scheduler],
seq_counter: Counter, seq_counter: Counter,
stop_checker: StopChecker, stop_checker: StopChecker,
): ):
@ -95,7 +95,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# not be used in the future iterations. # not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id) seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent) for scheduler in self.scheduler:
scheduler.free_seq(parent)
continue continue
# Fork the parent sequence if there are multiple child samples. # Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]: for child_sample in child_samples[:-1]:
@ -133,7 +134,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent: if seq is not parent:
seq_group.add(seq) seq_group.add(seq)
if not seq.is_finished(): if not seq.is_finished():
self.scheduler.fork_seq(parent, seq) for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block # Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output. # manager. Keep them in the sequence group as candidate output.
@ -141,7 +143,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# old sequences. # old sequences.
for seq, parent in child_seqs: for seq, parent in child_seqs:
if seq is parent and seq.is_finished(): if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
return return
# Beam search case # Beam search case
@ -226,13 +229,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
if seq is not parent: if seq is not parent:
seq_group.add(seq) seq_group.add(seq)
if not seq.is_finished(): if not seq.is_finished():
self.scheduler.fork_seq(parent, seq) for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block # Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output. # manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs: for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished(): if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and # Remove the unselected parent sequences from the sequence group and
# free their memory in block manager. # free their memory in block manager.
@ -241,7 +246,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Remove the parent sequence if it is not selected for next # Remove the parent sequence if it is not selected for next
# iteration # iteration
seq_group.remove(seq.seq_id) seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq) for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping( def _check_beam_search_early_stopping(
self, self,

View File

@ -69,7 +69,7 @@ class DistributedGPUExecutor(GPUExecutor):
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers( self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop", "start_worker_execution_loop",
async_run_remote_workers_only=True, async_run_tensor_parallel_workers_only=True,
**self.extra_execute_model_run_workers_kwargs) **self.extra_execute_model_run_workers_kwargs)
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
@ -138,17 +138,17 @@ class DistributedGPUExecutor(GPUExecutor):
self, self,
method: str, method: str,
*args, *args,
async_run_remote_workers_only: bool = False, async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. """Runs the given method on all workers.
Args: Args:
async_run_remote_workers_only: If True the method will be run only async_run_tensor_parallel_workers_only: If True the method will be
in the remote workers, not the driver worker. It will also be run only in the remote TP workers, not the driver worker.
run asynchronously and return a list of futures rather than It will also be run asynchronously and return a list of futures
blocking on the results. rather than blocking on the results.
""" """
raise NotImplementedError raise NotImplementedError

View File

@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
@ -110,6 +111,30 @@ class ExecutorBase(ABC):
class ExecutorAsyncBase(ExecutorBase): class ExecutorAsyncBase(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
# This locks each pipeline parallel stage so multiple virtual engines
# can't execute on the same stage at the same time
self.pp_locks = [
asyncio.Lock()
for _ in range(parallel_config.pipeline_parallel_size)
]
super().__init__(model_config, cache_config, parallel_config,
scheduler_config, device_config, load_config,
lora_config, vision_language_config,
speculative_config)
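A self-contained illustration (not the executor's real execute path) of what the per-stage locks buy: two virtual engines targeting the same pipeline stage are serialized, so their execution on that stage cannot interleave.

import asyncio

async def main():
    pp_locks = [asyncio.Lock() for _ in range(2)]   # pipeline_parallel_size == 2

    async def execute_on_stage(stage: int, virtual_engine: int):
        async with pp_locks[stage]:
            # only one virtual engine may hold a given stage at a time
            await asyncio.sleep(0.01)
            print(f"virtual engine {virtual_engine} finished on stage {stage}")

    await asyncio.gather(execute_on_stage(0, 0), execute_on_stage(0, 1))

asyncio.run(main())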
@abstractmethod @abstractmethod
async def execute_model_async( async def execute_model_async(
self, self,

View File

@ -45,7 +45,8 @@ class GPUExecutor(ExecutorBase):
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config, speculative_config=self.speculative_config,
is_driver_worker=rank == 0, is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
) )
def _create_worker(self, def _create_worker(self,

View File

@ -91,17 +91,17 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
self, self,
method: str, method: str,
*args, *args,
async_run_remote_workers_only: bool = False, async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. """Runs the given method on all workers.
Args: Args:
async_run_remote_workers_only: If True the method will be run only async_run_tensor_parallel_workers_only: If True the method will be
in the remote workers, not the driver worker. It will also be run only in the remote TP workers, not the driver worker.
run asynchronously and return a list of futures rather than It will also be run asynchronously and return a list of futures
blocking on the results. rather than blocking on the results.
""" """
if max_concurrent_workers: if max_concurrent_workers:
@ -114,7 +114,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
for worker in self.workers for worker in self.workers
] ]
if async_run_remote_workers_only: if async_run_tensor_parallel_workers_only:
# Just return futures # Just return futures
return worker_outputs return worker_outputs

View File

@ -62,7 +62,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1: if (self.parallel_config.tensor_parallel_size == 1
and self.parallel_config.pipeline_parallel_size == 1):
# For single GPU case, we use a ray worker with constrained memory. # For single GPU case, we use a ray worker with constrained memory.
num_gpus = self.cache_config.gpu_memory_utilization num_gpus = self.cache_config.gpu_memory_utilization
else: else:
@ -189,6 +190,26 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers=self.parallel_config. max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers) max_parallel_loading_workers)
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
for tp_rank in range(self.parallel_config.tensor_parallel_size):
rank = (pp_rank *
self.parallel_config.tensor_parallel_size) + tp_rank
if rank == 0:
pass
elif rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(self.workers[rank - 1])
else:
self.non_driver_workers.append(self.workers[rank - 1])
def _driver_execute_model( def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest] self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
@ -204,7 +225,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
self, self,
method: str, method: str,
*args, *args,
async_run_remote_workers_only: bool = False, async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None, all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
@ -215,10 +236,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
"""Runs the given method on all workers. Can be used in the following """Runs the given method on all workers. Can be used in the following
ways: ways:
- async_run_remote_workers_only: If True the method will be run only Args:
in the remote workers, not the driver worker. It will also be - async_run_tensor_parallel_workers_only: If True the method will be
run asynchronously and return a list of futures rather than blocking run only in the remote TP workers, not the driver worker.
on the results. It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs - args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified - all_args/all_kwargs: args/kwargs for each worker are specified
individually individually
@ -228,7 +250,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError( raise NotImplementedError(
"max_concurrent_workers is not supported yet.") "max_concurrent_workers is not supported yet.")
count = len(self.workers) count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
all_worker_args = repeat(args, count) if all_args is None \ all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None) else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
@ -242,14 +266,17 @@ class RayGPUExecutor(DistributedGPUExecutor):
ray_worker_outputs = [] ray_worker_outputs = []
else: else:
# Start the ray workers first. # Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [ ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, worker.execute_method.remote(method, *worker_args,
**worker_kwargs) **worker_kwargs)
for (worker, worker_args, worker_kwargs for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs) ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
] ]
if async_run_remote_workers_only: if async_run_tensor_parallel_workers_only:
# Just return futures # Just return futures
return ray_worker_outputs return ray_worker_outputs
@ -319,12 +346,32 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
return await self.driver_exec_method("execute_model",
execute_model_req) async def _run_task_with_lock(task, lock, *args, **kwargs):
async with lock:
return await task(*args, **kwargs)
tasks = []
tasks.append(
asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req)))
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method.remote,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self): async def _start_worker_execution_loop(self):
coros = [ coros = [
worker.execute_method.remote("start_worker_execution_loop") worker.execute_method.remote("start_worker_execution_loop")
for worker in self.workers for worker in self.non_driver_workers
] ]
return await asyncio.gather(*coros) return await asyncio.gather(*coros)
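
Putting the Ray-side pieces together, a condensed sketch of the new dispatch shape (Ray handles and worker objects are stubbed with plain async callables; only the structure follows the diff: one lock-guarded execute_model task per pipeline stage, keeping the last stage's result):

import asyncio
from typing import Any, Awaitable, Callable, List

async def _run_task_with_lock(task: Callable[..., Awaitable[Any]],
                              lock: asyncio.Lock, *args: Any) -> Any:
    async with lock:
        return await task(*args)

async def execute_model_across_stages(
        driver_exec_method: Callable[..., Awaitable[Any]],
        tp_driver_exec_methods: List[Callable[..., Awaitable[Any]]],
        pp_locks: List[asyncio.Lock],
        execute_model_req: Any) -> Any:
    # One task per pipeline stage: global rank 0 plus the rank-0 worker of
    # every other TP group, each guarded by that stage's lock.
    tasks = [
        asyncio.create_task(
            _run_task_with_lock(driver_exec_method, pp_locks[0],
                                "execute_model", execute_model_req))
    ]
    for pp_rank, exec_method in enumerate(tp_driver_exec_methods, start=1):
        tasks.append(
            asyncio.create_task(
                _run_task_with_lock(exec_method, pp_locks[pp_rank],
                                    "execute_model", execute_model_req)))
    results = await asyncio.gather(*tasks)
    # Only the last pipeline stage holds the final sampler output.
    return results[-1]

async def main() -> None:
    async def fake_exec(stage: int, method: str, req: Any) -> str:
        await asyncio.sleep(0.01)
        return f"stage {stage} ran {method} on {req}"

    pp_locks = [asyncio.Lock(), asyncio.Lock()]
    out = await execute_model_across_stages(
        lambda m, r: fake_exec(0, m, r),
        [lambda m, r: fake_exec(1, m, r)],
        pp_locks, "req-1")
    print(out)  # -> "stage 1 ran execute_model on req-1"

asyncio.run(main())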

View File

@ -29,7 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.arctic import ArcticConfig from vllm.transformers_utils.configs.arctic import ArcticConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -426,6 +426,7 @@ class ArcticForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
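
The same one-line signature change recurs for every model file below: forward() gains an optional intermediate_tensors argument so that, for models that actually implement pipeline stages in this commit (GPT-2 and Llama), a non-first stage can resume from the previous stage's activations instead of token embeddings. A schematic of that call shape (ToyModel and ToyIntermediateTensors are simplified stand-ins, not vLLM classes):

from typing import Dict, Optional
import torch

class ToyIntermediateTensors:
    """Stand-in for a dict-like container of per-stage tensors."""
    def __init__(self, tensors: Dict[str, torch.Tensor]) -> None:
        self.tensors = tensors
    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]

class ToyModel(torch.nn.Module):
    def __init__(self, hidden: int = 8) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(100, hidden)
        self.layer = torch.nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor,
                intermediate_tensors: Optional[ToyIntermediateTensors] = None
                ) -> torch.Tensor:
        if intermediate_tensors is None:        # first pipeline stage
            hidden_states = self.embed(input_ids)
        else:                                   # later stage: resume from activations
            hidden_states = intermediate_tensors["hidden_states"]
        return self.layer(hidden_states)

model = ToyModel()
stage0_out = model(torch.tensor([1, 2, 3]))
stage1_out = model(torch.tensor([1, 2, 3]),
                   ToyIntermediateTensors({"hidden_states": stage0_out}))
print(stage1_out.shape)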

View File

@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -338,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@ -286,6 +286,7 @@ class BloomForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -25,7 +25,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -365,6 +365,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -46,7 +46,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
@torch.compile @torch.compile
@ -353,6 +353,7 @@ class CohereForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -23,7 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig
@ -381,6 +381,7 @@ class DbrxForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
@ -387,6 +387,7 @@ class DeepseekForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -48,7 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
@ -475,6 +475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
@ -410,6 +410,7 @@ class FalconForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,

View File

@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -339,6 +339,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -338,6 +338,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@ -25,7 +25,9 @@ from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
from vllm.distributed.utils import get_pp_indices
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@ -38,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
@ -181,9 +183,17 @@ class GPT2Model(nn.Module):
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.start_layer, self.end_layer = get_pp_indices(
config.num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
self.h = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
GPT2Block(config, cache_config, quant_config) GPT2Block(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
]) ])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@ -193,14 +203,24 @@ class GPT2Model(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids) inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds hidden_states = inputs_embeds + position_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)): for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
@ -228,9 +248,10 @@ class GPT2LMHeadModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
@ -247,6 +268,16 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
@ -260,6 +291,7 @@ class GPT2LMHeadModel(nn.Module):
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
try:
param = params_dict[name] param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights. # Because of this, we need to transpose the weights.
@ -273,3 +305,5 @@ class GPT2LMHeadModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
except KeyError:
continue
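
A compact sketch of the partitioning scheme the GPT-2 changes follow (an even split is assumed here; vLLM's get_pp_indices is the real source of the indices): each stage materialises only its own slice of blocks, keeps nn.Identity placeholders elsewhere so layer indices stay global, and indexes its kv_caches locally with i - start_layer.

# even_pp_indices is an illustrative stand-in for the real index helper.
from typing import Tuple
import torch.nn as nn

def even_pp_indices(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
    per_stage = num_layers // pp_size
    start = pp_rank * per_stage
    end = num_layers if pp_rank == pp_size - 1 else start + per_stage
    return start, end

NUM_LAYERS, PP_RANK, PP_SIZE = 12, 1, 2
start_layer, end_layer = even_pp_indices(NUM_LAYERS, PP_RANK, PP_SIZE)

blocks = nn.ModuleList(
    [nn.Identity() for _ in range(start_layer)] +
    [nn.Linear(8, 8) for _ in range(start_layer, end_layer)] +   # real blocks
    [nn.Identity() for _ in range(end_layer, NUM_LAYERS)])

print(start_layer, end_layer)                           # 6 12 on the second stage
print(sum(isinstance(b, nn.Linear) for b in blocks))    # 6 real layers on this stage

Because a stage's params_dict only contains its own slice, the try/except KeyError around weight loading lets each rank quietly skip checkpoint entries owned by other stages.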

View File

@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -273,6 +273,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
@ -239,6 +239,7 @@ class GPTJForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
@ -251,6 +251,7 @@ class GPTNeoXForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class InternLM2MLP(nn.Module): class InternLM2MLP(nn.Module):
@ -263,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import JAISConfig from vllm.transformers_utils.configs import JAISConfig
@ -289,6 +289,7 @@ class JAISLMHeadModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@ -29,7 +29,8 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_pp_indices,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -46,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader) default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip, print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -261,11 +262,19 @@ class LlamaModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer = get_pp_indices(
config.num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
self.layers = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
LlamaDecoderLayer(config=config, LlamaDecoderLayer(config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
for idx in range(config.num_hidden_layers) for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -278,22 +287,36 @@ class LlamaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
for i in range(len(self.layers)): else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
@ -372,10 +395,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors] = None,
hidden_states = self.model(input_ids, positions, kv_caches, ) -> Union[torch.Tensor, IntermediateTensors]:
attn_metadata) model_output = self.model(input_ids, positions, kv_caches,
return hidden_states attn_metadata, intermediate_tensors)
return model_output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
@ -391,6 +415,20 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
@ -416,9 +454,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
try:
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
except KeyError:
pass
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
@ -437,10 +478,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
continue continue
else: else:
name = remapped_kv_scale_name name = remapped_kv_scale_name
try:
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
except KeyError:
pass
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
@ -452,6 +496,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quantization_param_path, tp_rank, tp_size, quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers, self.config.num_hidden_layers,
self.config.__class__.model_type): self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip(): if is_hip():
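
To make the hand-off between Llama stages concrete, a self-contained sketch (shapes and the plain-dict container are simplified stand-ins for the real intermediate-tensor type): a non-last stage returns both hidden_states and the running residual, the next stage resumes from exactly those two tensors, and that is why the empty receive buffers allocate the same pair.

from typing import Dict
import torch

def make_empty_intermediate_tensors(batch_size: int, hidden_size: int,
                                    dtype: torch.dtype,
                                    device: torch.device) -> Dict[str, torch.Tensor]:
    # Receive buffers for a non-first stage, mirroring what the diff allocates.
    return {
        "hidden_states": torch.zeros((batch_size, hidden_size), dtype=dtype, device=device),
        "residual": torch.zeros((batch_size, hidden_size), dtype=dtype, device=device),
    }

def stage_forward(hidden_states: torch.Tensor, residual: torch.Tensor,
                  is_last_rank: bool) -> Dict[str, torch.Tensor]:
    # Stand-in for a slice of decoder layers; each layer updates both tensors.
    hidden_states = hidden_states + 1.0
    residual = residual + hidden_states
    if not is_last_rank:
        # Non-last stages ship both tensors downstream instead of normalizing.
        return {"hidden_states": hidden_states, "residual": residual}
    return {"hidden_states": hidden_states + residual}  # final-norm stand-in

recv = make_empty_intermediate_tensors(4, 16, torch.float32, torch.device("cpu"))
mid = stage_forward(recv["hidden_states"], recv["residual"], is_last_rank=False)
out = stage_forward(mid["hidden_states"], mid["residual"], is_last_rank=True)
print(out["hidden_states"].shape)  # torch.Size([4, 16])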

View File

@ -18,7 +18,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision from .interfaces import SupportsVision
@ -202,6 +202,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
"""Run forward pass for LLaVA-1.5. """Run forward pass for LLaVA-1.5.
@ -247,6 +248,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states

View File

@ -22,7 +22,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_patch_grid_length) get_clip_patch_grid_length)
@ -376,6 +376,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
"""Run forward pass for LlaVA-NeXT. """Run forward pass for LlaVA-NeXT.
@ -430,6 +431,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states

View File

@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -462,6 +462,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -51,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -536,6 +536,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -47,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
@ -354,6 +354,7 @@ class MixtralForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
@ -273,6 +273,7 @@ class MPTForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
@ -301,6 +301,7 @@ class OlmoForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,

View File

@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
@ -304,6 +304,7 @@ class OPTForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -26,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class OrionMLP(nn.Module): class OrionMLP(nn.Module):
@ -269,6 +269,7 @@ class OrionForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -57,7 +57,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -278,6 +278,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
def load_column_parallel_weight(param: torch.nn.Parameter, def load_column_parallel_weight(param: torch.nn.Parameter,
@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module):
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
output_hidden_states = self.model( output_hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,

View File

@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision from .interfaces import SupportsVision
@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return None return None
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, **kwargs: object): attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states

View File

@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)

View File

@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
View File
@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
View File
@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
View File
@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class Starcoder2Attention(nn.Module): class Starcoder2Attention(nn.Module):
@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
View File
@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
View File
@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings return self.embeddings == other.embeddings
@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
"""
tensors: Dict[str, torch.Tensor]
def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value):
self.tensors[key] = value
def __len__(self):
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self.tensors == other.tensors
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"
@dataclass @dataclass
class SamplerOutput: class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object, """For each sequence group, we generate a list of SequenceOutput object,
@ -896,6 +924,8 @@ class ExecuteModelRequest:
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to copy. Source to dest block. # Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
# The number of requests in the running queue. # The number of requests in the running queue.
@ -914,6 +944,7 @@ class ExecuteModelRequest:
blocks_to_swap_in=self.blocks_to_swap_in.copy(), blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(), blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(), blocks_to_copy=self.blocks_to_copy.copy(),
virtual_engine=self.virtual_engine,
num_lookahead_slots=self.num_lookahead_slots, num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size, running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states, previous_hidden_states=self.previous_hidden_states,
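The new `virtual_engine` field survives `clone()`, so a request stays pinned to its pipeline stream across scheduling steps. A hypothetical sketch (assumes `clone()` keeps its single sequence-group-metadata-list parameter; the empty list only keeps the example small):

```python
from vllm.sequence import ExecuteModelRequest

# Hypothetical request routed to the second virtual engine (PP stream 1).
req = ExecuteModelRequest(seq_group_metadata_list=[], virtual_engine=1)

cloned = req.clone([])  # virtual_engine is carried over, as in the hunk above
assert cloned.virtual_engine == 1
```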
View File
@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
@ -76,7 +77,7 @@ class TP1DraftModelRunner(ModelRunner):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata: virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list """A temporary solution that caches the seq_group_metadata_list
for multi-step execution. for multi-step execution.
TODO: In-place update model_input and remove this function. TODO: In-place update model_input and remove this function.
@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner):
self, self,
model_input: ModelInputForGPUWithSamplingMetadata, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore, # Since we do not broadcast data inside execute_model anymore,
@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests, self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping) model_input.lora_mapping)
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = [] outputs: List[SamplerOutput] = []
for step in range(num_steps): for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner):
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else: else:
model_executable = self.model model_executable = self.model
@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner):
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs, **multi_modal_kwargs,
) )
View File
@ -38,7 +38,11 @@ class CacheEngine:
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
if cache_config.cache_dtype == "auto": if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype self.dtype = model_config.dtype
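The profiler still sizes the cache for the whole GPU, so each virtual engine's CacheEngine receives an equal share of the block budget. A small worked example with made-up numbers:

```python
# Made-up numbers, only to show how the split behaves.
num_gpu_blocks = 8192            # as reported by the memory profiler
pipeline_parallel_size = 2       # PP degree -> number of virtual engines

blocks_per_virtual_engine = num_gpu_blocks // pipeline_parallel_size
print(blocks_per_virtual_engine)  # 4096 blocks backing each CacheEngine
```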
View File
@ -13,7 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase,
@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> CPUModelInput: ) -> CPUModelInput:
multi_modal_kwargs = None multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self, self,
model_input: CPUModelInput, model_input: CPUModelInput,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
View File
@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: CPUCacheEngine self.cache_engine: List[CPUCacheEngine]
self.cpu_cache: List[torch.Tensor] self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None: def init_device(self) -> None:
self.init_distributed_environment() self.init_distributed_environment()
@ -242,17 +242,24 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"initializing the engine.") "initializing the engine.")
def _init_cache_engine(self) -> None: def _init_cache_engine(self) -> None:
self.cache_engine = CPUCacheEngine(self.cache_config, self.cache_engine = [
self.model_config, CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.parallel_config, self.device_config)
self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size)
self.cpu_cache = self.cache_engine.cpu_cache ]
self.model_runner.block_size = self.cache_engine.block_size self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
self.model_runner.block_size = self.cache_engine[0].block_size
assert self.cpu_cache is not None assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory # Populate the cache to warmup the memory
for layer_cache in self.cpu_cache: for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0) layer_cache.fill_(0)
@property @property
@ -260,7 +267,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return self.parallel_config.tensor_parallel_size > 1 return self.parallel_config.tensor_parallel_size > 1
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache return self.cpu_cache
def execute_worker( def execute_worker(
@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
) -> None: ) -> None:
if (worker_input.blocks_to_copy is not None if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy) self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None assert execute_model_req is not None
virtual_engine = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
def init_distributed_environment(self) -> None: def init_distributed_environment(self) -> None:
View File
@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
logger = init_logger(__name__) logger = init_logger(__name__)
@ -57,6 +58,7 @@ class EmbeddingModelRunner(
self, self,
model_input: ModelInputForGPUWithPoolingMetadata, model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[PoolerOutput]]: ) -> Optional[List[PoolerOutput]]:
if num_steps > 1: if num_steps > 1:
@ -73,10 +75,12 @@ class EmbeddingModelRunner(
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
@ -115,6 +119,7 @@ class EmbeddingModelRunner(
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
) -> ModelInputForGPUWithPoolingMetadata: ) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
View File
@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
import numpy as np import numpy as np
import torch import torch
import torch.distributed
import torch.nn as nn import torch.nn as nn
try: try:
@ -25,6 +26,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
@ -37,7 +39,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad) is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
@ -81,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_requests: Optional[Set[LoRARequest]] = None lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
virtual_engine: int = 0
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
@ -89,6 +93,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict return tensor_dict
@ -122,6 +127,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict, _add_sampling_metadata_broadcastable_dict(tensor_dict,
@ -179,7 +185,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.sliding_window = model_config.get_sliding_window() self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture. int, int]] = None # Set during graph capture.
# When using CUDA graph, the input block tables must be padded to # When using CUDA graph, the input block tables must be padded to
@ -787,9 +796,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = min( max_num_seqs = min(
max_num_seqs, max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size)) int(max_num_batched_tokens / vlm_config.image_feature_size))
batch_size = 0
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len) .dummy_data_for_profiling(model_config, seq_len)
@ -811,7 +822,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers kv_caches = [None] * num_layers
model_input = self.prepare_model_input(seqs) model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches) intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
return return
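Profiling (and graph capture below) needs placeholder inputs on non-first ranks. A hypothetical sketch of what a model's `make_empty_intermediate_tensors` factory might produce for a llama-style model; the key names and hidden size are assumptions, not taken from this diff:

```python
import torch

from vllm.sequence import IntermediateTensors


def make_empty_intermediate_tensors(batch_size: int, dtype: torch.dtype,
                                    device: torch.device,
                                    hidden_size: int = 4096) -> IntermediateTensors:
    # Zero buffers standing in for what the previous stage would normally send.
    return IntermediateTensors({
        key: torch.zeros((batch_size, hidden_size), dtype=dtype, device=device)
        for key in ("hidden_states", "residual")
    })
```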
@ -847,7 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
return self.lora_manager.list_loras() return self.lora_manager.list_loras()
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[torch.Tensor]) -> None: def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model. """Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number Note that CUDA graph's performance gain is negligible if number
@ -880,10 +897,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
slot_mapping.fill_(_PAD_SLOT_ID) slot_mapping.fill_(_PAD_SLOT_ID)
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda()
intermediate_inputs = None
if not get_pp_group().is_first_rank:
intermediate_inputs = self.model.make_empty_intermediate_tensors(
batch_size=max_batch_size,
dtype=self.model_config.dtype,
device=self.device)
# Prepare buffer for outputs. These will be reused for all batch sizes. # Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture. # It will be filled after the first graph capture.
hidden_states: Optional[torch.Tensor] = None hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
None
] * self.parallel_config.pipeline_parallel_size
graph_batch_size = _get_graph_batch_size( graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs) self.scheduler_config.max_num_seqs)
@ -912,13 +937,17 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
with graph_capture() as graph_capture_context: with graph_capture() as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
indptr_buffer = indptr_buffer[:batch_size + 1] indptr_buffer = indptr_buffer[:batch_size + 1]
last_page_len_buffer = last_page_len_buffer[:batch_size] last_page_len_buffer = last_page_len_buffer[:
batch_size]
num_qo_heads = self.model_config.get_num_attention_heads( num_qo_heads = (
self.parallel_config) self.model_config.get_num_attention_heads(
self.parallel_config))
num_kv_heads = self.model_config.get_num_kv_heads( num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config) self.parallel_config)
if num_qo_heads // num_kv_heads >= 4: if num_qo_heads // num_kv_heads >= 4:
@ -927,8 +956,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
use_tensor_cores = False use_tensor_cores = False
decode_wrapper = \ decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, indptr_buffer, indices_buffer, decode_workspace_buffer, indptr_buffer,
last_page_len_buffer, "NHD", use_tensor_cores) indices_buffer, last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype( kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype) self.kv_cache_dtype, self.model_config.dtype)
@ -990,8 +1020,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
) )
self.set_active_loras(set(), lora_mapping) self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model, graph_runner = CUDAGraphRunner(
self.attn_backend.get_name()) self.model, self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = indptr_buffer graph_runner.flashinfer_indptr_buffer = indptr_buffer
@ -1006,15 +1036,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
graph_runner.capture( graph_runner.capture(
input_tokens[:batch_size], input_tokens[:batch_size],
input_positions[:batch_size], input_positions[:batch_size],
hidden_states[:batch_size] hidden_or_intermediate_states[
if hidden_states is not None else None, virtual_engine] # type: ignore
kv_caches, [:batch_size]
if hidden_or_intermediate_states[virtual_engine]
is not None else None,
intermediate_inputs[:batch_size]
if intermediate_inputs is not None else None,
kv_caches[virtual_engine],
attn_metadata, attn_metadata,
memory_pool=self.graph_memory_pool, memory_pool=self.graph_memory_pool,
stream=graph_capture_context.stream, stream=graph_capture_context.stream,
) )
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner self.graph_runners[virtual_engine][batch_size] = (
graph_runner)
end_time = time.perf_counter() end_time = time.perf_counter()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
@ -1047,6 +1083,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
@ -1072,15 +1109,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if seq_group_metadata_list else None) if seq_group_metadata_list else None)
return dataclasses.replace(model_input, return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
is_prompt=is_prompt) is_prompt=is_prompt,
virtual_engine=virtual_engine)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
model_input: ModelInputForGPUWithSamplingMetadata, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1: if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner") raise ValueError("num_steps > 1 is not supported in ModelRunner")
@ -1124,27 +1163,34 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata decode_meta = model_input.attn_metadata.decode_metadata
# TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs, **multi_modal_kwargs,
) )
# Compute the logits. # Compute the logits in the last pipeline stage.
logits = self.model.compute_logits(hidden_states, if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:
return [] return []
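The tail of `execute_model` is easier to read outside the side-by-side diff; a condensed, illustrative restatement (not the literal method, `runner` stands in for the ModelRunner instance):

```python
from vllm.distributed import get_pp_group


def finish_forward(runner, hidden_or_intermediate_states, model_input):
    # Non-last stages return their IntermediateTensors so the worker can
    # ship them to the next pipeline rank.
    if not get_pp_group().is_last_rank:
        return hidden_or_intermediate_states

    # Only the last stage owns the LM head, and only its driver worker samples.
    logits = runner.model.compute_logits(hidden_or_intermediate_states,
                                         model_input.sampling_metadata)
    if not runner.is_driver_worker:
        return []
    return runner.model.sample(logits=logits,
                               sampling_metadata=model_input.sampling_metadata)
```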
@ -1159,9 +1205,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
assert model_input.sampling_metadata is not None assert model_input.sampling_metadata is not None
indices = model_input.sampling_metadata.selected_token_indices indices = model_input.sampling_metadata.selected_token_indices
if model_input.is_prompt: if model_input.is_prompt:
hidden_states = hidden_states.index_select(0, indices) hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
elif decode_meta.use_cuda_graph: elif decode_meta.use_cuda_graph:
hidden_states = hidden_states[:len(indices)] hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states
output.hidden_states = hidden_states output.hidden_states = hidden_states
@ -1195,13 +1244,15 @@ class CUDAGraphRunner:
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: Optional[torch.Tensor], hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
torch.Tensor]],
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]], memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream, stream: torch.cuda.Stream,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
assert self._graph is None assert self._graph is None
# Run the model a few times without capturing the graph. # Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the # This is to make sure that the captured graph does not include the
@ -1213,6 +1264,7 @@ class CUDAGraphRunner:
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_inputs,
**kwargs, **kwargs,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -1220,18 +1272,27 @@ class CUDAGraphRunner:
# Capture the graph. # Capture the graph.
self._graph = torch.cuda.CUDAGraph() self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_states = self.model( output_hidden_or_intermediate_states = self.model(
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_inputs,
**kwargs, **kwargs,
) )
if hidden_states is not None: if hidden_or_intermediate_states is not None:
hidden_states.copy_(output_hidden_states) if get_pp_group().is_last_rank:
hidden_or_intermediate_states.copy_(
output_hidden_or_intermediate_states)
else: else:
hidden_states = output_hidden_states for key in hidden_or_intermediate_states.tensors:
del output_hidden_states hidden_or_intermediate_states[key].copy_(
output_hidden_or_intermediate_states[key])
else:
hidden_or_intermediate_states = (
output_hidden_or_intermediate_states)
del output_hidden_or_intermediate_states
# make sure `output_hidden_states` is deleted # make sure `output_hidden_states` is deleted
# in the graph's memory pool # in the graph's memory pool
gc.collect() gc.collect()
@ -1255,8 +1316,15 @@ class CUDAGraphRunner:
attn_metadata.decode_metadata.seq_lens_tensor, attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} if intermediate_inputs is not None:
return hidden_states self.input_buffers.update(intermediate_inputs.tensors)
if get_pp_group().is_last_rank:
self.output_buffers = {
"hidden_states": hidden_or_intermediate_states
}
else:
self.output_buffers = hidden_or_intermediate_states
return hidden_or_intermediate_states
def forward( def forward(
self, self,
@ -1264,6 +1332,7 @@ class CUDAGraphRunner:
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them. # KV caches are fixed tensors, so we don't need to copy them.
@ -1280,12 +1349,19 @@ class CUDAGraphRunner:
non_blocking=True) non_blocking=True)
self.input_buffers["block_tables"].copy_( self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
if intermediate_tensors is not None:
for key in intermediate_tensors.tensors:
self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()
# Return the output tensor. # Return the output tensor.
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"] return self.output_buffers["hidden_states"]
return self.output_buffers
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
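On replay, the received intermediate tensors are copied into the fixed buffers that were captured with the graph, just like the other inputs. A tiny CPU-only illustration of that copy-by-key pattern (buffer names and shapes are made up):

```python
import torch

# Pre-allocated buffers captured with the graph (illustrative shapes).
input_buffers = {
    "hidden_states": torch.zeros(4, 8),
    "residual": torch.zeros(4, 8),
}
# Tensors received from the previous pipeline stage for this step.
incoming = {
    "hidden_states": torch.ones(4, 8),
    "residual": torch.ones(4, 8),
}

for key, value in incoming.items():
    input_buffers[key].copy_(value, non_blocking=True)
```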
View File
@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> T: ) -> T:
""" """
Prepare the inputs to ModelRunnerBase.execute_model from an execution Prepare the inputs to ModelRunnerBase.execute_model from an execution
@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self, self,
model_input: T, model_input: T,
kv_caches: Optional[List[torch.Tensor]], kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
""" """
View File
@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForNeuron: ) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self, self,
model_input: ModelInputForNeuron, model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None, kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
View File
@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return False return False
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None return None
@torch.inference_mode() @torch.inference_mode()
View File
@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase):
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if parallel_config and is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
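With pipeline parallelism, rank 0 is no longer the only driver: each tensor-parallel group gets one, which is exactly what the relaxed assertion checks. A quick check with illustrative sizes:

```python
# TP=2, PP=2 -> 4 workers; ranks 0 and 2 lead their tensor-parallel groups.
tensor_parallel_size = 2
pipeline_parallel_size = 2
world_size = tensor_parallel_size * pipeline_parallel_size

driver_ranks = [r for r in range(world_size) if r % tensor_parallel_size == 0]
print(driver_ranks)  # [0, 2]
```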
@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase):
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: CacheEngine self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[torch.tensor]] = None self.gpu_cache: Optional[List[List[torch.tensor]]] = None
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase):
def _init_cache_engine(self): def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.cache_engine = [
self.parallel_config, CacheEngine(self.cache_config, self.model_config,
self.device_config) self.parallel_config, self.device_config)
self.gpu_cache = self.cache_engine.gpu_cache for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
def _warm_up_model(self) -> None: def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase):
return self.parallel_config.tensor_parallel_size > 1 return self.parallel_config.tensor_parallel_size > 1
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache return self.gpu_cache
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list) num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync. # they contain parameters to launch cudamemcpyasync.
@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
@torch.inference_mode() @torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None: def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations. # Issue cache operations.
if (worker_input.blocks_to_swap_in is not None if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0): and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine.swap_in(worker_input.blocks_to_swap_in) self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0): and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine.swap_out(worker_input.blocks_to_swap_out) self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy) self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)
View File
@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread, is_hip, from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables) update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
@ -124,6 +125,7 @@ class WorkerInput:
blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
@ -139,6 +141,7 @@ class WorkerInput:
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
@ -151,6 +154,7 @@ class WorkerInput:
"blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
} }
return tensor_dict return tensor_dict
@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@property @property
@abstractmethod @abstractmethod
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
""" """
Get the kv cache to pass to the worker's model runner. Used by the Gets the list of kv caches to pass to the worker's model runner. Each
default `execute_model`. If the worker's model runner does not follow element in the list is a kv cache corresponding to a particular virtual
the ModelRunnerBase interface, then inherit from WorkerBase instead. engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
""" """
raise NotImplementedError raise NotImplementedError
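Concretely, the property now returns one cache list per virtual engine, indexed as `kv_cache[virtual_engine][layer]`; an illustrative shape with made-up sizes:

```python
import torch

pipeline_parallel_size = 2
num_layers_per_stage = 16

# kv_cache[virtual_engine][layer] -> the cache tensor for that layer.
kv_cache = [
    [torch.empty(0) for _ in range(num_layers_per_stage)]
    for _ in range(pipeline_parallel_size)
]

per_engine_cache = kv_cache[0]  # what execute_model hands to the model runner
```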
@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list)) execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine))
num_steps = execute_model_req.num_steps num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast: if self.do_metadata_broadcast:
@ -255,9 +262,24 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if worker_input.num_seq_groups == 0: if worker_input.num_seq_groups == 0:
return [] return []
return self.model_runner.execute_model(model_input, self.kv_cache, intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict())
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors,
num_steps) num_steps)
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(output.tensors)
return [None]
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return output
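The hand-off between adjacent stages lives entirely in this default `execute_model`; a condensed sketch of the same flow (the helper name and arguments are illustrative, while the `get_pp_group()` calls are the ones added by this change):

```python
from vllm.distributed import get_pp_group
from vllm.sequence import IntermediateTensors


def run_stage(model_runner, model_input, kv_cache, num_steps=1):
    pp_group = get_pp_group()

    intermediate_tensors = None
    if not pp_group.is_first_rank:
        # Pull the previous stage's hidden states / residuals.
        intermediate_tensors = IntermediateTensors(pp_group.recv_tensor_dict())

    output = model_runner.execute_model(model_input, kv_cache,
                                        intermediate_tensors, num_steps)

    if not pp_group.is_last_rank:
        # Forward our IntermediateTensors and report no sampler output.
        pp_group.send_tensor_dict(output.tensors)
        return [None]
    return output
```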
class WorkerWrapperBase: class WorkerWrapperBase:
""" """
View File
@ -12,7 +12,8 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
@ -190,6 +191,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForXPU: ) -> ModelInputForXPU:
multi_modal_input = None multi_modal_input = None
if self.is_driver_worker: if self.is_driver_worker:
@ -334,6 +336,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self, self,
model_input: ModelInputForXPU, model_input: ModelInputForXPU,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
View File
@ -85,8 +85,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: CacheEngine self.cache_engine: List[CacheEngine]
self.gpu_cache: List[torch.Tensor] self.gpu_cache: Optional[List[List[torch.Tensor]]]
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "xpu" and is_xpu(): if self.device_config.device.type == "xpu" and is_xpu():