[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

This commit is contained in:
SangBin Cho 2024-05-04 02:20:12 +09:00 committed by GitHub
parent 2d7bce9cd5
commit 3521ba4f25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 554 additions and 525 deletions

View File

@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def main(
version: str,
num_seqs: int,
context_len: int,
seq_len: int,
num_query_heads: int,
num_kv_heads: int,
head_size: int,
@ -48,12 +48,12 @@ def main(
dtype=torch.float,
device=device)
context_lens = [context_len for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
if version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype,
@ -110,9 +109,9 @@ def main(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -129,9 +128,9 @@ def main(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -166,7 +165,7 @@ if __name__ == '__main__':
choices=["v1", "v2"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096)
parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
@ -199,7 +198,7 @@ if __name__ == '__main__':
main(
version=args.version,
num_seqs=args.batch_size,
context_len=args.context_len,
seq_len=args.seq_len,
num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads,
head_size=args.head_size,

View File

@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int context_len = context_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len;
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
}
if (block_idx == num_context_blocks - 1) {
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
@ -746,9 +746,9 @@ void paged_attention_v1(
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
torch::Tensor& seq_lens, // [num_seqs]
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
@ -803,7 +803,7 @@ void paged_attention_v1(
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
context_lens_ptr, \
seq_lens_ptr, \
max_num_partitions);
template<
@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
@ -943,9 +943,9 @@ void paged_attention_v2(
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
torch::Tensor& seq_lens, // [num_seqs]
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {

View File

@ -70,11 +70,11 @@ template <typename T>
FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index,
const int context_len) {
data[0] += alibi_slope * (start_index - context_len + 1);
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0];
for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1);
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk;
max = max >= qk ? max : qk;
}
@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
@ -235,32 +235,32 @@ struct paged_attention_v1_impl {
static_assert(BLOCK_SIZE == 16);
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes =
parallel_work_item_num * max_context_len_padded * sizeof(float);
parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded]
// [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int context_len = context_lens[seq_idx];
int seq_len = seq_lens[seq_idx];
const int *seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num =
context_len - (block_num - 1) * BLOCK_SIZE;
seq_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_context_len_padded;
logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
@ -278,11 +278,11 @@ struct paged_attention_v1_impl {
// Compute softmax
if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, context_len,
reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
context_len);
seq_len);
} else {
reduceSoftmax(thread_block_logits, context_len,
reduceSoftmax(thread_block_logits, seq_len,
block_num * BLOCK_SIZE);
}
@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
torch::Tensor &block_tables, torch::Tensor &seq_lens,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 64:
@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes);
seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache,
int num_kv_heads, float scale,
torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx];
const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= context_len)
if (start_token_idx >= seq_len)
continue;
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1);
const int context_token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) -
const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx);
const int block_num =
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num =
context_token_num - (block_num - 1) * BLOCK_SIZE;
token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE;
@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
std::pair<float, float> max_and_sum;
if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi(
logits, context_token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len);
logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len);
} else {
max_and_sum = reduceSoftmax(logits, context_token_num,
max_and_sum = reduceSoftmax(logits, token_num,
block_num * BLOCK_SIZE);
}
@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx];
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int context_len = context_lens[seq_idx];
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 64:
@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \
max_context_len, alibi_slopes);
num_kv_heads, scale, block_tables, seq_lens, block_size, \
max_seq_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);

View File

