[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
parent 2d7bce9cd5
commit 3521ba4f25
@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def main(
version: str,
num_seqs: int,
context_len: int,
seq_len: int,
num_query_heads: int,
num_kv_heads: int,
head_size: int,
@@ -48,12 +48,12 @@ def main(
dtype=torch.float,
device=device)

context_lens = [context_len for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
@@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
if version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype,
@@ -110,9 +109,9 @@ def main(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@@ -129,9 +128,9 @@ def main(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@@ -166,7 +165,7 @@ if __name__ == '__main__':
choices=["v1", "v2"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096)
parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
@@ -199,7 +198,7 @@ if __name__ == '__main__':
main(
version=args.version,
num_seqs=args.batch_size,
context_len=args.context_len,
seq_len=args.seq_len,
num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads,
head_size=args.head_size,
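The benchmark's block and partition counts above are plain ceiling divisions over the renamed seq_len. A minimal, standalone sketch of that arithmetic (the block_size value is an assumption for illustration; PARTITION_SIZE = 512 and the 4096 default come from the file above):

# Illustrative only: the ceiling divisions used by the benchmark above.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

PARTITION_SIZE = 512
block_size = 16          # assumed value for illustration
seq_len = 4096           # benchmark default

max_num_blocks_per_seq = ceil_div(seq_len, block_size)    # 256 KV-cache blocks
num_partitions = ceil_div(seq_len, PARTITION_SIZE)         # 8 partitions for v2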
@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int context_len = context_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}

const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;

// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;

constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len;
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
}
if (block_idx == num_context_blocks - 1) {
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
}

@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride, kv_scale);
}

@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
num_kv_heads, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);

@@ -746,9 +746,9 @@ void paged_attention_v1(
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
torch::Tensor& seq_lens, // [num_seqs]
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
@@ -790,7 +790,7 @@ void paged_attention_v1(
num_kv_heads, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
@@ -803,7 +803,7 @@ void paged_attention_v1(
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
context_lens_ptr, \
seq_lens_ptr, \
max_num_partitions);

template<
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
num_kv_heads, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);

@@ -943,9 +943,9 @@ void paged_attention_v2(
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
torch::Tensor& seq_lens, // [num_seqs]
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
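The CUDA changes above are a pure rename of context_len to seq_len inside the partitioning logic. As a hedged restatement (not part of the commit), the per-partition block and token ranges the v2 kernel derives can be sketched in Python; BLOCK_SIZE and PARTITION_SIZE are compile-time constants in the real kernel and their values here are assumptions:

# Illustrative restatement of the v2 partitioning computed in the kernel above.
def partition_ranges(seq_len: int, partition_idx: int,
                     block_size: int = 16, partition_size: int = 512):
    num_seq_blocks = (seq_len + block_size - 1) // block_size   # DIVIDE_ROUND_UP
    blocks_per_partition = partition_size // block_size
    start_block = partition_idx * blocks_per_partition
    end_block = min(start_block + blocks_per_partition, num_seq_blocks)
    start_token = start_block * block_size
    end_token = min(start_token + (end_block - start_block) * block_size, seq_len)
    return (start_block, end_block), (start_token, end_token)

# Example: a 600-token sequence has its second partition cover blocks [32, 38)
# and tokens [512, 600).
assert partition_ranges(600, 1) == ((32, 38), (512, 600))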
@@ -70,11 +70,11 @@ template <typename T>
FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index,
const int context_len) {
data[0] += alibi_slope * (start_index - context_len + 1);
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0];
for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1);
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk;
max = max >= qk ? max : qk;
}
@@ -225,7 +225,7 @@ struct paged_attention_v1_impl {
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -235,32 +235,32 @@ struct paged_attention_v1_impl {

static_assert(BLOCK_SIZE == 16);

int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);

const int parallel_work_item_num = omp_get_max_threads();

size_t logits_bytes =
parallel_work_item_num * max_context_len_padded * sizeof(float);
parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded]
// [parallel_work_item_num, max_seq_len_padded]

#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int context_len = context_lens[seq_idx];
int seq_len = seq_lens[seq_idx];
const int *seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num =
context_len - (block_num - 1) * BLOCK_SIZE;
seq_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_context_len_padded;
logits + omp_get_thread_num() * max_seq_len_padded;

// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
@@ -278,11 +278,11 @@ struct paged_attention_v1_impl {

// Compute softmax
if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, context_len,
reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
context_len);
seq_len);
} else {
reduceSoftmax(thread_block_logits, context_len,
reduceSoftmax(thread_block_logits, seq_len,
block_num * BLOCK_SIZE);
}

@@ -340,7 +340,7 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);

@@ -348,8 +348,8 @@ template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
torch::Tensor &block_tables, torch::Tensor &seq_lens,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 64:
@@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher(
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes);
seq_lens, max_seq_len, alibi_slopes);

#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
@@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache,
int num_kv_heads, float scale,
torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
@@ -448,7 +448,7 @@ struct paged_attention_v2_impl {
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -465,22 +465,22 @@ struct paged_attention_v2_impl {
for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx];
const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE;

if (start_token_idx >= context_len)
if (start_token_idx >= seq_len)
continue;

const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1);
const int context_token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) -
const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx);
const int block_num =
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num =
context_token_num - (block_num - 1) * BLOCK_SIZE;
token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE;
@@ -507,10 +507,10 @@ struct paged_attention_v2_impl {
std::pair<float, float> max_and_sum;
if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi(
logits, context_token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len);
logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len);
} else {
max_and_sum = reduceSoftmax(logits, context_token_num,
max_and_sum = reduceSoftmax(logits, token_num,
block_num * BLOCK_SIZE);
}

@@ -583,9 +583,9 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx];
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

if (partition_num == 1)
continue;
@@ -612,9 +612,9 @@ struct paged_attention_v2_impl {
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int context_len = context_lens[seq_idx];
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

if (partition_num == 1)
continue;
@@ -649,7 +649,7 @@ struct paged_attention_v2_impl {
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);

@@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher(
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 64:
@@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher(
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \
max_context_len, alibi_slopes);
num_kv_heads, scale, block_tables, seq_lens, block_size, \
max_seq_len, alibi_slopes);

#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
@@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
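reduceSoftmaxAlibi above biases logit i by alibi_slope * (start_index + i - seq_len + 1) before the usual max-subtracted softmax and hands back a (max, sum) pair for the later reduction. A hedged Python sketch of that behaviour (illustrative only; the in-place details of the C++ helper may differ):

# Sketch of the ALiBi-biased softmax performed by reduceSoftmaxAlibi above.
import math

def softmax_with_alibi(logits, alibi_slope, start_index, seq_len):
    biased = [x + alibi_slope * (start_index + i - seq_len + 1)
              for i, x in enumerate(logits)]
    m = max(biased)
    exps = [math.exp(x - m) for x in biased]
    s = sum(exps)
    return [e / s for e in exps], (m, s)   # probabilities plus the (max, sum) pair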
@@ -10,9 +10,9 @@ void paged_attention_v1(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
torch::Tensor& seq_lens,
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
@@ -28,9 +28,9 @@ void paged_attention_v2(
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
torch::Tensor& seq_lens,
int block_size,
int max_context_len,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
@@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> None:
@@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention(
num_seqs = query.shape[0]

block_tables = block_tables.cpu().tolist()
context_lens = context_lens.cpu().tolist()
seq_lens = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
context_len = int(context_lens[i])
seq_len = int(seq_lens[i])

keys = []
values = []
for j in range(context_len):
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size

@@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float()
position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)

@@ -149,13 +149,13 @@ def test_paged_attention(
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int)
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
@@ -186,16 +186,15 @@ def test_paged_attention(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
@@ -218,9 +217,9 @@ def test_paged_attention(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@@ -255,7 +254,7 @@ def test_paged_attention(
key_cache,
value_cache,
block_tables,
context_lens,
seq_lens,
scale,
alibi_slopes,
)
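The reference implementation above walks every token of a sequence through the block table: token j lives in block block_table[j // block_size] at offset j % block_size. A tiny standalone restatement of that mapping (the KV-cache tensor layout itself is not reproduced here):

# Illustrative: map a token index to its (block_number, block_offset) slot.
def token_to_slot(block_table, j, block_size):
    block_number = int(block_table[j // block_size])
    block_offset = j % block_size
    return block_number, block_offset

assert token_to_slot([7, 3, 9], j=17, block_size=16) == (3, 1)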
@@ -51,12 +51,12 @@ def test_contexted_kv_attention(
cache_size = 640
block_size = 32
max_block_per_request = 64
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv

num_tokens = sum(subquery_lens)
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
@@ -75,15 +75,15 @@ def test_contexted_kv_attention(
num_kv_heads,
head_size,
dtype=dtype)
k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
@@ -92,7 +92,7 @@ def test_contexted_kv_attention(
dtype=torch.long),
dim=0)
for i in range(BS):
for j in range(subquery_lens[i]):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
@@ -178,7 +178,7 @@ def test_contexted_kv_attention(
value = value.unsqueeze(0)

attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens)
query_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)
@@ -58,7 +58,7 @@ def _do_sample(
device: str,
):
seq_group_metadata_list = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
@@ -68,12 +68,12 @@ def _do_sample(
sampling_params=sampling_params,
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
@@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
"Invalid test case, need seq_group_metadata_list"

batch_size = 0
prompt_lens = []
seq_lens = []
sampling_params_per_row = []
for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params
@@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# a prompt seq_group has only one sequence
seq_data = next(iter(sgm.seq_data.values()))
prompt_len = seq_data.get_prompt_len()
prompt_lens.append(prompt_len)
seq_lens.append(prompt_len)

if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in
@@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens=prompt_lens if prompt_lens else None,
subquery_lens=prompt_lens if prompt_lens else None,
seq_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else None,
device=device,
pin_memory=model_runner.pin_memory)
# the logits tensor is modified in-place by the sampler
@@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str):

seq_group_metadata_list = []
expected_tokens: List[Optional[List[int]]] = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
@@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str):
sampling_params=sampling_params,
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

def test_sampling(model_runner: ModelRunner):
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
sampler_output = sampler(logits=fake_logits,
@@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str):
# Shuffle the batch and resample
target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens):
expected_tokens, seq_lens):
random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index)
@@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert len(warpers) == 2  # top_p and top_k

seq_group_metadata_list = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
@@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=model_runner.pin_memory)
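In these sampler tests every sequence is a freshly prepared prompt, so each sequence's query length equals its sequence length; that is why the same seq_lens list is passed both positionally and as query_lens= above. A hedged standalone illustration of that relationship (plain Python, no vLLM APIs):

# Illustrative only: for prompt-only batches, query_len == seq_len per sequence.
prompt_token_ids = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
seq_lens = [len(toks) for toks in prompt_token_ids]
context_lens = [0] * len(prompt_token_ids)     # nothing computed yet at prefill
query_lens = [s - c for s, c in zip(seq_lens, context_lens)]
assert query_lens == seq_lens == [3, 2, 4]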
@@ -45,7 +45,7 @@ class AsyncLLM:
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
@@ -66,7 +66,7 @@ class AsyncLLM:
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
engine_use_ray=True,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
@@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int):
list(range(block_size * 2)),
]

final_seq_lens = [
final_prompt_lens = [
len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens)
]
@@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int):
prompts,
num_gpu_blocks,
block_size,
final_seq_lens,
final_prompt_lens,
continuations=prev_output_tokens)

assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
@@ -103,17 +103,21 @@ def test_same_output_for_single_step():
[6, 7, 8, 9, 10],
]

final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

multi_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

single_step_execute_model_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
@@ -181,7 +185,7 @@ def test_same_output_for_multi_step():
random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)]

final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
@@ -195,7 +199,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks,
block_size,
continuations=continuations,
final_seq_lens=final_seq_lens), )
final_prompt_lens=final_prompt_lens), )

# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
@@ -217,7 +221,7 @@ def test_same_output_for_multi_step():
num_gpu_blocks,
block_size,
continuations=continuations,
final_seq_lens=final_seq_lens))
final_prompt_lens=final_prompt_lens))

single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), ))
@@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match():
]

proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
@@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
]

proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
@@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all():
]

proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens))

proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
@@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts(
prompts: List[List[int]],
num_gpu_blocks: int,
block_size: int,
final_seq_lens: List[int],
final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
@@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts(
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_seq_lens)
for i, final_len in enumerate(final_prompt_lens)
}

return [
@@ -251,13 +251,13 @@ def create_batch(batch_size,
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_seq_lens = [
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]

execute_model_data = create_execute_model_data(
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
block_size, final_seq_lens,
block_size, final_prompt_lens,
prev_output_tokens, seq_ids), )
return execute_model_data, prompts, prev_output_tokens
@@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str):
return logits

seq_group_metadata_list = []
prompt_lens = []
seq_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
@@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str):
logits_processors=[pick_ith]),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor(
@@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size):
lora_config=None)
model_runner.set_block_size(16)

prompt_lens = []
seq_lens = []
seq_group_metadata_list = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size):

expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
seq_len - 1)
selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)

# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is True
assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.max_prompt_len == max(prompt_lens)
assert torch.allclose(
attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)

# Test subquery start locs.
start_idx = 0
start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
for seq_len in seq_lens:
start_idx += seq_len
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.subquery_start_loc,
@@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size):
# equivalent to subquery_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)

assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert attn_metadata.max_context_len is None
assert torch.allclose(
attn_metadata.context_lens,
torch.zeros(attn_metadata.context_lens.shape[0],
attn_metadata.context_lens_tensor,
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
dtype=torch.int,
device=device))

@@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size):
# Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False

assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions)

sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
@@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size):
lora_config=None)
model_runner.set_block_size(16)

prompt_lens = []
seq_lens = []
seq_group_metadata_list = []
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = list(range(seq_len))
seq_data = SequenceData(seq_data)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
@@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None
assert attn_metadata.max_prompt_len is None
assert attn_metadata.seq_lens is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_context_len == max(prompt_lens)
assert attn_metadata.max_seq_len == max(seq_lens)
assert torch.allclose(
attn_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))
attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device))

# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
@@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size):
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
for seq_len in seq_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens,
seq_lens,
query_lens=seq_lens,
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
@@ -241,14 +239,13 @@ def test_empty_seq_group():
assert attn_metadata is None
assert len(slot_mapping) == 0

(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0
assert len(return_seq_lens) == 0


@pytest.fixture
@@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
model_runner.set_block_size(16)

# Add prefill requests.
prompt_lens = []
seq_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
@@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len))
seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(seq_len))
seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
@@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
assert attn_metadata.num_prefill_tokens == sum(seq_lens)

# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
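The prefill assertions above expect the sampler to pick the last token of every prompt, so the selected indices are running prefix sums of seq_lens minus one, while each decode sequence contributes exactly one row. A small standalone check mirroring that logic (the concrete lengths are made up):

# Illustrative check of the expected_selected_token_indices logic above.
seq_lens = [3, 5, 2]
expected, start = [], 0
for seq_len in seq_lens:
    expected.append(start + seq_len - 1)   # last token of each prompt
    start += seq_len
assert expected == [2, 7, 9]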
@@ -39,17 +39,17 @@ def paged_attention_v1(
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_context_len: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables,
context_lens, block_size, max_context_len,
alibi_slopes, kv_cache_dtype, kv_scale)
num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, kv_scale)


def paged_attention_v2(
@@ -63,17 +63,17 @@ def paged_attention_v2(
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_context_len: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype,
block_tables, seq_lens, block_size,
max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale)
@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
# (batch_size,). The prompt length per sequence. None if it is a decoding.
|
||||
prompt_lens: Optional[List[int]]
|
||||
# prompt_lens stored as a tensor.
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seqlen ----------------------|
|
||||
# |- subquery_len -|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# WARNING(sang): context_len has different definition depending on if it is
|
||||
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
|
||||
# When it is for decoding, it includes a new token.
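A concrete reading of the diagram above, with illustrative numbers (not part of the diff): a 12-token prompt is being chunk-prefilled, 8 tokens were computed in earlier iterations and 4 new tokens are computed in this one.

context_len = 8                      # tokens already computed and stored in the KV cache
query_len = 4                        # new tokens attended to in this iteration
seq_len = context_len + query_len    # 12: everything the sequence has seen so far

assert seq_len == 12
# During decoding, query_len is 1 and seq_len grows by one token per step.
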

# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]

# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
@ -245,9 +245,9 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
@ -258,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,

@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]

# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|

# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.

# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]


class ROCmFlashAttentionImpl(AttentionImpl):
@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.prompt_lens is not None
assert prefill_meta.seq_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
@ -260,8 +260,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prompt_len,
prefill_meta.max_prompt_len,
prefill_meta.max_seq_len,
prefill_meta.max_seq_len,
True,
self.scale,
)
@ -274,7 +274,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query,
key,
value,
prefill_meta.prompt_lens,
prefill_meta.seq_lens,
self.scale,
)
else:
@ -284,8 +284,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
softmax_scale=self.scale,
causal=True,
)
@ -303,9 +303,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
)
@ -317,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -334,13 +334,13 @@ def _naive_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
prompt_lens: List[int],
seq_lens: List[int],
scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(prompt_lens):
end = start + prompt_len
for _, seq_len in enumerate(seq_lens):
end = start + seq_len
out = _naive_masked_attention(
query[start:end],
key[start:end],
@ -349,7 +349,7 @@ def _naive_attention(
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len
start += seq_len

return output
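An illustrative sketch (not from the diff) of how per-sequence seq_lens carve the flattened token dimension into [start, end) slices, which is exactly what the _naive_attention loop above iterates over.

seq_lens = [3, 5, 2]   # three sequences packed back to back along the token axis
slices = []
start = 0
for seq_len in seq_lens:
    end = start + seq_len
    slices.append((start, end))
    start = end

assert slices == [(0, 3), (3, 8), (8, 10)]
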
@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]]
seq_lens: Optional[List[int]]

def __post_init__(self):
# Set during the execution of the first attention op.
@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl):
kv_scale)

if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
assert attn_metadata.seq_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl):
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.prompt_lens) # type: ignore
attn_metadata.seq_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.prompt_lens, self.sliding_window,
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.prompt_lens)
att_masks = [None] * len(attn_metadata.seq_lens)
attn_metadata.attn_bias = att_masks

query = query.movedim(0, query.dim() - 2)
@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl):
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for prompt_len, mask in zip(attn_metadata.prompt_lens,
for seq_len, mask in zip(attn_metadata.seq_lens,
attn_metadata.attn_bias):
end = start + prompt_len
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl):
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl):
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
prompt_lens: List[int],
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
@ -221,7 +221,7 @@ def _make_alibi_bias(
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, prompt_len, prompt_len),
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))

@ -229,14 +229,14 @@ def _make_alibi_bias(


def _make_sliding_window_bias(
prompt_lens: List[int],
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for prompt_len in prompt_lens:
for seq_len in seq_lens:
tensor = torch.full(
(1, prompt_len, prompt_len),
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
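An illustrative sketch (not from the diff) of the kind of per-sequence bias the helpers above build: a slope-scaled relative-position term plus an upper-triangular -inf mask that enforces causality. The numbers are arbitrary.

import torch

seq_len, num_heads = 4, 2
alibi_slopes = torch.tensor([0.5, 0.25])

pos = torch.arange(seq_len, dtype=torch.float32)
rel = pos[None, :] - pos[:, None]                       # (seq_len, seq_len) relative distances
bias = rel[None, :, :] * alibi_slopes[:, None, None]    # (num_heads, seq_len, seq_len)
causal = torch.full((seq_len, seq_len), float("-inf")).triu_(diagonal=1)
mask = bias + causal                                    # added to attention logits per head

assert mask.shape == (num_heads, seq_len, seq_len)
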
@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]

# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seqlen ----------------------|
# |- subquery_len -|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|

# WARNING(sang): context_len has different definition depending on if it is
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
# When it is for decoding, it includes a new token.

# Maximum subquery length in the batch.
max_subquery_len: Optional[int]
# Maximum query length in the batch.
max_query_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum prompt length in the batch.
max_prompt_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]

# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
@ -242,9 +241,9 @@ class XFormersImpl(AttentionImpl):
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
)
@ -257,8 +256,8 @@ class XFormersImpl(AttentionImpl):
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -289,7 +288,7 @@ class XFormersImpl(AttentionImpl):
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.prompt_lens is not None
assert attn_metadata.seq_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
@ -310,7 +309,7 @@ class XFormersImpl(AttentionImpl):
if attn_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.prompt_lens)
attn_metadata.seq_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
@ -318,7 +317,7 @@ class XFormersImpl(AttentionImpl):
else:
attn_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
attn_metadata.prompt_lens)
attn_metadata.seq_lens)

# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
@ -343,8 +342,8 @@ class XFormersImpl(AttentionImpl):
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(original_query)
start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len
for i, seq_len in enumerate(attn_metadata.seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
@ -354,7 +353,7 @@ class XFormersImpl(AttentionImpl):
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len
start += seq_len
return output


@ -362,13 +361,13 @@ def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
prompt_lens: List[int],
seq_lens: List[int],
) -> LowerTriangularMaskWithTensorBias:
attn_biases = []
for prompt_len in prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
@ -376,16 +375,16 @@ def _make_alibi_bias(
# element.
bias = bias[None, :] - bias[:, None]

padded_len = (prompt_len + 7) // 8 * 8
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
prompt_len,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :prompt_len].copy_(bias)
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
@ -13,12 +13,11 @@ _PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor]
# Maximum context length in the batch.
max_context_len: Optional[int]
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
@ -85,8 +84,8 @@ class PagedAttention:
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
@ -97,7 +96,7 @@ class PagedAttention:

block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
@ -106,7 +105,7 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_context_len <= 8192
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
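# Illustrative sketch (not from the diff) of the V1/V2 selection heuristic
# above, using _PARTITION_SIZE = 512 as defined in this file. This is a
# standalone restatement for clarity, not vLLM's actual function.
_PARTITION_SIZE = 512

def choose_kernel(max_seq_len: int, num_seqs: int, num_heads: int) -> str:
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    use_v1 = (max_seq_len <= 8192
              and (max_num_partitions == 1 or num_seqs * num_heads > 512))
    return "v1" if use_v1 else "v2"

assert choose_kernel(max_seq_len=512, num_seqs=1, num_heads=8) == "v1"      # single partition
assert choose_kernel(max_seq_len=4096, num_seqs=64, num_heads=32) == "v1"   # enough CTAs without partitioning
assert choose_kernel(max_seq_len=16384, num_seqs=64, num_heads=32) == "v2"  # too long for V1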
if use_v1:
# Run PagedAttention V1.
@ -118,9 +117,9 @@ class PagedAttention:
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -150,9 +149,9 @@ class PagedAttention:
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -168,9 +167,9 @@ class PagedAttention:
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
@ -185,9 +184,9 @@ class PagedAttention:
block_tables,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1],
prompt_lens_tensor,
seq_lens_tensor,
context_lens,
max_subquery_len,
max_query_len,
alibi_slopes,
sliding_window,
)

@ -63,7 +63,10 @@ class ModelConfig:
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
"""
@ -84,6 +87,7 @@ class ModelConfig:
quantization_param_path: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
skip_tokenizer_init: bool = False,
) -> None:
@ -99,6 +103,11 @@ class ModelConfig:
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if self.max_context_len_to_capture is not None:
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = (max_seq_len_to_capture
or max_context_len_to_capture)
self.max_logprobs = max_logprobs
self.skip_tokenizer_init = skip_tokenizer_init

@ -190,9 +199,9 @@ class ModelConfig:
"non-quantized models.", self.quantization)

def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
if self.max_seq_len_to_capture is None:
self.max_seq_len_to_capture = self.max_model_len
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)

def verify_with_parallel_config(
@ -772,8 +781,8 @@ class SpeculativeConfig:
max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
max_context_len_to_capture,
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
)

@ -44,7 +44,8 @@ class EngineArgs:
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
@ -322,6 +323,14 @@ class EngineArgs:
default=EngineArgs.max_context_len_to_capture,
help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')')
parser.add_argument('--max-seq_len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
@ -492,7 +501,8 @@ class EngineArgs:
self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
self.max_logprobs, self.skip_tokenizer_init)
self.max_seq_len_to_capture, self.max_logprobs,
self.skip_tokenizer_init)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,

@ -69,6 +69,9 @@ class LLM:
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
@ -90,7 +93,8 @@ class LLM:
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
@ -112,6 +116,7 @@ class LLM:
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
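Illustrative usage after this change (not part of the diff): callers pass max_seq_len_to_capture, while the old max_context_len_to_capture kwarg survives only for the deprecation path. The model name below is just a placeholder.

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",     # placeholder model for illustration
    enforce_eager=False,
    max_seq_len_to_capture=8192,   # sequences longer than this fall back to eager mode
)
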
@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
assert seq_group.is_prompt, (
"Caller should ensure the sequence group is in a prefill stage.")
seq_ids = seq_group.seq_ids
subquery_len = seq_group.subquery_len
assert subquery_len is not None
query_len = seq_group.query_len
assert query_len is not None
# prompt has only 1 seq id.
assert len(seq_ids) == 1
seq_data = seq_group.seq_data[seq_ids[0]]
@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
prompt_tokens = seq_data.prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + subquery_len + 1,
next_token_index_end = min(computed_len + query_len + 1,
len(prompt_tokens))
next_prompt_tokens = prompt_tokens[
next_token_index_start:next_token_index_end]
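An illustrative worked example (not from the diff) of which "next" prompt tokens the slice above selects during a chunked prefill; token ids are made up.

prompt_tokens = [11, 12, 13, 14, 15, 16]   # 6-token prompt
computed_len = 2                            # tokens already processed in earlier chunks
query_len = 3                               # tokens processed in this step

next_token_index_start = computed_len + 1                                      # 3
next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens))   # 6
next_prompt_tokens = prompt_tokens[next_token_index_start:next_token_index_end]

assert next_prompt_tokens == [14, 15, 16]   # the token following each position in this chunk
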
@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558
|
||||
|
||||
@dataclass
|
||||
class SequenceGroupToSample:
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Sequence ids for the sequence group in a previous step.
|
||||
seq_ids: List[int]
|
||||
sampling_params: SamplingParams
|
||||
# seq_id -> sequence data.
|
||||
seq_data: Dict[int, SequenceData]
|
||||
# The length of the prompt of the sequence group. None if it is in a decode
|
||||
# The length of the sequence (all tokens seen in the past + new token to
|
||||
# compute attention) of the sequence group. None if it is in a decode
|
||||
# stage.
|
||||
prompt_len: Optional[int]
|
||||
# The length of the query tokens to compute in the current step. None if it
|
||||
# is in a decode stage. The length of subquery_len <= prompt_len.
|
||||
subquery_len: Optional[int]
|
||||
seq_len: Optional[int]
|
||||
# The length of new query tokens to compute in the current step. None if it
|
||||
# is in a decode stage. The length of query_len <= seq_len if chunked
|
||||
# prefill is enabled.
|
||||
query_len: Optional[int]
|
||||
# A random number generator for sampling.
|
||||
generator: Optional[torch.Generator]
|
||||
# True if the sequence group is in prefill stage. False if it is in a
|
||||
@ -46,8 +55,8 @@ class SequenceGroupToSample:
|
||||
if len(self.prompt_logprob_indices) > 0:
|
||||
assert self.sampling_params.prompt_logprobs is not None
|
||||
if self.is_prompt:
|
||||
assert self.prompt_len is not None
|
||||
assert self.subquery_len is not None
|
||||
assert self.seq_len is not None
|
||||
assert self.query_len is not None
|
||||
|
||||
|
||||
class SamplingMetadata:
|
||||
@ -94,8 +103,8 @@ class SamplingMetadata:
|
||||
@staticmethod
|
||||
def prepare(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
prompt_lens: List[int],
|
||||
subquery_lens: Optional[List[int]],
|
||||
seq_lens: List[int],
|
||||
query_lens: Optional[List[int]],
|
||||
device: str,
|
||||
pin_memory: bool,
|
||||
) -> "SamplingMetadata":
|
||||
@ -104,8 +113,8 @@ class SamplingMetadata:
|
||||
selected_token_indices,
|
||||
categorized_sample_indices,
|
||||
num_prompts,
|
||||
) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
|
||||
subquery_lens, device)
|
||||
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
|
||||
device)
|
||||
selected_token_indices = async_tensor_h2d(selected_token_indices,
|
||||
dtype=torch.long,
|
||||
target_device=device,
|
||||
@ -137,8 +146,8 @@ class SamplingMetadata:
|
||||
|
||||
def _prepare_seq_groups(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
prompt_lens: List[int],
|
||||
subquery_lens: Optional[List[int]],
|
||||
seq_lens: List[int],
|
||||
query_lens: Optional[List[int]],
|
||||
device: str,
|
||||
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
|
||||
SamplingType, List[Tuple[int, int]]], int]:
|
||||
@ -146,9 +155,9 @@ def _prepare_seq_groups(
|
||||
|
||||
Args:
|
||||
seq_group_metadata_list: A list of sequence group to batch.
|
||||
prompt_lens: A list of prompt lens per sequence group.
|
||||
seq_lens: A list of sequence lens per sequence group.
|
||||
Index of prompt len should match with seq_group_metadata_list.
|
||||
subquery_lens: A list of query lengths. Prompt lens include the length
|
||||
query_lens: A list of query lengths. Prompt lens include the length
|
||||
of entire prompt tokens, and it could be shorter.
|
||||
device: A device to use for random number generator,
|
||||
`SequenceGroupToSample.generator`.
|
||||
@ -189,8 +198,8 @@ def _prepare_seq_groups(
|
||||
is_prompt = seq_group_metadata.is_prompt
|
||||
generator: Optional[torch.Generator] = None
|
||||
# If the current seq group is in decode stage, it is None.
|
||||
prompt_len: Optional[int] = None
|
||||
subquery_len: Optional[int] = None
|
||||
seq_len: Optional[int] = None
|
||||
query_len: Optional[int] = None
|
||||
prompt_logprob_indices: List[int] = []
|
||||
sample_indices: List[int] = []
|
||||
do_sample = seq_group_metadata.do_sample
|
||||
@ -203,12 +212,12 @@ def _prepare_seq_groups(
|
||||
num_prompts += 1
|
||||
num_prefill_sample = len(seq_ids)
|
||||
assert num_prefill_sample == 1
|
||||
assert subquery_lens is not None and prompt_lens is not None
|
||||
subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
|
||||
assert query_lens is not None and seq_lens is not None
|
||||
query_len, seq_len = query_lens[i], seq_lens[i]
|
||||
# If we need sampling, exclude num_prefill_sample tokens from
|
||||
# prompt logprob.
|
||||
prompt_logprob_len = (subquery_len - num_prefill_sample
|
||||
if do_sample else subquery_len)
|
||||
prompt_logprob_len = (query_len - num_prefill_sample
|
||||
if do_sample else query_len)
|
||||
sample_len = num_prefill_sample if do_sample else 0
|
||||
else:
|
||||
# Decode
|
||||
@ -267,8 +276,8 @@ def _prepare_seq_groups(
|
||||
seq_ids=seq_ids,
|
||||
sampling_params=sampling_params,
|
||||
seq_data=seq_group_metadata.seq_data,
|
||||
prompt_len=prompt_len,
|
||||
subquery_len=subquery_len,
|
||||
seq_len=seq_len,
|
||||
query_len=query_len,
|
||||
generator=generator,
|
||||
is_prompt=is_prompt,
|
||||
prompt_logprob_indices=list(prompt_logprob_indices),
|
||||
@ -367,8 +376,8 @@ class SamplingTensors:
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
# For tokens in the prompt that we only need to get
|
||||
# their logprobs
|
||||
subquery_len = seq_group.subquery_len
|
||||
assert subquery_len is not None
|
||||
query_len = seq_group.query_len
|
||||
assert query_len is not None
|
||||
prefill_len = len(seq_group.prompt_logprob_indices)
|
||||
temperatures += [temperature] * prefill_len
|
||||
top_ps += [top_p] * prefill_len
|
||||
@ -397,8 +406,8 @@ class SamplingTensors:
|
||||
|
||||
if is_prompt:
|
||||
prompt_best_of.append(sampling_params.best_of)
|
||||
subquery_len = seq_group.subquery_len
|
||||
assert subquery_len is not None
|
||||
query_len = seq_group.query_len
|
||||
assert query_len is not None
|
||||
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group.seq_data[seq_id]
|
||||
|
@ -80,7 +80,7 @@ class CPUModelRunner:
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
prompt_lens: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
multi_modal_input_list: List[torch.Tensor] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
@ -92,15 +92,15 @@ class CPUModelRunner:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
prompt_len = len(prompt_tokens)
|
||||
seq_len = len(prompt_tokens)
|
||||
|
||||
prompt_lens.append(prompt_len) # Prompt token num
|
||||
seq_lens.append(seq_len) # Prompt token num
|
||||
input_tokens.extend(prompt_tokens) # Token ids
|
||||
|
||||
# Token position ids
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
input_positions.extend(list(range(computed_len, prompt_len)))
|
||||
input_positions.extend(list(range(computed_len, seq_len)))
|
||||
|
||||
if seq_group_metadata.multi_modal_data:
|
||||
multi_modal_input_list.append(
|
||||
@ -109,15 +109,15 @@ class CPUModelRunner:
|
||||
# Compute the slot mapping.
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
||||
# where start_idx is max(0, prompt_len - sliding_window).
|
||||
# where start_idx is max(0, seq_len - sliding_window).
|
||||
# For example, if the prompt len is 10, sliding window is 8, and
|
||||
# block size is 4, the first two tokens are masked and the slot
|
||||
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||
start_idx = 0
|
||||
if self.sliding_window is not None:
|
||||
start_idx = max(0, prompt_len - self.sliding_window)
|
||||
start_idx = max(0, seq_len - self.sliding_window)
|
||||
|
||||
for i in range(computed_len, prompt_len):
|
||||
for i in range(computed_len, seq_len):
|
||||
if i < start_idx:
|
||||
slot_mapping.append(_PAD_SLOT_ID)
|
||||
continue
|
||||
@ -151,19 +151,19 @@ class CPUModelRunner:
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
prompt_lens=prompt_lens,
|
||||
num_prefills=len(prompt_lens),
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=None,
|
||||
max_seq_len=None,
|
||||
num_prefills=len(seq_lens),
|
||||
num_prefill_tokens=num_prompt_tokens,
|
||||
num_decode_tokens=0,
|
||||
prefill_metadata=None,
|
||||
decode_metadata=None,
|
||||
max_context_len=None,
|
||||
context_lens=None,
|
||||
block_tables=torch.tensor([]),
|
||||
slot_mapping=slot_mapping,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
return (input_tokens, input_positions, attn_metadata, prompt_lens,
|
||||
return (input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_input)
|
||||
|
||||
def _prepare_decode(
|
||||
@ -174,7 +174,7 @@ class CPUModelRunner:
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
context_lens: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
block_tables: List[List[int]] = []
|
||||
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
@ -192,9 +192,9 @@ class CPUModelRunner:
|
||||
position = seq_len - 1
|
||||
input_positions.append(position)
|
||||
|
||||
context_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
context_lens.append(context_len)
|
||||
seq_lens.append(seq_len)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
block_number = block_table[position // self.block_size]
|
||||
@ -208,7 +208,7 @@ class CPUModelRunner:
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
max_context_len = max(context_lens)
|
||||
max_seq_len = max(seq_lens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
@ -219,7 +219,7 @@ class CPUModelRunner:
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
context_lens = torch.tensor(context_lens,
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
|
||||
@ -236,14 +236,14 @@ class CPUModelRunner:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
slot_mapping=slot_mapping,
|
||||
prompt_lens=None,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_seq_len=max_seq_len,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=len(input_tokens),
|
||||
max_context_len=max_context_len,
|
||||
num_prefills=0,
|
||||
prefill_metadata=None,
|
||||
decode_metadata=None,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
@ -265,20 +265,20 @@ class CPUModelRunner:
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
(input_tokens, input_positions, attn_metadata, prompt_lens,
|
||||
(input_tokens, input_positions, attn_metadata, seq_lens,
|
||||
multi_modal_input
|
||||
) = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
(input_tokens, input_positions,
|
||||
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
|
||||
prompt_lens = []
|
||||
seq_lens = []
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
# subquery_lens is not needed if chunked prefill is not
|
||||
seq_lens,
|
||||
# query_lens is not needed if chunked prefill is not
|
||||
# supported. Since CPU worker doesn't support chunked prefill
|
||||
# just use prompt_lens instead.
|
||||
prompt_lens,
|
||||
# just use seq_lens instead.
|
||||
seq_lens,
|
||||
self.device,
|
||||
pin_memory=False)
|
||||
# Broadcast the metadata.
|
||||
@ -300,7 +300,7 @@ class CPUModelRunner:
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=None,
|
||||
seq_data=None,
|
||||
prompt_lens=None,
|
||||
seq_lens=None,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=None,
|
||||
generators=None,
|
||||
|
@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple):
|
||||
input_tokens: List[int]
|
||||
input_positions: List[int]
|
||||
attn_metadata: Optional[AttentionMetadataPerStage]
|
||||
prompt_lens: List[int]
|
||||
subquery_lens: List[int]
|
||||
seq_lens: List[int]
|
||||
query_lens: List[int]
|
||||
lora_index_mapping: List[int]
|
||||
lora_prompt_mapping: List[int]
|
||||
lora_requests: Set[LoRARequest]
|
||||
@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple):
|
||||
input_tokens=[],
|
||||
input_positions=[],
|
||||
attn_metadata=None,
|
||||
prompt_lens=[],
|
||||
subquery_lens=[],
|
||||
seq_lens=[],
|
||||
query_lens=[],
|
||||
lora_index_mapping=[],
|
||||
lora_prompt_mapping=[],
|
||||
lora_requests=set(),
|
||||
@ -134,8 +134,7 @@ class ModelRunner:
|
||||
self.graph_memory_pool: Optional[Tuple[
|
||||
int, int]] = None # Set during graph capture.
|
||||
|
||||
self.max_context_len_to_capture = (
|
||||
self.model_config.max_context_len_to_capture
|
||||
self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
|
||||
if self.model_config is not None else 0)
|
||||
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
@ -149,7 +148,7 @@ class ModelRunner:
|
||||
self.model: torch.nn.Module # Set after load_model
|
||||
self.block_size: int # Set after initial profiling.
|
||||
# When using CUDA graph, the input block tables must be padded to
|
||||
# max_context_len_to_capture. However, creating the block table in
|
||||
# max_seq_len_to_capture. However, creating the block table in
|
||||
# Python can be expensive. To optimize this, we cache the block table
|
||||
# in numpy and only copy the actual input content at every iteration.
|
||||
# The shape of the cached block table will be
|
||||
@ -218,7 +217,7 @@ class ModelRunner:
|
||||
|
||||
def get_max_block_per_batch(self) -> int:
|
||||
block_size = self.block_size
|
||||
return (self.max_context_len_to_capture + block_size - 1) // block_size
|
||||
return (self.max_seq_len_to_capture + block_size - 1) // block_size
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
@ -231,9 +230,9 @@ class ModelRunner:
|
||||
lora_prompt_mapping: List[int] = []
|
||||
lora_requests: Set[LoRARequest] = set()
|
||||
|
||||
prompt_lens: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
context_lens: List[int] = []
|
||||
subquery_lens: List[int] = []
|
||||
query_lens: List[int] = []
|
||||
prefix_block_tables: List[List[int]] = []
|
||||
multi_modal_input_list: List[torch.Tensor] = []
|
||||
|
||||
@ -257,21 +256,19 @@ class ModelRunner:
|
||||
|
||||
token_chunk_size = seq_group_metadata.token_chunk_size
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
# We should use get_len here because in case of preemption
|
||||
# it contains output tokens.
|
||||
prefill_end = min(seq_data.get_len(),
|
||||
computed_len + token_chunk_size)
|
||||
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
|
||||
prompt_len = prefill_end
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
|
||||
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
seq_lens.append(seq_len)
|
||||
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
if computed_block_nums is not None and len(
|
||||
computed_block_nums) > 0 and self.sliding_window is None:
|
||||
# Prefix is not supported with sliding_window
|
||||
computed_len = len(computed_block_nums) * self.block_size
|
||||
prompt_tokens = prompt_tokens[computed_len:]
|
||||
context_len = len(computed_block_nums) * self.block_size
|
||||
prompt_tokens = prompt_tokens[context_len:]
|
||||
prefix_block_tables.append(computed_block_nums)
|
||||
elif self.scheduler_config.chunked_prefill_enabled:
|
||||
if seq_group_metadata.block_tables is not None:
|
||||
@ -285,25 +282,25 @@ class ModelRunner:
|
||||
prefix_block_tables.append([])
|
||||
# Right now, prefill start is always 0. However, this
|
||||
# assumption can be changed once chunked prefill is introduced.
|
||||
assert computed_len == 0
|
||||
assert context_len == 0
|
||||
|
||||
# actual prompt lens
|
||||
context_lens.append(computed_len)
|
||||
subquery_lens.append(prompt_len - computed_len)
|
||||
context_lens.append(context_len)
|
||||
query_lens.append(seq_len - context_len)
|
||||
|
||||
input_tokens.extend(prompt_tokens)
|
||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
||||
# is always the first token in the sequence.
|
||||
input_positions.extend(list(range(computed_len, prefill_end)))
|
||||
input_positions.extend(list(range(context_len, seq_len)))
|
||||
lora_id = seq_group_metadata.lora_int_id
|
||||
|
||||
if lora_id > 0:
|
||||
lora_requests.add(seq_group_metadata.lora_request)
|
||||
|
||||
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
|
||||
lora_index_mapping += [lora_id] * (seq_len - context_len)
|
||||
lora_prompt_mapping.extend(
|
||||
[lora_id] *
|
||||
(prompt_len - computed_len
|
||||
(seq_len - context_len
|
||||
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
|
||||
|
||||
if seq_group_metadata.multi_modal_data:
|
||||
@ -313,24 +310,24 @@ class ModelRunner:
|
||||
if seq_group_metadata.block_tables is None:
|
||||
# During memory profiling, the block tables are not initialized
|
||||
# yet. In this case, we just use a dummy slot mapping.
|
||||
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
|
||||
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
|
||||
continue
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
||||
# where start_idx is max(0, prompt_len - sliding_window).
|
||||
# where start_idx is max(0, seq_len - sliding_window).
|
||||
# For example, if the prompt len is 10, sliding window is 8, and
|
||||
# block size is 4, the first two tokens are masked and the slot
|
||||
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||
start_idx = 0
|
||||
if self.sliding_window is not None:
|
||||
assert computed_len == 0, (
|
||||
assert context_len == 0, (
|
||||
"Prefix caching is currently not supported with "
|
||||
"sliding window attention")
|
||||
start_idx = max(0, prompt_len - self.sliding_window)
|
||||
start_idx = max(0, seq_len - self.sliding_window)
|
||||
|
||||
for i in range(computed_len, prefill_end):
|
||||
for i in range(context_len, seq_len):
|
||||
if i < start_idx:
|
||||
slot_mapping.append(_PAD_SLOT_ID)
|
||||
continue
|
||||
@ -340,9 +337,9 @@ class ModelRunner:
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
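# Illustrative sketch (not from the diff): how a token's KV-cache slot follows
# from its block table entry via the formula used in _prepare_prompt above.
# The block table and block_size below are made-up numbers.
block_size = 4
block_table = [7, 2, 5]        # physical block numbers assigned to this sequence

def slot_for(token_index: int) -> int:
    block_number = block_table[token_index // block_size]
    block_offset = token_index % block_size
    return block_number * block_size + block_offset

assert slot_for(0) == 28    # block 7, offset 0
assert slot_for(5) == 9     # block 2, offset 1
assert slot_for(10) == 22   # block 5, offset 2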
|
||||
|
||||
max_subquery_len = max(subquery_lens)
|
||||
max_prompt_len = max(prompt_lens)
|
||||
assert max_subquery_len > 0
|
||||
max_query_len = max(query_lens)
|
||||
max_seq_len = max(seq_lens)
|
||||
assert max_query_len > 0
|
||||
|
||||
context_lens_tensor = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
@ -369,40 +366,39 @@ class ModelRunner:
|
||||
|
||||
# Query length can be shorter than key (i.e., prompt) when prefill
|
||||
# is chunked or prefix cached.
|
||||
subquery_lens_tensor = torch.tensor(subquery_lens,
|
||||
query_lens_tensor = torch.tensor(query_lens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
|
||||
subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
prompt_lens_tensor = torch.tensor(prompt_lens,
|
||||
dtype=torch.long,
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
|
||||
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
torch.cumsum(subquery_lens_tensor,
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=subquery_start_loc.dtype,
|
||||
out=subquery_start_loc[1:])
|
||||
|
||||
torch.cumsum(prompt_lens_tensor,
|
||||
torch.cumsum(seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=seq_start_loc.dtype,
|
||||
out=seq_start_loc[1:])
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=True,
|
||||
prompt_lens=prompt_lens,
|
||||
prompt_lens_tensor=prompt_lens_tensor,
|
||||
max_subquery_len=max_subquery_len,
|
||||
max_context_len=None,
|
||||
max_prompt_len=max_prompt_len,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
subquery_start_loc=subquery_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens=context_lens_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
@ -411,8 +407,8 @@ class ModelRunner:
|
||||
input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
attn_metadata=attn_metadata,
|
||||
prompt_lens=prompt_lens,
|
||||
subquery_lens=subquery_lens,
|
||||
seq_lens=seq_lens,
|
||||
query_lens=query_lens,
|
||||
lora_index_mapping=lora_index_mapping,
|
||||
lora_prompt_mapping=lora_prompt_mapping,
|
||||
lora_requests=lora_requests,
|
||||
@ -427,7 +423,7 @@ class ModelRunner:
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
context_lens: List[int] = []
|
||||
seq_lens: List[int] = []
|
||||
block_tables: List[List[int]] = []
|
||||
lora_index_mapping: List[int] = []
|
||||
lora_prompt_mapping: List[int] = []
|
||||
@ -455,9 +451,9 @@ class ModelRunner:
|
||||
position = seq_len - 1
|
||||
input_positions.append(position)
|
||||
|
||||
context_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
context_lens.append(context_len)
|
||||
seq_lens.append(seq_len)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
block_number = block_table[position // self.block_size]
|
||||
@ -477,11 +473,10 @@ class ModelRunner:
|
||||
# See `capture_model` API for more details.
|
||||
# For decoding requests, batch_size == input_tokens.
|
||||
batch_size = len(input_tokens)
|
||||
max_context_len = max(context_lens)
|
||||
use_captured_graph = (
|
||||
not self.model_config.enforce_eager
|
||||
max_seq_len = max(seq_lens)
|
||||
use_captured_graph = (not self.model_config.enforce_eager
|
||||
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
||||
and max_context_len <= self.max_context_len_to_capture)
|
||||
and max_seq_len <= self.max_seq_len_to_capture)
|
||||
if use_captured_graph:
|
||||
graph_batch_size = _get_graph_batch_size(batch_size)
|
||||
assert graph_batch_size >= batch_size
|
||||
@ -489,21 +484,21 @@ class ModelRunner:
|
||||
input_tokens.append(0)
|
||||
input_positions.append(0)
|
||||
slot_mapping.append(_PAD_SLOT_ID)
|
||||
context_lens.append(1)
|
||||
seq_lens.append(1)
|
||||
block_tables.append([])
|
||||
lora_index_mapping.append(0)
|
||||
batch_size = graph_batch_size
|
||||
|
||||
context_lens_tensor = torch.tensor(context_lens,
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.int,
|
||||
device=self.device)
|
||||
|
||||
if use_captured_graph:
|
||||
# When using cuda-graph all these tensors should be
|
||||
# padded.
|
||||
assert context_lens_tensor.shape[0] == len(input_tokens)
|
||||
assert context_lens_tensor.shape[0] == len(input_positions)
|
||||
assert context_lens_tensor.shape[0] == len(slot_mapping)
|
||||
assert seq_lens_tensor.shape[0] == len(input_tokens)
|
||||
assert seq_lens_tensor.shape[0] == len(input_positions)
|
||||
assert seq_lens_tensor.shape[0] == len(slot_mapping)
|
||||
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
@ -525,14 +520,13 @@ class ModelRunner:
|
||||
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
is_prompt=False,
|
||||
prompt_lens=None,
|
||||
prompt_lens_tensor=None,
|
||||
max_subquery_len=None,
|
||||
max_context_len=max_context_len,
|
||||
max_prompt_len=None,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=None,
|
||||
max_seq_len=max_seq_len,
|
||||
subquery_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens=context_lens_tensor,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
@ -565,8 +559,8 @@ class ModelRunner:
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
@ -583,13 +577,13 @@ class ModelRunner:
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, prompt_lens, subquery_lens,
self.device, self.pin_memory)
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)

if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0

num_prefills = len(prompt_lens)
num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)

@ -886,7 +880,7 @@ class ModelRunner:
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()

graph_batch_size = _get_graph_batch_size(
@ -908,14 +902,13 @@ class ModelRunner:
# Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
prompt_lens_tensor=None,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_seq_len=self.max_seq_len_to_capture,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size],
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
@ -1025,7 +1018,7 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.decode_metadata.context_lens,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
@ -1047,8 +1040,8 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.

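The CUDAGraphRunner hunks above never rebuild inputs for a captured graph; they copy the new batch into the graph's fixed input buffers (now keyed by "seq_lens_tensor" instead of "context_lens") and replay. A minimal PyTorch sketch of that general pattern, independent of vLLM's class; the toy model, shapes, and the omitted warm-up pass are simplifications:

import torch

class TinyGraphRunner:
    """Capture one fixed-shape forward pass, then replay it with new data."""

    def __init__(self, model: torch.nn.Module, batch_size: int, hidden: int):
        self.model = model
        # Static buffers: the captured graph always reads from and writes to these.
        self.static_input = torch.zeros(batch_size, hidden, device="cuda")
        self.graph = torch.cuda.CUDAGraph()
        # A real runner would warm the model up on a side stream before capture.
        with torch.cuda.graph(self.graph):
            self.static_output = self.model(self.static_input)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same trick as the diff: copy into the captured buffers, then replay.
        self.static_input.copy_(x, non_blocking=True)
        self.graph.replay()
        return self.static_output

# Usage (requires a CUDA device):
#   runner = TinyGraphRunner(torch.nn.Linear(64, 64).cuda(), batch_size=8, hidden=64)
#   out = runner.forward(torch.randn(8, 64, device="cuda"))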
@ -52,7 +52,7 @@ class NeuronModelRunner:
input_positions: List[List[int]] = []
input_block_ids: List[int] = []

prompt_lens: List[int] = []
seq_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@ -61,26 +61,26 @@ class NeuronModelRunner:

seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
seq_len = len(prompt_tokens)
seq_lens.append(seq_len)

input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
input_positions.append(list(range(seq_len)))

assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])

max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
@ -88,7 +88,7 @@ class NeuronModelRunner:
dtype=torch.long,
device=self.device)

return input_tokens, input_positions, input_block_ids, prompt_lens
return input_tokens, input_positions, input_block_ids, seq_lens

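make_tensor_with_pad above turns the ragged per-prompt token lists into one rectangular [num_seqs, max_seq_len] tensor. A hand-rolled equivalent, for illustration only; vLLM's actual helper takes additional options:

from typing import List
import torch

def pad_to_tensor(seqs: List[List[int]], max_len: int, pad: int = 0) -> torch.Tensor:
    # Right-pad every sequence to max_len, then stack into [num_seqs, max_len].
    padded = [seq + [pad] * (max_len - len(seq)) for seq in seqs]
    return torch.tensor(padded, dtype=torch.long)

prompt_tokens = [[11, 12, 13], [21, 22]]
batch = pad_to_tensor(prompt_tokens, max_len=3)
assert batch.shape == (2, 3) and batch[1, 2].item() == 0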
def _prepare_decode(
self,
@ -149,18 +149,18 @@ class NeuronModelRunner:
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
seq_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens,
# subquery_lens is not needed if chunked prefill is not
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use prompt_lens instead.
prompt_lens,
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)
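Passing seq_lens twice above reflects the comment in the hunk: without chunked prefill every prompt is processed in a single pass, so the per-sequence query lengths coincide with the full sequence lengths. A small numeric illustration with made-up values:

seq_lens = [5, 3, 8]        # full prompt lengths for three sequences
# No chunked prefill: each prompt is consumed in one step, so the number of
# query tokens scheduled per sequence equals its whole prompt length.
query_lens = list(seq_lens)
assert query_lens == seq_lens
# With chunked prefill, query_lens could instead cover a partial chunk,
# e.g. [5, 3, 4] if only 4 of the last prompt's 8 tokens ran this step.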