[torch.compile] Hide KV cache behind torch.compile boundary (#11677)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 3de2b1eafb
commit cf5f000d21
@@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
         torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
 
     # Construct KV cache
-    kv_cache = make_kv_cache(test_pt.num_blocks,
-                             test_pt.num_heads,
-                             test_pt.head_size,
-                             test_pt.block_size,
-                             device=CUDA_DEVICE,
-                             backend=test_pt.backend_name)
+    if test_pt.attn_type in (AttentionType.DECODER,
+                             AttentionType.ENCODER_DECODER):
+        kv_cache = make_kv_cache(test_pt.num_blocks,
+                                 test_pt.num_heads,
+                                 test_pt.head_size,
+                                 test_pt.block_size,
+                                 device=CUDA_DEVICE,
+                                 backend=test_pt.backend_name)
+    else:
+        kv_cache = torch.tensor([])
+
+    attn.kv_cache = [kv_cache]
     return TestResources(scale, attn, kv_cache)
@@ -7,9 +7,11 @@ import pytest
 import torch
 from vllm_test_utils import monitor
 
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
-                        StoreBoolean, deprecate_kwargs, get_open_port,
-                        memory_profiling, merge_async_iterators, supports_kw)
+                        StoreBoolean, bind_kv_cache, deprecate_kwargs,
+                        get_open_port, memory_profiling, merge_async_iterators,
+                        supports_kw)
 
 from .utils import error_on_warning, fork_new_process_for_each_test
@@ -325,6 +327,85 @@ def test_memory_profiling():
     lib.cudaFree(handle2)
 
 
+def test_bind_kv_cache():
+    from vllm.attention import Attention
+
+    ctx = {
+        'layers.0.self_attn': Attention(32, 128, 0.1),
+        'layers.1.self_attn': Attention(32, 128, 0.1),
+        'layers.2.self_attn': Attention(32, 128, 0.1),
+        'layers.3.self_attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
+    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
+    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
+
+
+def test_bind_kv_cache_non_attention():
+    from vllm.attention import Attention
+
+    # example from Jamba PP=2
+    ctx = {
+        'model.layers.20.attn': Attention(32, 128, 0.1),
+        'model.layers.28.attn': Attention(32, 128, 0.1),
+    }
+    kv_cache = [
+        torch.zeros((1, )),
+        torch.zeros((1, )),
+    ]
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
+
+
+def test_bind_kv_cache_encoder_decoder():
+    from vllm.attention import Attention, AttentionType
+
+    # example from bart
+    ctx = {
+        'encoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
+        'decoder.layers.0.encoder_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
+        'decoder.layers.0.self_attn.attn':
+        Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
+    }
+
+    kv_cache = [
+        torch.zeros((1, )),
+    ]
+    encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
+
+    bind_kv_cache(ctx, [kv_cache])
+    assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
+    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
+    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
+
+
+def test_bind_kv_cache_pp():
+    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
+    with set_current_vllm_config(cfg):
+        from vllm.attention import Attention
+
+        ctx = {
+            'layers.0.self_attn': Attention(32, 128, 0.1),
+        }
+    kv_cache = [
+        [torch.zeros((1, ))],
+        [torch.zeros((1, ))]
+    ]
+    bind_kv_cache(ctx, kv_cache)
+    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
+    assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
+
+
 def test_placeholder_module_error_handling():
     placeholder = PlaceholderModule("placeholder_1234")
@@ -4,6 +4,7 @@ import uuid
 import pytest
 from transformers import AutoTokenizer
 
+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
     )
 
 
+@fork_new_process_for_each_test
 def test_engine_core(monkeypatch):
 
     with monkeypatch.context() as m:
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
         assert len(engine_core.scheduler.running) == 0
 
 
+@fork_new_process_for_each_test
 def test_engine_core_advanced_sampling(monkeypatch):
     """
     A basic end-to-end test to verify that the engine functions correctly
@@ -6,6 +6,7 @@ from typing import Dict, List
 import pytest
 from transformers import AutoTokenizer
 
+from tests.utils import fork_new_process_for_each_test
 from vllm import SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break
 
 
+@fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
 
@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     client.abort_requests([request.request_id])
 
 
+@fork_new_process_for_each_test
 @pytest.mark.asyncio
 async def test_engine_core_client_asyncio(monkeypatch):
 
@@ -121,6 +121,13 @@ class Attention(nn.Module):
             compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix
         self.attn_type = attn_type
+        # use a placeholder kv cache tensor during init, which will be replaced
+        # by bind_kv_cache
+        # this variable will not be accessed if use_direct_call is True
+        self.kv_cache = [
+            torch.tensor([]) for _ in range(get_current_vllm_config(
+            ).parallel_config.pipeline_parallel_size)
+        ]
 
     def forward(
         self,
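Note on the hunk above: each Attention layer now carries one placeholder KV cache tensor per pipeline-parallel virtual engine, and bind_kv_cache later swaps the placeholders for the real buffers. A minimal sketch of the idea, using a hypothetical toy class rather than vLLM's Attention module:

    import torch

    class ToyAttention:
        """Illustrative stand-in for the layer above (not vLLM's Attention)."""

        def __init__(self, pipeline_parallel_size: int = 1) -> None:
            # One empty placeholder per virtual engine; replaced later by binding.
            self.kv_cache = [
                torch.tensor([]) for _ in range(pipeline_parallel_size)
            ]

    layer = ToyAttention(pipeline_parallel_size=2)
    assert all(t.numel() == 0 for t in layer.kv_cache)  # still unbound placeholders
    layer.kv_cache[0] = torch.zeros(16)  # what a binding step would do for engine 0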
@@ -148,11 +155,11 @@ class Attention(nn.Module):
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             torch.ops.vllm.unified_attention_with_output(
-                query, key, value, output, kv_cache, self.layer_name)
+                query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             return torch.ops.vllm.unified_attention(query, key, value,
-                                                    kv_cache, self.layer_name)
+                                                    self.layer_name)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
@@ -230,12 +237,12 @@ def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                              self._k_scale, self._v_scale)
 
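With the kv_cache argument removed, the custom op resolves everything from the forward context at call time: the layer object by name, then that layer's bound cache for the current virtual engine, so the cache tensor never crosses the torch.compile boundary as an op input. A rough sketch of the lookup pattern with toy stand-ins (ToyLayer and ToyContext are illustrative, not vLLM classes):

    from dataclasses import dataclass
    from typing import Any, Dict, List

    import torch

    @dataclass
    class ToyLayer:
        kv_cache: List[torch.Tensor]  # one tensor per virtual engine

    @dataclass
    class ToyContext:
        attn_layers: Dict[str, Any]  # layer name -> layer object
        attn_metadata: Any
        virtual_engine: int

    def toy_unified_attention(query: torch.Tensor, layer_name: str,
                              ctx: ToyContext) -> torch.Tensor:
        layer = ctx.attn_layers[layer_name]            # look up the layer by name
        kv_cache = layer.kv_cache[ctx.virtual_engine]  # fetch its bound cache
        # a real implementation would call the attention backend with kv_cache here
        return query

    ctx = ToyContext({'layers.0.self_attn': ToyLayer([torch.zeros(8)])}, None, 0)
    out = toy_unified_attention(torch.ones(4), 'layers.0.self_attn', ctx)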
@@ -244,7 +251,6 @@ def unified_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> torch.Tensor:
     return torch.empty_like(query).contiguous()
@@ -253,7 +259,7 @@ def unified_attention_fake(
 direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
@@ -264,12 +270,12 @@ def unified_attention_with_output(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.dynamic_forward_context
-    self = forward_context.static_forward_context[layer_name]
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(query,
                       key,
                       value,
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
     key: torch.Tensor,
     value: torch.Tensor,
     output: torch.Tensor,
-    kv_cache: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
 direct_register_custom_op(
     op_name="unified_attention_with_output",
     op_func=unified_attention_with_output,
-    mutates_args=["kv_cache", "output"],
+    mutates_args=["output"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
 )
@@ -2780,7 +2780,6 @@ class CompilationConfig(BaseModel):
     compilation_time: float = PrivateAttr
 
     # Per-model forward context
-    # Mainly used to store attention cls
     # Map from layer name to the attention cls
     static_forward_context: Dict[str, Any] = PrivateAttr
 
@@ -2,7 +2,7 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 
@@ -10,6 +10,9 @@ import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+
 logger = init_logger(__name__)
 
 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)
 
 @dataclass
 class ForwardContext:
-    static_forward_context: Dict[str, Any]
+    # copy from vllm_config.compilation_config.static_forward_context
+    attn_layers: Dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
-    dynamic_forward_context: Any
+    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    # TODO: remove after making all virtual_engines share the same kv cache
+    virtual_engine: int  # set dynamically for each forward pass
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:
 
 
 @contextmanager
-def set_forward_context(context: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any,
+                        vllm_config: VllmConfig,
+                        virtual_engine: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
     """
     global forward_start_time
-    need_to_track_batchsize = track_batchsize and context is not None
+    need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
-        static_forward_context=vllm_config.compilation_config.
-        static_forward_context,
-        dynamic_forward_context=context)
+        attn_layers=vllm_config.compilation_config.static_forward_context,
+        virtual_engine=virtual_engine,
+        attn_metadata=attn_metadata)
     try:
         yield
     finally:
-        global batchsize_counter
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
-            if hasattr(context, "num_prefill_tokens"):
+            if hasattr(attn_metadata, "num_prefill_tokens"):
                 # for v0 attention backends
-                batchsize = context.num_prefill_tokens + \
-                    context.num_decode_tokens
+                batchsize = attn_metadata.num_prefill_tokens + \
+                    attn_metadata.num_decode_tokens
             else:
                 # for v1 attention backends
-                batchsize = context.num_input_tokens
+                batchsize = attn_metadata.num_input_tokens
             # we use synchronous scheduling right now,
             # adding a sync point here should not affect
             # scheduling of the next batch
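set_forward_context now takes the attention metadata plus an optional virtual_engine, which the model runner hunks below pass as model_input.virtual_engine. A minimal usage sketch, assuming the helpers live in vllm.forward_context and reusing the config construction from the tests above (attn_metadata is left as None purely for illustration):

    from vllm.config import ParallelConfig, VllmConfig
    from vllm.forward_context import get_forward_context, set_forward_context

    cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
    attn_metadata = None  # a real runner passes backend-specific metadata here

    with set_forward_context(attn_metadata, cfg, virtual_engine=1):
        ctx = get_forward_context()
        assert ctx.virtual_engine == 1
        assert ctx.attn_metadata is attn_metadata
        # ctx.attn_layers is cfg.compilation_config.static_forward_context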
@@ -2138,3 +2138,38 @@ def get_mp_context():
     _check_multiproc_method()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
+
+
+def bind_kv_cache(
+    ctx: Dict[str, Any],
+    kv_cache: List[List[torch.Tensor]],  # [virtual_engine][layer_index]
+) -> None:
+    # Bind the kv_cache tensor to Attention modules, similar to
+    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
+    # Special things handled here:
+    # 1. Some models have non-attention layers, e.g., Jamba
+    # 2. Pipeline parallelism, each rank only has a subset of layers
+    # 3. Encoder attention has no kv cache
+    # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
+    #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
+    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
+    #    tensor
+    from vllm.attention import AttentionType
+    from vllm.model_executor.models.utils import extract_layer_index
+    layer_need_kv_cache = [
+        layer_name for layer_name in ctx
+        if ctx[layer_name].attn_type in (AttentionType.DECODER,
+                                         AttentionType.ENCODER_DECODER)
+    ]
+    layer_index_sorted = sorted(
+        set(
+            extract_layer_index(layer_name)
+            for layer_name in layer_need_kv_cache))
+    for layer_name in layer_need_kv_cache:
+        kv_cache_idx = layer_index_sorted.index(
+            extract_layer_index(layer_name))
+        forward_ctx = ctx[layer_name]
+        assert len(forward_ctx.kv_cache) == len(kv_cache)
+        for ve, ve_kv_cache in enumerate(kv_cache):
+            assert forward_ctx.kv_cache[ve].numel() == 0
+            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
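bind_kv_cache assigns each DECODER or ENCODER_DECODER layer the cache tensor whose position matches the layer's rank among cache-holding layers, so index gaps from non-attention layers or pipeline partitioning collapse away. A small worked example in the same style as the new tests above (the layer names and sizes are made up):

    import torch

    from vllm.attention import Attention
    from vllm.utils import bind_kv_cache

    # Hypothetical rank that only holds attention at layers 3 and 7 (cf. the
    # Jamba example in the tests), with a single virtual engine.
    ctx = {
        'model.layers.3.attn': Attention(32, 128, 0.1),
        'model.layers.7.attn': Attention(32, 128, 0.1),
    }
    kv_cache = [torch.zeros((1, )), torch.zeros((1, ))]  # per cache-holding layer

    bind_kv_cache(ctx, [kv_cache])  # outer list: one entry per virtual engine

    # Layer 3 is the 0th cache-holding layer, layer 7 the 1st.
    assert ctx['model.layers.3.attn'].kv_cache[0] is kv_cache[0]
    assert ctx['model.layers.7.attn'].kv_cache[0] is kv_cache[1]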
@@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv, is_pin_memory_available)
+                        LayerBlockType, bind_kv_cache, cdiv,
+                        is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                     FlashAttentionMetadata)
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
@@ -860,3 +861,6 @@ class GPUModelRunner:
                 torch.zeros(kv_cache_shape,
                             dtype=self.kv_cache_dtype,
                             device=self.device))
+        bind_kv_cache(
+            self.vllm_config.compilation_config.static_forward_context,
+            [self.kv_caches])
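The v1 GPU runner keeps a flat per-layer kv_caches list and has a single virtual engine, so the list is wrapped once to match bind_kv_cache's [virtual_engine][layer_index] convention. A tiny sketch of that shape with dummy tensors:

    import torch

    kv_caches = [torch.zeros(4), torch.zeros(4)]  # one cache tensor per layer
    per_virtual_engine = [kv_caches]              # [virtual_engine][layer_index]

    assert per_virtual_engine[0][1] is kv_caches[1]  # engine 0, layer index 1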
@@ -305,7 +305,8 @@ class CPUEncoderDecoderModelRunner(
             intermediate_tensors,
         }
 
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)
 
         # Compute the logits.
@@ -526,7 +526,8 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             execute_model_kwargs.update(
                 {"previous_hidden_states": previous_hidden_states})
 
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -69,7 +69,8 @@ class CPUPoolingModelRunner(
             intermediate_tensors,
         }
 
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)
 
         # Only perform pooling in the driver worker.
@@ -13,7 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -293,6 +293,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             self.cache_engine[ve].cpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.cpu_cache)
         self.model_runner.block_size = self.cache_engine[0].block_size
 
         assert all(
@@ -175,7 +175,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         } if self.has_inner_state else {}
 
         multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -1527,7 +1527,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self._update_inputs_to_capture_for_enc_dec_model(
                     capture_inputs)
 
-            with set_forward_context(attn_metadata, self.vllm_config):
+            with set_forward_context(attn_metadata, self.vllm_config,
+                                     virtual_engine):
                 graph_runner.capture(**capture_inputs)
             self.graph_memory_pool = graph_runner.graph.pool()
             self.graph_runners[virtual_engine][batch_size] = (
@@ -1695,7 +1696,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 
         if not bypass_model_exec:
             with set_forward_context(model_input.attn_metadata,
-                                     self.vllm_config):
+                                     self.vllm_config, virtual_engine):
                 hidden_or_intermediate_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,
@@ -105,7 +105,8 @@ class PoolingModelRunner(
         if model_input.token_types is not None:
             cross_enc_kwargs["token_type_ids"] = model_input.token_types
 
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -21,7 +21,7 @@ from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
-from vllm.utils import GiB_bytes, memory_profiling
+from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -285,6 +285,8 @@ class Worker(LocalOrDistributedWorkerBase):
             self.cache_engine[ve].gpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        bind_kv_cache(self.compilation_config.static_forward_context,
+                      self.gpu_cache)
 
     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
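Here the worker's gpu_cache already carries one sub-list per pipeline-parallel virtual engine, so it is passed to bind_kv_cache unwrapped and each layer ends up with one bound tensor per engine, later selected via ForwardContext.virtual_engine. An illustrative shape check with dummy tensors (sizes are arbitrary):

    import torch

    pipeline_parallel_size, num_layers = 2, 3
    # gpu_cache[virtual_engine][layer_index], as built from the cache engines above
    gpu_cache = [[torch.zeros(8) for _ in range(num_layers)]
                 for _ in range(pipeline_parallel_size)]

    assert len(gpu_cache) == pipeline_parallel_size
    assert len(gpu_cache[0]) == num_layers
    # after binding, layer i would hold [gpu_cache[0][i], gpu_cache[1][i]]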
@@ -43,6 +43,7 @@ class WorkerBase(ABC):
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
         self.kv_transfer_config = vllm_config.kv_transfer_config
+        self.compilation_config = vllm_config.compilation_config
         from vllm.platforms import current_platform
         self.current_platform = current_platform
 