[1/n][Chunked Prefill] Refactor input query shapes (#3236)
This commit is contained in:
parent 426ec4ec67
commit 6e435de766
@@ -47,7 +47,7 @@ steps:
- pytest -v -s prefix_caching

- label: Samplers Test
command: pytest -v -s samplers --forked
command: pytest -v -s samplers

- label: Worker Test
command: pytest -v -s worker
@@ -56,7 +56,7 @@ steps:
command: pytest -v -s spec_decode

- label: LoRA Test %N
command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4

- label: Metrics Test
@@ -13,6 +13,7 @@ MODELS = [
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models(
hf_runner,
vllm_runner,
@@ -20,12 +21,13 @@ def test_models(
model: str,
dtype: str,
max_tokens: int,
enforce_eager: bool,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype)
vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
@@ -10,7 +10,7 @@ from .utils import create_dummy_prompt

def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256)
scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
@@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():

def test_scheduler_abort_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256)
scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
@@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
block_size = 4
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256)
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
running.append(seq_group)

# Schedule seq groups prompts.
num_tokens = block_size * num_seq_group
seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running)
assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs(
)[0].get_len()
assert out.num_batched_tokens == num_tokens
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == num_seq_group
@@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
def test_scheduler_schedule_preempt_abort():
block_size = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, 2, max_model_len, 256)
scheduler_config = SchedulerConfig(64, 2, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 2
cache_config.num_gpu_blocks = 2
@@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts.
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2
assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 2
@@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
scheduler.abort_seq_group("1")
seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_b]
assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len()
assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1
@@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
num_seq_group = 4
max_seq_group = 2
max_model_len = 16
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256)
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
@@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
revision=None,
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32, 256),
scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"),
local_rank=0,
rank=0,
@@ -92,8 +92,8 @@ def test_same_output_for_single_step():
num_gpu_blocks,
seed,
)
multi_step_worker.model_runner = worker.model_runner
multi_step_worker.cache_engine = worker.cache_engine
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine

num_steps = 1
@@ -1,14 +1,132 @@
import random
import torch

from vllm.config import ModelConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT


def get_aligned_size(batch_size: int, alignment: int):
return ((batch_size + alignment - 1) // alignment * alignment)


def test_prepare_prompt():
model_runner = ModelRunner(None, None, None, None, None)
model_runner.set_block_size(16)

batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
))

expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, input_metadata, return_prompt_lens, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens

# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is True
assert torch.allclose(input_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert input_metadata.prompt_lens == prompt_lens
assert input_metadata.num_prompt_tokens == sum(prompt_lens)
assert input_metadata.num_generation_tokens == 0
assert input_metadata.max_seq_len == max(prompt_lens)

# Test subquery start locs.
start_idx = 0
start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
start_loc.append(start_idx)
assert torch.allclose(
input_metadata.subquery_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))

# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
seq_start_loc.append(start_idx)

assert torch.allclose(
input_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert input_metadata.max_context_len is None
assert torch.allclose(
input_metadata.context_lens,
torch.zeros(input_metadata.context_lens.shape[0],
dtype=torch.int,
device=device))

expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.allclose(input_metadata.block_tables, expected)
# Cuda graph should not be used for prefill.
assert input_metadata.use_cuda_graph is False
assert input_metadata.kv_cache_dtype == "auto"

assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
torch.testing.assert_close(input_tokens, input_positions)

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions)

actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)


def test_prepare_decode_cuda_graph():
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)

batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
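For reference, a minimal standalone sketch (not part of the diff; plain PyTorch, no vLLM imports, example lengths chosen arbitrarily) of the bookkeeping the prompt test above checks: with the refactored 1D layout, prompts are concatenated into a single [num_tokens] dimension with no padding, and the index of each prompt's last token is its cumulative start offset plus its length minus one.

    import torch

    prompt_lens = [3, 5, 2]  # example lengths; the test draws these at random

    # Cumulative start offset of each prompt in the flattened token dimension.
    start_locs = [0]
    for plen in prompt_lens:
        start_locs.append(start_locs[-1] + plen)

    # Index of the last token of each prompt -> the positions sampling reads from.
    selected_token_indices = [start + plen - 1
                              for start, plen in zip(start_locs, prompt_lens)]

    num_tokens = sum(prompt_lens)
    assert start_locs == [0, 3, 8, 10]
    assert selected_token_indices == [2, 7, 9]
    assert torch.arange(num_tokens).shape == (num_tokens, )  # flat, no padding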
@@ -20,29 +138,56 @@ def test_prepare_prompt():
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
is_prompt=False,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
))

input_tokens, input_positions, input_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))

# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None
assert input_metadata.max_context_len == max(prompt_lens)
assert torch.allclose(
input_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))

# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert input_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim corresponds to each token's block number.
# It is padded up to the max number of blocks per batch.
assert input_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should be used for decoding.
assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto"

assert input_tokens.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
assert input_positions.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
torch.testing.assert_close(input_tokens, input_positions)

# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
max_seq_len = max(prompt_lens)
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
assert input_positions.shape == (batch_size, max_seq_len)
torch.testing.assert_close(input_tokens, input_positions)

actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
@@ -535,7 +535,6 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
"""

def __init__(
@@ -543,7 +542,6 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
max_paddings: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
@@ -553,7 +551,6 @@ class SchedulerConfig:
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.max_paddings = max_paddings
self._verify_args()

def _verify_args(self) -> None:
@@ -173,12 +173,12 @@ class Scheduler:
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = []

# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences = deque()
num_batched_tokens = 0
while self.waiting:
seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs(
@@ -223,8 +223,7 @@ class Scheduler:
continue

# If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
num_batched_tokens += num_prompt_tokens
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break
@@ -236,11 +235,6 @@ class Scheduler:
self.scheduler_config.max_num_seqs):
break

num_paddings = num_batched_tokens - sum(new_seq_lens)
if num_paddings > self.scheduler_config.max_paddings:
break
seq_lens = new_seq_lens

if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.waiting.popleft()
@@ -255,8 +249,7 @@ class Scheduler:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=len(seq_lens) *
max(seq_lens) if seq_lens else 0,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
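A small illustrative sketch (plain Python, example numbers of my own) of the accounting change in the scheduler hunk above: the old budget counted padded tokens, len(seq_lens) * max(seq_lens), while the new budget sums the actual prompt tokens, so ragged prompts no longer consume budget for padding.

    seq_lens = [100, 20, 30]   # prompt lengths already admitted this step
    new_prompt = 50            # candidate prompt
    max_budget = 256           # scheduler_config.max_num_batched_tokens

    # Old accounting: the batch is conceptually padded to the longest prompt.
    padded = len(seq_lens + [new_prompt]) * max(seq_lens + [new_prompt])  # 4 * 100 = 400

    # New accounting: only real tokens count.
    exact = sum(seq_lens) + new_prompt                                    # 200

    print(padded > max_budget)  # True  -> old scheduler would stop here
    print(exact > max_budget)   # False -> new scheduler can still admit the prompt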
@@ -31,7 +31,6 @@ class EngineArgs:
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_paddings: int = 256
max_logprobs: int = 5  # OpenAI default value
disable_log_stats: bool = False
revision: Optional[str] = None
@@ -213,10 +212,6 @@ class EngineArgs:
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--max-paddings',
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument(
'--max-logprobs',
type=int,
@@ -347,8 +342,7 @@ class EngineArgs:
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.max_paddings)
model_config.max_model_len)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
@@ -561,7 +561,6 @@ class LLMEngine:
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))

return request_outputs

def step(self) -> List[RequestOutput]:
@@ -1,36 +1,92 @@
from dataclasses import dataclass, fields
from typing import Optional, Any, Dict
from typing import Optional, List, Any, Dict

import torch
from xformers.ops.fmha.attn_bias import AttentionBias


@dataclass
class InputMetadata:
"""Metadata for input sequences. Used in PagedAttention.

Args:
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
max_context_len: The maximum context length.
context_lens: the length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type to store kv cache.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, they should be stored in a tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
"""
Definition of context_len, subquery_len, and seqlen.
|---------- N-1 iteration --------|
|---------------- N iteration ---------------------|
|- tokenA -|......................|-- newTokens ---|
|---------- context_len ----------|
|-------------------- seqlen ----------------------|
                                  |- subquery_len -|

WARNING: context_len has different definition depending on if it is
prefill vs decoding. When it is prefill, it doesn't include new
tokens. When it is for decoding, it includes a new token.
"""

is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[torch.Tensor]
max_seq_len: Optional[int]
start_loc: Optional[torch.Tensor]
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum context length in the batch.
max_context_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
use_cuda_graph: bool
kv_cache_dtype: str

def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it needs to be set per prompt
# when alibi slopes are used, due to a limitation of the
# xformers API.
# will not appear in the __repr__ and __init__
self.attn_bias = None
self.attn_bias: Optional[List[AttentionBias]] = None

# Cuda graph is only used for decoding now.
if self.use_cuda_graph:
assert self.num_prompt_tokens == 0

def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
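To make the subquery_start_loc / seq_start_loc fields above concrete, here is a minimal sketch (plain PyTorch, example lengths of my own) that mirrors the cumsum construction used later in _prepare_prompt. It assumes two chunked-prefill requests where part of each prompt is already in the KV cache; only the new (subquery) tokens are fed to the model this step, and seqlen = context_len + subquery_len as in the diagram.

    import torch

    context_lens = torch.tensor([16, 0], dtype=torch.long)   # tokens already cached
    subquery_lens = torch.tensor([4, 6], dtype=torch.long)   # new tokens this step
    prompt_lens = context_lens + subquery_lens               # seqlen per sequence

    subquery_start_loc = torch.zeros(len(subquery_lens) + 1, dtype=torch.int32)
    torch.cumsum(subquery_lens, dim=0, dtype=torch.int32, out=subquery_start_loc[1:])

    seq_start_loc = torch.zeros(len(prompt_lens) + 1, dtype=torch.int32)
    torch.cumsum(prompt_lens, dim=0, dtype=torch.int32, out=seq_start_loc[1:])

    print(subquery_start_loc.tolist())  # [0, 4, 10]  -> indexes the packed query
    print(seq_start_loc.tolist())       # [0, 20, 26] -> indexes the full sequences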
@@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""

def _forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -17,11 +17,12 @@ class Attention(nn.Module):

This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.

The class does the following:

1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
3. Output the output tensor.
"""

def __init__(
@@ -1,7 +1,7 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

from flash_attn import flash_attn_func
from flash_attn import flash_attn_varlen_func
import torch

from vllm.model_executor.input_metadata import InputMetadata
@@ -10,6 +10,21 @@ from vllm.model_executor.layers.attention.ops.paged_attn import (


class FlashAttentionBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.

The prompts might have different lengths, while the generation tokens
always have length 1.
"""

def __init__(
self,
@@ -52,18 +67,18 @@ class FlashAttentionBackend:
"""Forward pass with FlashAttention and PagedAttention.

Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -82,13 +97,16 @@ class FlashAttentionBackend:
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
output = flash_attn_func(
query,
key,
value,
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=input_metadata.seq_start_loc,
cu_seqlens_k=input_metadata.seq_start_loc,
max_seqlen_q=input_metadata.max_seq_len,
max_seqlen_k=input_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
@@ -118,4 +136,4 @@ class FlashAttentionBackend:
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
return output.view(num_tokens, hidden_size)
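The varlen call above replaces a padded [batch_size, seq_len] layout with a packed 1D layout described by cumulative sequence lengths. A minimal sketch of that description (plain PyTorch only; it does not call flash_attn, and the lengths and head sizes are arbitrary examples):

    import torch

    # Three prompts of different lengths packed into one token dimension,
    # matching the new input shape [num_tokens, num_heads, head_size].
    prompt_lens = [4, 2, 3]
    num_heads, head_size = 2, 8
    q = torch.randn(sum(prompt_lens), num_heads, head_size)

    # cu_seqlens_q / cu_seqlens_k in the varlen call: int32 prefix sums with a
    # leading zero, one entry per sequence boundary.
    cu_seqlens = torch.zeros(len(prompt_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(prompt_lens), dim=0)
    max_seqlen = max(prompt_lens)

    # Recover the i-th prompt's slice from the packed tensor.
    i = 1
    start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
    assert q[start:end].shape == (prompt_lens[i], num_heads, head_size)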
@@ -14,6 +14,21 @@ from vllm.utils import is_hip


class XFormersBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|

Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.

The prompts might have different lengths, while the generation tokens
always have length 1.
"""

def __init__(
self,
@@ -55,19 +70,18 @@ class XFormersBackend:
"""Forward pass with xFormers and PagedAttention.

Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
num_tokens, hidden_size = query.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
@@ -82,9 +96,10 @@ class XFormersBackend:

if input_metadata.is_prompt:
# Prompt run.
# key_cache and value_cache are None when it is a profiling run.
# block tables are empty if the prompt has never been computed.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
@@ -103,61 +118,33 @@ class XFormersBackend:
self.num_queries_per_kv,
value.shape[-1])

# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)

if self.use_ref_attention:
output = _ref_masked_attention(
query,
key,
value,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.scale,
)
print("ref attention used.")
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = _ref_masked_attention(
query[None, start:end],
key[None, start:end],
value[None, start:end],
self.num_heads,
self.num_kv_heads,
self.head_size,
self.scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len

# Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return output.reshape(batch_size, seq_len, hidden_size)

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
return output.reshape(num_tokens, hidden_size)

output = self._run_memory_efficient_xformer_forward(
query, key, value, input_metadata)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
@@ -182,41 +169,117 @@ class XFormersBackend:
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
return output.view(-1, self.num_heads * self.head_size)

def _run_memory_efficient_xformer_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened into the `query` input.

Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
input_metadata.prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = [attn_bias]
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
input_metadata)

op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
is_hip()) else None
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias[0],
p=0.0,
scale=self.scale,
op=op)

return out.view_as(query)

# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query)
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=op)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output


def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
input_metadata: InputMetadata,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
attn_biases = []
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element is the ith element minus
# the jth element.
bias = bias[None, :] - bias[:, None]

# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
batch_size,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias
padded_len = (prompt_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1,  # batch size
num_heads,
prompt_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

return attn_biases


def _check_use_ref_attention() -> bool:
@@ -239,7 +302,6 @@ def _ref_masked_attention(
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
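The per-prompt loops above implement block-diagonal causal attention over the flattened prompt tokens. A small pure-PyTorch reference of that idea (my own sketch, not the backend's code path; it assumes num_kv_heads == num_heads and a packed [num_tokens, num_heads, head_size] layout):

    import torch

    def ref_block_causal_attention(query, key, value, prompt_lens, scale):
        # Each prompt attends only to its own tokens (block-diagonal),
        # with a causal mask inside the block.
        output = torch.empty_like(query)
        start = 0
        for prompt_len in prompt_lens:
            end = start + prompt_len
            q = query[start:end].transpose(0, 1)   # [H, L, D]
            k = key[start:end].transpose(0, 1)
            v = value[start:end].transpose(0, 1)
            scores = torch.einsum("hqd,hkd->hqk", q, k) * scale
            causal = torch.triu(
                torch.ones(prompt_len, prompt_len, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(causal, float("-inf"))
            out = torch.softmax(scores, dim=-1) @ v  # [H, L, D]
            output[start:end] = out.transpose(0, 1)
            start = end
        return output

    q = k = v = torch.randn(6, 2, 4)   # two prompts of lengths [4, 2], 2 heads
    out = ref_block_causal_attention(q, k, v, [4, 2], scale=4 ** -0.5)
    print(out.shape)                   # torch.Size([6, 2, 4])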
@@ -128,11 +128,12 @@ class PagedAttentionImpl:
output,
key_cache,
value_cache,
input_metadata.block_tables,  # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.block_tables,
# subquery_start_loc is (batch_size + 1,)
input_metadata.subquery_start_loc[:-1],
input_metadata.prompt_lens_tensor,
input_metadata.context_lens,
input_metadata.max_seq_len,
input_metadata.max_subquery_len,
alibi_slopes,
)
return output
@@ -128,7 +128,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
@@ -28,9 +28,12 @@ logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


class ModelRunner:
@@ -107,8 +110,7 @@ class ModelRunner:
), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens +
self.scheduler_config.max_paddings, self.vocab_size,
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
@@ -116,10 +118,13 @@ class ModelRunner:
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size

max_num_blocks = (self.max_context_len_to_capture + block_size -
1) // block_size
self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
dtype=np.int32)

def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size

def _prepare_prompt(
self,
@@ -127,9 +132,9 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
@@ -158,16 +163,18 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
context_len = computed_len
else:
prefix_block_tables.append([])
context_len = 0
# actual prompt lens
context_lens.append(computed_len)
context_lens.append(context_len)
subquery_lens.append(prompt_len - computed_len)

input_tokens.append(prompt_tokens)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(
input_positions.extend(
list(range(computed_len, computed_len + len(prompt_tokens))))

lora_id = seq_group_metadata.lora_int_id
@@ -175,7 +182,7 @@ class ModelRunner:
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)

lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - computed_len
@@ -184,11 +191,10 @@ class ModelRunner:
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
continue

# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
@@ -203,35 +209,30 @@ class ModelRunner:
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(computed_len, prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
slot_mapping.append(_PAD_SLOT_ID)
continue

block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
slot_mapping.append(slot)

max_subquery_len = max(subquery_lens)
max_seq_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping

max_prompt_len = max(subquery_lens)
assert max_prompt_len > 0
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device=self.device)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@@ -244,22 +245,45 @@ class ModelRunner:
dtype=torch.int,
device=self.device,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device=self.device)

# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

torch.cumsum(subquery_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])

torch.cumsum(prompt_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])

input_metadata = InputMetadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens_tensor,
max_seq_len=max_prompt_len,
start_loc=start_loc_tensor,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
@@ -275,9 +299,9 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
@@ -296,11 +320,11 @@ class ModelRunner:
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
input_tokens.append(generation_token)

seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
input_positions.append(position)

context_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
@@ -310,8 +334,8 @@ class ModelRunner:
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
lora_index_mapping.append([lora_id])
slot_mapping.append(slot)
lora_index_mapping.append(lora_id)
lora_prompt_mapping.append(lora_id)

if self.sliding_window is not None:
@@ -320,6 +344,9 @@ class ModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)

# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == len(input_tokens).
batch_size = len(input_tokens)
max_context_len = max(context_lens)
use_captured_graph = (
@@ -327,38 +354,37 @@ class ModelRunner:
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_context_len <= self.max_context_len_to_capture)
if use_captured_graph:
# Pad the input tokens, positions, and slot mapping to match the
# batch size of the captured graph.
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size):
input_tokens.append([])
input_positions.append([])
slot_mapping.append([])
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size

input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device=self.device)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)

if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens.shape[0] == input_tokens.shape[0]
assert context_lens.shape[0] == input_positions.shape[0]
assert context_lens.shape[0] == slot_mapping.shape[0]

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
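The hunk above pads a decoding batch up to the nearest captured CUDA-graph batch size. A small standalone sketch of that padding (example values of my own; graph_batch_size reimplements the rounding rule of _get_graph_batch_size locally rather than importing it):

    _BATCH_SIZE_ALIGNMENT = 8
    _PAD_SLOT_ID = -1

    def graph_batch_size(batch_size: int) -> int:
        # Same rounding rule as _get_graph_batch_size in the diff.
        if batch_size <= 2:
            return batch_size
        if batch_size <= 4:
            return 4
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
                // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    # One token per decoding sequence; pad the flat lists up to the graph size.
    input_tokens = [11, 22, 33, 44, 55]           # 5 real decode tokens
    padded = graph_batch_size(len(input_tokens))  # -> 8
    input_tokens += [0] * (padded - len(input_tokens))
    slot_mapping = [100, 101, 102, 103, 104] + [_PAD_SLOT_ID] * (padded - 5)
    context_lens = [7, 3, 9, 4, 6] + [1] * (padded - 5)

    assert len(input_tokens) == len(slot_mapping) == len(context_lens) == 8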
@@ -377,17 +403,18 @@ class ModelRunner:
device=self.device,
)

lora_index_mapping = [
_pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
]

input_metadata = InputMetadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
max_seq_len=None,
start_loc=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len,
max_seq_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
@@ -411,7 +438,6 @@ class ModelRunner:
categorized_sampled_token_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron

max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
@@ -439,7 +465,7 @@ class ModelRunner:
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += max_subquery_len
selected_token_start_idx += subquery_len

if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
@@ -521,11 +547,8 @@ class ModelRunner:
subquery_lens)

if self.lora_config:
flat_lora_index_mapping = [
item for sublist in lora_index_mapping for item in sublist
]
lora_mapping = LoRAMapping(
flat_lora_index_mapping,
lora_index_mapping,
lora_prompt_mapping,
)
else:
@@ -679,6 +702,18 @@ class ModelRunner:

@torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None:
"""Cuda graph capture a model.

Note that CUDA graph's performance gain is negligible if number
of batched tokens are larger than 200. And since CUDA graph
requires fixed sized tensors, supporting large/variable batch
size requires high GPU memory overhead. Thus, vLLM only captures
decoding requests. Mixed batch (chunked prefill + decoding) or
prefill requests are not captured.

Since it is used for decoding-only, it assumes there's only 1 token
per sequence in the batch.
"""
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs.
self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
@@ -697,10 +732,9 @@ class ModelRunner:

# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, 1,
dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
@@ -726,9 +760,14 @@ class ModelRunner:
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
max_seq_len=None,
start_loc=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_seq_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
@@ -845,7 +884,6 @@ class CUDAGraphRunner:
non_blocking=True)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
non_blocking=True)

# Run the graph.
self.graph.replay()

@@ -877,17 +915,28 @@ def _make_tensor_with_pad(
dtype: torch.dtype,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
"""Make a padded tensor from 2D inputs.

The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device)


def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.

Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return (batch_size + 7) // 8 * 8
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _async_h2d(
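Finally, a short sketch of the documented behavior of the 2D padding helper above (my own standalone version; pad_to_max is a local stand-in for the private _pad_to_max helper, which is not shown in this diff):

    import torch

    def pad_to_max(x, max_len, pad):
        # Local stand-in for the private _pad_to_max helper.
        return x + [pad] * (max_len - len(x))

    def make_tensor_with_pad(x, max_len, pad, dtype, device=None):
        # Mirrors the documented behavior: pad each inner list to max_len, then stack.
        padded_x = [pad_to_max(x_i, max_len, pad) for x_i in x]
        return torch.tensor(padded_x, dtype=dtype, device=device)

    block_tables = [[0, 1], [2], []]
    t = make_tensor_with_pad(block_tables, max_len=3, pad=0, dtype=torch.int)
    print(t.tolist())   # [[0, 1, 0], [2, 0, 0], [0, 0, 0]]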