[torch.compile] Hide KV cache behind torch.compile boundary (#11677)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang <zhangch99@outlook.com>
Date: 2025-01-10 13:14:42 +08:00 (committed via GitHub)
Parent: 3de2b1eafb
Commit: cf5f000d21
18 changed files with 198 additions and 44 deletions

View File

@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
if test_pt.attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER):
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
else:
kv_cache = torch.tensor([])
attn.kv_cache = [kv_cache]
return TestResources(scale, attn, kv_cache)
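
The empty tensor used for encoder self-attention above is not just filler: the binding helper added later in this commit asserts that any slot it fills still holds an empty placeholder. A quick, self-contained check of that convention (plain PyTorch, nothing vLLM-specific):

import torch

# An unbound KV cache slot is represented by an empty tensor; bind_kv_cache
# (added later in this commit) asserts numel() == 0 before replacing it.
placeholder = torch.tensor([])
assert placeholder.numel() == 0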

View File

@ -7,9 +7,11 @@ import pytest
import torch
from vllm_test_utils import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
StoreBoolean, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, supports_kw)
StoreBoolean, bind_kv_cache, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
from .utils import error_on_warning, fork_new_process_for_each_test
@ -325,6 +327,85 @@ def test_memory_profiling():
lib.cudaFree(handle2)
def test_bind_kv_cache():
from vllm.attention import Attention
ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
]
bind_kv_cache(ctx, [kv_cache])
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
def test_bind_kv_cache_non_attention():
from vllm.attention import Attention
# example from Jamba PP=2
ctx = {
'model.layers.20.attn': Attention(32, 128, 0.1),
'model.layers.28.attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
]
bind_kv_cache(ctx, [kv_cache])
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
def test_bind_kv_cache_encoder_decoder():
from vllm.attention import Attention, AttentionType
# example from bart
ctx = {
'encoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
'decoder.layers.0.encoder_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
'decoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
}
kv_cache = [
torch.zeros((1, )),
]
encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
bind_kv_cache(ctx, [kv_cache])
assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_pp():
cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
with set_current_vllm_config(cfg):
from vllm.attention import Attention
ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
[torch.zeros((1, ))],
[torch.zeros((1, ))]
]
bind_kv_cache(ctx, kv_cache)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
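
The encoder-decoder and Jamba tests above rely on layer names resolving to layer indices, so that, for example, both decoder attention modules of layer 0 share the same cache slot. A rough stand-in for that index extraction (the real helper is vllm.model_executor.models.utils.extract_layer_index; its exact parsing rules may differ):

import re

def extract_layer_index_sketch(layer_name: str) -> int:
    # 'decoder.layers.0.encoder_attn.attn' -> 0
    match = re.search(r"\.layers\.(\d+)\.", f".{layer_name}.")
    assert match is not None, f"no layer index in {layer_name!r}"
    return int(match.group(1))

assert extract_layer_index_sketch("model.layers.28.attn") == 28
assert (extract_layer_index_sketch("decoder.layers.0.self_attn.attn")
        == extract_layer_index_sketch("decoder.layers.0.encoder_attn.attn"))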
def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")

View File

@ -4,6 +4,7 @@ import uuid
import pytest
from transformers import AutoTokenizer
from tests.utils import fork_new_process_for_each_test
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
)
@fork_new_process_for_each_test
def test_engine_core(monkeypatch):
with monkeypatch.context() as m:
@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
assert len(engine_core.scheduler.running) == 0
@fork_new_process_for_each_test
def test_engine_core_advanced_sampling(monkeypatch):
"""
A basic end-to-end test to verify that the engine functions correctly

View File

@ -6,6 +6,7 @@ from typing import Dict, List
import pytest
from transformers import AutoTokenizer
from tests.utils import fork_new_process_for_each_test
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
break
@fork_new_process_for_each_test
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client.abort_requests([request.request_id])
@fork_new_process_for_each_test
@pytest.mark.asyncio
async def test_engine_core_client_asyncio(monkeypatch):

View File

@ -121,6 +121,13 @@ class Attention(nn.Module):
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
# Use a placeholder KV cache tensor during init; it will be replaced
# by bind_kv_cache.
# This list is not accessed if use_direct_call is True.
self.kv_cache = [
torch.tensor([]) for _ in range(get_current_vllm_config(
).parallel_config.pipeline_parallel_size)
]
def forward(
self,
@ -148,11 +155,11 @@ class Attention(nn.Module):
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, kv_cache, self.layer_name)
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
return torch.ops.vllm.unified_attention(query, key, value,
kv_cache, self.layer_name)
self.layer_name)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
@ -230,12 +237,12 @@ def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._k_scale, self._v_scale)
@ -244,7 +251,6 @@ def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
@ -253,7 +259,7 @@ def unified_attention_fake(
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
mutates_args=["kv_cache"],
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
@ -264,12 +270,12 @@ def unified_attention_with_output(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.dynamic_forward_context
self = forward_context.static_forward_context[layer_name]
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(query,
key,
value,
@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> None:
return
@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["kv_cache", "output"],
mutates_args=["output"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
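
This file is the heart of the change: the KV cache tensor is no longer an argument of the custom attention ops, so torch.compile never sees it as a graph input; each layer fetches its cache from the forward context inside the op body. Below is a minimal, self-contained sketch of the same idea using the public torch.library.custom_op API (PyTorch >= 2.4). The demo op, the _CACHES table, and all names are made up for illustration; vLLM itself registers its ops with direct_register_custom_op as shown above.

import torch

# Hypothetical global side-table standing in for the Attention modules'
# self.kv_cache lists; the compiled graph never sees these tensors.
_CACHES = {"kv": torch.zeros(16)}

@torch.library.custom_op("demo::cached_add", mutates_args=())
def cached_add(x: torch.Tensor, cache_name: str) -> torch.Tensor:
    # The cache is looked up by name inside the op, not passed as an input,
    # mirroring how unified_attention now resolves self.kv_cache[virtual_engine].
    cache = _CACHES[cache_name]
    return x + cache[: x.numel()].view_as(x)

@cached_add.register_fake
def _(x: torch.Tensor, cache_name: str) -> torch.Tensor:
    # Shape-only implementation used while torch.compile traces the graph.
    return torch.empty_like(x)

@torch.compile
def fwd(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.demo.cached_add(x, "kv")

print(fwd(torch.ones(4)))  # tensor([1., 1., 1., 1.])

Because the cache never appears in the op signature, swapping or resizing it does not invalidate the compiled graph, which is what allows the real KV cache to be allocated after compilation and bound in place.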

View File

@ -2780,7 +2780,6 @@ class CompilationConfig(BaseModel):
compilation_time: float = PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context: Dict[str, Any] = PrivateAttr

View File

@ -2,7 +2,7 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
import torch
@ -10,6 +10,9 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
logger = init_logger(__name__)
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@dataclass
class ForwardContext:
static_forward_context: Dict[str, Any]
# copied from vllm_config.compilation_config.static_forward_context
attn_layers: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context: Any
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
_forward_context: Optional[ForwardContext] = None
@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(context: Any, vllm_config: VllmConfig):
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global forward_start_time
need_to_track_batchsize = track_batchsize and context is not None
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
static_forward_context=vllm_config.compilation_config.
static_forward_context,
dynamic_forward_context=context)
attn_layers=vllm_config.compilation_config.static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata)
try:
yield
finally:
global batchsize_counter
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
if hasattr(context, "num_prefill_tokens"):
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = context.num_prefill_tokens + \
context.num_decode_tokens
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = context.num_input_tokens
batchsize = attn_metadata.num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
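
Taken together, the hunks above amount to a small pattern: a module-level context object, set once per forward pass and read by the attention custom ops. A stripped-down, self-contained rendering of that pattern (signatures simplified; the real set_forward_context takes a VllmConfig and also does the batch-size logging shown above):

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class SimplifiedForwardContext:
    attn_layers: Dict[str, Any]   # static: layer name -> Attention module
    attn_metadata: Any            # set dynamically for each forward pass
    virtual_engine: int           # set dynamically for each forward pass

_ctx: Optional[SimplifiedForwardContext] = None

def get_ctx() -> SimplifiedForwardContext:
    assert _ctx is not None, "no forward context is set"
    return _ctx

@contextmanager
def set_ctx(attn_metadata: Any, attn_layers: Dict[str, Any],
            virtual_engine: int = 0):
    global _ctx
    prev = _ctx
    _ctx = SimplifiedForwardContext(attn_layers, attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _ctx = prev  # restore whatever context was active before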

View File

@ -2138,3 +2138,38 @@ def get_mp_context():
_check_multiproc_method()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
def bind_kv_cache(
ctx: Dict[str, Any],
kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index]
) -> None:
# Bind the kv_cache tensors to Attention modules, roughly:
# ctx[layer_name].kv_cache[ve] = kv_cache[ve][extract_layer_index(layer_name)]
# Special cases handled here:
# 1. Some models have non-attention layers, e.g., Jamba.
# 2. With pipeline parallelism, each rank only has a subset of layers.
# 3. Encoder attention has no kv cache.
# 4. In encoder-decoder models, the encoder-decoder attention and the decoder
#    self-attention of the same layer (e.g., bart's decoder.layers.1.self_attn
#    and decoder.layers.1.encoder_attn) are mapped to the same kv cache
#    tensor.
from vllm.attention import AttentionType
from vllm.model_executor.models.utils import extract_layer_index
layer_need_kv_cache = [
layer_name for layer_name in ctx
if ctx[layer_name].attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER)
]
layer_index_sorted = sorted(
set(
extract_layer_index(layer_name)
for layer_name in layer_need_kv_cache))
for layer_name in layer_need_kv_cache:
kv_cache_idx = layer_index_sorted.index(
extract_layer_index(layer_name))
forward_ctx = ctx[layer_name]
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
assert forward_ctx.kv_cache[ve].numel() == 0
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
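
A quick worked example of the index compaction above, using the Jamba PP=2 names from the tests: this rank only sees global layer indices 20 and 28, which map to local kv_cache slots 0 and 1.

# Global layer indices present on this rank (from the Jamba PP=2 test above).
layer_indices = [20, 28]   # what extract_layer_index returns for the two layers

layer_index_sorted = sorted(set(layer_indices))              # [20, 28]
local_slots = [layer_index_sorted.index(i) for i in layer_indices]
assert local_slots == [0, 1]   # bound to kv_cache[ve][0] and kv_cache[ve][1]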

View File

@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
LayerBlockType, bind_kv_cache, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
@ -860,3 +861,6 @@ class GPUModelRunner:
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[self.kv_caches])
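
The extra list wrapping here is deliberate: bind_kv_cache expects a [virtual_engine][layer_index] layout, and this v1 runner appears to use a single virtual engine, so the per-layer list is wrapped once. A toy illustration with stand-in values rather than real cache tensors:

# Per-layer caches built by the loop above (stand-ins for real tensors).
kv_caches = ["layer0_cache", "layer1_cache"]

# Wrap once so the layout matches bind_kv_cache's [virtual_engine][layer_index].
wrapped = [kv_caches]
assert wrapped[0][1] == "layer1_cache"   # virtual engine 0, layer 1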

View File

@ -305,7 +305,8 @@ class CPUEncoderDecoderModelRunner(
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.

View File

@ -526,7 +526,8 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})
with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@ -69,7 +69,8 @@ class CPUPoolingModelRunner(
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.

View File

@ -13,7 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@ -293,6 +293,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.cpu_cache)
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(

View File

@ -175,7 +175,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
} if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@ -1527,7 +1527,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)
with set_forward_context(attn_metadata, self.vllm_config):
with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
@ -1695,7 +1696,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if not bypass_model_exec:
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@ -105,7 +105,8 @@ class PoolingModelRunner(
if model_input.token_types is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_types
with set_forward_context(model_input.attn_metadata, self.vllm_config):
with set_forward_context(model_input.attn_metadata, self.vllm_config,
virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,

View File

@ -21,7 +21,7 @@ from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import GiB_bytes, memory_profiling
from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@ -285,6 +285,8 @@ class Worker(LocalOrDistributedWorkerBase):
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.gpu_cache)
def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:

View File

@ -43,6 +43,7 @@ class WorkerBase(ABC):
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
self.compilation_config = vllm_config.compilation_config
from vllm.platforms import current_platform
self.current_platform = current_platform