[Experimental] Prefix Caching Support (#1669)

Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
shiyi.c_98 2024-01-17 16:32:10 -08:00 committed by GitHub
parent 14cc317ba4
commit d10f8e1d43
20 changed files with 1356 additions and 71 deletions


@ -31,6 +31,10 @@ steps:
- pytest -v -s models --forked
soft_fail: true
- label: Prefix Caching Test
commands:
- pytest -v -s prefix_caching
- label: Samplers Test
command: pytest -v -s samplers --forked


@ -0,0 +1,51 @@
from vllm import LLM, SamplingParams
prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
generating_prompts = [prefix + prompt for prompt in prompts]
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(generating_prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print("-" * 80)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
# Generate with prefix
outputs = llm.generate(generating_prompts, sampling_params,
prefix_pos=[prefix_pos] * len(generating_prompts))
# Print the outputs. You should see the same outputs as before.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@ -0,0 +1,168 @@
import random
import pytest
import time
import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
NUM_HEADS = [12]
HEAD_SIZES = [128]
DTYPES = [torch.float16]
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
head_size: int,
dtype: torch.dtype,
) -> None:
random.seed(0)
torch.manual_seed(0)
MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
num_tokens = sum(subquery_lens)
query = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
kv = torch.empty(sum(seq_lens),
2,
num_heads,
head_size,
dtype=dtype,
device='cuda')
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)
k_cache = torch.zeros(cache_size,
block_size,
num_heads,
head_size,
dtype=dtype,
device='cuda')
v_cache = torch.zeros(cache_size,
block_size,
num_heads,
head_size,
dtype=dtype,
device='cuda')
k = torch.zeros(sum(subquery_lens),
num_heads,
head_size,
dtype=dtype,
device='cuda')
v = torch.zeros(sum(subquery_lens),
num_heads,
head_size,
dtype=dtype,
device='cuda')
values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
dtype=torch.long,
device='cuda'),
dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long,
device='cuda'),
dim=0)
for i in range(BS):
for j in range(subquery_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_heads,
head_size).permute(0, 2, 3, 1).contiguous()
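    # Warm-up call: the first invocation triggers Triton JIT compilation,
    # so it is excluded from the timed run below.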
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
attn_op = xops.fmha.cutlass.FwOp()
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens)
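    # Warm-up run for the xformers reference as well; the timed run follows.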
output_ref = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.squeeze(0)
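    # Inputs are ~1e-3 in magnitude, so an absolute tolerance of 1e-6 is
    # attainable even in fp16.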
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)


@ -0,0 +1,41 @@
"""Compare the with and without prefix caching.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest
from vllm import LLM, SamplingParams
prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
example_prompts,
model: str,
max_tokens: int,
):
llm = LLM(model=model)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
prompts = [prefix + prompt for prompt in example_prompts]
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs_without_prefix = llm.generate(prompts, sampling_params)
outputs_with_prefix = llm.generate(prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(prompts))
for output_without_prefix, output_with_prefix in zip(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
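    # Only the prefix_pos run registers a prefix, and every prompt shares it.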
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1


@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -234,7 +238,8 @@ def test_sampler_logits_processors(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
sample_probs = None


@ -33,11 +33,12 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens = (
input_tokens, input_positions, _, return_prompt_lens, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
assert input_positions.shape == (batch_size, max_seq_len)
torch.testing.assert_close(input_tokens, input_positions)


@ -66,3 +66,7 @@ class PhysicalTokenBlock:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'ref_count={self.ref_count})')
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]


@ -2,13 +2,10 @@
import enum
from typing import Dict, List, Optional, Set, Tuple
from vllm.block import PhysicalTokenBlock
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
class BlockAllocator:
"""Manages free physical token blocks for a device.
@ -105,6 +102,10 @@ class BlockSpaceManager:
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if seq_group.prefix is not None and seq_group.prefix.allocated:
num_required_blocks -= seq_group.prefix.get_num_blocks()
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
@ -125,8 +126,21 @@ class BlockSpaceManager:
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
block_table: BlockTable = []
for logical_idx in range(len(seq.logical_token_blocks)):
prefix_block_table: BlockTable = []
num_prefix_blocks = 0
prefix = seq_group.prefix
if prefix is not None and prefix.allocated:
# Prefix has already been allocated. Use the existing block table.
num_prompt_blocks -= prefix.get_num_blocks()
for block in prefix.block_table:
block.ref_count += seq_group.num_seqs()
block_table.append(block)
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
@ -136,6 +150,15 @@ class BlockSpaceManager:
block.ref_count = seq_group.num_seqs()
block_table.append(block)
if prefix is not None and not prefix.allocated:
            # Allocate blocks for the prefix; we will compute the prefix's
            # KV cache in this run.
num_prefix_blocks = prefix.get_num_blocks()
prefix_block_table = block_table[:num_prefix_blocks]
for block in prefix_block_table:
block.ref_count += 1
prefix.set_block_table(prefix_block_table)
# Assign the block table for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
@ -210,10 +233,18 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
if seq_group.prefix is not None:
            # The prefix's KV cache must already be allocated and computed.
assert seq_group.prefix.allocated and seq_group.prefix.computed
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
if seq_group.prefix is not None:
for block in seq_group.prefix.block_table:
new_block_table.append(block)
block.ref_count += 1
for cpu_block in block_table:
if cpu_block in mapping:
@ -245,6 +276,12 @@ class BlockSpaceManager:
block_table = self.block_tables[seq.seq_id]
for gpu_block in block_table:
if (seq_group.prefix is not None
and gpu_block in seq_group.prefix.block_table):
# NOTE: We do not swap out the prefix blocks for now.
self.gpu_allocator.free(gpu_block)
continue
if gpu_block in mapping:
cpu_block = mapping[gpu_block]
cpu_block.ref_count += 1
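
The allocation and swap logic above shares a prefix's physical blocks across all sequences that use them via reference counting; a block is freed only once every user has released it. A minimal sketch of that accounting, with a hypothetical simplified stand-in for PhysicalTokenBlock (the real class also tracks device and block size):

class _Block:  # hypothetical stand-in for PhysicalTokenBlock
    def __init__(self, block_number: int) -> None:
        self.block_number = block_number
        self.ref_count = 0

prefix_blocks = [_Block(0), _Block(1)]   # allocated once for the prefix
num_seqs = 3
for _ in range(num_seqs):                # each sequence shares the same blocks
    block_table = list(prefix_blocks)    # prefix blocks head the block table
    for block in block_table:
        block.ref_count += 1
assert all(block.ref_count == num_seqs for block in prefix_blocks)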


@ -9,6 +9,7 @@ from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.prefix import PrefixPool
logger = init_logger(__name__)
@ -76,6 +77,9 @@ class Scheduler:
num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window)
# Create the prefix pool to cache the prefixes.
self.prefix_pool = PrefixPool(self.cache_config.block_size)
# Sequence groups in the WAITING state.
self.waiting: Deque[SequenceGroup] = deque()
# Sequence groups in the RUNNING state.
@ -316,6 +320,7 @@ class Scheduler:
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
prefix=seq_group.prefix,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs


@ -371,6 +371,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@ -383,6 +384,7 @@ class AsyncLLMEngine:
max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling params: {sampling_params}, "
f"prompt token ids: {shortened_token_ids}.")
@ -401,7 +403,8 @@ class AsyncLLMEngine:
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
arrival_time=arrival_time,
prefix_pos=prefix_pos)
return stream
@ -410,7 +413,8 @@ class AsyncLLMEngine:
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None
prompt_token_ids: Optional[List[int]] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@ -425,6 +429,11 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
@ -482,7 +491,8 @@ class AsyncLLMEngine:
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
arrival_time=arrival_time,
prefix_pos=prefix_pos)
async for request_output in stream:
yield request_output
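
For reference, a hedged end-to-end sketch of the extended async API; the model name, request id, and prefix_pos value are illustrative assumptions, not part of this commit:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # prefix_pos marks how many leading prompt tokens form the cacheable prefix.
    final_output = None
    async for request_output in engine.generate("<shared prefix> Hello,",
                                                SamplingParams(temperature=0.0),
                                                request_id="req-0",
                                                prefix_pos=8):
        final_output = request_output
    assert final_output is not None
    print(final_output.outputs[0].text)

asyncio.run(main())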


@ -337,6 +337,7 @@ class LLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
prefix_pos: Optional[int] = None,
) -> None:
"""Add a request to the engine's request pool.
@ -353,6 +354,11 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Details:
- Set arrival_time to the current time if it is None.
@ -389,9 +395,13 @@ class LLMEngine:
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
        # Check whether the input specifies a prefix to cache and reuse.
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time)
arrival_time, prefix)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
@ -662,6 +672,12 @@ class LLMEngine:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
        # Update prefix state: all prefixes scheduled in this step have now been computed.
for seq_group in scheduled_seq_groups:
if (seq_group.prefix is not None and seq_group.prefix.allocated
and not seq_group.prefix.computed):
seq_group.prefix.computed = True
if self.log_stats:
# Log the system stats.
self._log_system_stats(scheduler_outputs.prompt_run,


@ -33,11 +33,15 @@ async def generate(request: Request) -> Response:
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = engine.generate(prompt,
sampling_params,
request_id,
prefix_pos=prefix_pos)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:

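A client-side sketch of the new request field for the demo server; the host, port, and prefix_pos value are assumptions, and prefix_pos counts prompt tokens (the server tokenizes the prompt), not characters:

import requests  # assumes the demo api_server is running on localhost:8000

response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "<shared prefix> Hello, my name is",
        "prefix_pos": 8,  # cache the KV of the first 8 prompt tokens
        "temperature": 0.0,
        "max_tokens": 16,
    },
)
print(response.json())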

@ -120,6 +120,7 @@ class LLM:
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
@ -134,6 +135,11 @@ class LLM:
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
Returns:
@ -159,9 +165,10 @@ class LLM:
prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt, sampling_params, token_ids)
self._add_request(prompt, sampling_params, token_ids, prefix_pos_i)
return self._run_engine(use_tqdm)
def _add_request(
@ -169,10 +176,14 @@ class LLM:
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, prompt, sampling_params,
prompt_token_ids)
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
prefix_pos=prefix_pos)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.


@ -18,12 +18,18 @@ class InputMetadata:
self,
is_prompt: bool,
slot_mapping: torch.Tensor,
prompt_lens: Optional[torch.Tensor],
max_seq_len: Optional[int],
start_loc: Optional[torch.Tensor],
max_context_len: Optional[int],
context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor],
use_cuda_graph: bool,
) -> None:
self.is_prompt = is_prompt
self.prompt_lens = prompt_lens
self.max_seq_len = max_seq_len
self.start_loc = start_loc
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
self.context_lens = context_lens


@ -10,6 +10,8 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
context_attention_fwd)
from vllm.utils import is_hip
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
@ -115,45 +117,65 @@ class PagedAttention(nn.Module):
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
# Set attention bias if not provided. This typically happens at the
# very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
# TODO(woosuk): Too many view operations. Let's try to reduce them
# in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else:
# Decoding run.
output = _paged_attention(


@ -0,0 +1,728 @@
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
import triton
import triton.language as tl
if triton.__version__ >= "2.1.0":
@triton.jit
def _fwd_kernel(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
        # initialize the running max (m_i) and softmax normalizer (l_i)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
            # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
        # initialize the running max (m_i) and softmax normalizer (l_i)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # flash-attn v2 style: only the accumulator is rescaled by alpha;
            # p already uses exp(qk - m_i_new), so it needs no extra scaling.
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # flash-attn v2 style: only the accumulator is rescaled by alpha;
            # p already uses exp(qk - m_i_new), so it needs no extra scaling.
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
        # NOTE: unlike _fwd_kernel_alibi below, the final 1 / l_i
        # normalization is left commented out in this kernel.
        # acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
        # cur_batch_seq_len: total prompt length (prefix + new tokens)
        # cur_batch_ctx_len: length of the cached prefix
        # cur_batch_in_all_start_index: this sequence's start offset along dim 0
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
        # initialize the running max (m_i) and softmax normalizer (l_i)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # flash-attn v2 style: only the accumulator is rescaled by alpha;
            # p already uses exp(qk - m_i_new), so it needs no extra scaling.
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
        # compute q[BLOCK_M, BLOCK_DMODEL] @ k[prefix_len:, BLOCK_DMODEL]
        # for the newly added (non-prefix) tokens
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, allow_tf32=False)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # flash-attn v2 style: only the accumulator is rescaled by alpha;
            # p already uses exp(qk - m_i_new), so it needs no extra scaling.
            acc_scale = alpha
            acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=None):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, head, seq blocks)
    num_warps = 8
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
alibi_slopes,
            v_cache.shape[3],  # block_size
            8,  # x: vector width of the key-cache layout
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
        v_cache.shape[3],  # block_size
        8,  # x: vector width of the key-cache layout
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
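
All three kernels rely on the same streaming ("online") softmax update: a running maximum m_i and normalizer l_i let each BLOCK_N chunk of keys be folded into the output accumulator without ever materializing the full attention matrix. An illustrative PyTorch check of that identity for a single query row (an explanation aid, not part of the commit):

import torch

torch.manual_seed(0)
head_size, ctx_len, block_n = 64, 256, 32
q = torch.randn(head_size)
K = torch.randn(ctx_len, head_size)
V = torch.randn(ctx_len, head_size)
sm_scale = head_size**-0.5

# Reference: ordinary softmax attention for one query.
ref = torch.softmax((K @ q) * sm_scale, dim=0) @ V

# Streaming version over BLOCK_N-sized chunks, mirroring the kernels'
# (m_i, l_i, acc) bookkeeping with a final 1 / l_i normalization.
m_i = torch.tensor(float("-inf"))
l_i = torch.tensor(0.0)
acc = torch.zeros(head_size)
for start in range(0, ctx_len, block_n):
    qk = (K[start:start + block_n] @ q) * sm_scale
    m_i_new = torch.maximum(m_i, qk.max())
    p = torch.exp(qk - m_i_new)
    alpha = torch.exp(m_i - m_i_new)
    l_i = alpha * l_i + p.sum()
    acc = acc * alpha + p @ V[start:start + block_n]
    m_i = m_i_new
assert torch.allclose(acc / l_i, ref, atol=1e-5)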

vllm/prefix.py (new file)

@ -0,0 +1,87 @@
from typing import Dict, List, Sequence, Tuple, Optional
from vllm.block import BlockTable
class Prefix:
"""Data and states associated with a prefix of prompt tokens for multiple
sequence groups.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
token_ids: The token ids of the prefix.
block_size: The block size of the executed model.
"""
def __init__(
self,
token_ids: Sequence[int],
block_size: int,
) -> None:
self.token_ids = tuple(token_ids)
self.block_size = block_size
self.length = len(token_ids)
        self.hash = hash(self.token_ids)  # hash the tuple; the input may be any sequence
assert self.length % block_size == 0
self.block_table: Optional[BlockTable] = None
self.computed = False
@property
def allocated(self) -> bool:
return self.block_table is not None
def get_num_blocks(self) -> int:
return self.length // self.block_size
def get_block_numbers(self) -> List[int]:
return [block.block_number for block in self.block_table]
def get_length(self) -> int:
return self.length
def __hash__(self) -> int:
return self.hash
def set_block_table(self, block_table: BlockTable) -> None:
self.block_table = block_table.copy()
class PrefixPool:
"""Manages all the prompt prefixes.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
block_size: The block size of the executed model.
Attributes:
prefixes: A list of all the prefixes.
block_size: The block size of the executed model.
"""
def __init__(
self,
block_size: int,
) -> None:
# TODO(zhuohan): Add a capacity limit to the prefix pool.
self.prefixes: Dict[int, Prefix] = {}
self.block_size = block_size
    def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int, ...]:
new_length = len(token_ids) // self.block_size * self.block_size
return tuple(token_ids[:new_length])
def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]:
token_ids = self._truncate_token_ids(token_ids)
if len(token_ids) == 0:
# Prefix is empty.
return None
prefix = Prefix(token_ids, self.block_size)
prefix_hash = hash(prefix)
if prefix_hash not in self.prefixes:
self.prefixes[prefix_hash] = prefix
return self.prefixes[prefix_hash]
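
A usage sketch of the pool's truncation and deduplication behavior; the block size and token values are arbitrary:

pool = PrefixPool(block_size=16)
a = pool.add_or_get_prefix(list(range(40)))  # truncated to 32 tokens (2 blocks)
b = pool.add_or_get_prefix(list(range(40)))  # same tokens -> same Prefix object
assert a is b and a.get_length() == 32
assert pool.add_or_get_prefix([1, 2, 3]) is None  # shorter than one block
assert len(pool.prefixes) == 1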


@ -4,6 +4,7 @@ import enum
from typing import Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams
PromptLogprobs = List[Optional[Dict[int, float]]]
@ -236,11 +237,13 @@ class SequenceGroup:
seqs: List[Sequence],
sampling_params: SamplingParams,
arrival_time: float,
prefix: Optional[Prefix] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params
self.arrival_time = arrival_time
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None
@property
@ -327,7 +330,6 @@ class SequenceGroup:
class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
Args:
request_id: The ID of the request.
is_prompt: Whether the request is at prompt stage.
@ -335,6 +337,7 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
prefix: The prefix of the prompt of the sequence group.
"""
def __init__(
@ -344,12 +347,14 @@ class SequenceGroupMetadata:
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
prefix: Optional[Prefix] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
self.prefix = prefix
class SequenceOutput:


@ -74,13 +74,17 @@ class ModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
List[int]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
prompt_lens: List[int] = []
context_lens: List[int] = []
subquery_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@ -91,11 +95,23 @@ class ModelRunner:
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
prefix_len = 0
prefix = seq_group_metadata.prefix
if prefix is not None and prefix.computed:
prefix_len = prefix.get_length()
prompt_tokens = prompt_tokens[prefix_len:]
prefix_block_tables.append(prefix.get_block_numbers())
else:
prefix_block_tables.append([])
            # The context is the cached prefix; the subquery is the part of
            # the prompt that still has to be computed in this run.
            context_lens.append(prefix_len)
            subquery_lens.append(prompt_len - prefix_len)
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(prompt_len)))
input_positions.append(
list(range(prefix_len, prefix_len + len(prompt_tokens))))
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
@ -113,8 +129,11 @@ class ModelRunner:
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert prefix_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prompt_len):
for i in range(prefix_len, prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
@ -124,7 +143,7 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
max_prompt_len = max(prompt_lens)
max_prompt_len = max(subquery_lens)
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
@ -137,16 +156,39 @@ class ModelRunner:
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device='cuda')
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad(
prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device='cuda')
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device='cuda')
input_metadata = InputMetadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens_tensor,
max_seq_len=max_prompt_len,
start_loc=start_loc_tensor,
max_context_len=None,
context_lens=None,
block_tables=None,
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
return input_tokens, input_positions, input_metadata, prompt_lens
return (input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens)
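
To make the bookkeeping above concrete, a tiny worked example with assumed lengths:

# With a computed 32-token prefix and a 40-token prompt, only the 8-token
# suffix is run through the model; positions and slots start at prefix_len.
prompt_len, prefix_len = 40, 32
subquery_len = prompt_len - prefix_len      # 8 tokens left to compute
context_len = prefix_len                    # the KV cache already holds 32
input_positions = list(range(prefix_len, prompt_len))  # [32, ..., 39]
assert len(input_positions) == subquery_len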
def _prepare_decode(
self,
@ -248,6 +290,9 @@ class ModelRunner:
input_metadata = InputMetadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
max_seq_len=None,
start_loc=None,
max_context_len=max_context_len,
context_lens=context_lens,
block_tables=block_tables,
@ -259,6 +304,7 @@ class ModelRunner:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
@ -266,7 +312,7 @@ class ModelRunner:
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
max_prompt_len = max(prompt_lens) if prompt_lens else 1
max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
@ -274,10 +320,11 @@ class ModelRunner:
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
prompt_len = prompt_lens[i]
assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
@ -287,10 +334,10 @@ class ModelRunner:
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_prompt_len
subquery_len - 1)
selected_token_start_idx += max_subquery_len
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
@ -335,14 +382,16 @@ class ModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_metadata,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
(input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, input_metadata
) = self._prepare_decode(seq_group_metadata_list)
subquery_lens = None
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
prompt_lens,
subquery_lens)
def get_size_or_none(x: Optional[torch.Tensor]):
return x.size() if x is not None else None
@ -359,6 +408,12 @@ class ModelRunner:
input_metadata.is_prompt,
"slot_mapping_size":
get_size_or_none(input_metadata.slot_mapping),
"prompt_lens_size":
get_size_or_none(input_metadata.prompt_lens),
"max_seq_len":
input_metadata.max_seq_len,
"start_loc_size":
get_size_or_none(input_metadata.start_loc),
"max_context_len":
input_metadata.max_context_len,
"context_lens_size":
@ -376,6 +431,10 @@ class ModelRunner:
broadcast(input_positions, src=0)
if input_metadata.slot_mapping is not None:
broadcast(input_metadata.slot_mapping, src=0)
if input_metadata.prompt_lens is not None:
broadcast(input_metadata.prompt_lens, src=0)
if input_metadata.start_loc is not None:
broadcast(input_metadata.start_loc, src=0)
if input_metadata.context_lens is not None:
broadcast(input_metadata.context_lens, src=0)
if input_metadata.block_tables is not None:
@ -400,6 +459,20 @@ class ModelRunner:
broadcast(slot_mapping, src=0)
else:
slot_mapping = None
if py_data["prompt_lens_size"] is not None:
prompt_lens = torch.empty(*py_data["prompt_lens_size"],
dtype=torch.long,
device="cuda")
broadcast(prompt_lens, src=0)
else:
prompt_lens = None
if py_data["start_loc_size"] is not None:
start_loc = torch.empty(*py_data["start_loc_size"],
dtype=torch.long,
device="cuda")
broadcast(start_loc, src=0)
else:
start_loc = None
if py_data["context_lens_size"] is not None:
context_lens = torch.empty(*py_data["context_lens_size"],
dtype=torch.int,
@ -422,6 +495,9 @@ class ModelRunner:
input_metadata = InputMetadata(
is_prompt=py_data["is_prompt"],
slot_mapping=slot_mapping,
prompt_lens=prompt_lens,
max_seq_len=py_data["max_seq_len"],
start_loc=start_loc,
max_context_len=py_data["max_context_len"],
context_lens=context_lens,
block_tables=block_tables,
@ -534,6 +610,9 @@ class ModelRunner:
input_metadata = InputMetadata(
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
max_seq_len=None,
start_loc=None,
max_context_len=self.max_context_len_to_capture,
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],