[torch.compile] Hide KV cache behind torch.compile boundary (#11677)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
3de2b1eafb
commit
cf5f000d21
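Summary: KV cache tensors no longer cross the torch.compile custom-op boundary as arguments. Each Attention layer now holds a per-virtual-engine list of placeholder tensors, a new bind_kv_cache utility swaps in the real cache tensors once they are allocated, the unified attention ops fetch the tensor from the forward context by layer name at run time, and set_forward_context gains a virtual_engine argument so the ops know which cache to pick. The visible change at the call site (quoted from the diff below):

# before: the cache was an explicit operator argument
torch.ops.vllm.unified_attention(query, key, value, kv_cache, self.layer_name)
# after: only the layer name is passed; the op looks the cache up itself
torch.ops.vllm.unified_attention(query, key, value, self.layer_name)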
@@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
             torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

     # Construct KV cache
-    kv_cache = make_kv_cache(test_pt.num_blocks,
-                             test_pt.num_heads,
-                             test_pt.head_size,
-                             test_pt.block_size,
-                             device=CUDA_DEVICE,
-                             backend=test_pt.backend_name)
+    if test_pt.attn_type in (AttentionType.DECODER,
+                             AttentionType.ENCODER_DECODER):
+        kv_cache = make_kv_cache(test_pt.num_blocks,
+                                 test_pt.num_heads,
+                                 test_pt.head_size,
+                                 test_pt.block_size,
+                                 device=CUDA_DEVICE,
+                                 backend=test_pt.backend_name)
+    else:
+        kv_cache = torch.tensor([])
+
+    attn.kv_cache = [kv_cache]
     return TestResources(scale, attn, kv_cache)
@@ -7,9 +7,11 @@ import pytest
 import torch
 from vllm_test_utils import monitor

+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
-                        StoreBoolean, deprecate_kwargs, get_open_port,
-                        memory_profiling, merge_async_iterators, supports_kw)
+                        StoreBoolean, bind_kv_cache, deprecate_kwargs,
+                        get_open_port, memory_profiling, merge_async_iterators,
+                        supports_kw)

 from .utils import error_on_warning, fork_new_process_for_each_test
@@ -325,6 +327,85 @@ def test_memory_profiling():
     lib.cudaFree(handle2)


+def test_bind_kv_cache():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
+
+
+def test_bind_kv_cache_non_attention():
+    from vllm.attention import Attention
+
+    # example from Jamba PP=2
+    ctx = {
+        'model.layers.20.attn': Attention(32, 128, 0.1),
+        'model.layers.28.attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
+
+
+def test_bind_kv_cache_encoder_decoder():
+    from vllm.attention import Attention, AttentionType
+
+    # example from bart
+    ctx = {
+        'encoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
+        'decoder.layers.0.encoder_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
+        'decoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
+    }
+
+    kv_cache = [
+        torch.zeros((1, )),
+    ]
+    encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
+
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
+
+
+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+    kv_cache = [
+        [torch.zeros((1, ))],
+        [torch.zeros((1, ))],
+    ]
+    bind_kv_cache(ctx, kv_cache)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+    assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
+
+
 def test_placeholder_module_error_handling():
     placeholder = PlaceholderModule("placeholder_1234")
@@ -4,6 +4,7 @@ import uuid
 import pytest
 from transformers import AutoTokenizer

+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
     )


+@fork_new_process_for_each_test
 def test_engine_core(monkeypatch):

     with monkeypatch.context() as m:
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
     assert len(engine_core.scheduler.running) == 0


+@fork_new_process_for_each_test
 def test_engine_core_advanced_sampling(monkeypatch):
     """
     A basic end-to-end test to verify that the engine functions correctly
@@ -6,6 +6,7 @@ from typing import Dict, List
 import pytest
 from transformers import AutoTokenizer

+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
         break


+@fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     client.abort_requests([request.request_id])


+@fork_new_process_for_each_test
 @pytest.mark.asyncio
 async def test_engine_core_client_asyncio(monkeypatch):
@@ -121,6 +121,13 @@ class Attention(nn.Module):
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix
         self.attn_type = attn_type
+        # use a placeholder kv cache tensor during init, which will be replaced
+        # by bind_kv_cache
+        # this variable will not be accessed if use_direct_call is True
+        self.kv_cache = [
+            torch.tensor([]) for _ in range(get_current_vllm_config(
+            ).parallel_config.pipeline_parallel_size)
+        ]

     def forward(
         self,
@@ -148,11 +155,11 @@ class Attention(nn.Module):
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, kv_cache, self.layer_name)
+                query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             return torch.ops.vllm.unified_attention(query, key, value,
-                                                    kv_cache, self.layer_name)
+                                                    self.layer_name)

     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
@@ -230,12 +237,12 @@ def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                              self._k_scale, self._v_scale)

@@ -244,7 +251,6 @@ def unified_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     return torch.empty_like(query).contiguous()
@@ -253,7 +259,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
@@ -264,12 +270,12 @@ def unified_attention_with_output(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(query,
                       key,
                       value,
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
 direct_register_custom_op(
     op_name="unified_attention_with_output",
     op_func=unified_attention_with_output,
-    mutates_args=["kv_cache", "output"],
+    mutates_args=["output"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
 )
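The effect of the hunks above, reduced to a toy that is not vLLM API (every name here is made up): the cache tensor becomes ordinary Python state on the layer, looked up by name inside the op body, so the call that crosses the compile boundary carries only real tensor inputs plus a string.

import torch

_LAYERS = {}  # stands in for forward_context.attn_layers


class ToyAttention(torch.nn.Module):

    def __init__(self, name: str):
        super().__init__()
        self.name = name
        self.kv_cache = [torch.tensor([])]  # placeholder until "bound"
        _LAYERS[name] = self

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # only tensor inputs and a string cross the op boundary
        return toy_unified_attention(q, self.name)


def toy_unified_attention(q: torch.Tensor, layer_name: str) -> torch.Tensor:
    layer = _LAYERS[layer_name]
    kv_cache = layer.kv_cache[0]  # fetched from layer state, not an argument
    kv_cache[:q.shape[0]] = q     # in-place cache update stays hidden
    return q * 2.0


layer = ToyAttention("layers.0.attn")
layer.kv_cache[0] = torch.zeros(8)  # the toy equivalent of bind_kv_cache
print(layer(torch.ones(4)))         # tensor([2., 2., 2., 2.])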
@@ -2780,7 +2780,6 @@ class CompilationConfig(BaseModel):
     compilation_time: float = PrivateAttr

-    # Per-model forward context
-    # Mainly used to store attention cls
+    # Map from layer name to the attention cls
     static_forward_context: Dict[str, Any] = PrivateAttr
@@ -2,7 +2,7 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional

 import torch

@@ -10,6 +10,9 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger

+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+
 logger = init_logger(__name__)

 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)

 @dataclass
 class ForwardContext:
-    static_forward_context: Dict[str, Any]
+    # copy from vllm_config.compilation_config.static_forward_context
+    attn_layers: Dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
-    dynamic_forward_context: Any
+    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    # TODO: remove after making all virtual_engines share the same kv cache
+    virtual_engine: int  # set dynamically for each forward pass


 _forward_context: Optional[ForwardContext] = None
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:


 @contextmanager
-def set_forward_context(context: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any,
+                        vllm_config: VllmConfig,
+                        virtual_engine: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     global forward_start_time
-    need_to_track_batchsize = track_batchsize and context is not None
+    need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        static_forward_context=vllm_config.compilation_config.
-        static_forward_context,
-        dynamic_forward_context=context)
+        attn_layers=vllm_config.compilation_config.static_forward_context,
+        virtual_engine=virtual_engine,
+        attn_metadata=attn_metadata)
     try:
         yield
     finally:
         global batchsize_counter
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
-            if hasattr(context, "num_prefill_tokens"):
+            if hasattr(attn_metadata, "num_prefill_tokens"):
                 # for v0 attention backends
-                batchsize = context.num_prefill_tokens + \
-                    context.num_decode_tokens
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = context.num_input_tokens
+                batchsize = attn_metadata.num_input_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
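For callers, the only visible change is the extra virtual_engine argument, which defaults to 0. A usage sketch (the module path vllm.forward_context is assumed; model, attn_metadata, and vllm_config are placeholders supplied by the caller):

from vllm.forward_context import get_forward_context, set_forward_context


def run_one_forward_pass(model, attn_metadata, vllm_config, virtual_engine=0):
    # Mirrors the model-runner call sites later in this diff.
    with set_forward_context(attn_metadata, vllm_config, virtual_engine):
        # The attention custom ops resolve their layer via
        # get_forward_context().attn_layers[layer_name] and then pick the
        # cache bound for this virtual engine.
        assert get_forward_context().virtual_engine == virtual_engine
        return model()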
@@ -2138,3 +2138,38 @@ def get_mp_context():
     _check_multiproc_method()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
+
+
+def bind_kv_cache(
+    ctx: Dict[str, Any],
+    kv_cache: List[List[torch.Tensor]],  # [virtual_engine][layer_index]
+) -> None:
+    # Bind the kv_cache tensor to Attention modules, similar to
+    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
+    # Special things handled here:
+    # 1. Some models have non-attention layers, e.g., Jamba
+    # 2. Pipeline parallelism, each rank only has a subset of layers
+    # 3. Encoder attention has no kv cache
+    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
+    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
+    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
+    #    tensor
+    from vllm.attention import AttentionType
+    from vllm.model_executor.models.utils import extract_layer_index
+    layer_need_kv_cache = [
+        layer_name for layer_name in ctx
+        if ctx[layer_name].attn_type in (AttentionType.DECODER,
+                                         AttentionType.ENCODER_DECODER)
+    ]
+    layer_index_sorted = sorted(
+        set(
+            extract_layer_index(layer_name)
+            for layer_name in layer_need_kv_cache))
+    for layer_name in layer_need_kv_cache:
+        kv_cache_idx = layer_index_sorted.index(
+            extract_layer_index(layer_name))
+        forward_ctx = ctx[layer_name]
+        assert len(forward_ctx.kv_cache) == len(kv_cache)
+        for ve, ve_kv_cache in enumerate(kv_cache):
+            assert forward_ctx.kv_cache[ve].numel() == 0
+            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
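The subtle part of bind_kv_cache is the index mapping: layer names carry the model's absolute layer indices, which can be sparse under pipeline parallelism or for models like Jamba, while the per-virtual-engine cache list is dense. A self-contained illustration of just that mapping, with a hypothetical stand-in for extract_layer_index:

layer_names = ['model.layers.20.attn', 'model.layers.28.attn']  # Jamba PP rank


def fake_extract_layer_index(name: str) -> int:
    # hypothetical stand-in for the real extract_layer_index helper
    return int(next(part for part in name.split('.') if part.isdigit()))


layer_index_sorted = sorted({fake_extract_layer_index(n) for n in layer_names})
for name in layer_names:
    dense_idx = layer_index_sorted.index(fake_extract_layer_index(name))
    print(f"{name} -> kv_cache[ve][{dense_idx}]")
# model.layers.20.attn -> kv_cache[ve][0]
# model.layers.28.attn -> kv_cache[ve][1]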
@@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+                        LayerBlockType, bind_kv_cache, cdiv,
+                        is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
@@ -860,3 +861,6 @@ class GPUModelRunner:
                 torch.zeros(kv_cache_shape,
                             dtype=self.kv_cache_dtype,
                             device=self.device))
+        bind_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            [self.kv_caches])
@@ -305,7 +305,8 @@ class CPUEncoderDecoderModelRunner(
             intermediate_tensors,
         }

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)

         # Compute the logits.
@@ -526,7 +526,8 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
         execute_model_kwargs.update(
             {"previous_hidden_states": previous_hidden_states})

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -69,7 +69,8 @@ class CPUPoolingModelRunner(
             intermediate_tensors,
         }

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)

         # Only perform pooling in the driver worker.
@@ -13,7 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -293,6 +293,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             self.cache_engine[ve].cpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.cpu_cache)
         self.model_runner.block_size = self.cache_engine[0].block_size

         assert all(
@@ -175,7 +175,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         } if self.has_inner_state else {}

         multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -1527,7 +1527,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self._update_inputs_to_capture_for_enc_dec_model(
                     capture_inputs)

-            with set_forward_context(attn_metadata, self.vllm_config):
+            with set_forward_context(attn_metadata, self.vllm_config,
+                                     virtual_engine):
                 graph_runner.capture(**capture_inputs)
             self.graph_memory_pool = graph_runner.graph.pool()
             self.graph_runners[virtual_engine][batch_size] = (
@@ -1695,7 +1696,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

         if not bypass_model_exec:
             with set_forward_context(model_input.attn_metadata,
-                                     self.vllm_config):
+                                     self.vllm_config, virtual_engine):
                 hidden_or_intermediate_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,
@@ -105,7 +105,8 @@ class PoolingModelRunner(
         if model_input.token_types is not None:
             cross_enc_kwargs["token_type_ids"] = model_input.token_types

-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -21,7 +21,7 @@ from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
-from vllm.utils import GiB_bytes, memory_profiling
+from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -285,6 +285,8 @@ class Worker(LocalOrDistributedWorkerBase):
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.gpu_cache)

     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
@@ -43,6 +43,7 @@ class WorkerBase(ABC):
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
         self.kv_transfer_config = vllm_config.kv_transfer_config
+        self.compilation_config = vllm_config.compilation_config
         from vllm.platforms import current_platform
         self.current_platform = current_platform