[V1][core] Implement pipeline parallel on Ray (#12996)

Rui Qiao 2025-02-13 00:02:46 -08:00 committed by GitHub
parent 0ccd8769fb
commit 9605c1256e
GPG Key ID: B5690EEEBB952194
7 changed files with 110 additions and 45 deletions


@@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):

 @dataclass
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
+    # NOTE: the length of distributed_backends and
+    # vllm_major_versions should be the same, and they
+    # are first zipped together to iterate over all
+    # test settings.
     distributed_backends: List[str]
+    # vllm major version: "0" for V0, "1" for V1
+    vllm_major_versions: List[str]
     task: TaskOption
     test_options: PPTestOptions

+    def __post_init__(self):
+        if len(self.distributed_backends) != len(self.vllm_major_versions):
+            raise ValueError(
+                f"Length mismatch: distributed_backends "
+                f"({len(self.distributed_backends)}) != "
+                f"vllm_major_versions ({len(self.vllm_major_versions)})")
+
     @staticmethod
     def detailed(
         *,
@@ -79,7 +92,9 @@ class PPTestSettings:
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            distributed_backends=["mp", "ray"],
+            # only ray is supported for V1
+            distributed_backends=["mp", "ray", "ray"],
+            vllm_major_versions=["0", "0", "1"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        trust_remote_code=trust_remote_code,
@@ -108,6 +123,7 @@ class PPTestSettings:
                               chunked_prefill=False),
             ],
             distributed_backends=["mp"],
+            vllm_major_versions=["0"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        trust_remote_code=trust_remote_code,
@@ -120,8 +136,9 @@ class PPTestSettings:
         opts = self.test_options

         for parallel_setup in self.parallel_setups:
-            for distributed_backend in self.distributed_backends:
-                yield (model_name, parallel_setup, distributed_backend,
+            for backend, vllm_major_version in zip(self.distributed_backends,
+                                                   self.vllm_major_versions):
+                yield (model_name, parallel_setup, backend, vllm_major_version,
                        self.task, opts)
@@ -244,6 +261,7 @@ def _compare_tp(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available: int,
@@ -296,10 +314,13 @@ def _compare_tp(
     if hf_overrides:
         common_args.extend(["--hf-overrides", hf_overrides])

-    if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
-            and chunked_prefill):
-        # Test Ray ADAG for a subset of the tests
+    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
+    if distributed_backend == "ray" and (vllm_major_version == "1"
+                                         or specific_case):
+        # For V1, test Ray ADAG for all the tests
+        # For V0, test Ray ADAG for a subset of the tests
         pp_env = {
+            "VLLM_USE_V1": vllm_major_version,
             "VLLM_USE_RAY_COMPILED_DAG": "1",
             "VLLM_USE_RAY_SPMD_WORKER": "1",
             "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
@@ -348,8 +369,8 @@ def _compare_tp(
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in TEXT_GENERATION_MODELS.items()
        for params in settings.iter_params(model_name)
@@ -361,6 +382,7 @@ def test_tp_language_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
@@ -368,6 +390,7 @@ def test_tp_language_generation(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
@@ -375,8 +398,8 @@ def test_tp_language_generation(
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in EMBEDDING_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -388,6 +411,7 @@ def test_tp_language_embedding(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
@@ -395,6 +419,7 @@ def test_tp_language_embedding(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
@@ -402,8 +427,8 @@ def test_tp_language_embedding(
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in MULTIMODAL_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -415,6 +440,7 @@ def test_tp_multimodal_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
@@ -422,6 +448,7 @@ def test_tp_multimodal_generation(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
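
The zipped backend/version iteration above can be exercised in isolation. A minimal sketch follows; the class below is a simplified stand-in rather than the real PPTestSettings, and the model name is a placeholder:

# Standalone sketch of the zipped backend/version iteration above; the class
# below is a simplified stand-in, not the real PPTestSettings.
from dataclasses import dataclass
from typing import List


@dataclass
class Settings:
    distributed_backends: List[str]
    vllm_major_versions: List[str]

    def __post_init__(self):
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError("Length mismatch between backends and versions")

    def iter_params(self, model_name: str):
        for backend, major_version in zip(self.distributed_backends,
                                          self.vllm_major_versions):
            yield (model_name, backend, major_version)


settings = Settings(distributed_backends=["mp", "ray", "ray"],
                    vllm_major_versions=["0", "0", "1"])
for params in settings.iter_params("some-model"):
    print(params)
# ('some-model', 'mp', '0')   -> V0 multiprocessing backend
# ('some-model', 'ray', '0')  -> V0 Ray backend
# ('some-model', 'ray', '1')  -> V1 Ray backend (the new case)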


@@ -35,7 +35,7 @@ try:

     class RayWorkerWrapper(WorkerWrapperBase):
         """Ray wrapper for vllm.worker.Worker, allowing Worker to be
-        lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

         def __init__(self, *args, **kwargs) -> None:
             super().__init__(*args, **kwargs)
@@ -118,7 +118,14 @@ try:
         ) -> "ModelRunnerOutput":
             self.setup_device_if_necessary()
             assert self.worker is not None, "Worker is not initialized"
-            output = self.worker.model_runner.execute_model(scheduler_output)
+            if isinstance(scheduler_output, tuple):
+                scheduler_output, intermediate_tensors = scheduler_output
+            else:
+                scheduler_output, intermediate_tensors = scheduler_output, None
+            output = self.worker.model_runner.execute_model(
+                scheduler_output, intermediate_tensors)
+            if isinstance(output, IntermediateTensors):
+                output = scheduler_output, output
             return output

         def override_env_vars(self, vars: Dict[str, str]):
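
A sketch of the (scheduler_output, intermediate_tensors) convention the wrapper now uses to chain pipeline stages over Ray. Every class here is a simplified stand-in for the corresponding vLLM type, so treat it as an illustration of the control flow rather than the actual API:

# Stand-in types: non-last ranks emit IntermediateTensors that the next
# stage's wrapper unpacks; the last rank returns the final runner output.
from typing import Any


class IntermediateTensors:  # stand-in for vllm.sequence.IntermediateTensors
    def __init__(self, payload: Any):
        self.payload = payload


class ModelRunnerOutput:  # stand-in for the last rank's sampler output
    pass


class FakeModelRunner:
    def __init__(self, is_last_rank: bool):
        self.is_last_rank = is_last_rank

    def execute_model(self, scheduler_output, intermediate_tensors):
        # Non-last ranks emit hidden states; only the last rank samples.
        if self.is_last_rank:
            return ModelRunnerOutput()
        return IntermediateTensors({"hidden_states": "..."})


def ray_stage(model_runner: FakeModelRunner, scheduler_output):
    # Mirrors the unpack/repack logic added to RayWorkerWrapper.execute_model.
    if isinstance(scheduler_output, tuple):
        scheduler_output, intermediate_tensors = scheduler_output
    else:
        intermediate_tensors = None
    output = model_runner.execute_model(scheduler_output, intermediate_tensors)
    if isinstance(output, IntermediateTensors):
        # Repack so the next stage receives both the scheduler output and
        # the hidden states produced by this stage.
        output = (scheduler_output, output)
    return output


stage0 = FakeModelRunner(is_last_rank=False)   # first pipeline stage
stage1 = FakeModelRunner(is_last_rank=True)    # last pipeline stage
out = ray_stage(stage1, ray_stage(stage0, "scheduler_output"))
print(type(out).__name__)                      # -> ModelRunnerOutput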


@@ -488,7 +488,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:

 def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                       kv_cache_spec: KVCacheSpec,
-                                      available_memory: int) -> KVCacheConfig:
+                                      available_memory: int,
+                                      num_layers: int) -> KVCacheConfig:
     """
     Generates the KV cache configuration for a model with one type of KV cache.
     Divide the available memory equally among all layers.
@@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         vllm_config: The global VllmConfig
         kv_cache_spec: The kv cache spec of the model
         available_memory: Memory available for KV cache in bytes.
+        num_layers: The number of layers in the model.

     Returns:
         The generated KVCacheConfig
@@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     assert len(page_sizes) == 1
     page_size = page_sizes.pop()

-    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
+    num_blocks = int(available_memory // page_size // num_layers)
     num_blocks = max(num_blocks, 0)

     if vllm_config.cache_config.num_gpu_blocks_override is not None:
@@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     return kv_cache_config


-def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
-                        available_memory: int) -> KVCacheConfig:
+def get_kv_cache_configs(vllm_config: VllmConfig,
+                         kv_cache_specs: List[KVCacheSpec],
+                         available_memory: int) -> List[KVCacheConfig]:
     """
     Generates the KV cache configuration for a model
     TODO: support hybrid models with more than one type of KV cache.

     Args:
         vllm_config: The global VllmConfig
-        kv_cache_spec: The kv cache spec of the model
+        kv_cache_specs: The kv cache specs of the model
         available_memory: Memory available for KV cache in bytes.

     Returns:
-        The generated KVCacheConfig
+        The generated KVCacheConfigs
     """
-    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
-    if is_kv_cache_type_uniform(kv_cache_spec):
-        # KV cache of all layers are the same, which is true for most models.
-        # Allocate the same amount of memory for each layer.
-        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
-                                                 available_memory)
-    else:
-        raise NotImplementedError
+    # Use the max number of layers to conservatively determine
+    # the number of blocks.
+    num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
+    kv_cache_configs = []
+    for kv_cache_spec in kv_cache_specs:
+        check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
+                                     available_memory)
+        if is_kv_cache_type_uniform(kv_cache_spec):
+            # KV cache of all layers are the same, which is true for
+            # most models. Allocate the same amount of memory for
+            # each layer.
+            kv_cache_configs.append(
+                _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
                                                   available_memory,
+                                                  num_layers))
+        else:
+            raise NotImplementedError
+
+    return kv_cache_configs
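
A back-of-the-envelope illustration of the conservative sizing above, with made-up numbers; real page sizes depend on the model's head count, head size, dtype, and block size:

# Back-of-the-envelope numbers for the conservative sizing above (made up).
# With pipeline parallelism, workers can host different numbers of layers;
# dividing by the max layer count keeps num_blocks identical and safe for all.
available_memory = 16 * 1024**3      # 16 GiB left for KV cache after profiling
page_size = 2 * 1024 * 1024          # bytes per block per layer (assumed)
layers_per_worker = [17, 15]         # e.g. a 32-layer model split over PP=2

num_layers = max(layers_per_worker)
num_blocks = available_memory // page_size // num_layers
print(num_blocks)                    # 481 blocks, the same on every worker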


@@ -16,7 +16,7 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
-from vllm.v1.core.kv_cache_utils import get_kv_cache_config
+from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
@@ -73,20 +73,25 @@ class EngineCore:
         start = time.time()

         # Get all kv cache needed by the model
-        kv_cache_spec = self.model_executor.get_kv_cache_spec()
+        kv_cache_specs = self.model_executor.get_kv_cache_specs()

         # Profiles the peak memory usage of the model to determine how much
         # memory can be allocated for kv cache.
-        availble_gpu_memory = self.model_executor.determine_available_memory()
+        available_gpu_memory = self.model_executor.determine_available_memory()

         # Get the kv cache tensor size
-        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
-                                              availble_gpu_memory)
-        num_gpu_blocks = kv_cache_config.num_blocks
+        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
+                                                available_gpu_memory)
+        num_gpu_blocks_set = set(config.num_blocks
+                                 for config in kv_cache_configs)
+        assert len(num_gpu_blocks_set) == 1, (
+            f"num_gpu_blocks need to be the same across workers, "
+            f"but they are different: {num_gpu_blocks_set}")
+        num_gpu_blocks = num_gpu_blocks_set.pop()
         num_cpu_blocks = 0

         # Initialize kv cache and warmup the execution
-        self.model_executor.initialize(kv_cache_config)
+        self.model_executor.initialize(kv_cache_configs)

         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "
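
A small sketch of the consistency check above; the block counts are placeholders:

# Every worker must report the same block count so the scheduler can plan
# against a single num_gpu_blocks value (placeholder numbers).
kv_cache_num_blocks = [481, 481]     # one num_blocks per generated config

num_gpu_blocks_set = set(kv_cache_num_blocks)
assert len(num_gpu_blocks_set) == 1, (
    f"num_gpu_blocks need to be the same across workers, "
    f"but they are different: {num_gpu_blocks_set}")
num_gpu_blocks = num_gpu_blocks_set.pop()
print(num_gpu_blocks)                # 481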


@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Type
+from typing import List, Type

 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
@@ -48,12 +48,12 @@ class Executor(ExecutorBase):
                 f"{distributed_executor_backend}")
         return executor_class

-    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
+        self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")

     def determine_available_memory(self) -> int:  # in bytes
@@ -63,11 +63,9 @@ class Executor(ExecutorBase):
         # operators can be applied to all workers.
         return min(output)

-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_specs(self) -> List[KVCacheSpec]:
         output = self.collective_rpc("get_kv_cache_spec")
-        for x in output:
-            assert x == output[0]
-        return output[0]
+        return output

     def execute_model(
         self,
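
The old get_kv_cache_spec asserted that every worker reported an identical spec; with pipeline parallelism each rank owns a different slice of layers, so the executor now keeps one spec per rank. A stand-in sketch of what the gathered output could look like (layer names and spec values are hypothetical):

# Hypothetical per-rank output of collective_rpc("get_kv_cache_spec") with
# PP=2: each rank reports only the layers it hosts, so the old "all specs
# are equal" assertion would fail, and the list is now returned as-is.
rank0_spec = {"model.layers.0": "full_attention",
              "model.layers.1": "full_attention"}
rank1_spec = {"model.layers.2": "full_attention",
              "model.layers.3": "full_attention"}

output = [rank0_spec, rank1_spec]
assert not all(x == output[0] for x in output)   # specs differ across ranks
kv_cache_specs = output                          # keep one spec per rank
print(len(kv_cache_specs))                       # 2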


@@ -12,7 +12,7 @@ import torch.nn as nn
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
@@ -773,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         batch_changed = self._update_states(scheduler_output)
@@ -831,8 +833,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 positions=positions,
                 kv_caches=self.kv_caches,
                 attn_metadata=None,
+                intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
+        if not get_pp_group().is_last_rank:
+            return hidden_states

         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1007,12 +1012,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=num_tokens,
+                dtype=self.model_config.dtype,
+                device=self.device)
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
                 kv_caches=kv_caches,
                 attn_metadata=None,
+                intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
         return hidden_states
@@ -1142,6 +1154,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.max_num_tokens,
                                         dummy_kv_caches)
+        if not get_pp_group().is_last_rank:
+            return hidden_states
         hidden_states = hidden_states[logit_indices]
         logits = self.model.compute_logits(hidden_states, None)
         # TODO(woosuk): Consider the memory usage of the sampler.
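
The warm-up placeholders for non-first pipeline ranks can be pictured as follows: during profiling and CUDA graph capture there is no upstream stage, so the rank fabricates inputs of the right shape. The keys and hidden size below are assumptions, since each vLLM model defines its own make_empty_intermediate_tensors:

# Sketch of placeholder intermediate tensors for a non-first PP rank during
# a dummy run. Keys ("hidden_states", "residual") and hidden_size=4096 are
# assumed values for illustration only.
import torch


def make_empty_intermediate_tensors(batch_size: int, dtype: torch.dtype,
                                    device: torch.device) -> dict:
    hidden_size = 4096                        # assumed model hidden size
    return {
        "hidden_states": torch.zeros(batch_size, hidden_size,
                                     dtype=dtype, device=device),
        "residual": torch.zeros(batch_size, hidden_size,
                                dtype=dtype, device=device),
    }


dummy = make_empty_intermediate_tensors(batch_size=8, dtype=torch.float16,
                                        device=torch.device("cpu"))
print(dummy["hidden_states"].shape)           # torch.Size([8, 4096])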


@@ -2,7 +2,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch
 import torch.distributed
@@ -194,8 +194,9 @@ class Worker:
     def get_kv_cache_spec(self) -> KVCacheSpec:
         return self.model_runner.get_kv_cache_spec()

-    def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
+        kv_cache_config = kv_cache_configs[self.rank]
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
             context = allocator.use_memory_pool(tag="kv_cache")