@ -10,9 +10,9 @@ void paged_attention_v1(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
torch::Tensor& seq_lens,
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
@ -28,9 +28,9 @@ void paged_attention_v2(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
torch::Tensor& seq_lens,
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);

View File

@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> None:
@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention(
num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist()
context_lens = context_lens.cpu().tolist()
seq_lens = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
context_len = int(context_lens[i])
seq_len = int(seq_lens[i])
keys = []
values = []
for j in range(context_len):
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float()
position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@ -149,13 +149,13 @@ def test_paged_attention(
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int)
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
@ -186,16 +186,15 @@ def test_paged_attention(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
@ -218,9 +217,9 @@ def test_paged_attention(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -255,7 +254,7 @@ def test_paged_attention(
key_cache,
value_cache,
block_tables,
context_lens,
seq_lens,
scale,
alibi_slopes,
)

View File

@ -51,12 +51,12 @@ def test_contexted_kv_attention(
cache_size = 640
block_size = 32
max_block_per_request = 64
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(subquery_lens)
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
@ -75,15 +75,15 @@ def test_contexted_kv_attention(
num_kv_heads,
head_size,
dtype=dtype)
k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
@ -92,7 +92,7 @@ def test_contexted_kv_attention(
dtype=torch.long),
dim=0)
for i in range(BS):
for j in range(subquery_lens[i]):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
@ -178,7 +178,7 @@ def test_contexted_kv_attention(
value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens)
query_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)

View File

@ -58,7 +58,7 @@ def _do_sample(
device: str,
):
seq_group_metadata_list = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
@ -68,12 +68,12 @@ def _do_sample(
sampling_params=sampling_params,
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"Invalid test case, need seq_group_metadata_list"
batch_size = 0
prompt_lens = []
seq_lens = []
sampling_params_per_row = []
for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params
@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# a prompt seq_group has only one sequence
seq_data = next(iter(sgm.seq_data.values()))
prompt_len = seq_data.get_prompt_len()
prompt_lens.append(prompt_len)
seq_lens.append(prompt_len)
if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in
@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens=prompt_lens if prompt_lens else None,
subquery_lens=prompt_lens if prompt_lens else None,
seq_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else None,
device=device,
pin_memory=model_runner.pin_memory)
# the logits tensor is modified in-place by the sampler
@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_group_metadata_list = []
expected_tokens: List[Optional[List[int]]] = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str):
sampling_params=sampling_params,
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
def test_sampling(model_runner: ModelRunner):
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
sampler_output = sampler(logits=fake_logits,
@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
# Shuffle the batch and resample
target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens):
expected_tokens, seq_lens):
random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index)
@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)

View File

@ -45,7 +45,7 @@ class AsyncLLM:
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
@ -66,7 +66,7 @@ class AsyncLLM:
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
engine_use_ray=True,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,

View File

@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int):
list(range(block_size * 2)),
]
final_seq_lens = [
final_prompt_lens = [
len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens)
]
@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int):
prompts,
num_gpu_blocks,
block_size,
final_seq_lens,
final_prompt_lens,
continuations=prev_output_tokens)
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
@ -103,17 +103,21 @@ def test_same_output_for_single_step():
[6, 7, 8, 9, 10],
]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
single_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)]
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
@ -195,7 +199,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks,
block_size,
continuations=continuations,
final_seq_lens=final_seq_lens), )
final_prompt_lens=final_prompt_lens), )
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
@ -217,7 +221,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks,
block_size,
continuations=continuations,
final_seq_lens=final_seq_lens))
final_prompt_lens=final_prompt_lens))
single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), ))

View File

@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match():
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all():
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),

View File

@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
prompts: List[List[int]],
num_gpu_blocks: int,
block_size: int,
final_seq_lens: List[int],
final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_seq_lens)
for i, final_len in enumerate(final_prompt_lens)
}
return [
@ -251,13 +251,13 @@ def create_batch(batch_size,
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_seq_lens = [
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
block_size, final_seq_lens,
block_size, final_prompt_lens,
prev_output_tokens, seq_ids), )
return execute_model_data, prompts, prev_output_tokens

View File

@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
return logits
seq_group_metadata_list = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str):
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor(

View File

@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size):
lora_config=None)
model_runner.set_block_size(16)
prompt_lens = []
seq_lens = []
seq_group_metadata_list = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
seq_len - 1)
selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is True
assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.max_prompt_len == max(prompt_lens)
assert torch.allclose(
attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
# Test subquery start locs.
start_idx = 0
start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
for seq_len in seq_lens:
start_idx += seq_len
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.subquery_start_loc,
@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size):
# equivalent to subquery_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert attn_metadata.max_context_len is None
assert torch.allclose(
attn_metadata.context_lens,
torch.zeros(attn_metadata.context_lens.shape[0],
attn_metadata.context_lens_tensor,
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
dtype=torch.int,
device=device))
@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size):
# Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is False
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
lora_config=None)
model_runner.set_block_size(16)
prompt_lens = []
seq_lens = []
seq_group_metadata_list = []
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = list(range(seq_len))
seq_data = SequenceData(seq_data)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None
assert attn_metadata.max_prompt_len is None
assert attn_metadata.seq_lens is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_context_len == max(prompt_lens)
assert attn_metadata.max_seq_len == max(seq_lens)
assert torch.allclose(
attn_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))
attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device))
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
@ -241,14 +239,13 @@ def test_empty_seq_group():
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0
assert len(return_seq_lens) == 0
@pytest.fixture
@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
model_runner.set_block_size(16)
# Add prefill requests.
prompt_lens = []
seq_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len))
seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(seq_len))
seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.

