[V1][core] Implement pipeline parallel on Ray (#12996)

Rui Qiao 2025-02-13 00:02:46 -08:00 committed by GitHub
parent 0ccd8769fb
commit 9605c1256e
7 changed files with 110 additions and 45 deletions

View File

@@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
@dataclass
class PPTestSettings:
parallel_setups: List[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: List[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: List[str]
task: TaskOption
test_options: PPTestOptions
def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")
@staticmethod
def detailed(
*,
@@ -79,7 +92,9 @@ class PPTestSettings:
eager_mode=True,
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
# only ray is supported for V1
distributed_backends=["mp", "ray", "ray"],
vllm_major_versions=["0", "0", "1"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
@@ -108,6 +123,7 @@ class PPTestSettings:
chunked_prefill=False),
],
distributed_backends=["mp"],
vllm_major_versions=["0"],
task=task,
test_options=PPTestOptions(multi_node_only=multi_node_only,
trust_remote_code=trust_remote_code,
@@ -120,8 +136,9 @@ class PPTestSettings:
opts = self.test_options
for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend,
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_name, parallel_setup, backend, vllm_major_version,
self.task, opts)
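
To make the new pairing concrete, here is a minimal standalone sketch of how iter_params now walks the two zipped lists, using the literal values from the detailed() settings above (illustrative only, not part of the commit):

    # The same Ray backend appears twice so it is exercised once under V0
    # and once under V1, while multiprocessing stays V0-only.
    distributed_backends = ["mp", "ray", "ray"]
    vllm_major_versions = ["0", "0", "1"]

    for backend, vllm_major_version in zip(distributed_backends,
                                           vllm_major_versions):
        print(backend, vllm_major_version)
    # Prints: mp 0 / ray 0 / ray 1
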
@@ -244,6 +261,7 @@ def _compare_tp(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available: int,
@@ -296,10 +314,13 @@ def _compare_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", hf_overrides])
if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
and chunked_prefill):
# Test Ray ADAG for a subset of the tests
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
if distributed_backend == "ray" and (vllm_major_version == "1"
or specific_case):
# For V1, test Ray ADAG for all the tests
# For V0, test Ray ADAG for a subset of the tests
pp_env = {
"VLLM_USE_V1": vllm_major_version,
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
@@ -348,8 +369,8 @@ def _compare_tp(
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_name)
@@ -361,6 +382,7 @@ def test_tp_language_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
@@ -368,6 +390,7 @@ def test_tp_language_generation(
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
@@ -375,8 +398,8 @@ def test_tp_language_generation(
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_name)
@@ -388,6 +411,7 @@ def test_tp_language_embedding(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
@@ -395,6 +419,7 @@ def test_tp_language_embedding(
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,
@@ -402,8 +427,8 @@ def test_tp_language_embedding(
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend", "task",
"test_options"),
("model_name", "parallel_setup", "distributed_backend",
"vllm_major_version", "task", "test_options"),
[
params for model_name, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_name)
@@ -415,6 +440,7 @@ def test_tp_multimodal_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
task: TaskOption,
test_options: PPTestOptions,
num_gpus_available,
@@ -422,6 +448,7 @@ def test_tp_multimodal_generation(
_compare_tp(model_name,
parallel_setup,
distributed_backend,
vllm_major_version,
task,
test_options,
num_gpus_available,

View File

@@ -35,7 +35,7 @@ try:
class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -118,7 +118,14 @@ try:
) -> "ModelRunnerOutput":
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
output = self.worker.model_runner.execute_model(scheduler_output)
if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
else:
scheduler_output, intermediate_tensors = scheduler_output, None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
return output
def override_env_vars(self, vars: Dict[str, str]):
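
The tuple handling above exists because, with pipeline parallelism on Ray's compiled graph, one stage's output is wired directly into the next stage's input. A hedged sketch of the pass-through pattern (the function name and the worker object are illustrative stand-ins; IntermediateTensors is the real vLLM class):

    from vllm.sequence import IntermediateTensors

    def pp_stage_execute(worker, payload):
        # Downstream stages receive (scheduler_output, intermediate_tensors)
        # from the previous stage; the first stage receives just the
        # scheduler output.
        if isinstance(payload, tuple):
            scheduler_output, intermediate_tensors = payload
        else:
            scheduler_output, intermediate_tensors = payload, None

        output = worker.model_runner.execute_model(scheduler_output,
                                                   intermediate_tensors)

        # Non-last ranks produce IntermediateTensors; bundle them with the
        # scheduler output so the next stage can unpack both.
        if isinstance(output, IntermediateTensors):
            return scheduler_output, output
        return output
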

View File

@@ -488,7 +488,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
available_memory: int,
num_layers: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
@@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.
Returns:
The generated KVCacheConfig
@@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert len(page_sizes) == 1
page_size = page_sizes.pop()
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
@@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config
def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: List[KVCacheSpec],
available_memory: int) -> List[KVCacheConfig]:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_specs: The kv cache specs of the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for most models.
# Allocate the same amount of memory for each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
else:
raise NotImplementedError
# Use the max number of layers to conservatively determine
# the number of blocks.
num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
kv_cache_configs = []
for kv_cache_spec in kv_cache_specs:
check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs.append(
_get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory,
num_layers))
else:
raise NotImplementedError
return kv_cache_configs
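
A rough illustration of the sizing arithmetic above, with made-up numbers (not from the commit): using the maximum layer count across all workers keeps the block count conservative when pipeline stages hold different numbers of layers, so every worker arrives at the same num_blocks.

    available_memory = 8 * 1024**3   # 8 GiB reserved for KV cache (made up)
    page_size = 2 * 1024**2          # bytes per block per layer (made up)
    layers_per_stage = [16, 12]      # layers held by each pipeline stage

    # Size every stage as if it held the deepest stage's layer count.
    num_layers = max(layers_per_stage)
    num_blocks = available_memory // page_size // num_layers
    print(num_blocks)  # 256
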

View File

@@ -16,7 +16,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
@@ -73,20 +73,25 @@ class EngineCore:
start = time.time()
# Get all kv cache needed by the model
kv_cache_spec = self.model_executor.get_kv_cache_spec()
kv_cache_specs = self.model_executor.get_kv_cache_specs()
# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
availble_gpu_memory = self.model_executor.determine_available_memory()
available_gpu_memory = self.model_executor.determine_available_memory()
# Get the kv cache tensor size
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
availble_gpu_memory)
num_gpu_blocks = kv_cache_config.num_blocks
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
num_gpu_blocks_set = set(config.num_blocks
for config in kv_cache_configs)
assert len(num_gpu_blocks_set) == 1, (
f"num_gpu_blocks need to be the same across workers, "
f"but they are different: {num_gpu_blocks_set}")
num_gpu_blocks = num_gpu_blocks_set.pop()
num_cpu_blocks = 0
# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_config)
self.model_executor.initialize(kv_cache_configs)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Type
from typing import List, Type
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
@@ -48,12 +48,12 @@ class Executor(ExecutorBase):
f"{distributed_executor_backend}")
return executor_class
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_available_memory(self) -> int: # in bytes
@@ -63,11 +63,9 @@ class Executor(ExecutorBase):
# operators can be applied to all workers.
return min(output)
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_specs(self) -> List[KVCacheSpec]:
output = self.collective_rpc("get_kv_cache_spec")
for x in output:
assert x == output[0]
return output[0]
return output
def execute_model(
self,
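
The old get_kv_cache_spec asserted that every worker reported an identical spec and returned only the first one. With pipeline parallelism each stage owns a different subset of layers, so the specs can legitimately differ and the whole list is returned instead. A toy illustration (made-up layer names and sizes) of why the old assertion no longer holds:

    # Toy per-stage specs (layer name -> page size in bytes), illustrative only.
    stage0_spec = {"layers.0.attn": 2 * 1024**2, "layers.1.attn": 2 * 1024**2}
    stage1_spec = {"layers.2.attn": 2 * 1024**2}  # last stage holds fewer layers

    per_worker_specs = [stage0_spec, stage1_spec]

    # The old behaviour asserted all specs were equal; under PP that fails here.
    assert per_worker_specs[0] != per_worker_specs[1]
    # New behaviour: hand the full list to get_kv_cache_configs (see above).
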

View File

@@ -12,7 +12,7 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
@@ -773,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> ModelRunnerOutput:
batch_changed = self._update_states(scheduler_output)
@@ -831,8 +833,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if not get_pp_group().is_last_rank:
return hidden_states
hidden_states = hidden_states[:num_scheduled_tokens]
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1007,12 +1012,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device)
with set_forward_context(None, self.vllm_config):
hidden_states = model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
@@ -1142,6 +1154,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens,
dummy_kv_caches)
if not get_pp_group().is_last_rank:
return hidden_states
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
# TODO(woosuk): Consider the memory usage of the sampler.
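
The rank checks above split the model runner's work by pipeline stage. The following schematic (not the real GPUModelRunner; all callables are injected placeholders) summarizes the control flow: non-first ranks feed intermediate tensors into the model, non-last ranks return activations early, and only the last rank computes logits and samples.

    def forward_one_stage(model, inputs, pp_group,
                          make_empty_intermediate_tensors,
                          compute_logits, sample):
        intermediate_tensors = None
        if not pp_group.is_first_rank:
            # Downstream stages consume the previous stage's activations;
            # during dummy/profile runs these are empty placeholders.
            intermediate_tensors = make_empty_intermediate_tensors()

        hidden_states = model(**inputs,
                              intermediate_tensors=intermediate_tensors)

        if not pp_group.is_last_rank:
            # Hand activations to the next stage; no logits or sampling here.
            return hidden_states

        logits = compute_logits(hidden_states)
        return sample(logits)
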

View File

@@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.distributed
@@ -194,8 +194,9 @@ class Worker:
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
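
Since the executor broadcasts the full list of configs to every worker, each worker keeps only the entry matching its own rank. A minimal sketch of that selection (the class and values are illustrative stand-ins):

    class ToyWorker:  # stand-in for vllm.v1.worker.gpu_worker.Worker
        def __init__(self, rank: int):
            self.rank = rank

        def initialize_cache(self, kv_cache_configs):
            # Pick this worker's own config out of the broadcast list.
            return kv_cache_configs[self.rank]

    configs = ["config_for_rank_0", "config_for_rank_1"]
    assert ToyWorker(rank=1).initialize_cache(configs) == "config_for_rank_1"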