[torch.compile] Hide KV cache behind torch.compile boundary (#11677)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang, 2025-01-10 13:14:42 +08:00 (committed by GitHub)
Parent: 3de2b1eafb
Commit: cf5f000d21
18 changed files with 198 additions and 44 deletions


@@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
         torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

     # Construct KV cache
-    kv_cache = make_kv_cache(test_pt.num_blocks,
-                             test_pt.num_heads,
-                             test_pt.head_size,
-                             test_pt.block_size,
-                             device=CUDA_DEVICE,
-                             backend=test_pt.backend_name)
+    if test_pt.attn_type in (AttentionType.DECODER,
+                             AttentionType.ENCODER_DECODER):
+        kv_cache = make_kv_cache(test_pt.num_blocks,
+                                 test_pt.num_heads,
+                                 test_pt.head_size,
+                                 test_pt.block_size,
+                                 device=CUDA_DEVICE,
+                                 backend=test_pt.backend_name)
+    else:
+        kv_cache = torch.tensor([])
+    attn.kv_cache = [kv_cache]
     return TestResources(scale, attn, kv_cache)


@@ -7,9 +7,11 @@ import pytest
 import torch
 from vllm_test_utils import monitor

+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
-                        StoreBoolean, deprecate_kwargs, get_open_port,
-                        memory_profiling, merge_async_iterators, supports_kw)
+                        StoreBoolean, bind_kv_cache, deprecate_kwargs,
+                        get_open_port, memory_profiling, merge_async_iterators,
+                        supports_kw)

 from .utils import error_on_warning, fork_new_process_for_each_test
@@ -325,6 +327,85 @@ def test_memory_profiling():
     lib.cudaFree(handle2)


+def test_bind_kv_cache():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
+
+
+def test_bind_kv_cache_non_attention():
+    from vllm.attention import Attention
+
+    # example from Jamba PP=2
+    ctx = {
+        'model.layers.20.attn': Attention(32, 128, 0.1),
+        'model.layers.28.attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
+
+
+def test_bind_kv_cache_encoder_decoder():
+    from vllm.attention import Attention, AttentionType
+
+    # example from bart
+    ctx = {
+        'encoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
+        'decoder.layers.0.encoder_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
+        'decoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+    ]
+    encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
+
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
+
+
+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+    kv_cache = [
+        [torch.zeros((1, ))],
+        [torch.zeros((1, ))]
+    ]
+    bind_kv_cache(ctx, kv_cache)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+    assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
+
+
 def test_placeholder_module_error_handling():
     placeholder = PlaceholderModule("placeholder_1234")


@@ -4,6 +4,7 @@ import uuid
 import pytest
 from transformers import AutoTokenizer

+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
     )


+@fork_new_process_for_each_test
 def test_engine_core(monkeypatch):
     with monkeypatch.context() as m:
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
     assert len(engine_core.scheduler.running) == 0


+@fork_new_process_for_each_test
 def test_engine_core_advanced_sampling(monkeypatch):
     """
     A basic end-to-end test to verify that the engine functions correctly


@@ -6,6 +6,7 @@ from typing import Dict, List
 import pytest
 from transformers import AutoTokenizer

+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break


+@fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     client.abort_requests([request.request_id])


+@fork_new_process_for_each_test
 @pytest.mark.asyncio
 async def test_engine_core_client_asyncio(monkeypatch):


@@ -121,6 +121,13 @@ class Attention(nn.Module):
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix
         self.attn_type = attn_type
+        # use a placeholder kv cache tensor during init, which will be replaced
+        # by bind_kv_cache
+        # this variable will not be accessed if use_direct_call is True
+        self.kv_cache = [
+            torch.tensor([]) for _ in range(get_current_vllm_config(
+            ).parallel_config.pipeline_parallel_size)
+        ]

     def forward(
         self,
@@ -148,11 +155,11 @@
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, kv_cache, self.layer_name)
+                query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
-            return torch.ops.vllm.unified_attention(query, key, value,
-                                                    kv_cache, self.layer_name)
+            return torch.ops.vllm.unified_attention(query, key, value,
+                                                    self.layer_name)

     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
@@ -230,12 +237,12 @@ def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                              self._k_scale, self._v_scale)
@@ -244,7 +251,6 @@ def unified_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     return torch.empty_like(query).contiguous()
@@ -253,7 +259,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
@@ -264,12 +270,12 @@ def unified_attention_with_output(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(query,
                       key,
                       value,
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
 direct_register_custom_op(
     op_name="unified_attention_with_output",
     op_func=unified_attention_with_output,
-    mutates_args=["kv_cache", "output"],
+    mutates_args=["output"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
 )
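Taken together, the layer.py changes above mean the KV cache tensor is no longer an input to the unified attention custom ops, so it never crosses the torch.compile boundary; the op recovers it at run time from the layer registered under its name and from the current virtual engine. Below is a minimal, self-contained sketch of that pattern; the names STATIC_CTX, TinyAttention and run_attention are illustrative stand-ins, not vLLM APIs.

from dataclasses import dataclass, field
from typing import Dict, List

import torch

STATIC_CTX: Dict[str, "TinyAttention"] = {}  # layer name -> module, filled at init
CURRENT_VIRTUAL_ENGINE = 0                   # set per forward pass


@dataclass
class TinyAttention:
    name: str
    # placeholder caches, one per virtual engine; replaced by a later bind step
    kv_cache: List[torch.Tensor] = field(
        default_factory=lambda: [torch.tensor([])])

    def __post_init__(self):
        STATIC_CTX[self.name] = self


def run_attention(query: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Only the tensors worth tracing (query) plus a string cross the op
    # boundary; the KV cache comes from the side channel, keyed by layer
    # name and virtual engine index.
    layer = STATIC_CTX[layer_name]
    kv_cache = layer.kv_cache[CURRENT_VIRTUAL_ENGINE]
    return query + kv_cache.sum()  # stand-in for the real attention kernel


layer = TinyAttention("layers.0.self_attn")
layer.kv_cache[0] = torch.zeros(16)  # "bind" a real cache after profiling
out = run_attention(torch.ones(4), "layers.0.self_attn")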


@@ -2780,7 +2780,6 @@ class CompilationConfig(BaseModel):
     compilation_time: float = PrivateAttr

     # Per-model forward context
-    # Mainly used to store attention cls
     # Map from layer name to the attention cls
     static_forward_context: Dict[str, Any] = PrivateAttr


@@ -2,7 +2,7 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional

 import torch
@@ -10,6 +10,9 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger

+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+
 logger = init_logger(__name__)

 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)

 @dataclass
 class ForwardContext:
-    static_forward_context: Dict[str, Any]
+    # copy from vllm_config.compilation_config.static_forward_context
+    attn_layers: Dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
-    dynamic_forward_context: Any
+    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    # TODO: remove after making all virtual_engines share the same kv cache
+    virtual_engine: int  # set dynamically for each forward pass


 _forward_context: Optional[ForwardContext] = None
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:

 @contextmanager
-def set_forward_context(context: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any,
+                        vllm_config: VllmConfig,
+                        virtual_engine: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     global forward_start_time
-    need_to_track_batchsize = track_batchsize and context is not None
+    need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        static_forward_context=vllm_config.compilation_config.
-        static_forward_context,
-        dynamic_forward_context=context)
+        attn_layers=vllm_config.compilation_config.static_forward_context,
+        virtual_engine=virtual_engine,
+        attn_metadata=attn_metadata)
     try:
         yield
     finally:
-        global batchsize_counter
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
-            if hasattr(context, "num_prefill_tokens"):
+            if hasattr(attn_metadata, "num_prefill_tokens"):
                 # for v0 attention backends
-                batchsize = context.num_prefill_tokens + \
-                    context.num_decode_tokens
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = context.num_input_tokens
+                batchsize = attn_metadata.num_input_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
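For reference, the save/restore discipline of set_forward_context can be reduced to a few lines. This is a simplified standalone sketch with illustrative names (not the actual vllm.forward_context module); it shows why nested or back-to-back forward passes each see their own attn_metadata and virtual_engine.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class _Ctx:
    attn_layers: Dict[str, Any]
    attn_metadata: Any
    virtual_engine: int


_current: Optional[_Ctx] = None


def get_ctx() -> _Ctx:
    assert _current is not None, "no forward pass in progress"
    return _current


@contextmanager
def forward_ctx(attn_layers, attn_metadata, virtual_engine: int = 0):
    # Swap the global in for the duration of one forward pass, then restore
    # whatever was there before (same prev_context discipline as above).
    global _current
    prev, _current = _current, _Ctx(attn_layers, attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _current = prev


with forward_ctx({}, attn_metadata=None, virtual_engine=1):
    assert get_ctx().virtual_engine == 1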


@@ -2138,3 +2138,38 @@ def get_mp_context():
     _check_multiproc_method()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
+
+
+def bind_kv_cache(
+    ctx: Dict[str, Any],
+    kv_cache: List[List[torch.Tensor]],  # [virtual_engine][layer_index]
+) -> None:
+    # Bind the kv_cache tensor to Attention modules, similar to
+    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
+    # Special things handled here:
+    # 1. Some models have non-attention layers, e.g., Jamba
+    # 2. Pipeline parallelism, each rank only has a subset of layers
+    # 3. Encoder attention has no kv cache
+    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
+    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
+    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
+    #    tensor
+    from vllm.attention import AttentionType
+    from vllm.model_executor.models.utils import extract_layer_index
+    layer_need_kv_cache = [
+        layer_name for layer_name in ctx
+        if ctx[layer_name].attn_type in (AttentionType.DECODER,
+                                         AttentionType.ENCODER_DECODER)
+    ]
+    layer_index_sorted = sorted(
+        set(
+            extract_layer_index(layer_name)
+            for layer_name in layer_need_kv_cache))
+    for layer_name in layer_need_kv_cache:
+        kv_cache_idx = layer_index_sorted.index(
+            extract_layer_index(layer_name))
+        forward_ctx = ctx[layer_name]
+        assert len(forward_ctx.kv_cache) == len(kv_cache)
+        for ve, ve_kv_cache in enumerate(kv_cache):
+            assert forward_ctx.kv_cache[ve].numel() == 0
+            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
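One subtle step in bind_kv_cache above is the sorted-unique-index remapping: with pipeline parallelism the layer numbers embedded in module names are global (for example 20 and 28 on the second Jamba rank), while the rank-local cache list is dense. A small illustration of that mapping follows, assuming names of the form model.layers.<i>.attn and using a regex as a stand-in for vLLM's extract_layer_index.

import re


def layer_index(name: str) -> int:
    # stand-in for vllm.model_executor.models.utils.extract_layer_index
    return int(re.search(r"layers\.(\d+)\.", name).group(1))


names = ["model.layers.20.attn", "model.layers.28.attn"]
dense = sorted({layer_index(n) for n in names})           # [20, 28]
mapping = {n: dense.index(layer_index(n)) for n in names}
print(mapping)  # {'model.layers.20.attn': 0, 'model.layers.28.attn': 1}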


@@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+                        LayerBlockType, bind_kv_cache, cdiv,
+                        is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
@@ -860,3 +861,6 @@ class GPUModelRunner:
                 torch.zeros(kv_cache_shape,
                             dtype=self.kv_cache_dtype,
                             device=self.device))
+        bind_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            [self.kv_caches])
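Note the extra list wrapped around self.kv_caches in the call above: bind_kv_cache expects the kv_cache[virtual_engine][layer_index] nesting, and the v1 runner keeps a single flat per-layer list, so it is passed as one virtual engine; the v0 workers further down already hold one cache list per virtual engine and pass theirs unwrapped. A tiny sketch of the expected nesting (sizes are made up):

import torch

kv_caches = [torch.zeros(8), torch.zeros(8)]  # one tensor per layer (flat, v1-style)
nested = [kv_caches]                          # [virtual_engine][layer_index]
assert nested[0][1] is kv_caches[1]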


@@ -305,7 +305,8 @@ class CPUEncoderDecoderModelRunner(
             intermediate_tensors,
         }

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)

         # Compute the logits.


@@ -526,7 +526,8 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             execute_model_kwargs.update(
                 {"previous_hidden_states": previous_hidden_states})

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,


@@ -69,7 +69,8 @@ class CPUPoolingModelRunner(
             intermediate_tensors,
         }

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)

         # Only perform pooling in the driver worker.


@@ -13,7 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -293,6 +293,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             self.cache_engine[ve].cpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.cpu_cache)
         self.model_runner.block_size = self.cache_engine[0].block_size

         assert all(


@@ -175,7 +175,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         } if self.has_inner_state else {}

         multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,


@@ -1527,7 +1527,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self._update_inputs_to_capture_for_enc_dec_model(
                     capture_inputs)

-            with set_forward_context(attn_metadata, self.vllm_config):
+            with set_forward_context(attn_metadata, self.vllm_config,
+                                     virtual_engine):
                 graph_runner.capture(**capture_inputs)
             self.graph_memory_pool = graph_runner.graph.pool()
             self.graph_runners[virtual_engine][batch_size] = (
@@ -1695,7 +1696,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         if not bypass_model_exec:
             with set_forward_context(model_input.attn_metadata,
-                                     self.vllm_config):
+                                     self.vllm_config, virtual_engine):
                 hidden_or_intermediate_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,


@@ -105,7 +105,8 @@ class PoolingModelRunner(
         if model_input.token_types is not None:
             cross_enc_kwargs["token_type_ids"] = model_input.token_types

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,


@@ -21,7 +21,7 @@ from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
-from vllm.utils import GiB_bytes, memory_profiling
+from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -285,6 +285,8 @@ class Worker(LocalOrDistributedWorkerBase):
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.gpu_cache)

     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:


@@ -43,6 +43,7 @@ class WorkerBase(ABC):
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
         self.kv_transfer_config = vllm_config.kv_transfer_config
+        self.compilation_config = vllm_config.compilation_config
         from vllm.platforms import current_platform
         self.current_platform = current_platform