View File

@ -39,17 +39,17 @@ def paged_attention_v1(
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_context_len: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables,
context_lens, block_size, max_context_len,
alibi_slopes, kv_cache_dtype, kv_scale)
num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, kv_scale)
def paged_attention_v2(
@ -63,17 +63,17 @@ def paged_attention_v2(
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_context_len: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype,
block_tables, seq_lens, block_size,
max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale)

View File

@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
@ -245,9 +245,9 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
@ -258,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,

View File

@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
class ROCmFlashAttentionImpl(AttentionImpl):
@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.prompt_lens is not None
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
@ -260,8 +260,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prompt_len,
prefill_meta.max_prompt_len,
prefill_meta.max_seq_len,
prefill_meta.max_seq_len,
True,
self.scale,
)
@ -274,7 +274,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query,
key,
value,
prefill_meta.prompt_lens,
prefill_meta.seq_lens,
self.scale,
)
else:
@ -284,8 +284,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
)
@ -303,9 +303,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
@ -317,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -334,13 +334,13 @@ def _naive_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
prompt_lens: List[int],
seq_lens: List[int],
scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(prompt_lens):
end = start + prompt_len
for _, seq_len in enumerate(seq_lens):
end = start + seq_len
out = _naive_masked_attention(
query[start:end],
key[start:end],
@ -349,7 +349,7 @@ def _naive_attention(
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len
start += seq_len
return output

View File

@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]]
seq_lens: Optional[List[int]]
def __post_init__(self):
# Set during the execution of the first attention op.
@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale)
if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.prompt_lens) # type: ignore
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.prompt_lens, self.sliding_window,
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.prompt_lens)
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for prompt_len, mask in zip(attn_metadata.prompt_lens,
attn_metadata.attn_bias):
end = start + prompt_len
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
prompt_lens: List[int],
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
@ -221,7 +221,7 @@ def _make_alibi_bias(
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, prompt_len, prompt_len),
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
@ -229,14 +229,14 @@ def _make_alibi_bias(
def _make_sliding_window_bias(
prompt_lens: List[int],
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
for seq_len in seq_lens:
tensor = torch.full(
(1, prompt_len, prompt_len),
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)

View File

@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.
# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
@ -242,9 +241,9 @@ class XFormersImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
)
@ -257,8 +256,8 @@ class XFormersImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -289,7 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.prompt_lens is not None
assert attn_metadata.seq_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
@ -310,7 +309,7 @@ class XFormersImpl(AttentionImpl):
if attn_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.prompt_lens)
attn_metadata.seq_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
@ -318,7 +317,7 @@ class XFormersImpl(AttentionImpl):
else:
attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
attn_metadata.prompt_lens)
attn_metadata.seq_lens)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
@ -343,8 +342,8 @@ class XFormersImpl(AttentionImpl):
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(original_query)
start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len
for i, seq_len in enumerate(attn_metadata.seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
@ -354,7 +353,7 @@ class XFormersImpl(AttentionImpl):
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len
start += seq_len
return output
@ -362,13 +361,13 @@ def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
prompt_lens: List[int],
seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
@ -376,16 +375,16 @@ def _make_alibi_bias(
# element.
bias = bias[None, :] - bias[:, None]
padded_len = (prompt_len + 7) // 8 * 8
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
prompt_len,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))

