[BugFix] MLA + V1, illegal memory access and accuracy issues (#14253)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 71eaf8969b
commit f6bb18fd9a
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import inspect
 from typing import Optional
 
 import numpy as np
@@ -9,7 +10,8 @@ import torch
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
+                                            InputBatch)
 
 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
@@ -20,6 +22,34 @@ CUDA_DEVICES = [
 MAX_NUM_PROMPT_TOKENS = 64
 
 
+def _compare_objs(obj1, obj2):
+    attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
+    attr_names = set([
+        a[0] for a in attrs
+        if not (a[0].startswith('__') and a[0].endswith('__'))
+    ])
+    for attr_name in attr_names:
+        a = getattr(obj1, attr_name)
+        b = getattr(obj2, attr_name)
+
+        is_same = False
+        if isinstance(a, torch.Tensor):
+            if (a.numel() == 0 or b.numel() == 0):
+                is_same = (a.numel() == 0 and b.numel() == 0)
+            elif torch.allclose(a, b):
+                is_same = True
+        elif isinstance(a, np.ndarray):
+            if np.allclose(a, b):
+                is_same = True
+        elif isinstance(a, (BlockTable, SamplingMetadata)):
+            _compare_objs(a, b)
+            is_same = True  # if we make it here must be same
+        elif a == b:
+            is_same = True
+        assert is_same, f"Attribute {attr_name} is different"\
+            f" in {obj1} and {obj2}: {a} != {b}"
+
+
 def _remove_requests(
     input_batch: InputBatch, batch_size: int,
     reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
@@ -254,3 +284,61 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     assert torch.allclose(
         expected_sampling_metadata.allowed_token_ids_mask,
         sampling_metadata.allowed_token_ids_mask)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [32])
+@pytest.mark.parametrize("swap_list", [((0, 1), )])
+def test_swap_states_in_input_batch(device: str, batch_size: int,
+                                    swap_list: list):
+    """
+    Tests the logic for swapping request states in the InputBatch.
+
+    This test adds a set of requests to an InputBatch, swaps the states of
+    the index pairs given in `swap_list` via `swap_states`, and then builds
+    a reference InputBatch by adding the same requests in the already
+    swapped order. The two batches are compared attribute by attribute to
+    ensure `swap_states` moves every piece of per-request state.
+    """
+    input_batch: InputBatch = InputBatch(
+        max_num_reqs=batch_size,
+        max_model_len=1024,
+        max_num_blocks_per_req=10,
+        device=torch.device(device),
+        pin_memory=is_pin_memory_available(),
+        vocab_size=1024,
+    )
+    ref_input_batch: InputBatch = InputBatch(
+        max_num_reqs=batch_size,
+        max_model_len=1024,
+        max_num_blocks_per_req=10,
+        device=torch.device(device),
+        pin_memory=is_pin_memory_available(),
+        vocab_size=1024,
+    )
+
+    reqs: list[CachedRequestState] = []
+    req_id_reqs = {}
+    req_id_output_token_ids = {}
+    # Add requests
+    for req_index in range(batch_size):
+        req: CachedRequestState = _construct_cached_request_state(req_index)
+        input_batch.add_request(req, req_index)
+        reqs.append(req)
+        req_id_reqs[req.req_id] = req
+        req_id_output_token_ids[req.req_id] = req.output_token_ids
+
+    reordered_reqs = reqs.copy()
+    for swap_pair in swap_list:
+        reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
+            reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
+        input_batch.swap_states(swap_pair[0], swap_pair[1])
+
+    for req_index in range(batch_size):
+        req = reordered_reqs[req_index]
+        ref_input_batch.add_request(req, req_index)
+
+    input_batch.refresh_sampling_metadata()
+    ref_input_batch.refresh_sampling_metadata()
+
+    _compare_objs(input_batch, ref_input_batch)
@@ -100,8 +100,8 @@ class FlashAttentionMetadataBuilder:
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput"):
-        pass
+                      scheduler_output: "SchedulerOutput") -> bool:
+        return False
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
@@ -275,17 +275,47 @@ class MLACommonBackend(AttentionBackend):
 
 
 @dataclass
-class MLACommonMetadata:
+class MLACommonPrefillMetadata:
+    """ Prefill Specific Metadata """
+
+    @dataclass
+    class ChunkedContextMetadata:
+        # New for MLA (compared to FlashAttention)
+        # For handling chunked prefill
+        cu_seq_lens: torch.Tensor
+        starts: torch.Tensor
+        seq_tot: list[int]
+        max_seq_lens: list[int]
+        workspace: torch.Tensor
+
+    # Input positions for rotary embeddings since for MLA the rotary
+    # position embeddings are applied inside the attention backend
+    input_positions: torch.Tensor
+    block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    max_query_len: int
+    chunked_context: Optional[ChunkedContextMetadata] = None
+
+
+@dataclass
+class MLACommonDecodeMetadata:
+    # Input positions for rotary embeddings since for MLA the rotary
+    # position embeddings are applied inside the attention backend
+    input_positions: torch.Tensor
+    block_table: torch.Tensor
+    seq_lens: torch.Tensor
+
+
+D = TypeVar("D", bound=MLACommonDecodeMetadata)
+
+
+@dataclass
+class MLACommonMetadata(Generic[D]):
     """Metadata for MLACommon.
 
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
-    # New for MLA (compared to FlashAttention)
-    # Input positions for rotrary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
-
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
@@ -295,30 +325,23 @@ class MLACommonMetadata:
     # |-- query_len ---|
 
     num_actual_tokens: int  # Number of tokens excluding padding.
-    max_query_len: int
     query_start_loc: torch.Tensor
-    max_seq_len: int
-    seq_lens: torch.Tensor
-    block_table: torch.Tensor
     slot_mapping: torch.Tensor
 
+    # New for MLA (compared to FlashAttention)
+    # For handling prefill decode split
+    num_decodes: int
+    num_decode_tokens: int
+    num_prefills: int
+
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
     # The dimension of the attention heads
     head_dim: Optional[int] = None
 
-    # New for MLA (compared to FlashAttention)
-    # For chunked prefill
-    num_decodes: Optional[int] = None
-    num_decode_tokens: Optional[int] = None
-    num_prefills: Optional[int] = None
-    has_context: bool = False
-    context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
-    context_chunk_starts: Optional[torch.Tensor] = None
-    context_chunk_seq_tot: Optional[list[int]] = None
-    context_chunk_max_seq_lens: Optional[list[int]] = None
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    decode: Optional[D] = None
+    prefill: Optional[MLACommonPrefillMetadata] = None
 
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
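For reference, the pattern this hunk introduces can be sketched in isolation as follows (toy names and list-based stand-ins for tensors; not vLLM code): decode-only state lives in a per-backend subclass of a shared decode dataclass, and the top-level metadata carries optional decode/prefill halves keyed by a TypeVar so each backend can narrow the decode type it receives.

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar


@dataclass
class DecodeMeta:
    # Fields every decode path needs.
    block_table: list[int]
    seq_lens: list[int]


@dataclass
class FlashMLADecodeMeta(DecodeMeta):
    # Backend-specific extra field, added by subclassing.
    num_splits: list[int]


D = TypeVar("D", bound=DecodeMeta)


@dataclass
class Meta(Generic[D]):
    num_decodes: int
    num_prefills: int
    decode: Optional[D] = None      # present only when num_decodes > 0
    prefill: Optional[object] = None


m = Meta(num_decodes=2, num_prefills=0,
         decode=FlashMLADecodeMeta(block_table=[0, 1],
                                   seq_lens=[7, 9],
                                   num_splits=[1, 1]))
assert m.decode is not None and m.decode.num_splits == [1, 1]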
@@ -329,10 +352,10 @@ class MLACommonMetadata:
                 f"received {self.head_dim}.")
 
 
-T = TypeVar("T", bound=MLACommonMetadata)
+M = TypeVar("M", bound=MLACommonMetadata)
 
 
-class MLACommonMetadataBuilder(Generic[T]):
+class MLACommonMetadataBuilder(Generic[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -340,8 +363,9 @@ class MLACommonMetadataBuilder(Generic[T]):
 
     def __init__(self,
                  runner: "GPUModelRunner",
-                 cls: Optional[type[T]] = None):
-        self.cls = cls if cls is not None else MLACommonMetadata
+                 metadata_cls: Optional[type[M]] = None):
+        self.metadata_cls = metadata_cls \
+            if metadata_cls is not None else MLACommonMetadata
         self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config
@@ -375,7 +399,7 @@ class MLACommonMetadataBuilder(Generic[T]):
         self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput"):
+                      scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
         # the front and the "prefill" requests are at the back, using the least
         # amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests
@@ -413,6 +437,7 @@ class MLACommonMetadataBuilder(Generic[T]):
         num_decodes = len(decodes)
         num_prefills = len(prefills)
         first_prefill = 0
+        modified_batch = False
 
         for i in range(1, min(num_decodes, num_prefills) + 1):
             # If the decode is at the "back" of the batch, i, we can swap it
@@ -421,6 +446,7 @@ class MLACommonMetadataBuilder(Generic[T]):
                 input_batch.swap_states(prefills[first_prefill],
                                         decodes[num_decodes - i])
                 first_prefill += 1
+                modified_batch = True
             else:
                 break
 
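For reference, the reordering performed by reorder_batch can be sketched in isolation as follows (illustrative only; the reorder helper and its permutation return value are made up for this example). Decode indices are walked from the back of the batch, each one that sits past the decode region is swapped with the first remaining prefill at the front, and a flag records whether anything actually moved so the caller knows to refresh per-request state.

def reorder(is_decode: list[bool]) -> tuple[list[int], bool]:
    order = list(range(len(is_decode)))
    decodes = [i for i, d in enumerate(is_decode) if d]
    prefills = [i for i, d in enumerate(is_decode) if not d]
    num_decodes = len(decodes)
    first_prefill = 0
    modified = False
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        # If this decode already sits inside the first `num_decodes` slots,
        # all remaining decodes do too, so no further swaps are needed.
        if decodes[num_decodes - i] < num_decodes:
            break
        j, k = prefills[first_prefill], decodes[num_decodes - i]
        order[j], order[k] = order[k], order[j]
        first_prefill += 1
        modified = True
    return order, modified

# reorder([False, True, True, False]) -> ([2, 1, 0, 3], True): decodes now lead.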
@@ -432,10 +458,21 @@ class MLACommonMetadataBuilder(Generic[T]):
         self._num_decode_tokens = num_decode_tokens
         self._num_prefill_tokens = num_prefill_tokens
 
+        return modified_batch
+
+    def _build_decode(self, input_positions: torch.Tensor,
+                      block_table: torch.Tensor, seq_lens: torch.Tensor):
+        return MLACommonDecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+        )
+
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int) -> T:
+              common_prefix_len: int) -> M:
+        assert self._num_decodes + self._num_prefills == num_reqs
+
         device = self.runner.device
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
             device, non_blocking=True)
         seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
@@ -447,85 +484,103 @@ class MLACommonMetadataBuilder(Generic[T]):
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
 
-        context_chunk_cu_seq_lens = None
-        context_chunk_starts = None
-        context_chunk_seq_tot = None
-        context_chunk_max_seq_lens = None
-
-        num_computed_tokens_cpu_tensor = \
-            self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
-        context_lens_tensor = \
-            num_computed_tokens_cpu_tensor.to(device, non_blocking=True)
-
-        if self.chunked_prefill_enabled and self._num_prefills > 0 \
-            and context_lens_tensor[self._num_decodes:].max() > 0:
-            # NOTE: it is recommend you read the `Chunked Prefill` section in
-            # the comment at the top of the file before trying to understand
-            # the following code
-
-            self.has_context = True
-
-            num_prefills_with_context = \
-                (context_lens_tensor[self._num_decodes:] > 0).sum().item()
-
-            # currently we allocate an equal amount of workspace for each
-            # prefill in the batch, we could probably use a more advanced
-            # algorithm here and allocate more workspace to prefills with
-            # longer context lengths
-            max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
-
-            # align max_context_chunk to page_size by rounding down,
-            # currently the `gather_cache` kernel cannot handle
-            # `context_chunk_starts` that are not aligned to page_size
-            max_context_chunk = round_down(max_context_chunk, self.page_size)
-            assert max_context_chunk > 0
-            num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
-
-            # if `max_context_chunk = 256`, `num_chunks = 3`, and
-            # `num_prefills_with_context = 4`, create a tensor that looks like
-            # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
-            context_chunk_starts = \
-                torch.arange(num_chunks, device=device, dtype=torch.int32) \
-                .unsqueeze(1).expand(-1, self._num_prefills) \
-                * max_context_chunk
-            chunk_ends = torch.min(context_lens_tensor[self._num_decodes:] \
-                .unsqueeze(0), context_chunk_starts + max_context_chunk)
-            chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
-            _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
-                torch.int32)
-            zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \
-                .unsqueeze(-1)
-            context_chunk_cu_seq_lens = \
-                torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
-            context_chunk_max_seq_lens = \
-                chunk_seq_lens.max(dim=1).values.tolist()
-            context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
-            assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
-
-        return self.cls(
-            input_positions=input_positions,
+        prefill_metadata = None
+        if self._num_prefills > 0:
+            reqs_start = self._num_decodes  # prefill_start
+            tokens_start = self._num_decode_tokens
+
+            context_lens_cpu = self.runner.input_batch.\
+                num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
+            context_lens = context_lens_cpu.to(device, non_blocking=True)
+
+            chunked_context_metadata = None
+            if self.chunked_prefill_enabled and self._num_prefills > 0 \
+                and context_lens.max() > 0:
+                # NOTE: it is recommended you read the `Chunked Prefill`
+                # section in the comment at the top of the file before
+                # trying to understand the following code
+
+                num_prefills_with_context = (context_lens > 0).sum().item()
+
+                # currently we allocate an equal amount of workspace for each
+                # prefill in the batch, we could probably use a more advanced
+                # algorithm here and allocate more workspace to prefills with
+                # longer context lengths
+                max_context_chunk = \
+                    self.chunked_prefill_workspace_size \
+                    // num_prefills_with_context
+
+                # align max_context_chunk to page_size by rounding down,
+                # currently the `gather_cache` kernel cannot handle
+                # `context_chunk_starts` that are not aligned to page_size
+                max_context_chunk = round_down(max_context_chunk,
+                                               self.page_size)
+
+                assert max_context_chunk > 0
+                num_chunks = cdiv(context_lens.max(), max_context_chunk)
+
+                # if `max_context_chunk = 256`, `num_chunks = 3`, and
+                # `num_prefills_with_context = 4`, create a tensor that looks
+                # like
+                # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
+                chunk_starts = \
+                    torch.arange(num_chunks, device=device, dtype=torch.int32) \
+                    .unsqueeze(1).expand(-1, self._num_prefills) \
+                    * max_context_chunk
+                chunk_ends = torch.min(context_lens.unsqueeze(0),
+                                       chunk_starts + max_context_chunk)
+                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
+                _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
+                    torch.int32)
+                zero = torch.zeros(num_chunks,
+                                   dtype=torch.int32,
+                                   device=device).unsqueeze(-1)
+
+                chunked_context_metadata = \
+                    MLACommonPrefillMetadata.ChunkedContextMetadata(
+                        cu_seq_lens=torch.cat(
+                            [zero, _chunk_cu_seq_lens], dim=1),
+                        starts=chunk_starts,
+                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                        workspace=self.chunked_prefill_workspace,
+                    )
+
+                assert max(chunked_context_metadata.max_seq_lens) <= \
+                    self.chunked_prefill_workspace_size
+
+            prefill_metadata = MLACommonPrefillMetadata(
+                input_positions=input_positions[tokens_start:],
+                block_table=block_table[reqs_start:, ...],
+                query_start_loc=query_start_loc[reqs_start:] -
+                query_start_loc[reqs_start],
+                max_query_len=seq_lens[reqs_start:].max().item(),
+                chunked_context=chunked_context_metadata,
+            )
+
+        decode_metadata = None
+        if self._num_decodes > 0:
+            decode_metadata = self._build_decode(
+                input_positions=input_positions[:self._num_decode_tokens],
+                block_table=block_table[:self._num_decodes, ...],
+                seq_lens=seq_lens[:self._num_decodes],
+            )
+
+        return self.metadata_cls(
             num_actual_tokens=num_actual_tokens,
-            max_query_len=max_query_len,
             query_start_loc=query_start_loc,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
             num_decodes=self._num_decodes,
             num_decode_tokens=self._num_decode_tokens,
             num_prefills=self._num_prefills,
-            context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
-            context_chunk_starts=context_chunk_starts,
-            context_chunk_seq_tot=context_chunk_seq_tot,
-            context_chunk_max_seq_lens=context_chunk_max_seq_lens,
+            prefill=prefill_metadata,
+            decode=decode_metadata,
         )
 
 
-class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
+class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
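As a sanity check on the chunking arithmetic above, the standalone snippet below reproduces it with made-up numbers (four prefills with context lengths 100/300/600/700 tokens and max_context_chunk = 256, assumed already page-aligned); it reuses the variable names from the diff but is illustrative only.

import torch

context_lens = torch.tensor([100, 300, 600, 700], dtype=torch.int32)
max_context_chunk = 256  # assumed already rounded down to the page size
num_prefills = context_lens.numel()
num_chunks = (int(context_lens.max()) + max_context_chunk - 1) \
    // max_context_chunk  # ceil-div -> 3

# chunk_starts[c, p] = offset at which chunk c of prefill p begins
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
    .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

print(chunk_starts.tolist())
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
print(chunk_seq_lens.tolist())
# [[100, 256, 256, 256], [0, 44, 256, 256], [0, 0, 88, 188]]

Each row sum of chunk_seq_lens (868, 556 and 276 tokens here) becomes seq_tot for that chunk and must fit in the 4 * 256 = 1024-token workspace that gather_cache fills on each iteration of the chunked-context loop.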
@@ -798,28 +853,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
     ):
-        assert attn_metadata.num_prefills is not None
-        assert attn_metadata.context_chunk_seq_tot is not None
-        assert attn_metadata.context_chunk_cu_seq_lens is not None
-        assert attn_metadata.context_chunk_starts is not None
-        assert attn_metadata.context_chunk_max_seq_lens is not None
+        assert attn_metadata.prefill is not None
+        prefill_metadata = attn_metadata.prefill
+        assert prefill_metadata.chunked_context is not None
 
         output = None
-        iters = len(attn_metadata.context_chunk_seq_tot)
-
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        iters = len(prefill_metadata.chunked_context.seq_tot)
+        workspace = prefill_metadata.chunked_context.workspace
 
         for i in range(iters):
-            toks = attn_metadata.context_chunk_seq_tot[i]
+            toks = prefill_metadata.chunked_context.seq_tot[i]
 
             ops.gather_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
-                block_table=attn_metadata.block_table,
-                cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i],
+                block_table=prefill_metadata.block_table,
+                cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                 batch_size=attn_metadata.num_prefills,
-                seq_starts=attn_metadata.context_chunk_starts[i],
+                seq_starts=prefill_metadata.chunked_context.starts[i],
             )
 
             kv_c_normed = workspace[:toks]\
@@ -845,10 +896,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 q=q,
                 k=k,
                 v=v_padded,
-                cu_seqlens_q=attn_metadata.query_start_loc,
-                cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i],
-                max_seqlen_q=attn_metadata.max_query_len,
-                max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i],
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
+                max_seqlen_q=prefill_metadata.max_query_len,
+                max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
                 softmax_scale=self.scale,
                 causal=False,  # Context is unmasked
                 return_softmax_lse=True,
|
|||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: MLACommonMetadata,
|
attn_metadata: MLACommonMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
has_context = attn_metadata.has_context
|
assert attn_metadata.prefill is not None
|
||||||
|
|
||||||
|
has_context = attn_metadata.prefill.chunked_context is not None
|
||||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
k_nope, v = kv_nope\
|
k_nope, v = kv_nope\
|
||||||
@@ -898,10 +951,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             q=q,
             k=k,
             v=v_padded,
-            cu_seqlens_q=attn_metadata.query_start_loc,
-            cu_seqlens_k=attn_metadata.query_start_loc,
-            max_seqlen_q=attn_metadata.max_query_len,
-            max_seqlen_k=attn_metadata.max_seq_len,
+            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
+            cu_seqlens_k=attn_metadata.prefill.query_start_loc,
+            max_seqlen_q=attn_metadata.prefill.max_query_len,
+            max_seqlen_k=attn_metadata.prefill.max_query_len,
             softmax_scale=self.scale,
             causal=True,
             return_softmax_lse=has_context,
@@ -934,7 +987,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: T,
+        attn_metadata: M,
    ) -> torch.Tensor:
        raise NotImplementedError
 
@@ -945,7 +998,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
-        attn_metadata: T,
+        attn_metadata: M,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
@@ -966,7 +1019,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
-        assert hasattr(attn_metadata, "input_positions")
 
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
@@ -978,28 +1030,27 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
         decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
         decode_k_pe = k_pe[:num_decode_tokens]
-        decode_input_positions = \
-            attn_metadata.input_positions[:num_decode_tokens]
 
         prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
         prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_input_positions = \
-            attn_metadata.input_positions[num_decode_tokens:]
         prefill_k_c_normed = k_c_normed[num_decode_tokens:]
 
         if has_decode:
+            assert attn_metadata.decode is not None
             decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                decode_input_positions, decode_q_pe, decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
 
         if has_prefill:
+            assert attn_metadata.prefill is not None
             prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                prefill_input_positions, prefill_q_pe, prefill_k_pe)
+                attn_metadata.prefill.input_positions, prefill_q_pe,
+                prefill_k_pe)
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          is_flashmla_supported)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                    MLACommonDecodeMetadata,
                                                     MLACommonImpl,
                                                     MLACommonMetadata,
                                                     MLACommonMetadataBuilder)
@@ -38,34 +39,41 @@ class FlashMLABackend(MLACommonBackend):
 
 
 @dataclass
-class FlashMLAMetadata(MLACommonMetadata):
-    decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor,
-                                                   torch.Tensor]] = None
-    decode_num_splits: Optional[torch.Tensor] = None
+class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
+    tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor]
+    num_splits: torch.Tensor
+
+
+@dataclass
+class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
+    pass
 
 
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
 
     def __init__(self, runner):
-        super().__init__(runner, cls=FlashMLAMetadata)
+        super().__init__(runner)
 
         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
-        m = super().build(num_reqs, num_actual_tokens, max_query_len,
-                          common_prefix_len)
-
-        if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
-            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
-                get_mla_metadata(
-                m.seq_lens[:m.num_decode_tokens],
-                self.num_q_heads,
-                1, # MQA for the decode path
-            )
-
-        return m
+    def _build_decode(self, input_positions: torch.Tensor,
+                      block_table: torch.Tensor,
+                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
+        tile_scheduler_metadata, num_splits = \
+            get_mla_metadata(
+                seq_lens,
+                self.num_q_heads,
+                1, # MQA for the decode path
+            )
+
+        return FlashMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            tile_scheduler_metadata=tile_scheduler_metadata,
+            num_splits=num_splits,
+        )
 
 
 class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
@@ -115,6 +123,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
+
         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError("FP8 FlashMLA not yet supported")
 
@@ -124,14 +134,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         o, _ = flash_mla_with_kvcache(
             q=q,
             k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
-            block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
-                                                  ...],
-            cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
-                                                 num_decode_tokens],
+            block_table=attn_metadata.decode.block_table,
+            cache_seqlens=attn_metadata.decode.seq_lens,
             head_dim_v=self.kv_lora_rank,
-            tile_scheduler_metadata=attn_metadata.
-            decode_tile_scheduler_metadata,
-            num_splits=attn_metadata.decode_num_splits,
+            tile_scheduler_metadata=attn_metadata.decode.
+            tile_scheduler_metadata,
+            num_splits=attn_metadata.decode.num_splits,
             softmax_scale=self.scale,
             causal=True,
         )
@@ -69,6 +69,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
+
         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError("FP8 Triton MLA not yet supported")
 
@@ -104,7 +106,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
 
         # Run MQA
         decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
-                             attn_metadata.block_table, attn_metadata.seq_lens,
-                             attn_logits, num_kv_splits, self.scale, PAGE_SIZE)
+                             attn_metadata.decode.block_table,
+                             attn_metadata.decode.seq_lens, attn_logits,
+                             num_kv_splits, self.scale, PAGE_SIZE)
 
         return self._v_up_proj_and_o_proj(o)
@@ -383,8 +383,6 @@ class InputBatch:
             self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
         self.num_tokens[i1], self.num_tokens[i2] =\
             self.num_tokens[i2], self.num_tokens[i1]
-        self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
-            self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
         self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
             self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
         self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
@@ -406,24 +404,47 @@ class InputBatch:
         self.min_p_cpu[i1], self.min_p_cpu[i2] =\
             self.min_p_cpu[i2], self.min_p_cpu[i1]
 
+        # NOTE: the following is unsafe
+        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
+        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
+        # instead, we need to temporarily copy the data for one of the indices
+        # TODO(lucas): optimize this by only copying valid indices
+        tmp = self.token_ids_cpu[i1, ...].copy()
+        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
+        self.token_ids_cpu[i2, ...] = tmp
+
         g1 = self.generators.get(i1)
         g2 = self.generators.get(i2)
         if g1 is not None:
             self.generators[i2] = g1
+        else:
+            self.generators.pop(i2, None)
         if g2 is not None:
             self.generators[i1] = g2
+        else:
+            self.generators.pop(i1, None)
 
         t1 = self.min_tokens.get(i1)
         t2 = self.min_tokens.get(i2)
         if t1 is not None:
             self.min_tokens[i2] = t1
+        else:
+            self.min_tokens.pop(i2, None)
         if t2 is not None:
             self.min_tokens[i1] = t2
+        else:
+            self.min_tokens.pop(i1, None)
 
         self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
             self.request_lora_mapping[i2], self.request_lora_mapping[i1]
         self.logit_bias[i1], self.logit_bias[i2] =\
             self.logit_bias[i2], self.logit_bias[i1]
 
+        if self.allowed_token_ids_mask_cpu_tensor is not None:
+            self.allowed_token_ids_mask_cpu_tensor[i1], \
+                self.allowed_token_ids_mask_cpu_tensor[i2] =\
+                self.allowed_token_ids_mask_cpu_tensor[i2], \
+                self.allowed_token_ids_mask_cpu_tensor[i1]
         self.block_table.swap_row(i1, i2)
 
     def condense(self, empty_req_indices: list[int]) -> None:
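The NOTE in this hunk is easy to reproduce in isolation: tuple assignment between rows of a 2-D numpy array goes through views of the same buffer, so the first row is clobbered before the second assignment reads it. A minimal standalone sketch (illustrative, not vLLM code):

import numpy as np

token_ids = np.arange(6).reshape(2, 3)      # [[0, 1, 2], [3, 4, 5]]
# Unsafe: the RHS holds views, and row 0 is overwritten before row 1 is set.
token_ids[0, ...], token_ids[1, ...] = token_ids[1, ...], token_ids[0, ...]
print(token_ids.tolist())                   # [[3, 4, 5], [3, 4, 5]] -- row lost

# Safe: copy one row into a temporary first, as the commit does.
token_ids = np.arange(6).reshape(2, 3)
tmp = token_ids[0, ...].copy()
token_ids[0, ...] = token_ids[1, ...]
token_ids[1, ...] = tmp
print(token_ids.tolist())                   # [[3, 4, 5], [0, 1, 2]]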
@@ -456,8 +456,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Some attention backends (namely MLA) may want to separate requests
         # based on if the attention computation will be compute-bound or
         # memory-bound. This gives them a hook to do that.
-        self.attn_metadata_builder.reorder_batch(self.input_batch,
-                                                 scheduler_output)
+        modified_batch = self.attn_metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)
+        if modified_batch:
+            self.input_batch.refresh_sampling_metadata()
 
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.