[Experimental] Prefix Caching Support (#1669)
Co-authored-by: DouHappy <2278958187@qq.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
14cc317ba4
commit
d10f8e1d43
@ -31,6 +31,10 @@ steps:
|
||||
- pytest -v -s models --forked
|
||||
soft_fail: true
|
||||
|
||||
- label: Prefix Caching Test
|
||||
commands:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
- label: Samplers Test
|
||||
command: pytest -v -s samplers --forked
|
||||
|
||||
|
51
examples/offline_inference_with_prefix.py
Normal file
51
examples/offline_inference_with_prefix.py
Normal file
@ -0,0 +1,51 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prefix = (
|
||||
"You are an expert school principal, skilled in effectively managing "
|
||||
"faculty and staff. Draft 10-15 questions for a potential first grade "
|
||||
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
|
||||
"community, joyful discovery, and life-long learning. The candidate is "
|
||||
"coming in for a first-round panel interview for a 8th grade Math "
|
||||
"teaching role. They have 5 years of previous teaching experience "
|
||||
"as an assistant teacher at a co-ed, public school with experience "
|
||||
"in middle school math teaching. Based on these information, fulfill "
|
||||
"the following paragraph: ")
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.0)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
|
||||
generating_prompts = [prefix + prompt for prompt in prompts]
|
||||
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(generating_prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
print("-" * 80)
|
||||
|
||||
# -1 since the last token can change when concatenating prompts.
|
||||
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
|
||||
|
||||
# Generate with prefix
|
||||
outputs = llm.generate(generating_prompts, sampling_params,
|
||||
prefix_pos=[prefix_pos] * len(generating_prompts))
|
||||
|
||||
# Print the outputs. You should see the same outputs as before
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
168
tests/kernels/test_prefix_prefill.py
Normal file
168
tests/kernels/test_prefix_prefill.py
Normal file
@ -0,0 +1,168 @@
|
||||
import random
|
||||
import pytest
|
||||
import time
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
|
||||
context_attention_fwd)
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
||||
|
||||
NUM_HEADS = [12]
|
||||
HEAD_SIZES = [128]
|
||||
DTYPES = [torch.float16]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
MAX_SEQ_LEN = 1024
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
|
||||
|
||||
num_tokens = sum(subquery_lens)
|
||||
query = torch.empty(num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
query.uniform_(-1e-3, 1e-3)
|
||||
output = torch.empty(num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
|
||||
kv = torch.empty(sum(seq_lens),
|
||||
2,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
k = torch.zeros(sum(subquery_lens),
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
v = torch.zeros(sum(subquery_lens),
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:BS * max_block_per_request].view(
|
||||
BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
|
||||
dtype=torch.long,
|
||||
device='cuda'),
|
||||
dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
|
||||
dtype=torch.long,
|
||||
device='cuda'),
|
||||
dim=0)
|
||||
for i in range(BS):
|
||||
for j in range(subquery_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
|
||||
j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
|
||||
b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
start_loc = b_seq_start_loc[i] + cur_ctx
|
||||
if cur_ctx + block_size > b_ctx_len[i]:
|
||||
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
|
||||
else:
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc])
|
||||
v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc])
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
|
||||
k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
|
||||
8).permute(0, 2, 3, 1, 4).contiguous()
|
||||
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
|
||||
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
|
||||
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
|
||||
b_start_loc, b_seq_len, b_ctx_len, max_input_len)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
attn_op = xops.fmha.cutlass.FwOp()
|
||||
|
||||
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||
subquery_lens, seq_lens)
|
||||
output_ref = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
output_ref = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
output_ref = output_ref.squeeze(0)
|
||||
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
|
41
tests/prefix_caching/test_prefix_caching.py
Normal file
41
tests/prefix_caching/test_prefix_caching.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Compare the with and without prefix caching.
|
||||
|
||||
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prefix = (
|
||||
"You are an expert school principal, skilled in effectively managing "
|
||||
"faculty and staff. Draft 10-15 questions for a potential first grade "
|
||||
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
|
||||
"community, joyful discovery, and life-long learning. The candidate is "
|
||||
"coming in for a first-round panel interview for a 8th grade Math "
|
||||
"teaching role. They have 5 years of previous teaching experience "
|
||||
"as an assistant teacher at a co-ed, public school with experience "
|
||||
"in middle school math teaching. Based on these information, fulfill "
|
||||
"the following paragraph: ")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||
@pytest.mark.parametrize("max_tokens", [16])
|
||||
def test_prefix_caching(
|
||||
example_prompts,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
):
|
||||
llm = LLM(model=model)
|
||||
# -1 since the last token can change when concatenating prompts.
|
||||
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
|
||||
prompts = [prefix + prompt for prompt in example_prompts]
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs_without_prefix = llm.generate(prompts, sampling_params)
|
||||
outputs_with_prefix = llm.generate(prompts,
|
||||
sampling_params,
|
||||
prefix_pos=[prefix_pos] * len(prompts))
|
||||
for output_without_prefix, output_with_prefix in zip(
|
||||
outputs_without_prefix, outputs_with_prefix):
|
||||
assert (output_without_prefix.outputs[0].token_ids ==
|
||||
output_with_prefix.outputs[0].token_ids)
|
||||
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
|
@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int):
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int):
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int):
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int):
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
@ -234,7 +238,8 @@ def test_sampler_logits_processors(seed: int):
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int):
|
||||
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
|
||||
sample_probs = None
|
||||
|
||||
|
@ -33,11 +33,12 @@ def test_prepare_prompt():
|
||||
expected_selected_token_indices.append(selected_token_start_idx +
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += max_seq_len
|
||||
input_tokens, input_positions, _, return_prompt_lens = (
|
||||
input_tokens, input_positions, _, return_prompt_lens, _ = (
|
||||
model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert return_prompt_lens == prompt_lens
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
assert input_tokens.shape == (batch_size, max_seq_len)
|
||||
assert input_positions.shape == (batch_size, max_seq_len)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
|
@ -66,3 +66,7 @@ class PhysicalTokenBlock:
|
||||
return (f'PhysicalTokenBlock(device={self.device}, '
|
||||
f'block_number={self.block_number}, '
|
||||
f'ref_count={self.ref_count})')
|
||||
|
||||
|
||||
# Mapping: logical block number -> physical block.
|
||||
BlockTable = List[PhysicalTokenBlock]
|
||||
|
@ -2,13 +2,10 @@
|
||||
import enum
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
from vllm.block import BlockTable, PhysicalTokenBlock
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
# Mapping: logical block number -> physical block.
|
||||
BlockTable = List[PhysicalTokenBlock]
|
||||
|
||||
|
||||
class BlockAllocator:
|
||||
"""Manages free physical token blocks for a device.
|
||||
@ -105,6 +102,10 @@ class BlockSpaceManager:
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
||||
num_required_blocks = len(seq.logical_token_blocks)
|
||||
|
||||
if seq_group.prefix is not None and seq_group.prefix.allocated:
|
||||
num_required_blocks -= seq_group.prefix.get_num_blocks()
|
||||
|
||||
if self.block_sliding_window is not None:
|
||||
num_required_blocks = min(num_required_blocks,
|
||||
self.block_sliding_window)
|
||||
@ -125,8 +126,21 @@ class BlockSpaceManager:
|
||||
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
|
||||
|
||||
# Allocate new physical token blocks that will store the prompt tokens.
|
||||
num_prompt_blocks = len(seq.logical_token_blocks)
|
||||
|
||||
block_table: BlockTable = []
|
||||
for logical_idx in range(len(seq.logical_token_blocks)):
|
||||
prefix_block_table: BlockTable = []
|
||||
num_prefix_blocks = 0
|
||||
|
||||
prefix = seq_group.prefix
|
||||
if prefix is not None and prefix.allocated:
|
||||
# Prefix has already been allocated. Use the existing block table.
|
||||
num_prompt_blocks -= prefix.get_num_blocks()
|
||||
for block in prefix.block_table:
|
||||
block.ref_count += seq_group.num_seqs()
|
||||
block_table.append(block)
|
||||
|
||||
for logical_idx in range(num_prompt_blocks):
|
||||
if (self.block_sliding_window is not None
|
||||
and logical_idx >= self.block_sliding_window):
|
||||
block = block_table[logical_idx % self.block_sliding_window]
|
||||
@ -136,6 +150,15 @@ class BlockSpaceManager:
|
||||
block.ref_count = seq_group.num_seqs()
|
||||
block_table.append(block)
|
||||
|
||||
if prefix is not None and not prefix.allocated:
|
||||
# Allocate blocks for the prefix, we will compute the prefix's
|
||||
# KV cache in this run.
|
||||
num_prefix_blocks = prefix.get_num_blocks()
|
||||
prefix_block_table = block_table[:num_prefix_blocks]
|
||||
for block in prefix_block_table:
|
||||
block.ref_count += 1
|
||||
prefix.set_block_table(prefix_block_table)
|
||||
|
||||
# Assign the block table for each sequence.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
||||
self.block_tables[seq.seq_id] = block_table.copy()
|
||||
@ -210,10 +233,18 @@ class BlockSpaceManager:
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||
# CPU block -> GPU block.
|
||||
if seq_group.prefix is not None:
|
||||
# make sure to swap in the prefix first
|
||||
assert seq_group.prefix.allocated and seq_group.prefix.computed
|
||||
|
||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||
new_block_table: BlockTable = []
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
if seq_group.prefix is not None:
|
||||
for block in seq_group.prefix.block_table:
|
||||
new_block_table.append(block)
|
||||
block.ref_count += 1
|
||||
|
||||
for cpu_block in block_table:
|
||||
if cpu_block in mapping:
|
||||
@ -245,6 +276,12 @@ class BlockSpaceManager:
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
for gpu_block in block_table:
|
||||
if (seq_group.prefix is not None
|
||||
and gpu_block in seq_group.prefix.block_table):
|
||||
# NOTE: We do not swap out the prefix blocks for now.
|
||||
self.gpu_allocator.free(gpu_block)
|
||||
continue
|
||||
|
||||
if gpu_block in mapping:
|
||||
cpu_block = mapping[gpu_block]
|
||||
cpu_block.ref_count += 1
|
||||
|
@ -9,6 +9,7 @@ from vllm.core.policy import PolicyFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceStatus)
|
||||
from vllm.prefix import PrefixPool
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -76,6 +77,9 @@ class Scheduler:
|
||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||
sliding_window=self.cache_config.sliding_window)
|
||||
|
||||
# Create the prefix pool to cache the prefixes.
|
||||
self.prefix_pool = PrefixPool(self.cache_config.block_size)
|
||||
|
||||
# Sequence groups in the WAITING state.
|
||||
self.waiting: Deque[SequenceGroup] = deque()
|
||||
# Sequence groups in the RUNNING state.
|
||||
@ -316,6 +320,7 @@ class Scheduler:
|
||||
seq_data=seq_data,
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
prefix=seq_group.prefix,
|
||||
)
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
|
@ -371,6 +371,7 @@ class AsyncLLMEngine:
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
shortened_prompt = prompt
|
||||
@ -383,6 +384,7 @@ class AsyncLLMEngine:
|
||||
max_log_len]
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {shortened_prompt!r}, "
|
||||
f"prefix_pos: {prefix_pos},"
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {shortened_token_ids}.")
|
||||
|
||||
@ -401,7 +403,8 @@ class AsyncLLMEngine:
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
arrival_time=arrival_time,
|
||||
prefix_pos=prefix_pos)
|
||||
|
||||
return stream
|
||||
|
||||
@ -410,7 +413,8 @@ class AsyncLLMEngine:
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> AsyncIterator[RequestOutput]:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
@ -425,6 +429,11 @@ class AsyncLLMEngine:
|
||||
request_id: The unique id of the request.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
prefix_pos: If not None, we use the given position as the prefix
|
||||
position for each prompt. We will cache the prefix's KV
|
||||
cache and reuse it for the next request with the same prefix.
|
||||
This is an experimental feature, and may be replaced with
|
||||
automatic prefix caching in the future.
|
||||
|
||||
Yields:
|
||||
The output `RequestOutput` objects from the LLMEngine for the
|
||||
@ -482,7 +491,8 @@ class AsyncLLMEngine:
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
arrival_time=arrival_time,
|
||||
prefix_pos=prefix_pos)
|
||||
|
||||
async for request_output in stream:
|
||||
yield request_output
|
||||
|
@ -337,6 +337,7 @@ class LLMEngine:
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Add a request to the engine's request pool.
|
||||
|
||||
@ -353,6 +354,11 @@ class LLMEngine:
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
arrival_time: The arrival time of the request. If None, we use
|
||||
the current monotonic time.
|
||||
prefix_pos: If not None, we use the given position as the prefix
|
||||
position for each prompt. We will cache the prefix's KV
|
||||
cache and reuse it for the next request with the same prefix.
|
||||
This is an experimental feature, and may be replaced with
|
||||
automatic prefix caching in the future.
|
||||
|
||||
Details:
|
||||
- Set arrival_time to the current time if it is None.
|
||||
@ -389,9 +395,13 @@ class LLMEngine:
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
|
||||
# Check whether the input specifies prefix
|
||||
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
|
||||
prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
arrival_time)
|
||||
arrival_time, prefix)
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
@ -662,6 +672,12 @@ class LLMEngine:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
|
||||
# Update prefix state, now all the uncomputed prefixes are computed.
|
||||
for seq_group in scheduled_seq_groups:
|
||||
if (seq_group.prefix is not None and seq_group.prefix.allocated
|
||||
and not seq_group.prefix.computed):
|
||||
seq_group.prefix.computed = True
|
||||
|
||||
if self.log_stats:
|
||||
# Log the system stats.
|
||||
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||
|
@ -33,11 +33,15 @@ async def generate(request: Request) -> Response:
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
prompt = request_dict.pop("prompt")
|
||||
prefix_pos = request_dict.pop("prefix_pos", None)
|
||||
stream = request_dict.pop("stream", False)
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
results_generator = engine.generate(prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
prefix_pos=prefix_pos)
|
||||
|
||||
# Streaming case
|
||||
async def stream_results() -> AsyncGenerator[bytes, None]:
|
||||
|
@ -120,6 +120,7 @@ class LLM:
|
||||
prompts: Optional[Union[str, List[str]]] = None,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
prefix_pos: Optional[Union[int, List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
) -> List[RequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
@ -134,6 +135,11 @@ class LLM:
|
||||
None, we use the default sampling parameters.
|
||||
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
prefix_pos: If not None, we use the given position as the prefix
|
||||
position for each prompt. We will cache the prefix's KV
|
||||
cache and reuse it for the next request with the same prefix.
|
||||
This is an experimental feature, and may be replaced with
|
||||
automatic prefix caching in the future.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
|
||||
Returns:
|
||||
@ -159,9 +165,10 @@ class LLM:
|
||||
prompt_token_ids)
|
||||
for i in range(num_requests):
|
||||
prompt = prompts[i] if prompts is not None else None
|
||||
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
|
||||
token_ids = None if prompt_token_ids is None else prompt_token_ids[
|
||||
i]
|
||||
self._add_request(prompt, sampling_params, token_ids)
|
||||
self._add_request(prompt, sampling_params, token_ids, prefix_pos_i)
|
||||
return self._run_engine(use_tqdm)
|
||||
|
||||
def _add_request(
|
||||
@ -169,10 +176,14 @@ class LLM:
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]],
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> None:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_engine.add_request(request_id, prompt, sampling_params,
|
||||
prompt_token_ids)
|
||||
self.llm_engine.add_request(request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids,
|
||||
prefix_pos=prefix_pos)
|
||||
|
||||
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
|
||||
# Initialize tqdm.
|
||||
|
@ -18,12 +18,18 @@ class InputMetadata:
|
||||
self,
|
||||
is_prompt: bool,
|
||||
slot_mapping: torch.Tensor,
|
||||
prompt_lens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[int],
|
||||
start_loc: Optional[torch.Tensor],
|
||||
max_context_len: Optional[int],
|
||||
context_lens: Optional[torch.Tensor],
|
||||
block_tables: Optional[torch.Tensor],
|
||||
use_cuda_graph: bool,
|
||||
) -> None:
|
||||
self.is_prompt = is_prompt
|
||||
self.prompt_lens = prompt_lens
|
||||
self.max_seq_len = max_seq_len
|
||||
self.start_loc = start_loc
|
||||
self.max_context_len = max_context_len
|
||||
self.slot_mapping = slot_mapping
|
||||
self.context_lens = context_lens
|
||||
|
@ -10,6 +10,8 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
||||
from vllm._C import ops
|
||||
from vllm._C import cache_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
|
||||
context_attention_fwd)
|
||||
from vllm.utils import is_hip
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
@ -115,45 +117,65 @@ class PagedAttention(nn.Module):
|
||||
self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
value.shape[-1])
|
||||
# normal attention
|
||||
if (key_cache is None or value_cache is None
|
||||
or input_metadata.block_tables.numel() == 0):
|
||||
# Set attention bias if not provided. This typically happens at
|
||||
# the very attention layer of every iteration.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
if input_metadata.attn_bias is None:
|
||||
if self.alibi_slopes is None:
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
||||
[seq_len] * batch_size)
|
||||
if self.sliding_window is not None:
|
||||
attn_bias = attn_bias.make_local_attention(
|
||||
self.sliding_window)
|
||||
input_metadata.attn_bias = attn_bias
|
||||
else:
|
||||
input_metadata.attn_bias = _make_alibi_bias(
|
||||
self.alibi_slopes, self.num_kv_heads, batch_size,
|
||||
seq_len, query.dtype)
|
||||
|
||||
# Set attention bias if not provided. This typically happens at the
|
||||
# very attention layer of every iteration.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
if input_metadata.attn_bias is None:
|
||||
# TODO(woosuk): Too many view operations. Let's try to reduce
|
||||
# them in the future for code readability.
|
||||
if self.alibi_slopes is None:
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(
|
||||
[seq_len] * batch_size)
|
||||
if self.sliding_window is not None:
|
||||
attn_bias = attn_bias.make_local_attention(
|
||||
self.sliding_window)
|
||||
input_metadata.attn_bias = attn_bias
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
else:
|
||||
input_metadata.attn_bias = _make_alibi_bias(
|
||||
self.alibi_slopes, self.num_kv_heads, batch_size,
|
||||
seq_len, query.dtype)
|
||||
query = query.unflatten(0, (batch_size, seq_len))
|
||||
key = key.unflatten(0, (batch_size, seq_len))
|
||||
value = value.unflatten(0, (batch_size, seq_len))
|
||||
|
||||
# TODO(woosuk): Too many view operations. Let's try to reduce them
|
||||
# in the future for code readability.
|
||||
if self.alibi_slopes is None:
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=input_metadata.attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
|
||||
(is_hip()) else None,
|
||||
)
|
||||
output = out.view_as(query)
|
||||
else:
|
||||
query = query.unflatten(0, (batch_size, seq_len))
|
||||
key = key.unflatten(0, (batch_size, seq_len))
|
||||
value = value.unflatten(0, (batch_size, seq_len))
|
||||
# prefix-enabled attention
|
||||
output = torch.empty_like(query)
|
||||
context_attention_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata.block_tables, # [BS, max_block_per_request]
|
||||
input_metadata.start_loc,
|
||||
input_metadata.prompt_lens,
|
||||
input_metadata.context_lens,
|
||||
input_metadata.max_seq_len,
|
||||
getattr(self, "alibi_slopes", None),
|
||||
)
|
||||
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=input_metadata.attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
|
||||
(is_hip()) else None,
|
||||
)
|
||||
output = out.view_as(query)
|
||||
else:
|
||||
# Decoding run.
|
||||
output = _paged_attention(
|
||||
|
728
vllm/model_executor/layers/triton_kernel/prefix_prefill.py
Normal file
728
vllm/model_executor/layers/triton_kernel/prefix_prefill.py
Normal file
@ -0,0 +1,728 @@
|
||||
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
if triton.__version__ >= "2.1.0":
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
B_Ctxlen,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
q = tl.load(
|
||||
Q + off_q,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0)
|
||||
off_k = (bn[None, :] * stride_k_cache_bs +
|
||||
cur_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_n[None, :]) % block_size) *
|
||||
stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
off_v = (
|
||||
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k = tl.load(K_cache + off_k,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(V_cache + off_v,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# # update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(
|
||||
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_flash_attn_v2(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
B_Ctxlen,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
q = tl.load(
|
||||
Q + off_q,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0)
|
||||
off_k = (bn[None, :] * stride_k_cache_bs +
|
||||
cur_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_n[None, :]) % block_size) *
|
||||
stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
off_v = (
|
||||
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k = tl.load(K_cache + off_k,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(V_cache + off_v,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(
|
||||
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
# acc /= l_i[:, None]
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
|
||||
return
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_alibi(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
K_cache,
|
||||
V_cache,
|
||||
B_Loc,
|
||||
sm_scale,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
B_Ctxlen,
|
||||
Alibi_slopes,
|
||||
block_size,
|
||||
x,
|
||||
Out,
|
||||
stride_b_loc_b,
|
||||
stride_b_loc_s,
|
||||
stride_qbs,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kbs,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vbs,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_obs,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
stride_k_cache_bs,
|
||||
stride_k_cache_h,
|
||||
stride_k_cache_d,
|
||||
stride_k_cache_bl,
|
||||
stride_k_cache_x,
|
||||
stride_v_cache_bs,
|
||||
stride_v_cache_h,
|
||||
stride_v_cache_d,
|
||||
stride_v_cache_bl,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# attn_bias[]
|
||||
cur_batch = tl.program_id(0)
|
||||
cur_head = tl.program_id(1)
|
||||
start_m = tl.program_id(2)
|
||||
|
||||
# cur_batch_seq_len: the length of prompts
|
||||
# cur_batch_ctx_len: the length of prefix
|
||||
# cur_batch_in_all_start_index: the start id of the dim=0
|
||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||
|
||||
block_start_loc = BLOCK_M * start_m
|
||||
|
||||
# initialize offsets
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_q = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
|
||||
cur_head * stride_qh + offs_d[None, :] * stride_qd)
|
||||
|
||||
q = tl.load(
|
||||
Q + off_q,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
# # initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(
|
||||
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = 0
|
||||
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
|
||||
((start_n + offs_n) // block_size) * stride_b_loc_s,
|
||||
mask=(start_n + offs_n) < cur_batch_ctx_len,
|
||||
other=0)
|
||||
off_k = (bn[None, :] * stride_k_cache_bs +
|
||||
cur_head * stride_k_cache_h +
|
||||
(offs_d[:, None] // x) * stride_k_cache_d +
|
||||
((start_n + offs_n[None, :]) % block_size) *
|
||||
stride_k_cache_bl +
|
||||
(offs_d[:, None] % x) * stride_k_cache_x)
|
||||
off_v = (
|
||||
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
|
||||
offs_d[None, :] * stride_v_cache_d +
|
||||
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
|
||||
k = tl.load(K_cache + off_k,
|
||||
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
|
||||
float("-inf"))
|
||||
qk *= sm_scale
|
||||
|
||||
# load alibi
|
||||
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
|
||||
alibi_start_q[:, None]) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
|
||||
alibi, float("-inf"))
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(V_cache + off_v,
|
||||
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v, allow_tf32=False)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
|
||||
offs_d[:, None] * stride_kd)
|
||||
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
|
||||
offs_d[None, :] * stride_vd)
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
block_mask = tl.where(
|
||||
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
|
||||
|
||||
# init alibi
|
||||
alibi_slope = tl.load(Alibi_slopes + cur_head)
|
||||
alibi_start_q = tl.arange(
|
||||
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
|
||||
alibi_start_k = cur_batch_ctx_len
|
||||
# # init debuger
|
||||
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
|
||||
# offset_db_k = tl.arange(0, BLOCK_N)
|
||||
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
|
||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||
mask=(start_n + offs_n[None, :]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, allow_tf32=False)
|
||||
qk *= sm_scale
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
|
||||
float("-inf"))
|
||||
|
||||
# load alibi
|
||||
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
|
||||
alibi_start_q[:, None]) * alibi_slope
|
||||
alibi = tl.where(
|
||||
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
|
||||
alibi, float("-inf"))
|
||||
qk += alibi
|
||||
alibi_start_k += BLOCK_N
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
p = tl.math.exp(qk - m_i_new[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
|
||||
alpha = tl.math.exp(m_i - m_i_new)
|
||||
l_i_new = alpha * l_i + l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
# scale acc
|
||||
acc_scale = alpha
|
||||
# acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs +
|
||||
(cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||
mask=(start_n + offs_n[:, None]) <
|
||||
cur_batch_seq_len - cur_batch_ctx_len,
|
||||
other=0.0)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v, allow_tf32=False)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
|
||||
# initialize pointers to output
|
||||
off_o = (
|
||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
|
||||
cur_head * stride_oh + offs_d[None, :] * stride_od)
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
|
||||
return
|
||||
|
||||
@torch.inference_mode()
|
||||
def context_attention_fwd(q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
alibi_slopes=None):
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
|
||||
sm_scale = 1.0 / (Lq**0.5)
|
||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
|
||||
|
||||
num_warps = 8 if Lk <= 64 else 8
|
||||
if alibi_slopes is not None:
|
||||
_fwd_kernel_alibi[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
alibi_slopes,
|
||||
v_cache.shape[3],
|
||||
8,
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(
|
||||
4
|
||||
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(
|
||||
3), #[num_blocks, num_kv_heads, head_size, block_size]
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
b_loc,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
v_cache.shape[3],
|
||||
8,
|
||||
o,
|
||||
b_loc.stride(0),
|
||||
b_loc.stride(1),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
k_cache.stride(
|
||||
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
v_cache.stride(0),
|
||||
v_cache.stride(1),
|
||||
v_cache.stride(2),
|
||||
v_cache.stride(
|
||||
3), #[num_blocks, num_kv_heads, head_size, block_size]
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return
|
87
vllm/prefix.py
Normal file
87
vllm/prefix.py
Normal file
@ -0,0 +1,87 @@
|
||||
from typing import Dict, List, Sequence, Tuple, Optional
|
||||
|
||||
from vllm.block import BlockTable
|
||||
|
||||
|
||||
class Prefix:
|
||||
"""Data and states associated with a prefix of prompt tokens for multiple
|
||||
sequence groups.
|
||||
|
||||
NOTE: This feature is experimental and may be replaced with automatic
|
||||
prefix caching in the future.
|
||||
|
||||
Args:
|
||||
prefix_id: The id of the prefix in the prefix pool.
|
||||
token_ids: The token ids of the prefix.
|
||||
block_size: The block size of the executed model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_ids: Sequence[int],
|
||||
block_size: int,
|
||||
) -> None:
|
||||
self.token_ids = tuple(token_ids)
|
||||
self.block_size = block_size
|
||||
self.length = len(token_ids)
|
||||
self.hash = hash(token_ids)
|
||||
assert self.length % block_size == 0
|
||||
self.block_table: Optional[BlockTable] = None
|
||||
self.computed = False
|
||||
|
||||
@property
|
||||
def allocated(self) -> bool:
|
||||
return self.block_table is not None
|
||||
|
||||
def get_num_blocks(self) -> int:
|
||||
return self.length // self.block_size
|
||||
|
||||
def get_block_numbers(self) -> List[int]:
|
||||
return [block.block_number for block in self.block_table]
|
||||
|
||||
def get_length(self) -> int:
|
||||
return self.length
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.hash
|
||||
|
||||
def set_block_table(self, block_table: BlockTable) -> None:
|
||||
self.block_table = block_table.copy()
|
||||
|
||||
|
||||
class PrefixPool:
|
||||
"""Manages all the prompt prefixes.
|
||||
|
||||
NOTE: This feature is experimental and may be replaced with automatic
|
||||
prefix caching in the future.
|
||||
|
||||
Args:
|
||||
block_size: The block size of the executed model.
|
||||
|
||||
Attributes:
|
||||
prefixes: A list of all the prefixes.
|
||||
block_size: The block size of the executed model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
) -> None:
|
||||
# TODO(zhuohan): Add a capacity limit to the prefix pool.
|
||||
self.prefixes: Dict[int, Prefix] = {}
|
||||
self.block_size = block_size
|
||||
|
||||
def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
|
||||
new_length = len(token_ids) // self.block_size * self.block_size
|
||||
return tuple(token_ids[:new_length])
|
||||
|
||||
def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]:
|
||||
token_ids = self._truncate_token_ids(token_ids)
|
||||
if len(token_ids) == 0:
|
||||
# Prefix is empty.
|
||||
return None
|
||||
prefix = Prefix(token_ids, self.block_size)
|
||||
prefix_hash = hash(prefix)
|
||||
if prefix_hash not in self.prefixes:
|
||||
self.prefixes[prefix_hash] = prefix
|
||||
return self.prefixes[prefix_hash]
|
@@ -4,6 +4,7 @@ import enum
from typing import Dict, List, Optional, Union

from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams

PromptLogprobs = List[Optional[Dict[int, float]]]
@@ -236,11 +237,13 @@ class SequenceGroup:
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        prefix: Optional[Prefix] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
        self.prefix: Optional[Prefix] = prefix
        self.prompt_logprobs: Optional[PromptLogprobs] = None

    @property
@@ -327,7 +330,6 @@ class SequenceGroup:
class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.


    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
@@ -335,6 +337,7 @@ class SequenceGroupMetadata:
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
        prefix: The prefix of the prompt of the sequence group.
    """

    def __init__(
@@ -344,12 +347,14 @@ class SequenceGroupMetadata:
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        prefix: Optional[Prefix] = None,
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
        self.prefix = prefix


class SequenceOutput:
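
Editor's note (not part of this commit): the hunks above only thread an optional prefix through SequenceGroup and SequenceGroupMetadata; it defaults to None, so callers that do not use prefix caching are untouched. A hedged sketch of how a caller could turn the user-facing prefix_pos (as used in the example script) into a pooled Prefix; the helper name and signature are hypothetical and not the engine code in this commit:

from typing import List, Optional

from vllm.prefix import Prefix, PrefixPool


def lookup_prefix(pool: PrefixPool, prompt_token_ids: List[int],
                  prefix_pos: Optional[int]) -> Optional[Prefix]:
    # Hypothetical helper: prefix_pos marks how many leading prompt tokens are
    # shared; the pool truncates to whole blocks and deduplicates by hash.
    if prefix_pos is None:
        return None
    return pool.add_or_get_prefix(prompt_token_ids[:prefix_pos])
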
@@ -74,13 +74,17 @@ class ModelRunner:
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
               List[int]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []

        prompt_lens: List[int] = []
        context_lens: List[int] = []
        subquery_lens: List[int] = []
        prefix_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -91,11 +95,23 @@ class ModelRunner:
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)
            prefix_len = 0
            prefix = seq_group_metadata.prefix
            if prefix is not None and prefix.computed:
                prefix_len = prefix.get_length()
                prompt_tokens = prompt_tokens[prefix_len:]
                prefix_block_tables.append(prefix.get_block_numbers())
            else:
                prefix_block_tables.append([])
            # actual prompt lens
            context_lens.append(prefix_len)
            subquery_lens.append(prompt_len - prefix_len)

            input_tokens.append(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.append(list(range(prompt_len)))
            input_positions.append(
                list(range(prefix_len, prefix_len + len(prompt_tokens))))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
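
Editor's note (not part of this commit): a standalone sketch of the bookkeeping in the loop above. Once a prefix is marked computed, only the remaining "subquery" tokens are fed to the model, and their positions start at prefix_len; the names below are illustrative.

from typing import List, Tuple


def split_prompt(prompt_tokens: List[int],
                 prefix_len: int) -> Tuple[List[int], List[int], int, int]:
    prompt_len = len(prompt_tokens)
    new_tokens = prompt_tokens[prefix_len:]            # tokens still to run
    positions = list(range(prefix_len, prefix_len + len(new_tokens)))
    context_len = prefix_len                           # tokens already cached
    subquery_len = prompt_len - prefix_len
    return new_tokens, positions, context_len, subquery_len


tokens, positions, context_len, subquery_len = split_prompt(
    list(range(100, 110)), prefix_len=4)
print(tokens)                     # [104, 105, 106, 107, 108, 109]
print(positions)                  # [4, 5, 6, 7, 8, 9]
print(context_len, subquery_len)  # 4 6
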
@@ -113,8 +129,11 @@ class ModelRunner:
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                assert prefix_len == 0, (
                    "Prefix caching is currently not supported with "
                    "sliding window attention")
                start_idx = max(0, prompt_len - self.sliding_window)
            for i in range(prompt_len):
            for i in range(prefix_len, prompt_len):
                if i < start_idx:
                    slot_mapping[-1].append(_PAD_SLOT_ID)
                    continue
@@ -124,7 +143,7 @@ class ModelRunner:
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        max_prompt_len = max(prompt_lens)
        max_prompt_len = max(subquery_lens)
        input_tokens = _make_tensor_with_pad(input_tokens,
                                             max_prompt_len,
                                             pad=0,
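
Editor's note (not part of this commit): the loop above now starts at prefix_len, so tokens already covered by the cached prefix get no new cache slot. A sketch of the resulting slot mapping (hypothetical helper, ignoring the sliding-window padding):

from typing import List


def prompt_slot_mapping(block_table: List[int], prompt_len: int,
                        prefix_len: int, block_size: int) -> List[int]:
    slots = []
    for i in range(prefix_len, prompt_len):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slots.append(block_number * block_size + block_offset)
    return slots


# block_size=4: a 10-token prompt whose first 4 tokens are a computed prefix
# only needs slots for positions 4..9.
print(prompt_slot_mapping([7, 2, 5], prompt_len=10, prefix_len=4,
                          block_size=4))  # [8, 9, 10, 11, 20, 21]
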
@@ -137,16 +156,39 @@ class ModelRunner:
                                             max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long)
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device='cuda')
        # Prepare prefix block tables
        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
        block_tables = _make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
            dtype=torch.int,
        )
        start_loc_tensor = torch.arange(0,
                                        len(prompt_lens) * max_prompt_len,
                                        max_prompt_len,
                                        dtype=torch.long,
                                        device='cuda')
        prompt_lens_tensor = torch.tensor(prompt_lens,
                                          dtype=torch.long,
                                          device='cuda')

        input_metadata = InputMetadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            prompt_lens=prompt_lens_tensor,
            max_seq_len=max_prompt_len,
            start_loc=start_loc_tensor,
            max_context_len=None,
            context_lens=None,
            block_tables=None,
            context_lens=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
        )
        return input_tokens, input_positions, input_metadata, prompt_lens
        return (input_tokens, input_positions, input_metadata, prompt_lens,
                subquery_lens)

    def _prepare_decode(
        self,
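
Editor's note (not part of this commit): with prefix caching, the padded prompt batch is sized by the subquery lengths, and start_loc marks where each padded sequence begins, exactly as the torch.arange call above computes. A tiny CPU illustration with made-up lengths:

import torch

subquery_lens = [6, 3, 5]
max_prompt_len = max(subquery_lens)  # every sequence is padded to this length
start_loc = torch.arange(0, len(subquery_lens) * max_prompt_len,
                         max_prompt_len, dtype=torch.long)
print(start_loc)  # tensor([ 0,  6, 12])
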
@@ -248,6 +290,9 @@ class ModelRunner:
        input_metadata = InputMetadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            max_seq_len=None,
            start_loc=None,
            max_context_len=max_context_len,
            context_lens=context_lens,
            block_tables=block_tables,
@@ -259,6 +304,7 @@ class ModelRunner:
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
    ) -> SamplingMetadata:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
@@ -266,7 +312,7 @@ class ModelRunner:
        categorized_sample_indices = {t: [] for t in SamplingType}
        categorized_sample_indices_start_idx = 0

        max_prompt_len = max(prompt_lens) if prompt_lens else 1
        max_subquery_len = max(subquery_lens) if subquery_lens else 1
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
@@ -274,10 +320,11 @@ class ModelRunner:

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                prompt_len = prompt_lens[i]
                assert subquery_lens is not None
                subquery_len = subquery_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need sample, skip
                    categorized_sample_indices_start_idx += prompt_len - 1
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
@@ -287,10 +334,10 @@ class ModelRunner:
                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                              selected_token_start_idx + subquery_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += max_prompt_len
                                              subquery_len - 1)
                selected_token_start_idx += max_subquery_len
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
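
Editor's note (not part of this commit): a sketch of the index arithmetic above. In the padded prompt batch, the logit used for sampling each sequence comes from the last real token of its subquery, i.e. offset subquery_len - 1 within a row of width max_subquery_len; the lengths are made up.

subquery_lens = [6, 3, 5]
max_subquery_len = max(subquery_lens)

selected_token_indices = []
selected_token_start_idx = 0
for subquery_len in subquery_lens:
    selected_token_indices.append(selected_token_start_idx + subquery_len - 1)
    selected_token_start_idx += max_subquery_len

print(selected_token_indices)  # [5, 8, 16]
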
@@ -335,14 +382,16 @@ class ModelRunner:
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, input_metadata,
                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
                (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions, input_metadata
                 ) = self._prepare_decode(seq_group_metadata_list)
                subquery_lens = None
                prompt_lens = []
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
                                                     prompt_lens,
                                                     subquery_lens)

            def get_size_or_none(x: Optional[torch.Tensor]):
                return x.size() if x is not None else None
@@ -359,6 +408,12 @@ class ModelRunner:
                input_metadata.is_prompt,
                "slot_mapping_size":
                get_size_or_none(input_metadata.slot_mapping),
                "prompt_lens_size":
                get_size_or_none(input_metadata.prompt_lens),
                "max_seq_len":
                input_metadata.max_seq_len,
                "start_loc_size":
                get_size_or_none(input_metadata.start_loc),
                "max_context_len":
                input_metadata.max_context_len,
                "context_lens_size":
@@ -376,6 +431,10 @@ class ModelRunner:
            broadcast(input_positions, src=0)
            if input_metadata.slot_mapping is not None:
                broadcast(input_metadata.slot_mapping, src=0)
            if input_metadata.prompt_lens is not None:
                broadcast(input_metadata.prompt_lens, src=0)
            if input_metadata.start_loc is not None:
                broadcast(input_metadata.start_loc, src=0)
            if input_metadata.context_lens is not None:
                broadcast(input_metadata.context_lens, src=0)
            if input_metadata.block_tables is not None:
@@ -400,6 +459,20 @@ class ModelRunner:
                broadcast(slot_mapping, src=0)
            else:
                slot_mapping = None
            if py_data["prompt_lens_size"] is not None:
                prompt_lens = torch.empty(*py_data["prompt_lens_size"],
                                          dtype=torch.long,
                                          device="cuda")
                broadcast(prompt_lens, src=0)
            else:
                prompt_lens = None
            if py_data["start_loc_size"] is not None:
                start_loc = torch.empty(*py_data["start_loc_size"],
                                        dtype=torch.long,
                                        device="cuda")
                broadcast(start_loc, src=0)
            else:
                start_loc = None
            if py_data["context_lens_size"] is not None:
                context_lens = torch.empty(*py_data["context_lens_size"],
                                           dtype=torch.int,
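
Editor's note (not part of this commit): the new prompt_lens and start_loc tensors follow the existing pattern for optional metadata under tensor parallelism: the driver sends a dict of tensor sizes, and the other ranks allocate empty tensors of those sizes and receive the data via broadcast. A single-process sketch on CPU (gloo backend); in vLLM the process group is already initialized and the tensors live on CUDA:

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501",
                        rank=0, world_size=1)

# Size metadata as the driver would communicate it (values are illustrative).
py_data = {"prompt_lens_size": (3,)}

if py_data["prompt_lens_size"] is not None:
    prompt_lens = torch.empty(*py_data["prompt_lens_size"], dtype=torch.long)
    dist.broadcast(prompt_lens, src=0)  # no-op with a single rank
else:
    prompt_lens = None

dist.destroy_process_group()
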
@@ -422,6 +495,9 @@ class ModelRunner:
            input_metadata = InputMetadata(
                is_prompt=py_data["is_prompt"],
                slot_mapping=slot_mapping,
                prompt_lens=prompt_lens,
                max_seq_len=py_data["max_seq_len"],
                start_loc=start_loc,
                max_context_len=py_data["max_context_len"],
                context_lens=context_lens,
                block_tables=block_tables,
@@ -534,6 +610,9 @@ class ModelRunner:
            input_metadata = InputMetadata(
                is_prompt=False,
                slot_mapping=slot_mapping[:batch_size],
                prompt_lens=None,
                max_seq_len=None,
                start_loc=None,
                max_context_len=self.max_context_len_to_capture,
                context_lens=context_lens[:batch_size],
                block_tables=block_tables[:batch_size],