View File

@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# Maximum context length in the batch.
max_context_len: Optional[int]
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
@ -85,8 +84,8 @@ class PagedAttention:
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
@ -97,7 +96,7 @@ class PagedAttention:
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
@ -106,7 +105,7 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_context_len <= 8192
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
@ -118,9 +117,9 @@ class PagedAttention:
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -150,9 +149,9 @@ class PagedAttention:
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -168,9 +167,9 @@ class PagedAttention:
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
@ -185,9 +184,9 @@ class PagedAttention:
block_tables,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1],
prompt_lens_tensor,
seq_lens_tensor,
context_lens,
max_subquery_len,
max_query_len,
alibi_slopes,
sliding_window,
)

View File

@ -63,7 +63,10 @@ class ModelConfig:
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
"""
@ -84,6 +87,7 @@ class ModelConfig:
quantization_param_path: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
skip_tokenizer_init: bool = False,
) -> None:
@ -99,6 +103,11 @@ class ModelConfig:
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if self.max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.skip_tokenizer_init = skip_tokenizer_init
@ -190,10 +199,10 @@ class ModelConfig:
"non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
self.max_model_len)
if self.max_seq_len_to_capture is None:
self.max_seq_len_to_capture = self.max_model_len
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
def verify_with_parallel_config(
self,
@ -772,8 +781,8 @@ class SpeculativeConfig:
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
max_context_len_to_capture,
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
)

View File

@ -44,7 +44,8 @@ class EngineArgs:
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
@ -322,6 +323,14 @@ class EngineArgs:
default=EngineArgs.max_context_len_to_capture,
help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')')
parser.add_argument('--max-seq_len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
@ -492,7 +501,8 @@ class EngineArgs:
self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init)
self.max_seq_len_to_capture, self.max_logprobs,
self.skip_tokenizer_init)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,

View File

@ -69,6 +69,9 @@ class LLM:
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
@ -90,7 +93,8 @@ class LLM:
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
@ -112,6 +116,7 @@ class LLM:
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)

View File

@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
assert seq_group.is_prompt, (
"Caller should ensure the sequence group is in a prefill stage.")
seq_ids = seq_group.seq_ids
subquery_len = seq_group.subquery_len
assert subquery_len is not None
query_len = seq_group.query_len
assert query_len is not None
# prompt has only 1 seq id.
assert len(seq_ids) == 1
seq_data = seq_group.seq_data[seq_ids[0]]
@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
prompt_tokens = seq_data.prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + subquery_len + 1,
next_token_index_end = min(computed_len + query_len + 1,
len(prompt_tokens))
next_prompt_tokens = prompt_tokens[
next_token_index_start:next_token_index_end]

View File

@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558
@dataclass
class SequenceGroupToSample:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Sequence ids for the sequence group in a previous step.
seq_ids: List[int]
sampling_params: SamplingParams
# seq_id -> sequence data.
seq_data: Dict[int, SequenceData]
# The length of the prompt of the sequence group. None if it is in a decode
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
prompt_len: Optional[int]
# The length of the query tokens to compute in the current step. None if it
# is in a decode stage. The length of subquery_len <= prompt_len.
subquery_len: Optional[int]
seq_len: Optional[int]
# The length of new query tokens to compute in the current step. None if it
# is in a decode stage. The length of query_len <= seq_len if chunked
# prefill is enabled.
query_len: Optional[int]
# A random number generator for sampling.
generator: Optional[torch.Generator]
# True if the sequence group is in prefill stage. False if it is in a
@ -46,8 +55,8 @@ class SequenceGroupToSample:
if len(self.prompt_logprob_indices) > 0:
assert self.sampling_params.prompt_logprobs is not None
if self.is_prompt:
assert self.prompt_len is not None
assert self.subquery_len is not None
assert self.seq_len is not None
assert self.query_len is not None
class SamplingMetadata:
@ -94,8 +103,8 @@ class SamplingMetadata:
@staticmethod
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
pin_memory: bool,
) -> "SamplingMetadata":
@ -104,8 +113,8 @@ class SamplingMetadata:
selected_token_indices,
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
subquery_lens, device)
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
@ -137,8 +146,8 @@ class SamplingMetadata:
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
@ -146,9 +155,9 @@ def _prepare_seq_groups(
Args:
seq_group_metadata_list: A list of sequence group to batch.
prompt_lens: A list of prompt lens per sequence group.
seq_lens: A list of sequence lens per sequence group.
Index of prompt len should match with seq_group_metadata_list.
subquery_lens: A list of query lengths. Prompt lens include the length
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
`SequenceGroupToSample.generator`.
@ -189,8 +198,8 @@ def _prepare_seq_groups(
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
prompt_len: Optional[int] = None
subquery_len: Optional[int] = None
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
do_sample = seq_group_metadata.do_sample
@ -203,12 +212,12 @@ def _prepare_seq_groups(
num_prompts += 1
num_prefill_sample = len(seq_ids)
assert num_prefill_sample == 1
assert subquery_lens is not None and prompt_lens is not None
subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
assert query_lens is not None and seq_lens is not None
query_len, seq_len = query_lens[i], seq_lens[i]
# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len = (subquery_len - num_prefill_sample
if do_sample else subquery_len)
prompt_logprob_len = (query_len - num_prefill_sample
if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0
else:
# Decode
@ -267,8 +276,8 @@ def _prepare_seq_groups(
seq_ids=seq_ids,
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
prompt_len=prompt_len,
subquery_len=subquery_len,
seq_len=seq_len,
query_len=query_len,
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
@ -367,8 +376,8 @@ class SamplingTensors:
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
subquery_len = seq_group.subquery_len
assert subquery_len is not None
query_len = seq_group.query_len
assert query_len is not None
prefill_len = len(seq_group.prompt_logprob_indices)
temperatures += [temperature] * prefill_len
top_ps += [top_p] * prefill_len
@ -397,8 +406,8 @@ class SamplingTensors:
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
subquery_len = seq_group.subquery_len
assert subquery_len is not None
query_len = seq_group.query_len
assert query_len is not None
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]

View File

@ -80,7 +80,7 @@ class CPUModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_lens: List[int] = []
seq_lens: List[int] = []
multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list:
@ -92,15 +92,15 @@ class CPUModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
prompt_len = len(prompt_tokens)
seq_len = len(prompt_tokens)
prompt_lens.append(prompt_len) # Prompt token num
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len)))
input_positions.extend(list(range(computed_len, seq_len)))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
@ -109,15 +109,15 @@ class CPUModelRunner:
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window)
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prompt_len):
for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
@ -151,19 +151,19 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
prompt_lens=prompt_lens,
num_prefills=len(prompt_lens),
seq_lens=seq_lens,
seq_lens_tensor=None,
max_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
prefill_metadata=None,
decode_metadata=None,
max_context_len=None,
context_lens=None,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input)
def _prepare_decode(
@ -174,7 +174,7 @@ class CPUModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
@ -192,9 +192,9 @@ class CPUModelRunner:
position = seq_len - 1
input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min(
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
@ -208,7 +208,7 @@ class CPUModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_context_len = max(context_lens)
max_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
@ -219,9 +219,9 @@ class CPUModelRunner:
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
max_block_table_len = max(
len(block_table) for block_table in block_tables)
@ -236,14 +236,14 @@ class CPUModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_seq_len=max_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
max_context_len=max_context_len,
num_prefills=0,
prefill_metadata=None,
decode_metadata=None,
context_lens=context_lens,
block_tables=block_tables,
kv_cache_dtype=self.kv_cache_dtype,
)
@ -265,20 +265,20 @@ class CPUModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens,
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
# Broadcast the metadata.
@ -300,7 +300,7 @@ class CPUModelRunner:
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
seq_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,

View File

@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int]
subquery_lens: List[int]
seq_lens: List[int]
query_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens=[],
input_positions=[],
attn_metadata=None,
prompt_lens=[],
subquery_lens=[],
seq_lens=[],
query_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
@ -134,9 +134,8 @@ class ModelRunner:
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = (
self.model_config.max_context_len_to_capture
if self.model_config is not None else 0)
self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
if self.model_config is not None else 0)
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
@ -149,7 +148,7 @@ class ModelRunner:
self.model: torch.nn.Module # Set after load_model
self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
@ -218,7 +217,7 @@ class ModelRunner:
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_prompt(
self,
@ -231,9 +230,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = []
seq_lens: List[int] = []
context_lens: List[int] = []
subquery_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
@ -257,21 +256,19 @@ class ModelRunner:
token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
computed_len = seq_data.get_num_computed_tokens()
context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption
# it contains output tokens.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = prefill_end
prompt_lens.append(prompt_len)
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
# NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
@ -285,25 +282,25 @@ class ModelRunner:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert computed_len == 0
assert context_len == 0
# actual prompt lens
context_lens.append(computed_len)
subquery_lens.append(prompt_len - computed_len)
context_lens.append(context_len)
query_lens.append(seq_len - context_len)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end)))
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - computed_len
(seq_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data:
@ -313,24 +310,24 @@ class ModelRunner:
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert computed_len == 0, (
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prefill_end):
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
@ -340,9 +337,9 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens)
assert max_subquery_len > 0
max_query_len = max(query_lens)
max_seq_len = max(seq_lens)
assert max_query_len > 0
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
@ -369,40 +366,39 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(subquery_lens_tensor,
torch.cumsum(query_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])
torch.cumsum(prompt_lens_tensor,
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_prompt_len=max_prompt_len,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
@ -411,8 +407,8 @@ class ModelRunner:
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
@ -427,7 +423,7 @@ class ModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
@ -455,9 +451,9 @@ class ModelRunner:
position = seq_len - 1
input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min(
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
@ -477,11 +473,10 @@ class ModelRunner:
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens)
max_context_len = max(context_lens)
use_captured_graph = (
not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_context_len <= self.max_context_len_to_capture)
max_seq_len = max(seq_lens)
use_captured_graph = (not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
@ -489,21 +484,21 @@ class ModelRunner:
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1)
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens_tensor.shape[0] == len(input_tokens)
assert context_lens_tensor.shape[0] == len(input_positions)
assert context_lens_tensor.shape[0] == len(slot_mapping)
assert seq_lens_tensor.shape[0] == len(input_tokens)
assert seq_lens_tensor.shape[0] == len(input_positions)
assert seq_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
@ -525,14 +520,13 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
prompt_lens_tensor=None,
max_subquery_len=None,
max_context_len=max_context_len,
max_prompt_len=None,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens_tensor,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
@ -565,8 +559,8 @@ class ModelRunner:
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
@ -583,13 +577,13 @@ class ModelRunner:
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, prompt_lens, subquery_lens,
self.device, self.pin_memory)
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
@ -886,7 +880,7 @@ class ModelRunner:
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
graph_batch_size = _get_graph_batch_size(
@ -908,14 +902,13 @@ class ModelRunner:
# Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
prompt_lens_tensor=None,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_seq_len=self.max_seq_len_to_capture,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size],
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
@ -1025,7 +1018,7 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.decode_metadata.context_lens,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
@ -1047,8 +1040,8 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.

View File

@ -52,7 +52,7 @@ class NeuronModelRunner:
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
prompt_lens: List[int] = []
seq_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@ -61,26 +61,26 @@ class NeuronModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
seq_len = len(prompt_tokens)
seq_lens.append(seq_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
input_positions.append(list(range(seq_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
@ -88,7 +88,7 @@ class NeuronModelRunner:
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids, prompt_lens
return input_tokens, input_positions, input_block_ids, seq_lens
def _prepare_decode(
self,
@ -149,18 +149,18 @@ class NeuronModelRunner:
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
seq_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)