[1/n][Chunked Prefill] Refactor input query shapes (#3236)

This commit is contained in:
SangBin Cho 2024-03-21 06:46:05 +09:00 committed by GitHub
parent 426ec4ec67
commit 6e435de766
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 579 additions and 263 deletions

View File

@ -47,7 +47,7 @@ steps:
- pytest -v -s prefix_caching
- label: Samplers Test
command: pytest -v -s samplers --forked
command: pytest -v -s samplers
- label: Worker Test
command: pytest -v -s worker
@ -56,7 +56,7 @@ steps:
command: pytest -v -s spec_decode
- label: LoRA Test %N
command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- label: Metrics Test

View File

@ -13,6 +13,7 @@ MODELS = [
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models(
hf_runner,
vllm_runner,
@ -20,12 +21,13 @@ def test_models(
model: str,
dtype: str,
max_tokens: int,
enforce_eager: bool,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

View File

@ -10,7 +10,7 @@ from .utils import create_dummy_prompt
def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256)
scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():
def test_scheduler_abort_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256)
scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
block_size = 4
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256)
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
running.append(seq_group)
# Schedule seq groups prompts.
num_tokens = block_size * num_seq_group
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running)
assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs(
)[0].get_len()
assert out.num_batched_tokens == num_tokens
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == num_seq_group
@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
def test_scheduler_schedule_preempt_abort():
block_size = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, 2, max_model_len, 256)
scheduler_config = SchedulerConfig(64, 2, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 2
cache_config.num_gpu_blocks = 2
@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts.
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 2
@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
scheduler.abort_seq_group("1")
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_b]
assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len()
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1
@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
num_seq_group = 4
max_seq_group = 2
max_model_len = 16
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256)
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8

View File

@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
revision=None,
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32, 256),
scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"),
local_rank=0,
rank=0,

View File

@ -92,8 +92,8 @@ def test_same_output_for_single_step():
num_gpu_blocks,
seed,
)
multi_step_worker.model_runner = worker.model_runner
multi_step_worker.cache_engine = worker.cache_engine
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine
num_steps = 1

View File

