[V1][core] Implement pipeline parallel on Ray (#12996)
This commit is contained in:
parent
0ccd8769fb
commit
9605c1256e
@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
|
||||
@dataclass
|
||||
class PPTestSettings:
|
||||
parallel_setups: List[ParallelSetup]
|
||||
# NOTE: the length of distributed_backends and
|
||||
# vllm_major_versions should be the same, and they
|
||||
# are first zipped together to iterate over all
|
||||
# test settings.
|
||||
distributed_backends: List[str]
|
||||
# vllm major version: "0" for V0, "1" for V1
|
||||
vllm_major_versions: List[str]
|
||||
task: TaskOption
|
||||
test_options: PPTestOptions
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.distributed_backends) != len(self.vllm_major_versions):
|
||||
raise ValueError(
|
||||
f"Length mismatch: distributed_backends "
|
||||
f"({len(self.distributed_backends)}) != "
|
||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
||||
|
||||
@staticmethod
|
||||
def detailed(
|
||||
*,
|
||||
@ -79,7 +92,9 @@ class PPTestSettings:
|
||||
eager_mode=True,
|
||||
chunked_prefill=False),
|
||||
],
|
||||
distributed_backends=["mp", "ray"],
|
||||
# only ray is supported for V1
|
||||
distributed_backends=["mp", "ray", "ray"],
|
||||
vllm_major_versions=["0", "0", "1"],
|
||||
task=task,
|
||||
test_options=PPTestOptions(multi_node_only=multi_node_only,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -108,6 +123,7 @@ class PPTestSettings:
|
||||
chunked_prefill=False),
|
||||
],
|
||||
distributed_backends=["mp"],
|
||||
vllm_major_versions=["0"],
|
||||
task=task,
|
||||
test_options=PPTestOptions(multi_node_only=multi_node_only,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -120,8 +136,9 @@ class PPTestSettings:
|
||||
opts = self.test_options
|
||||
|
||||
for parallel_setup in self.parallel_setups:
|
||||
for distributed_backend in self.distributed_backends:
|
||||
yield (model_name, parallel_setup, distributed_backend,
|
||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
||||
self.vllm_major_versions):
|
||||
yield (model_name, parallel_setup, backend, vllm_major_version,
|
||||
self.task, opts)
|
||||
|
||||
|
||||
@ -244,6 +261,7 @@ def _compare_tp(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
task: TaskOption,
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available: int,
|
||||
@ -296,10 +314,13 @@ def _compare_tp(
|
||||
if hf_overrides:
|
||||
common_args.extend(["--hf-overrides", hf_overrides])
|
||||
|
||||
if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
|
||||
and chunked_prefill):
|
||||
# Test Ray ADAG for a subset of the tests
|
||||
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
|
||||
if distributed_backend == "ray" and (vllm_major_version == "1"
|
||||
or specific_case):
|
||||
# For V1, test Ray ADAG for all the tests
|
||||
# For V0, test Ray ADAG for a subset of the tests
|
||||
pp_env = {
|
||||
"VLLM_USE_V1": vllm_major_version,
|
||||
"VLLM_USE_RAY_COMPILED_DAG": "1",
|
||||
"VLLM_USE_RAY_SPMD_WORKER": "1",
|
||||
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
|
||||
@ -348,8 +369,8 @@ def _compare_tp(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "parallel_setup", "distributed_backend", "task",
|
||||
"test_options"),
|
||||
("model_name", "parallel_setup", "distributed_backend",
|
||||
"vllm_major_version", "task", "test_options"),
|
||||
[
|
||||
params for model_name, settings in TEXT_GENERATION_MODELS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
@ -361,6 +382,7 @@ def test_tp_language_generation(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
task: TaskOption,
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
@ -368,6 +390,7 @@ def test_tp_language_generation(
|
||||
_compare_tp(model_name,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
vllm_major_version,
|
||||
task,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
@ -375,8 +398,8 @@ def test_tp_language_generation(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "parallel_setup", "distributed_backend", "task",
|
||||
"test_options"),
|
||||
("model_name", "parallel_setup", "distributed_backend",
|
||||
"vllm_major_version", "task", "test_options"),
|
||||
[
|
||||
params for model_name, settings in EMBEDDING_MODELS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
@ -388,6 +411,7 @@ def test_tp_language_embedding(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
task: TaskOption,
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
@ -395,6 +419,7 @@ def test_tp_language_embedding(
|
||||
_compare_tp(model_name,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
vllm_major_version,
|
||||
task,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
@ -402,8 +427,8 @@ def test_tp_language_embedding(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "parallel_setup", "distributed_backend", "task",
|
||||
"test_options"),
|
||||
("model_name", "parallel_setup", "distributed_backend",
|
||||
"vllm_major_version", "task", "test_options"),
|
||||
[
|
||||
params for model_name, settings in MULTIMODAL_MODELS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
@ -415,6 +440,7 @@ def test_tp_multimodal_generation(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
task: TaskOption,
|
||||
test_options: PPTestOptions,
|
||||
num_gpus_available,
|
||||
@ -422,6 +448,7 @@ def test_tp_multimodal_generation(
|
||||
_compare_tp(model_name,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
vllm_major_version,
|
||||
task,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
|
@ -35,7 +35,7 @@ try:
|
||||
|
||||
class RayWorkerWrapper(WorkerWrapperBase):
|
||||
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
|
||||
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
||||
lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -118,7 +118,14 @@ try:
|
||||
) -> "ModelRunnerOutput":
|
||||
self.setup_device_if_necessary()
|
||||
assert self.worker is not None, "Worker is not initialized"
|
||||
output = self.worker.model_runner.execute_model(scheduler_output)
|
||||
if isinstance(scheduler_output, tuple):
|
||||
scheduler_output, intermediate_tensors = scheduler_output
|
||||
else:
|
||||
scheduler_output, intermediate_tensors = scheduler_output, None
|
||||
output = self.worker.model_runner.execute_model(
|
||||
scheduler_output, intermediate_tensors)
|
||||
if isinstance(output, IntermediateTensors):
|
||||
output = scheduler_output, output
|
||||
return output
|
||||
|
||||
def override_env_vars(self, vars: Dict[str, str]):
|
||||
|
@ -488,7 +488,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
|
||||
|
||||
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
available_memory: int,
|
||||
num_layers: int) -> KVCacheConfig:
|
||||
"""
|
||||
Generates the KV cache configuration for a model with one type of KV cache.
|
||||
Divide the available memory equally among all layers.
|
||||
@ -497,6 +498,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The kv cache spec of the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
num_layers: The number of layers in the model.
|
||||
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
@ -506,7 +508,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
assert len(page_sizes) == 1
|
||||
page_size = page_sizes.pop()
|
||||
|
||||
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
|
||||
num_blocks = int(available_memory // page_size // num_layers)
|
||||
num_blocks = max(num_blocks, 0)
|
||||
|
||||
if vllm_config.cache_config.num_gpu_blocks_override is not None:
|
||||
@ -536,25 +538,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
return kv_cache_config
|
||||
|
||||
|
||||
def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
def get_kv_cache_configs(vllm_config: VllmConfig,
|
||||
kv_cache_specs: List[KVCacheSpec],
|
||||
available_memory: int) -> List[KVCacheConfig]:
|
||||
"""
|
||||
Generates the KV cache configuration for a model
|
||||
TODO: support hybrid models with more than one type of KV cache.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The kv cache spec of the model
|
||||
kv_cache_specs: The kv cache specs of the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
The generated KVCacheConfigs
|
||||
"""
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
|
||||
if is_kv_cache_type_uniform(kv_cache_spec):
|
||||
# KV cache of all layers are the same, which is true for most models.
|
||||
# Allocate the same amount of memory for each layer.
|
||||
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# Use the max number of layers to conservatively determine
|
||||
# the number of blocks.
|
||||
num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
|
||||
kv_cache_configs = []
|
||||
for kv_cache_spec in kv_cache_specs:
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
if is_kv_cache_type_uniform(kv_cache_spec):
|
||||
# KV cache of all layers are the same, which is true for
|
||||
# most models. Allocate the same amount of memory for
|
||||
# each layer.
|
||||
kv_cache_configs.append(
|
||||
_get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
|
||||
available_memory,
|
||||
num_layers))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return kv_cache_configs
|
||||
|
@ -16,7 +16,7 @@ from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
@ -73,20 +73,25 @@ class EngineCore:
|
||||
start = time.time()
|
||||
|
||||
# Get all kv cache needed by the model
|
||||
kv_cache_spec = self.model_executor.get_kv_cache_spec()
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
|
||||
# Profiles the peak memory usage of the model to determine how much
|
||||
# memory can be allocated for kv cache.
|
||||
availble_gpu_memory = self.model_executor.determine_available_memory()
|
||||
available_gpu_memory = self.model_executor.determine_available_memory()
|
||||
|
||||
# Get the kv cache tensor size
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
availble_gpu_memory)
|
||||
num_gpu_blocks = kv_cache_config.num_blocks
|
||||
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
|
||||
available_gpu_memory)
|
||||
num_gpu_blocks_set = set(config.num_blocks
|
||||
for config in kv_cache_configs)
|
||||
assert len(num_gpu_blocks_set) == 1, (
|
||||
f"num_gpu_blocks need to be the same across workers, "
|
||||
f"but they are different: {num_gpu_blocks_set}")
|
||||
num_gpu_blocks = num_gpu_blocks_set.pop()
|
||||
num_cpu_blocks = 0
|
||||
|
||||
# Initialize kv cache and warmup the execution
|
||||
self.model_executor.initialize(kv_cache_config)
|
||||
self.model_executor.initialize(kv_cache_configs)
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(("init engine (profile, create kv cache, "
|
||||
|
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Type
|
||||
from typing import List, Type
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
@ -48,12 +48,12 @@ class Executor(ExecutorBase):
|
||||
f"{distributed_executor_backend}")
|
||||
return executor_class
|
||||
|
||||
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
|
||||
"""
|
||||
Initialize the KV caches and begin the model execution loop of the
|
||||
underlying workers.
|
||||
"""
|
||||
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
|
||||
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
|
||||
self.collective_rpc("compile_or_warm_up_model")
|
||||
|
||||
def determine_available_memory(self) -> int: # in bytes
|
||||
@ -63,11 +63,9 @@ class Executor(ExecutorBase):
|
||||
# operators can be applied to all workers.
|
||||
return min(output)
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
def get_kv_cache_specs(self) -> List[KVCacheSpec]:
|
||||
output = self.collective_rpc("get_kv_cache_spec")
|
||||
for x in output:
|
||||
assert x == output[0]
|
||||
return output[0]
|
||||
return output
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.distributed.parallel_state import get_pp_group, graph_capture
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import init_logger
|
||||
@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LayerBlockType, cdiv, is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
|
||||
@ -773,6 +774,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> ModelRunnerOutput:
|
||||
batch_changed = self._update_states(scheduler_output)
|
||||
|
||||
@ -831,8 +833,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions=positions,
|
||||
kv_caches=self.kv_caches,
|
||||
attn_metadata=None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
@ -1007,12 +1012,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions = self.mrope_positions[:, :num_tokens]
|
||||
else:
|
||||
positions = self.positions[:num_tokens]
|
||||
intermediate_tensors = None
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
||||
batch_size=num_tokens,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device)
|
||||
with set_forward_context(None, self.vllm_config):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
@ -1142,6 +1154,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Trigger compilation for general shape.
|
||||
hidden_states = self._dummy_run(self.max_num_tokens,
|
||||
dummy_kv_caches)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
# TODO(woosuk): Consider the memory usage of the sampler.
|
||||
|
@ -2,7 +2,7 @@
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -194,8 +194,9 @@ class Worker:
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
kv_cache_config = kv_cache_configs[self.rank]
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
context = allocator.use_memory_pool(tag="kv_cache")
|
||||
|
Loading…
x
Reference in New Issue
Block a user