[Attention] MLA support for V1 (#13789)
Signed-off-by: Yang Chen <yangche@fb.com>
parent f1579b229d
commit 58d1b2aa77
@@ -89,6 +89,7 @@ class Attention(nn.Module):
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0
 
+        self.use_mla = use_mla
         self.num_heads = num_heads
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
@@ -158,6 +159,10 @@ class Attention(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        # For some alternate attention backends like MLA the attention output
+        # shape does not match the query shape, so we optionally let the model
+        # definition specify the output tensor shape.
+        output_shape: Optional[torch.Size] = None,
     ) -> torch.Tensor:
         """
         The KV cache is stored inside this class and is accessed via
@@ -173,17 +178,25 @@ class Attention(nn.Module):
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(key, value)
         if self.use_output:
-            output = torch.empty_like(query)
-            hidden_size = query.size(-1)
-            # Reshape the query, key, and value tensors.
-            # NOTE(woosuk): We do this outside the custom op to minimize the
-            # CPU overheads from the non-CUDA-graph regions.
-            query = query.view(-1, self.num_heads, self.head_size)
-            output = output.view(-1, self.num_heads, self.head_size)
-            if key is not None:
-                key = key.view(-1, self.num_kv_heads, self.head_size)
-            if value is not None:
-                value = value.view(-1, self.num_kv_heads, self.head_size)
+            output_shape = (output_shape
+                            if output_shape is not None else query.shape)
+            output = torch.empty(output_shape,
+                                 dtype=query.dtype,
+                                 device=query.device)
+            hidden_size = output_shape[-1]
+            # We skip reshaping query, key and value tensors for the MLA
+            # backend since these tensors have different semantics and are
+            # processed differently.
+            if not self.use_mla:
+                # Reshape the query, key, and value tensors.
+                # NOTE(woosuk): We do this outside the custom op to minimize the
+                # CPU overheads from the non-CUDA-graph regions.
+                query = query.view(-1, self.num_heads, self.head_size)
+                output = output.view(-1, self.num_heads, self.head_size)
+                if key is not None:
+                    key = key.view(-1, self.num_kv_heads, self.head_size)
+                if value is not None:
+                    value = value.view(-1, self.num_kv_heads, self.head_size)
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
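Note: a minimal standalone sketch (not part of this diff) of the shape mismatch the new output_shape argument handles; with MLA, the tensors passed as "query"/"key"/"value" are latent/RoPE components, so the output shape cannot be inferred from the query shape. The dimensions below are assumed DeepSeek-V2-style values, not taken from this commit.

import torch

num_tokens, hidden_size = 8, 5120
q_c = torch.randn(num_tokens, 1536)          # compressed query input
kv_c_normed = torch.randn(num_tokens, 512)   # latent KV (kv_lora_rank)
k_pe = torch.randn(num_tokens, 64)           # decoupled RoPE key

output_shape = torch.Size((num_tokens, hidden_size))
# Mirrors the logic above: use the caller-provided shape when given,
# otherwise fall back to the query shape (standard MHA case).
out = torch.empty(output_shape if output_shape is not None else q_c.shape,
                  dtype=q_c.dtype, device=q_c.device)
assert out.shape != q_c.shape  # the output shape is decoupled from the query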
@@ -420,9 +420,15 @@ class DeepseekV2MLAAttention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
+        # In the MLA backend, kv_cache includes both k_c and
+        # pe (i.e. decoupled position embeddings). In particular,
+        # the concat_and_cache_mla op requires
+        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
+        # i.e.
+        # kv_lora_rank + qk_rope_head_dim == head_size
         self.mla_attn = Attention(
             num_heads=self.num_local_heads,
-            head_size=self.kv_lora_rank,
+            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
             scale=self.scaling,
             num_kv_heads=1,
             cache_config=cache_config,
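Note: a worked shape check for the constraint in the comment above; the numeric values are assumed DeepSeek-V2 defaults, not taken from this diff.

kv_lora_rank = 512       # k_c.size(1)
qk_rope_head_dim = 64    # k_pe.size(1)
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576  # what concat_and_cache_mla expects for kv_cache.size(2)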
@@ -458,7 +464,10 @@ class DeepseekV2MLAAttention(nn.Module):
         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe)
+        return self.mla_attn(hidden_states_or_q_c,
+                             kv_c_normed,
+                             k_pe,
+                             output_shape=hidden_states.shape)
 
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -162,8 +162,13 @@ class CudaPlatformBase(Platform):
                              kv_cache_dtype, block_size, use_v1,
                              use_mla) -> str:
         if use_v1:
-            logger.info("Using Flash Attention backend on V1 engine.")
-            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+            if use_mla:
+                logger.info("Using Triton MLA backend on V1 engine.")
+                return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
+            else:
+                logger.info("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends.flash_attn."
+                        "FlashAttentionBackend")
         if use_mla:
             if selected_backend == _Backend.FLASHMLA:
                 from vllm.attention.backends.flashmla import (
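Note: the selector above returns a dotted class path rather than a class object. A hedged sketch of how such a string can be resolved to a class; this is an illustration using stdlib importlib, not vLLM's actual loader.

import importlib

def resolve_backend_cls(path: str):
    # Split "pkg.module.ClassName" into module and attribute, then import.
    module_name, class_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# e.g. resolve_backend_cls(
#     "vllm.v1.attention.backends.triton_mla.TritonMLABackend")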
@@ -35,6 +35,7 @@ class _Backend(enum.Enum):
     OPENVINO = enum.auto()
     FLASHINFER = enum.auto()
     TRITON_MLA = enum.auto()
+    TRITON_MLA_VLLM_V1 = enum.auto()
     FLASHMLA = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import numpy as np
 import torch
@@ -14,6 +14,11 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
 
+if TYPE_CHECKING:
+    from vllm.v1.core.scheduler_output import SchedulerOutput
+    from vllm.v1.worker.gpu_input_batch import InputBatch
+    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn import flash_attn_varlen_func
 
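Note: the TYPE_CHECKING block above is the standard pattern for referencing runner/scheduler types in annotations without importing them at runtime, avoiding circular imports. A generic sketch of the pattern with made-up names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; never executed at runtime, so the
    # annotation-only reference costs nothing and cannot create an import cycle.
    from some_pkg.runner import Runner  # hypothetical module

def make_builder(runner: "Runner") -> str:
    # The quoted annotation stays valid even though Runner is not imported
    # at runtime.
    return type(runner).__name__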
@@ -40,6 +45,10 @@ class FlashAttentionBackend(AttentionBackend):
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return FlashAttentionMetadata
 
+    @staticmethod
+    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
+        return FlashAttentionMetadataBuilder
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -85,6 +94,62 @@ class FlashAttentionMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.
 
 
+class FlashAttentionMetadataBuilder:
+
+    def __init__(self, runner: "GPUModelRunner"):
+        self.runner = runner
+
+    def reorder_batch(self, input_batch: "InputBatch",
+                      scheduler_output: "SchedulerOutput"):
+        pass
+
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int):
+        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
+        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
+            self.runner.device, non_blocking=True)
+        seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
+                                                          non_blocking=True)
+        block_table = (
+            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+            self.runner.device, non_blocking=True).long()
+
+        use_cascade = common_prefix_len > 0
+        if use_cascade:
+            # TODO: Optimize.
+            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+            prefix_kv_lens = torch.tensor([common_prefix_len],
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
+                              common_prefix_len)
+            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
+                self.runner.device)
+        else:
+            cu_prefix_query_lens = None
+            prefix_kv_lens = None
+            suffix_kv_lens = None
+
+        attn_metadata = FlashAttentionMetadata(
+            num_actual_tokens=num_actual_tokens,
+            max_query_len=max_query_len,
+            query_start_loc=query_start_loc,
+            max_seq_len=max_seq_len,
+            seq_lens=seq_lens,
+            block_table=block_table,
+            slot_mapping=slot_mapping,
+            use_cascade=use_cascade,
+            common_prefix_len=common_prefix_len,
+            cu_prefix_query_lens=cu_prefix_query_lens,
+            prefix_kv_lens=prefix_kv_lens,
+            suffix_kv_lens=suffix_kv_lens,
+        )
+        return attn_metadata
+
+
 class FlashAttentionImpl(AttentionImpl):
 
     def __init__(
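Note: a toy sketch of the builder protocol this class introduces, mirroring how the GPUModelRunner changes later in this diff wire it up; the Dummy* names are invented for illustration and are not vLLM classes.

import weakref

class DummyMetadataBuilder:

    def __init__(self, runner):
        self.runner = runner

    def reorder_batch(self, input_batch, scheduler_output):
        pass  # FlashAttention keeps the scheduler's request order

    def build(self, num_reqs, num_actual_tokens, max_query_len,
              common_prefix_len):
        # A real builder reads per-step state off self.runner here.
        return {"num_actual_tokens": num_actual_tokens,
                "max_query_len": max_query_len}

class DummyRunner:

    def __init__(self):
        # The runner owns the builder and hands it a weak proxy of itself,
        # as in attn_backend.get_builder_cls()(weakref.proxy(self)).
        self.attn_metadata_builder = DummyMetadataBuilder(weakref.proxy(self))

runner = DummyRunner()
meta = runner.attn_metadata_builder.build(num_reqs=2, num_actual_tokens=10,
                                          max_query_len=5, common_prefix_len=0)
assert meta["num_actual_tokens"] == 10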
@@ -371,4 +436,4 @@ def cascade_attention(
 
     # Merge prefix and suffix outputs, and store the result in output.
     merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
-                      suffix_lse)
+                      suffix_lse)
 vllm/v1/attention/backends/mla/__init__.py | 0     new file
 vllm/v1/attention/backends/mla/common.py   | 1022  new file (diff suppressed because it is too large)
 vllm/v1/attention/backends/triton_mla.py   | 110   new file
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional, Type
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                   MLACommonImpl,
+                                                   MLACommonMetadata)
+
+logger = init_logger(__name__)
+
+
+class TritonMLABackend(MLACommonBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "TRITON_MLA_VLLM_V1"
+
+    @staticmethod
+    def get_impl_cls() -> Type["TritonMLAImpl"]:
+        return TritonMLAImpl
+
+
+class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[List[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            blocksparse_params: Optional[Dict[str, Any]],
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         blocksparse_params, logits_soft_cap, attn_type,
+                         **mla_args)
+
+        unsupported_features = [
+            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
+        ]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "TritonMLAImpl does not support one of the following: "
+                "alibi_slopes, sliding_window, blocksparse_params, "
+                "logits_soft_cap")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "TritonMLAImpl")
+
+    def _forward_decode(
+        self,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: MLACommonMetadata,
+    ) -> torch.Tensor:
+        assert kv_c_and_k_pe_cache.numel() > 0
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError("FP8 Triton MLA not yet supported")
+
+        B = q_nope.shape[0]
+
+        q = torch.cat([q_nope, q_pe], dim=-1)
+        o = torch.zeros(B,
+                        self.num_heads,
+                        self.kv_lora_rank,
+                        dtype=q.dtype,
+                        device=q.device)
+
+        num_kv_splits = 4  # TODO: heuristic
+
+        # TODO(lucas) Allocate ahead of time
+        attn_logits = torch.empty(
+            (
+                B,
+                self.num_heads,
+                num_kv_splits,
+                # NOTE(lucas) idk why the +1 is here but sglang has it so we
+                # just mirror that
+                self.kv_lora_rank + 1,
+            ),
+            dtype=torch.float32,
+            device=q.device,
+        )
+
+        # Add a head dim of 1
+        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
+        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
+        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
+
+        # Run MQA
+        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
+                             attn_metadata.block_table, attn_metadata.seq_lens,
+                             attn_logits, num_kv_splits, self.scale, PAGE_SIZE)
+
+        return self._v_up_proj_and_o_proj(o)
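Note: to make the tensor layout in _forward_decode concrete, here is a standalone shape sketch of the MQA-style trick, where every query head attends to a single shared latent KV head. Dimensions are assumed DeepSeek-V2-style values; this is not vLLM code.

import torch

B, num_heads = 4, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
num_blocks, page_size = 8, 16

q_nope = torch.randn(B, num_heads, kv_lora_rank)       # W_UK-absorbed queries
q_pe = torch.randn(B, num_heads, qk_rope_head_dim)
q = torch.cat([q_nope, q_pe], dim=-1)                   # [B, H, 576]

kv_c_and_k_pe_cache = torch.randn(num_blocks, page_size,
                                  kv_lora_rank + qk_rope_head_dim)
cache = kv_c_and_k_pe_cache.unsqueeze(2)                # add a KV head dim of 1
kv_c_cache = cache[..., :kv_lora_rank]                  # latent part doubles as V

assert q.shape == (B, num_heads, 576)
assert cache.shape == (num_blocks, page_size, 1, 576)
assert kv_c_cache.shape == (num_blocks, page_size, 1, kv_lora_rank)
# The kernel writes an output of shape [B, H, kv_lora_rank]; the
# _v_up_proj_and_o_proj step then maps it back to the model's hidden size.
o = torch.zeros(B, num_heads, kv_lora_rank)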
@@ -80,7 +80,14 @@ class InputBatch:
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
-        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+        self.num_computed_tokens_cpu_tensor = torch.zeros(
+            (max_num_reqs, ),
+            device="cpu",
+            dtype=torch.int32,
+            pin_memory=pin_memory,
+        )
+        self.num_computed_tokens_cpu = \
+            self.num_computed_tokens_cpu_tensor.numpy()
 
         # Block table.
         self.block_table = BlockTable(
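Note: a small sketch (not vLLM code) of why num_computed_tokens_cpu becomes a numpy view over a torch tensor: the two share one buffer, so numpy-side bookkeeping stays visible to the (optionally pinned) tensor used for asynchronous host-to-device copies.

import torch

# pin_memory=True would be used on a CUDA host; omitted here so the
# sketch runs anywhere.
num_computed_tokens_cpu_tensor = torch.zeros((4, ), dtype=torch.int32)
num_computed_tokens_cpu = num_computed_tokens_cpu_tensor.numpy()

num_computed_tokens_cpu[0] = 7                         # numpy-side write...
assert num_computed_tokens_cpu_tensor[0].item() == 7   # ...seen by the tensor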
@@ -356,6 +363,61 @@ class InputBatch:
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         return req_index
 
+    def swap_states(self, i1: int, i2: int) -> None:
+        old_id_i1 = self._req_ids[i1]
+        old_id_i2 = self._req_ids[i2]
+        self._req_ids[i1], self._req_ids[i2] =\
+            self._req_ids[i2], self._req_ids[i1]  # noqa
+        self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
+            self.req_output_token_ids[i2], self.req_output_token_ids[i1]
+        assert old_id_i1 is not None and old_id_i2 is not None
+        self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
+            self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
+        self.num_tokens[i1], self.num_tokens[i2] =\
+            self.num_tokens[i2], self.num_tokens[i1]
+        self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
+            self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
+        self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
+            self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
+        self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
+            self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
+        self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
+            self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
+        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
+            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
+        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
+            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
+        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
+            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+        self.min_p_cpu[i1], self.min_p_cpu[i2] =\
+            self.min_p_cpu[i2], self.min_p_cpu[i1]
+
+        g1 = self.generators.get(i1)
+        g2 = self.generators.get(i2)
+        if g1 is not None:
+            self.generators[i2] = g1
+        if g2 is not None:
+            self.generators[i1] = g2
+
+        t1 = self.min_tokens.get(i1)
+        t2 = self.min_tokens.get(i2)
+        if t1 is not None:
+            self.min_tokens[i2] = t1
+        if t2 is not None:
+            self.min_tokens[i1] = t2
+
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
+            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
+        self.logit_bias[i1], self.logit_bias[i2] =\
+            self.logit_bias[i2], self.logit_bias[i1]
+        self.block_table.swap_row(i1, i2)
+
     def condense(self, empty_req_indices: List[int]) -> None:
         num_reqs = self.num_reqs
         if num_reqs == 0:
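Note: swap_states exists so a backend-specific reorder_batch hook can regroup requests in place. Below is a hedged sketch of how an MLA-style builder might use it to move single-token (decode) requests ahead of prefills; the attribute names read off scheduler_output and input_batch follow the types referenced in this diff but should be treated as assumptions.

def reorder_decodes_first(input_batch, scheduler_output):
    # Assumes scheduler_output.num_scheduled_tokens maps req_id -> token count
    # and input_batch.req_ids lists request ids by batch index.
    num_reqs = input_batch.num_reqs
    is_decode = [
        scheduler_output.num_scheduled_tokens[input_batch.req_ids[i]] == 1
        for i in range(num_reqs)
    ]
    left, right = 0, num_reqs - 1
    while left < right:
        if is_decode[left]:          # already in the decode prefix
            left += 1
        elif not is_decode[right]:   # already in the prefill suffix
            right -= 1
        else:                        # prefill on the left, decode on the right
            input_batch.swap_states(left, right)
            is_decode[left], is_decode[right] = True, False
            left += 1
            right -= 1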
@@ -2,6 +2,7 @@
 
 import gc
 import time
+import weakref
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -9,7 +10,7 @@ import torch
 import torch.distributed
 import torch.nn as nn
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
@@ -24,8 +25,7 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
-from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
-                                                    FlashAttentionMetadata)
+from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -92,6 +92,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
 
+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            self.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
+        )
+        if self.attn_backend is None:
+            error_msg = (
+                f"Error with get_att_backend: {self.head_size=}, "
+                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
+                f"{self.model_config.is_attention_free=}, "
+                f"{self.model_config.use_mla=}")
+            logger.error(error_msg)
+            raise NotImplementedError(
+                "Non-Attention backend is not supported by V1 GPUModelRunner.")
+
+        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+            weakref.proxy(self))
+
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
         self.mm_registry = MULTIMODAL_REGISTRY
@@ -433,6 +454,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
+        # Some attention backends (namely MLA) may want to separate requests
+        # based on if the attention computation will be compute-bound or
+        # memory-bound. This gives them a hook to do that.
+        self.attn_metadata_builder.reorder_batch(self.input_batch,
+                                                 scheduler_output)
+
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
@@ -515,7 +542,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        max_seq_len = self.seq_lens_np[:num_reqs].max()
 
         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
@@ -530,49 +556,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
-        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
-            self.device, non_blocking=True).long()
 
         # Prepare for cascade attention if needed.
         common_prefix_len = self._compute_cascade_attn_prefix_len(
             num_scheduled_tokens,
             scheduler_output.num_common_prefix_blocks,
         )
-        use_cascade = common_prefix_len > 0
-        if use_cascade:
-            # TODO: Optimize.
-            cu_prefix_query_lens = torch.tensor(
-                [0, total_num_scheduled_tokens],
-                dtype=torch.int32,
-                device=self.device)
-            prefix_kv_lens = torch.tensor([common_prefix_len],
-                                          dtype=torch.int32,
-                                          device=self.device)
-            suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
-            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
-        else:
-            cu_prefix_query_lens = None
-            prefix_kv_lens = None
-            suffix_kv_lens = None
-
-        attn_metadata = FlashAttentionMetadata(
+        attn_metadata = self.attn_metadata_builder.build(
+            num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
-            query_start_loc=query_start_loc,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=(
-                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
-            slot_mapping=slot_mapping,
-            use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
-            cu_prefix_query_lens=cu_prefix_query_lens,
-            prefix_kv_lens=prefix_kv_lens,
-            suffix_kv_lens=suffix_kv_lens,
         )
 
         use_spec_decode = len(
@@ -586,7 +580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # from these partial requests, we do so for simplicity.
         # We will ignore the sampled tokens from the partial requests.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        logits_indices = attn_metadata.query_start_loc[1:] - 1
 
         # Hot-Swap lora model
         if self.lora_config:
|
||||
# common_prefix_len should be a multiple of the block size.
|
||||
common_prefix_len = (common_prefix_len // self.block_size *
|
||||
self.block_size)
|
||||
use_cascade = FlashAttentionBackend.use_cascade_attention(
|
||||
use_cascade = self.attn_backend.use_cascade_attention(
|
||||
common_prefix_len=common_prefix_len,
|
||||
query_lens=num_scheduled_tokens,
|
||||
num_query_heads=self.num_query_heads,
|
||||
@ -1379,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert tensor_config.size % layer_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // layer_spec.page_size_bytes
|
||||
if isinstance(layer_spec, FullAttentionSpec):
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
|
||||
layer_spec.head_size)
|
||||
dtype = layer_spec.dtype
|
||||
|