@ -1,14 +1,132 @@
import random
import torch
from vllm.config import ModelConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
def get_aligned_size(batch_size: int, alignment: int):
return ((batch_size + alignment - 1) // alignment * alignment)
def test_prepare_prompt():
model_runner = ModelRunner(None, None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
))
expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, input_metadata, return_prompt_lens, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is True
assert torch.allclose(input_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert input_metadata.prompt_lens == prompt_lens
assert input_metadata.num_prompt_tokens == sum(prompt_lens)
assert input_metadata.num_generation_tokens == 0
assert input_metadata.max_seq_len == max(prompt_lens)
# Test subquery start locs.
start_idx = 0
start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
start_loc.append(start_idx)
assert torch.allclose(
input_metadata.subquery_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
seq_start_loc.append(start_idx)
assert torch.allclose(
input_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert input_metadata.max_context_len is None
assert torch.allclose(
input_metadata.context_lens,
torch.zeros(input_metadata.context_lens.shape[0],
dtype=torch.int,
device=device))
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.allclose(input_metadata.block_tables, expected)
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is False
assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
def test_prepare_decode_cuda_graph():
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
@ -20,29 +138,56 @@ def test_prepare_prompt():
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
is_prompt=False,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
))
input_tokens, input_positions, input_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None
assert input_metadata.max_context_len == max(prompt_lens)
assert torch.allclose(
input_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert input_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim correspondsd to each token's block number.
# It is padded up to
assert input_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
assert input_positions.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
torch.testing.assert_close(input_tokens, input_positions)
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
max_seq_len = max(prompt_lens)
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
assert input_positions.shape == (batch_size, max_seq_len)
torch.testing.assert_close(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,

View File

@ -535,7 +535,6 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
"""
def __init__(
@ -543,7 +542,6 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
max_paddings: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@ -553,7 +551,6 @@ class SchedulerConfig:
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.max_paddings = max_paddings
self._verify_args()
def _verify_args(self) -> None:

View File

@ -173,12 +173,12 @@ class Scheduler:
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences = deque()
num_batched_tokens = 0
while self.waiting:
seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs(
@ -223,8 +223,7 @@ class Scheduler:
continue
# If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
num_batched_tokens += num_prompt_tokens
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break
@ -236,11 +235,6 @@ class Scheduler:
self.scheduler_config.max_num_seqs):
break
num_paddings = num_batched_tokens - sum(new_seq_lens)
if num_paddings > self.scheduler_config.max_paddings:
break
seq_lens = new_seq_lens
if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.waiting.popleft()
@ -255,8 +249,7 @@ class Scheduler:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=len(seq_lens) *
max(seq_lens) if seq_lens else 0,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,

View File

@ -31,7 +31,6 @@ class EngineArgs:
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False
revision: Optional[str] = None
@ -213,10 +212,6 @@ class EngineArgs:
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--max-paddings',
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument(
'--max-logprobs',
type=int,
@ -347,8 +342,7 @@ class EngineArgs:
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings)
model_config.max_model_len)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,

View File

@ -561,7 +561,6 @@ class LLMEngine:
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs
def step(self) -> List[RequestOutput]:

View File

@ -1,36 +1,92 @@
from dataclasses import dataclass, fields
from typing import Optional, Any, Dict
from typing import Optional, List, Any, Dict
import torch
from xformers.ops.fmha.attn_bias import AttentionBias
@dataclass
class InputMetadata:
"""Metadata for input sequences. Used in PagedAttention.
Args:
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
max_context_len: The maximum context length.
context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type to store kv cache.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
"""
Definition of context_len, subquery_len, and seqlen.
|---------- N-1 iteration --------|
|---------------- N iteration ---------------------|
|- tokenA -|......................|-- newTokens ---|
|---------- context_len ----------|
|-------------------- seqlen ----------------------|
|- subquery_len -|
WARNING: context_len has different definition depending on if it is
prefill vs decoding. When it is prefill, it doesn't include new
tokens. When it is for decoding, it includes a new token.
"""
is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[torch.Tensor]
max_seq_len: Optional[int]
start_loc: Optional[torch.Tensor]
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum context length in the batch.
max_context_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
use_cuda_graph: bool
kv_cache_dtype: str
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias = None
self.attn_bias: Optional[List[AttentionBias]] = None
# Cuda graph is only used for decoding now.
if self.use_cuda_graph:
assert self.num_prompt_tokens == 0
def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""

View File

@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor:

View File

@ -17,11 +17,12 @@ class Attention(nn.Module):
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
3. Output the output tensor.
"""
def __init__(

View File

@ -1,7 +1,7 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional
from flash_attn import flash_attn_func
from flash_attn import flash_attn_varlen_func
import torch
from vllm.model_executor.input_metadata import InputMetadata
@ -10,6 +10,21 @@ from vllm.model_executor.layers.attention.ops.paged_attn import (
class FlashAttentionBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
self,
@ -52,18 +67,18 @@ class FlashAttentionBackend:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
@ -82,13 +97,16 @@ class FlashAttentionBackend:
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
output = flash_attn_func(
query,
key,
value,
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=input_metadata.seq_start_loc,
cu_seqlens_k=input_metadata.seq_start_loc,
max_seqlen_q=input_metadata.max_seq_len,
max_seqlen_k=input_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
@ -118,4 +136,4 @@ class FlashAttentionBackend:
)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
return output.view(num_tokens, hidden_size)

View File

@ -14,6 +14,21 @@ from vllm.utils import is_hip
class XFormersBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
self,
@ -55,19 +70,18 @@ class XFormersBackend:
"""Forward pass with xFormers and PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
@ -82,9 +96,10 @@ class XFormersBackend:
if input_metadata.is_prompt:
# Prompt run.
# key_cache and value_cache are None when it is a profiling run.
# block tables are empty if the prompt has never been computed.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
@ -103,61 +118,33 @@ class XFormersBackend:
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
if self.use_ref_attention:
output = _ref_masked_attention(
query,
key,
value,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.scale,
)
print("ref attention used.")
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = _ref_masked_attention(
query[None, start:end],
key[None, start:end],
value[None, start:end],
self.num_heads,
self.num_kv_heads,
self.head_size,
self.scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len
# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return output.reshape(batch_size, seq_len, hidden_size)
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
return output.reshape(num_tokens, hidden_size)
output = self._run_memory_efficient_xformer_forward(
query, key, value, input_metadata)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
@ -182,41 +169,117 @@ class XFormersBackend:
)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformer_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
input_metadata.prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = [attn_bias]
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
input_metadata)
op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
is_hip()) else None
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias[0],
p=0.0,
scale=self.scale,
op=op)
return out.view_as(query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query)
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=op)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
input_metadata: InputMetadata,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
attn_biases = []
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias = bias[None, :] - bias[:, None]
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
batch_size,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias
padded_len = (prompt_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
prompt_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
return attn_biases
def _check_use_ref_attention() -> bool:
@ -239,7 +302,6 @@ def _ref_masked_attention(
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,

View File

@ -128,11 +128,12 @@ class PagedAttentionImpl:
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.block_tables,
# subquery_start_loc is (batch_size + 1,)
input_metadata.subquery_start_loc[:-1],
input_metadata.prompt_lens_tensor,
input_metadata.context_lens,
input_metadata.max_seq_len,
input_metadata.max_subquery_len,
alibi_slopes,
)
return output

View File

@ -128,7 +128,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)

