Change scheduler & input tensor shape (#1381)

Woosuk Kwon 2023-10-16 17:48:42 -07:00 committed by GitHub
parent 651c614aa4
commit c1376e0f82
13 changed files with 180 additions and 178 deletions
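In short, the model inputs move from flattened 1D token tensors (padded to a multiple of 8) to [batch_size, seq_len] tensors padded per batch to the longest prompt, and the kernels now infer the token count from the tensor itself. A rough sketch of the two layouts (not from the commit; variable names are illustrative):

    import torch

    prompts = [[11, 12, 13], [21, 22]]          # two prompts, lengths 3 and 2

    # Old layout: one flat 1D tensor, padded to a multiple of 8 for Tensor Cores.
    flat = [t for p in prompts for t in p]
    old_tokens = torch.tensor(flat + [0] * (-len(flat) % 8))   # shape [8]

    # New layout: [batch_size, max_seq_len], each prompt padded to the longest one.
    max_len = max(len(p) for p in prompts)
    new_tokens = torch.tensor([p + [0] * (max_len - len(p)) for p in prompts])  # shape [2, 3]

    # Kernels that used to read num_tokens = hidden.size(0) now treat all leading
    # dims as token positions: num_tokens = hidden.numel() // hidden.size(-1).
    hidden = torch.randn(2, 3, 64)               # [batch_size, seq_len, hidden_size]
    print(hidden.numel() // hidden.size(-1))     # 6 padded token positions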

View File

@@ -13,8 +13,8 @@ __device__ __forceinline__ T silu(const T& x) {
 template<typename scalar_t>
 __global__ void silu_and_mul_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, 2, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
@@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel(
 } // namespace vllm

 void silu_and_mul(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, 2 * d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
 {
-  int num_tokens = input.size(0);
-  int d = input.size(1) / 2;
+  int num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1) / 2;

   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
@@ -52,8 +52,8 @@ namespace vllm {
 // Element-wise activation kernel template.
 template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void activation_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
@@ -66,8 +66,8 @@ __global__ void activation_kernel(
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                  \
-  int num_tokens = input.size(0);                                         \
-  int d = input.size(1);                                                  \
+  int d = input.size(-1);                                                 \
+  int num_tokens = input.numel() / d;                                     \
   dim3 grid(num_tokens);                                                  \
   dim3 block(std::min(d, 1024));                                          \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();           \
@@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
 } // namespace vllm

 void gelu_new(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
 }

 void gelu_fast(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
 }

View File

@@ -154,6 +154,11 @@ __global__ void reshape_and_cache_kernel(
   const int x) {
   const int token_idx = blockIdx.x;
   const int slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
   const int block_idx = slot_idx / block_size;
   const int block_offset = slot_idx % block_size;
@@ -176,8 +181,8 @@ __global__ void reshape_and_cache_kernel(
       + head_idx * head_size * block_size
       + head_offset * block_size
       + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    key_cache[tgt_key_idx] = key[src_key_idx];
+    value_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
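Because the batch is now rectangular, padding positions reach the cache kernel too; they are marked with a negative slot index so the kernel skips them. A small illustration of the convention (values made up):

    import torch

    # Two sequences with prompt lengths 3 and 2, padded to the longer one.
    # -1 marks padding positions that reshape_and_cache must not write.
    slot_mapping = torch.tensor([[16, 17, 18],
                                 [32, 33, -1]], dtype=torch.int)
    # One CUDA block per flattened token position; negative slots return early.
    print((slot_mapping.view(-1) >= 0).tolist())   # [True, True, True, True, True, False]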

View File

@@ -9,8 +9,8 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template<typename scalar_t>
 __global__ void rms_norm_kernel(
-  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
-  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
+  scalar_t* __restrict__ out,             // [..., hidden_size]
+  const scalar_t* __restrict__ input,     // [..., hidden_size]
   const scalar_t* __restrict__ weight,    // [hidden_size]
   const float epsilon,
   const int num_tokens,
@@ -37,12 +37,12 @@ __global__ void rms_norm_kernel(
 } // namespace vllm

 void rms_norm(
-  torch::Tensor& out,      // [num_tokens, hidden_size]
-  torch::Tensor& input,    // [num_tokens, hidden_size]
+  torch::Tensor& out,      // [..., hidden_size]
+  torch::Tensor& input,    // [..., hidden_size]
   torch::Tensor& weight,   // [hidden_size]
   float epsilon) {
-  int num_tokens = input.size(0);
-  int hidden_size = input.size(1);
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));

View File

@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(
 template<typename scalar_t, bool IS_NEOX>
 __global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [num_tokens]
-  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int query_stride,
@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
 } // namespace vllm

 void rotary_embedding(
-  torch::Tensor& positions,         // [num_tokens]
-  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
+  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
   int head_size,
   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.size(0);
+  int num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(1) / head_size;
-  int num_kv_heads = key.size(1) / head_size;
-  int query_stride = query.stride(0);
-  int key_stride = key.stride(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int query_stride = query.stride(-2);
+  int key_stride = key.stride(-2);
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
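With a [batch_size, seq_len, num_heads * head_size] query, stride(-2) is the distance between consecutive token positions, which plays the role stride(0) had on the old 2D input. A quick check (sizes made up):

    import torch

    query = torch.randn(2, 3, 4 * 16)           # batch=2, seq_len=3, 4 heads of size 16
    print(query.stride(-2))                     # 64 = num_heads * head_size
    print(query.numel() // query.size(-1))      # 6 token positions -> kernel grid size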

View File

@@ -268,6 +268,7 @@ class SchedulerConfig:
             iteration.
         max_model_len: Maximum length of a sequence (including prompt
             and generated text).
+        max_paddings: Maximum number of paddings to be added to a batch.
     """

     def __init__(
@@ -275,6 +276,7 @@ class SchedulerConfig:
         max_num_batched_tokens: Optional[int],
         max_num_seqs: int,
         max_model_len: int,
+        max_paddings: int,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
@@ -284,6 +286,7 @@ class SchedulerConfig:
             self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self.max_paddings = max_paddings
         self._verify_args()

     def _verify_args(self) -> None:
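For reference, a minimal construction of the updated config (values illustrative; assumes the class lives in vllm.config as in this repo):

    from vllm.config import SchedulerConfig

    scheduler_config = SchedulerConfig(
        max_num_batched_tokens=None,   # falls back to max(max_model_len, 2048)
        max_num_seqs=256,
        max_model_len=4096,
        max_paddings=256,              # new: cap on padding tokens in a prompt batch
    )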

View File

@@ -131,7 +131,8 @@ class Scheduler:
             # requests in the generation phase.
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
-            num_batched_tokens = 0
+            seq_lens: List[int] = []
+
             # Optimization: We do not sort the waiting queue since the preempted
             # sequence groups are added to the front and the new sequence groups
             # are added to the back.
@@ -157,7 +158,9 @@
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
+                new_seq_lens = seq_lens + [num_prompt_tokens]
+                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
+                if (num_batched_tokens >
                         self.scheduler_config.max_num_batched_tokens):
                     break
@@ -168,10 +171,14 @@
                         self.scheduler_config.max_num_seqs):
                     break

+                num_paddings = num_batched_tokens - sum(new_seq_lens)
+                if num_paddings > self.scheduler_config.max_paddings:
+                    break
+                seq_lens = new_seq_lens
+
                 seq_group = self.waiting.pop(0)
                 self._allocate(seq_group)
                 self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)
@@ -179,7 +186,7 @@
             scheduler_outputs = SchedulerOutputs(
                 scheduled_seq_groups=scheduled,
                 prompt_run=True,
-                num_batched_tokens=num_batched_tokens,
+                num_batched_tokens=len(seq_lens) * max(seq_lens),
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,
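The arithmetic behind the two new checks, on made-up numbers: since the batch is padded to the longest prompt, admitting a short prompt next to long ones can add a lot of padding.

    seq_lens = [512, 480]              # prompts already admitted this step
    num_prompt_tokens = 64             # candidate prompt being considered

    new_seq_lens = seq_lens + [num_prompt_tokens]
    num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)   # 3 * 512 = 1536
    num_paddings = num_batched_tokens - sum(new_seq_lens)        # 1536 - 1056 = 480

    # The candidate is skipped if num_batched_tokens exceeds max_num_batched_tokens
    # or num_paddings exceeds max_paddings (480 > 256 with the default, so it waits).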

View File

@@ -27,6 +27,7 @@ class EngineArgs:
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
+    max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
@@ -156,6 +157,10 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
+        parser.add_argument('--max-paddings',
+                            type=int,
+                            default=EngineArgs.max_paddings,
+                            help='maximum number of paddings in a batch')
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
@@ -193,7 +198,8 @@ class EngineArgs:
                                          self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
-                                           model_config.max_model_len)
+                                           model_config.max_model_len,
+                                           self.max_paddings)
         return model_config, cache_config, parallel_config, scheduler_config

View File

@@ -39,11 +39,12 @@ class InputMetadata:
         self.max_context_len = max_context_len
         self.block_tables = block_tables

+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
         if sliding_window is not None:
             # We need to keep the positions of sliding windows within
             # the key / value tables, this is helpful to know which
-            # elements we need to cache and where
+            # elements we need to cache.
             to_cache, start_idx = [], 0
             for prompt_len in self.prompt_lens:
                 to_cache.extend(
@@ -51,16 +52,15 @@ class InputMetadata:
                         start_idx + max(0, prompt_len - sliding_window),
                         start_idx + prompt_len,
                     ))
-                start_idx += prompt_len
+                start_idx += self.max_prompt_len
             to_cache.extend(range(start_idx, slot_mapping.shape[0]))
             self.to_cache = torch.tensor(to_cache,
                                          dtype=torch.int32,
                                          device=self.slot_mapping.device)

         self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = sum(prompt_lens)
+        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
@@ -69,12 +69,11 @@ class InputMetadata:
         assert context_lens.shape[0] == self.num_generation_tokens

         # Set during the execution of the first attention op.
-        self.attn_bias: List[AttentionBias] = []
+        self.attn_bias: Optional[AttentionBias] = None

     def __repr__(self) -> str:
         # Print only useful metadata.
         return (f'InputMetadata('
-                f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
                 f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
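With per-batch padding, the prompt bookkeeping counts padded positions as well, e.g. (illustrative lengths):

    prompt_lens = [5, 3, 2]
    max_prompt_len = max(prompt_lens)                   # 5
    num_prompts = len(prompt_lens)                      # 3
    num_prompt_tokens = num_prompts * max_prompt_len    # 15, padding positions included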

View File

@@ -8,17 +8,17 @@ from vllm import activation_ops
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.

-    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

     Shapes:
-        x: (num_tokens, 2 * d)
-        return: (num_tokens, d)
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
     """

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1] // 2
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out
@@ -26,9 +26,7 @@ class SiluAndMul(nn.Module):
 class NewGELU(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_new(out, x)
         return out
@@ -36,9 +34,7 @@ class NewGELU(nn.Module):
 class FastGELU(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_fast(out, x)
         return out
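A pure-PyTorch sketch of the shape handling in the new SiluAndMul.forward (input sizes made up; the layer itself calls the fused activation_ops.silu_and_mul kernel instead of the reference line below):

    import torch

    x = torch.randn(2, 3, 8)                 # [..., 2 * d] with d = 4
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d, )      # torch.Size([2, 3, 4])
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    # Reference of what the fused kernel computes per position:
    out.copy_(torch.nn.functional.silu(x[..., :d]) * x[..., d:])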

View File

@@ -23,25 +23,9 @@ class PagedAttention(nn.Module):
     # pylint: disable=line-too-long
     """GPT-style multi-head PagedAttention.

-    This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can either contain prompt tokens or generation tokens, in
-    addition to paddings.
-
-    If the input tensors contain prompt tokens, the layout is as follows:
-
-    |<---------------------- num_valid_tokens ---------------------->|
-    |<--------------- num_prompt_tokens -------------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
-
-    Otherwise, the layout is as follows:
-
-    |<------------------ num_valid_tokens ------------------->|
-    |<------- num_generation_tokens (M) ------->|
-    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
-
-    The prompts might have different lengths, while the generation tokens always
-    have length 1. The paddings are appended to make the input length a multiple
-    of 8, which is desirable for Tensor Cores.
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens, in addition to
+    paddings.

     The class does the following:
     1. Perform multi_query_kv_attention for the prompts. This operation does
@@ -53,7 +37,7 @@ class PagedAttention(nn.Module):
     4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
-    5. Output a flattened 1D tensor.
+    5. Return the output tensor.
     """

     def __init__(self,
@@ -85,14 +69,15 @@
         dtype: torch.dtype,
     ) -> None:
         del dtype  # Unused.
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        prompt_lens = input_metadata.prompt_lens
+        prompt_lens = [input_metadata.max_prompt_len
+                       ] * input_metadata.num_prompts
         attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         if self.sliding_window is not None:
             attn_bias = attn_bias.make_local_attention(self.sliding_window)
-        input_metadata.attn_bias.append(attn_bias)
+        input_metadata.attn_bias = attn_bias

     def multi_query_kv_attention(
         self,
@@ -111,7 +96,6 @@
             value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
             key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
@@ -124,7 +108,7 @@
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=input_metadata.attn_bias[0],
+            attn_bias=input_metadata.attn_bias,
             p=0.0,
             scale=self.scale,
         )
@@ -232,12 +216,12 @@
         """PagedAttention forward pass.

         NOTE: The query, key, and value tensors must be sliced from a qkv
-        tensor of shape [num_tokens, 3 * num_heads * head_size].
+        tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].

         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -246,9 +230,9 @@
             cache_event: event to wait for the cache operations to finish.

         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """
+        batch_size, seq_len, _ = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -264,10 +248,10 @@
             assert input_metadata.num_generation_tokens == 0
             self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
-                output[:num_prompt_tokens],
-                query[:num_prompt_tokens],
-                key[:num_prompt_tokens],
-                value[:num_prompt_tokens],
+                output,
+                query,
+                key,
+                value,
                 input_metadata,
             )
@@ -278,13 +262,10 @@
         # Reshape the keys and values and store them in the cache.
         # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
-        num_valid_tokens = input_metadata.num_valid_tokens
-        if (num_valid_tokens > 0 and key_cache is not None
-                and value_cache is not None):
-            # The stride is 3 because the key and value are sliced from qkv.
-            key_to_cache = key[:num_valid_tokens]
-            value_to_cache = value[:num_valid_tokens]
-            slot_mapping = input_metadata.slot_mapping
+        if key_cache is not None and value_cache is not None:
+            key_to_cache = key
+            value_to_cache = value
+            slot_mapping = input_metadata.slot_mapping.view(-1)
             if input_metadata.to_cache is not None:
                 key_to_cache = key_to_cache[input_metadata.to_cache]
                 value_to_cache = value_to_cache[input_metadata.to_cache]
@@ -305,14 +286,14 @@
                     "key_cache and value_cache must be provided when "
                     "generating tokens.")
             # Compute the attention op for generation tokens.
-            self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens], key_cache,
-                value_cache, input_metadata, self.get_alibi_slopes())
+            self.single_query_cached_kv_attention(output, query, key_cache,
+                                                  value_cache, input_metadata,
+                                                  self.get_alibi_slopes())

         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
-        return output.view(-1, self.num_heads * self.head_size)
+        return output.view(batch_size, seq_len,
+                           self.num_heads * self.head_size)


 class PagedAttentionWithRoPE(PagedAttention):
@@ -368,10 +349,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         """ PagedAttention forward pass with rotary embedding.

         Args:
-            positions: shape = [num_tokens]
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            positions: shape = [batch_size, seq_len]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -380,7 +361,7 @@
             cache_event: event to wait for the cache operations to finish.

         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """

         # Apply rotary embedding to the query and key before passing them
@@ -414,34 +395,34 @@ class PagedAttentionWithALiBi(PagedAttention):
     def set_attn_bias(self, input_metadata: InputMetadata,
                       dtype: torch.dtype) -> None:
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        # Generates ALiBi mask for each prompt.
-        for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len, dtype=dtype)
-            # NOTE(zhuohan): HF uses
-            #     `bias = bias[None, :].repeat(prompt_len, 1)`
-            # here. We find that both biases give the same results, but
-            # the bias below more accurately follows the original ALiBi
-            # paper.
-            bias = bias[None, :] - bias[:, None]
-            bias = bias.to(self.alibi_slopes.device)
-
-            # When using custom attention bias, xformers requires the bias to
-            # be sliced from a tensor whose length is a multiple of 8.
-            padded_len = (prompt_len + 7) // 8 * 8
-            bias = torch.empty(
-                1,  # batch_size
-                self.num_heads,
-                prompt_len,
-                padded_len,
-                device=self.alibi_slopes.device,
-                dtype=dtype,
-            )[:, :, :, :prompt_len].copy_(bias)
-            bias.mul_(self.alibi_slopes[:, None, None])
-            attn_bias = LowerTriangularMaskWithTensorBias(bias)
-            input_metadata.attn_bias.append(attn_bias)
+        # Generates ALiBi mask based on the max prompt length.
+        max_prompt_len = input_metadata.max_prompt_len
+        bias = torch.arange(max_prompt_len, dtype=dtype)
+        # NOTE(zhuohan): HF uses
+        #     `bias = bias[None, :].repeat(prompt_len, 1)`
+        # here. We find that both biases give the same results, but
+        # the bias below more accurately follows the original ALiBi
+        # paper.
+        bias = bias[None, :] - bias[:, None]
+        bias = bias.to(self.alibi_slopes.device)
+
+        # When using custom attention bias, xformers requires the bias to
+        # be sliced from a tensor whose length is a multiple of 8.
+        padded_len = (max_prompt_len + 7) // 8 * 8
+        bias = torch.empty(
+            input_metadata.num_prompts,
+            self.num_heads,
+            max_prompt_len,
+            padded_len,
+            device=self.alibi_slopes.device,
+            dtype=dtype,
+        )[:, :, :, :max_prompt_len].copy_(bias)
+        bias.mul_(self.alibi_slopes[:, None, None])
+        attn_bias = LowerTriangularMaskWithTensorBias(bias)
+        input_metadata.attn_bias = attn_bias

     def multi_query_kv_attention(
         self,
@@ -466,24 +447,19 @@
             value = torch.repeat_interleave(value,
                                             self.num_queries_per_kv,
                                             dim=1)
-        # FIXME(woosuk): Because xformers does not support dynamic sequence
-        # lengths with custom attention bias, we process each prompt one by
-        # one. This is inefficient, especially when we have many short prompts.
-        start = 0
-        for i, prompt_len in enumerate(input_metadata.prompt_lens):
-            end = start + prompt_len
-            out = xops.memory_efficient_attention_forward(
-                query[None, start:end],
-                key[None, start:end],
-                value[None, start:end],
-                attn_bias=input_metadata.attn_bias[i],
-                p=0.0,
-                scale=self.scale,
-            )
-            # TODO(woosuk): Unnecessary copy. Optimize.
-            output[start:end].copy_(out.squeeze(0))
-            start += prompt_len
+        batch_size = input_metadata.num_prompts
+        seq_len = input_metadata.max_prompt_len
+        out = xops.memory_efficient_attention_forward(
+            query.view(batch_size, seq_len, self.num_heads, self.head_size),
+            key.view(batch_size, seq_len, self.num_heads, self.head_size),
+            value.view(batch_size, seq_len, self.num_heads, self.head_size),
+            attn_bias=input_metadata.attn_bias,
+            p=0.0,
+            scale=self.scale,
+        )
+        # TODO(woosuk): Unnecessary copy. Optimize.
+        output.copy_(out.view(-1, self.num_heads, self.head_size))
         return output
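Since every prompt in a batch is now treated as max_prompt_len tokens long, a single attention bias covers the whole batch and the per-prompt loop disappears. A minimal sketch of that bias setup, assuming the xformers API this file imports:

    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    num_prompts, max_prompt_len = 2, 3           # illustrative sizes
    prompt_lens = [max_prompt_len] * num_prompts
    attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
    # The base PagedAttention passes query.unsqueeze(0) (one long sequence of
    # num_prompts * max_prompt_len positions) with this block-diagonal mask;
    # the ALiBi variant instead views q/k/v as [num_prompts, max_prompt_len, ...].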
def get_alibi_slopes(self) -> Optional[torch.Tensor]: def get_alibi_slopes(self) -> Optional[torch.Tensor]:

View File

@@ -50,7 +50,7 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
         bias: Optional[torch.Tensor],
     ) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)
@@ -95,7 +95,7 @@ class AWQRowParallelLinear(RowParallelLinear):
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)

View File

@@ -119,7 +119,7 @@ def _prune_hidden_states(
                 selected_token_indices.extend(
                     range(start_idx, start_idx + prompt_len - 1))
             selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += prompt_len
+            start_idx += input_metadata.max_prompt_len
         else:
             num_seqs = len(seq_ids)
             selected_token_indices.extend(
@@ -129,6 +129,7 @@ def _prune_hidden_states(
     selected_token_indices = torch.tensor(selected_token_indices,
                                           dtype=torch.long,
                                           device=hidden_states.device)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0, selected_token_indices)
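Index math after padding, on made-up lengths: each prompt now occupies max_prompt_len rows of the flattened hidden states, so the stride between prompts is the padded length rather than the true one.

    prompt_lens = [3, 2]
    max_prompt_len = 3
    selected_token_indices, start_idx = [], 0
    for prompt_len in prompt_lens:
        # keep each prompt's last-token position within the padded layout
        selected_token_indices.append(start_idx + prompt_len - 1)
        start_idx += max_prompt_len
    print(selected_token_indices)    # [2, 4] -> rows picked after view(-1, hidden_size)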

View File

@@ -158,9 +158,9 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        input_tokens: List[int] = []
-        input_positions: List[int] = []
-        slot_mapping: List[int] = []
+        input_tokens: List[List[int]] = []
+        input_positions: List[List[int]] = []
+        slot_mapping: List[List[int]] = []

         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -180,24 +180,25 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)

-            input_tokens.extend(prompt_tokens)
+            input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(range(len(prompt_tokens)))
+            input_positions.append(list(range(prompt_len)))

             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
-                slot_mapping.extend([0] * prompt_len)
+                slot_mapping.append([0] * prompt_len)
                 continue

             # Compute the slot mapping.
+            slot_mapping.append([])
             block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping[-1].append(slot)

         # Add generation tokens.
         max_context_len = 0
@@ -215,13 +216,13 @@ class Worker:
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
-                input_tokens.append(generation_token)
+                input_tokens.append([generation_token])

                 context_len = seq_data.get_len()
                 position = context_len - 1
                 if self.sliding_window is not None:
                     context_len = min(context_len, self.sliding_window)
-                input_positions.append(position)
+                input_positions.append([position])

                 block_table = seq_group_metadata.block_tables[seq_id]
@@ -233,7 +234,7 @@ class Worker:
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping.append([slot])

                 if self.sliding_window is not None:
                     sliding_window_blocks = (self.sliding_window //
@@ -241,28 +242,36 @@ class Worker:
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)

-        # Optimization: Pad the input length to be a multiple of 8.
-        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
-        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
-        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
+        padded_input_tokens = [
+            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
+        ]
+        padded_input_positions = [
+            _pad_to_max(positions, max_seq_len, pad=0)
+            for positions in input_positions
+        ]
+        padded_slot_mapping = [
+            _pad_to_max(mapping, max_seq_len, pad=-1)
+            for mapping in slot_mapping
+        ]
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
+            for block_table in generation_block_tables
+        ]

         # Convert to tensors.
-        tokens_tensor = torch.tensor(input_tokens,
+        tokens_tensor = torch.tensor(padded_input_tokens,
                                      dtype=torch.long,
                                      device="cuda")
-        positions_tensor = torch.tensor(input_positions,
+        positions_tensor = torch.tensor(padded_input_positions,
                                         dtype=torch.long,
                                         device="cuda")
-        slot_mapping_tensor = torch.tensor(slot_mapping,
+        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
                                            dtype=torch.int,
                                            device="cuda")
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables
-        ]
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
@@ -361,12 +370,12 @@ def _init_distributed_environment(
                                   parallel_config.pipeline_parallel_size)


-def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
-    return x + [0] * ((-len(x)) % multiple_of)
+def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
+    return x + [pad] * ((-len(x)) % multiple_of)


-def _pad_to_max(x: List[int], max_len: int) -> List[int]:
-    return x + [0] * (max_len - len(x))
+def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
+    return x + [pad] * (max_len - len(x))


 def _check_if_can_support_max_seq_len(max_seq_len: int,
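A small end-to-end illustration of the new padding helper (values made up): every per-sequence list is padded to the same length so the batch stacks into a rectangular tensor, with -1 marking slot-mapping entries the cache kernel must skip.

    from typing import List

    def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        return x + [pad] * (max_len - len(x))

    input_tokens = [[11, 12, 13], [21, 22]]
    slot_mapping = [[16, 17, 18], [32, 33]]
    max_seq_len = max(len(t) for t in input_tokens)                      # 3

    print([_pad_to_max(t, max_seq_len, pad=0) for t in input_tokens])    # [[11, 12, 13], [21, 22, 0]]
    print([_pad_to_max(s, max_seq_len, pad=-1) for s in slot_mapping])   # [[16, 17, 18], [32, 33, -1]]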