View File

@ -28,9 +28,12 @@ logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
class ModelRunner:
@ -107,8 +110,7 @@ class ModelRunner:
), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, self.vocab_size,
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
@ -116,10 +118,13 @@ class ModelRunner:
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
max_num_blocks = (self.max_context_len_to_capture + block_size -
1) // block_size
self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
dtype=np.int32)
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size
def _prepare_prompt(
self,
@ -127,9 +132,9 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
@ -158,16 +163,18 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
context_len = computed_len
else:
prefix_block_tables.append([])
context_len = 0
# actual prompt lens
context_lens.append(computed_len)
context_lens.append(context_len)
subquery_lens.append(prompt_len - computed_len)
input_tokens.append(prompt_tokens)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(
input_positions.extend(
list(range(computed_len, computed_len + len(prompt_tokens))))
lora_id = seq_group_metadata.lora_int_id
@ -175,7 +182,7 @@ class ModelRunner:
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - computed_len
@ -184,11 +191,10 @@ class ModelRunner:
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
continue
# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
@ -203,35 +209,30 @@ class ModelRunner:
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(computed_len, prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
slot_mapping.append(slot)
max_subquery_len = max(subquery_lens)
max_seq_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping
max_prompt_len = max(subquery_lens)
assert max_prompt_len > 0
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device=self.device)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@ -244,22 +245,45 @@ class ModelRunner:
dtype=torch.int,
device=self.device,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device=self.device)
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(subquery_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])
torch.cumsum(prompt_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
input_metadata = InputMetadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens_tensor,
max_seq_len=max_prompt_len,
start_loc=start_loc_tensor,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
@ -275,9 +299,9 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
@ -296,11 +320,11 @@ class ModelRunner:
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
input_tokens.append(generation_token)
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
@ -310,8 +334,8 @@ class ModelRunner:
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
lora_index_mapping.append([lora_id])
slot_mapping.append(slot)
lora_index_mapping.append(lora_id)
lora_prompt_mapping.append(lora_id)
if self.sliding_window is not None:
@ -320,6 +344,9 @@ class ModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens)
max_context_len = max(context_lens)
use_captured_graph = (
@ -327,38 +354,37 @@ class ModelRunner:
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_context_len <= self.max_context_len_to_capture)
if use_captured_graph:
# Pad the input tokens, positions, and slot mapping to match the
# batch size of the captured graph.
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size):
input_tokens.append([])
input_positions.append([])
slot_mapping.append([])
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size
input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device=self.device)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens.shape[0] == input_tokens.shape[0]
assert context_lens.shape[0] == input_positions.shape[0]
assert context_lens.shape[0] == slot_mapping.shape[0]
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
@ -377,17 +403,18 @@ class ModelRunner:
device=self.device,
)
lora_index_mapping = [
_pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
]
input_metadata = InputMetadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
max_seq_len=None,
start_loc=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len,
max_seq_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
@ -411,7 +438,6 @@ class ModelRunner:
categorized_sampled_token_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron
max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
@ -439,7 +465,7 @@ class ModelRunner:
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += max_subquery_len
selected_token_start_idx += subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
@ -521,11 +547,8 @@ class ModelRunner:
subquery_lens)
if self.lora_config:
flat_lora_index_mapping = [
item for sublist in lora_index_mapping for item in sublist
]
lora_mapping = LoRAMapping(
flat_lora_index_mapping,
lora_index_mapping,
lora_prompt_mapping,
)
else:
@ -679,6 +702,18 @@ class ModelRunner:
@torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None:
"""Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number
of batched tokens are larger than 200. And since CUDA graph
requires fixed sized tensors, supporting large/variable batch
size requires high GPU memory overhead. Thus, vLLM only captures
decoding requests. Mixed batch (chunked prefill + decoding) or
prefill requests are not captured.
Since it is used for decoding-only, it assumes there's only 1 token
per sequence in the batch.
"""
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
@ -697,10 +732,9 @@ class ModelRunner:
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, 1,
dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
@ -726,9 +760,14 @@ class ModelRunner:
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
max_seq_len=None,
start_loc=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_seq_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
@ -845,7 +884,6 @@ class CUDAGraphRunner:
non_blocking=True)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
non_blocking=True)
# Run the graph.
self.graph.replay()
@ -877,17 +915,28 @@ def _make_tensor_with_pad(
dtype: torch.dtype,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
"""Make a padded tensor of a 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return (batch_size + 7) // 8 * 8
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
def _async_h2d(