From c1376e0f825e88e32b5aca85c676fe547bcb03c9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Oct 2023 17:48:42 -0700 Subject: [PATCH] Change scheduler & input tensor shape (#1381) --- csrc/activation_kernels.cu | 28 +-- csrc/cache_kernels.cu | 9 +- csrc/layernorm_kernels.cu | 12 +- csrc/pos_encoding_kernels.cu | 22 +-- vllm/config.py | 3 + vllm/core/scheduler.py | 15 +- vllm/engine/arg_utils.py | 8 +- vllm/model_executor/input_metadata.py | 11 +- vllm/model_executor/layers/activation.py | 20 +-- vllm/model_executor/layers/attention.py | 164 ++++++++---------- .../layers/quantized_linear/awq.py | 4 +- vllm/model_executor/layers/sampler.py | 3 +- vllm/worker/worker.py | 59 ++++--- 13 files changed, 180 insertions(+), 178 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index c6ae5db8..581525e9 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -13,8 +13,8 @@ __device__ __forceinline__ T silu(const T& x) { template __global__ void silu_and_mul_kernel( - scalar_t* __restrict__ out, // [num_tokens, d] - const scalar_t* __restrict__ input, // [num_tokens, 2, d] + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { @@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel( } // namespace vllm void silu_and_mul( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, 2 * d] + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { - int num_tokens = input.size(0); - int d = input.size(1) / 2; + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; dim3 grid(num_tokens); dim3 block(std::min(d, 1024)); @@ -52,8 +52,8 @@ namespace vllm { // Element-wise activation kernel template. template __global__ void activation_kernel( - scalar_t* __restrict__ out, // [num_tokens, d] - const scalar_t* __restrict__ input, // [num_tokens, d] + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { @@ -66,8 +66,8 @@ __global__ void activation_kernel( // Launch element-wise activation kernel. #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int num_tokens = input.size(0); \ - int d = input.size(1); \ + int d = input.size(-1); \ + int num_tokens = input.numel() / d; \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ @@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) { } // namespace vllm void gelu_new( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, d] + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } void gelu_fast( - torch::Tensor& out, // [num_tokens, d] - torch::Tensor& input) // [num_tokens, d] + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ddad2b5a..4c806828 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -154,6 +154,11 @@ __global__ void reshape_and_cache_kernel( const int x) { const int token_idx = blockIdx.x; const int slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. 
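+    // (The worker pads slot_mapping with -1, so a negative slot marks a padding token.)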
+ return; + } + const int block_idx = slot_idx / block_size; const int block_offset = slot_idx % block_size; @@ -176,8 +181,8 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); - value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + key_cache[tgt_key_idx] = key[src_key_idx]; + value_cache[tgt_value_idx] = value[src_value_idx]; } } diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f932b9e2..fe07c272 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -9,8 +9,8 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [num_tokens, hidden_size] - const scalar_t* __restrict__ input, // [num_tokens, hidden_size] + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, @@ -37,12 +37,12 @@ __global__ void rms_norm_kernel( } // namespace vllm void rms_norm( - torch::Tensor& out, // [num_tokens, hidden_size] - torch::Tensor& input, // [num_tokens, hidden_size] + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] float epsilon) { - int num_tokens = input.size(0); - int hidden_size = input.size(1); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b4351ee0..41001ba6 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding( template __global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [num_tokens] - scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, const int query_stride, @@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel( } // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [num_tokens] - torch::Tensor& query, // [num_tokens, num_heads * head_size] - torch::Tensor& key, // [num_tokens, num_kv_heads * head_size] + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] int head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { - int num_tokens = query.size(0); + int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); - int num_heads = query.size(1) / head_size; - int num_kv_heads = key.size(1) / head_size; - int query_stride = query.stride(0); - int key_stride = key.stride(0); + int num_heads = query.size(-1) / 
head_size; + int num_kv_heads = key.size(-1) / head_size; + int query_stride = query.stride(-2); + int key_stride = key.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); diff --git a/vllm/config.py b/vllm/config.py index 90ffe822..d45bb885 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -268,6 +268,7 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). + max_paddings: Maximum number of paddings to be added to a batch. """ def __init__( @@ -275,6 +276,7 @@ class SchedulerConfig: max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, + max_paddings: int, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -284,6 +286,7 @@ class SchedulerConfig: self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len + self.max_paddings = max_paddings self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8f381add..516f23d2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -131,7 +131,8 @@ class Scheduler: # requests in the generation phase. num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) - num_batched_tokens = 0 + seq_lens: List[int] = [] + # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. @@ -157,7 +158,9 @@ class Scheduler: break # If the number of batched tokens exceeds the limit, stop. - if (num_batched_tokens + num_prompt_tokens > + new_seq_lens = seq_lens + [num_prompt_tokens] + num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -168,10 +171,14 @@ class Scheduler: self.scheduler_config.max_num_seqs): break + num_paddings = num_batched_tokens - sum(new_seq_lens) + if num_paddings > self.scheduler_config.max_paddings: + break + seq_lens = new_seq_lens + seq_group = self.waiting.pop(0) self._allocate(seq_group) self.running.append(seq_group) - num_batched_tokens += num_prompt_tokens num_curr_seqs += num_new_seqs scheduled.append(seq_group) @@ -179,7 +186,7 @@ class Scheduler: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=num_batched_tokens, + num_batched_tokens=len(seq_lens) * max(seq_lens), blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 88f19bbc..51a8161b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,6 +27,7 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 + max_paddings: int = 256 disable_log_stats: bool = False revision: Optional[str] = None tokenizer_revision: Optional[str] = None @@ -156,6 +157,10 @@ class EngineArgs: type=int, default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') + parser.add_argument('--max-paddings', + type=int, + default=EngineArgs.max_paddings, + help='maximum number of paddings in a batch') parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') @@ -193,7 +198,8 @@ class EngineArgs: self.worker_use_ray) scheduler_config = 
SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, - model_config.max_model_len) + model_config.max_model_len, + self.max_paddings) return model_config, cache_config, parallel_config, scheduler_config diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index a0a62034..bc5c2482 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -39,11 +39,12 @@ class InputMetadata: self.max_context_len = max_context_len self.block_tables = block_tables + self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 self.to_cache = None if sliding_window is not None: # We need to keep the positions of sliding windows within # the key / value tables, this is helpful to know which - # elements we need to cache and where + # elements we need to cache. to_cache, start_idx = [], 0 for prompt_len in self.prompt_lens: to_cache.extend( @@ -51,16 +52,15 @@ class InputMetadata: start_idx + max(0, prompt_len - sliding_window), start_idx + prompt_len, )) - start_idx += prompt_len + start_idx += self.max_prompt_len to_cache.extend(range(start_idx, slot_mapping.shape[0])) self.to_cache = torch.tensor(to_cache, dtype=torch.int32, device=self.slot_mapping.device) self.num_prompts = len(prompt_lens) - self.num_prompt_tokens = sum(prompt_lens) + self.num_prompt_tokens = self.num_prompts * self.max_prompt_len self.num_generation_tokens = context_lens.shape[0] - self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: self.max_num_blocks_per_seq = block_tables.shape[1] else: @@ -69,12 +69,11 @@ class InputMetadata: assert context_lens.shape[0] == self.num_generation_tokens # Set during the execution of the first attention op. - self.attn_bias: List[AttentionBias] = [] + self.attn_bias: Optional[AttentionBias] = None def __repr__(self) -> str: # Print only useful metadata. return (f'InputMetadata(' - f'num_valid_tokens={self.num_valid_tokens}, ' f'num_prompt_tokens={self.num_prompt_tokens}, ' f'num_prompts={self.num_prompts}, ' f'prompt_lens={self.prompt_lens}, ' diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 9222fe27..109451d4 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -8,17 +8,17 @@ from vllm import activation_ops class SiluAndMul(nn.Module): """An activation function for SwiGLU. - The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2. + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
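+    (Equivalent PyTorch reference: torch.nn.functional.silu(x[..., :d]) * x[..., d:].)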
Shapes: - x: (num_tokens, 2 * d) - return: (num_tokens, d) + x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) + return: (batch_size, seq_len, d) or (num_tokens, d) """ def forward(self, x: torch.Tensor) -> torch.Tensor: - num_tokens = x.shape[0] - d = x.shape[1] // 2 - out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device) + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) activation_ops.silu_and_mul(out, x) return out @@ -26,9 +26,7 @@ class SiluAndMul(nn.Module): class NewGELU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: - num_tokens = x.shape[0] - d = x.shape[1] - out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device) + out = torch.empty_like(x) activation_ops.gelu_new(out, x) return out @@ -36,9 +34,7 @@ class NewGELU(nn.Module): class FastGELU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: - num_tokens = x.shape[0] - d = x.shape[1] - out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device) + out = torch.empty_like(x) activation_ops.gelu_fast(out, x) return out diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 084f4d98..58f868d4 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -23,25 +23,9 @@ class PagedAttention(nn.Module): # pylint: disable=line-too-long """GPT-style multi-head PagedAttention. - This class takes flattened 1D query, key, and value tensors as input. The - input 1D tensors can either contain prompt tokens or generation tokens, in - addition to paddings. - - If the input tensors contain prompt tokens, the layout is as follows: - - |<---------------------- num_valid_tokens ---------------------->| - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| - - Otherwise, the layout is as follows: - - |<------------------ num_valid_tokens ------------------->| - |<------- num_generation_tokens (M) ------->| - |<--generation_0-->|...|<--generation_M-1-->|<--padding-->| - - The prompts might have different lengths, while the generation tokens always - have length 1. The paddings are appended to make the input length a multiple - of 8, which is desirable for Tensor Cores. + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens, in addition to + paddings. The class does the following: 1. Perform multi_query_kv_attention for the prompts. This operation does @@ -53,7 +37,7 @@ class PagedAttention(nn.Module): 4. Perform single_query_cached_kv_attention for the generation tokens. This operation reads the previous key and value tensors from the KV cache. - 5. Output a flattened 1D tensor. + 5. Return the output tensor. """ def __init__(self, @@ -85,14 +69,15 @@ class PagedAttention(nn.Module): dtype: torch.dtype, ) -> None: del dtype # Unused. - if input_metadata.attn_bias: + if input_metadata.attn_bias is not None: # Already set by a previous layer. 
return - prompt_lens = input_metadata.prompt_lens + prompt_lens = [input_metadata.max_prompt_len + ] * input_metadata.num_prompts attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention(self.sliding_window) - input_metadata.attn_bias.append(attn_bias) + input_metadata.attn_bias = attn_bias def multi_query_kv_attention( self, @@ -111,7 +96,6 @@ class PagedAttention(nn.Module): value: shape = [num_prompt_tokens, num_kv_heads, head_size] input_metadata: metadata for paged attention. """ - if self.num_kv_heads != self.num_heads: # Project the key and value tensors to the desired number of heads. key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1) @@ -124,7 +108,7 @@ class PagedAttention(nn.Module): query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), - attn_bias=input_metadata.attn_bias[0], + attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, ) @@ -232,12 +216,12 @@ class PagedAttention(nn.Module): """PagedAttention forward pass. NOTE: The query, key, and value tensors must be sliced from a qkv - tensor of shape [num_tokens, 3 * num_heads * head_size]. + tensor of shape [batch_size, seq_len, 3 * num_heads * head_size]. Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, @@ -246,9 +230,9 @@ class PagedAttention(nn.Module): cache_event: event to wait for the cache operations to finish. Returns: - shape = [num_tokens, num_heads * head_size] + shape = [batch_size, seq_len, num_heads * head_size] """ - + batch_size, seq_len, _ = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -264,10 +248,10 @@ class PagedAttention(nn.Module): assert input_metadata.num_generation_tokens == 0 self.set_attn_bias(input_metadata, dtype=query.dtype) self.multi_query_kv_attention( - output[:num_prompt_tokens], - query[:num_prompt_tokens], - key[:num_prompt_tokens], - value[:num_prompt_tokens], + output, + query, + key, + value, input_metadata, ) @@ -278,13 +262,10 @@ class PagedAttention(nn.Module): # Reshape the keys and values and store them in the cache. # When key_cache and value_cache are not provided, the new key # and value vectors will not be cached. - num_valid_tokens = input_metadata.num_valid_tokens - if (num_valid_tokens > 0 and key_cache is not None - and value_cache is not None): - # The stride is 3 because the key and value are sliced from qkv. - key_to_cache = key[:num_valid_tokens] - value_to_cache = value[:num_valid_tokens] - slot_mapping = input_metadata.slot_mapping + if key_cache is not None and value_cache is not None: + key_to_cache = key + value_to_cache = value + slot_mapping = input_metadata.slot_mapping.view(-1) if input_metadata.to_cache is not None: key_to_cache = key_to_cache[input_metadata.to_cache] value_to_cache = value_to_cache[input_metadata.to_cache] @@ -305,14 +286,14 @@ class PagedAttention(nn.Module): "key_cache and value_cache must be provided when " "generating tokens.") # Compute the attention op for generation tokens. 
- self.single_query_cached_kv_attention( - output[num_prompt_tokens:num_valid_tokens], - query[num_prompt_tokens:num_valid_tokens], key_cache, - value_cache, input_metadata, self.get_alibi_slopes()) + self.single_query_cached_kv_attention(output, query, key_cache, + value_cache, input_metadata, + self.get_alibi_slopes()) # Reshape the output tensor. # NOTE(woosuk): The output tensor may include paddings. - return output.view(-1, self.num_heads * self.head_size) + return output.view(batch_size, seq_len, + self.num_heads * self.head_size) class PagedAttentionWithRoPE(PagedAttention): @@ -368,10 +349,10 @@ class PagedAttentionWithRoPE(PagedAttention): """ PagedAttention forward pass with rotary embedding. Args: - positions: shape = [num_tokens] - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + positions: shape = [batch_size, seq_len] + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, @@ -380,7 +361,7 @@ class PagedAttentionWithRoPE(PagedAttention): cache_event: event to wait for the cache operations to finish. Returns: - shape = [num_tokens, num_heads * head_size] + shape = [batch_size, seq_len, num_heads * head_size] """ # Apply rotary embedding to the query and key before passing them @@ -414,34 +395,34 @@ class PagedAttentionWithALiBi(PagedAttention): def set_attn_bias(self, input_metadata: InputMetadata, dtype: torch.dtype) -> None: - if input_metadata.attn_bias: + if input_metadata.attn_bias is not None: # Already set by a previous layer. return - # Generates ALiBi mask for each prompt. - for prompt_len in input_metadata.prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - bias = bias.to(self.alibi_slopes.device) + # Generates ALiBi mask based on the max prompt length. + max_prompt_len = input_metadata.max_prompt_len + bias = torch.arange(max_prompt_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + bias = bias.to(self.alibi_slopes.device) - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - padded_len = (prompt_len + 7) // 8 * 8 - bias = torch.empty( - 1, # batch_size - self.num_heads, - prompt_len, - padded_len, - device=self.alibi_slopes.device, - dtype=dtype, - )[:, :, :, :prompt_len].copy_(bias) - bias.mul_(self.alibi_slopes[:, None, None]) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - input_metadata.attn_bias.append(attn_bias) + # When using custom attention bias, xformers requires the bias to + # be sliced from a tensor whose length is a multiple of 8. 
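+        # padded_len rounds max_prompt_len up to a multiple of 8; a single bias then
+        # covers the whole batch because every prompt is padded to max_prompt_len.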
+ padded_len = (max_prompt_len + 7) // 8 * 8 + bias = torch.empty( + input_metadata.num_prompts, + self.num_heads, + max_prompt_len, + padded_len, + device=self.alibi_slopes.device, + dtype=dtype, + )[:, :, :, :max_prompt_len].copy_(bias) + bias.mul_(self.alibi_slopes[:, None, None]) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + input_metadata.attn_bias = attn_bias def multi_query_kv_attention( self, @@ -466,24 +447,19 @@ class PagedAttentionWithALiBi(PagedAttention): value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=1) + batch_size = input_metadata.num_prompts + seq_len = input_metadata.max_prompt_len - # FIXME(woosuk): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - start = 0 - for i, prompt_len in enumerate(input_metadata.prompt_lens): - end = start + prompt_len - out = xops.memory_efficient_attention_forward( - query[None, start:end], - key[None, start:end], - value[None, start:end], - attn_bias=input_metadata.attn_bias[i], - p=0.0, - scale=self.scale, - ) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.squeeze(0)) - start += prompt_len + out = xops.memory_efficient_attention_forward( + query.view(batch_size, seq_len, self.num_heads, self.head_size), + key.view(batch_size, seq_len, self.num_heads, self.head_size), + value.view(batch_size, seq_len, self.num_heads, self.head_size), + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output.copy_(out.view(-1, self.num_heads, self.head_size)) return output def get_alibi_slopes(self) -> Optional[torch.Tensor]: diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 2c2d0f8c..0d7d0f91 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -50,7 +50,7 @@ class AWQColumnParallelLinear(ColumnParallelLinear): bias: Optional[torch.Tensor], ) -> torch.Tensor: pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor) + out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, self.qzeros, pack_factor) @@ -95,7 +95,7 @@ class AWQRowParallelLinear(RowParallelLinear): def apply_weights(self, x: torch.Tensor) -> torch.Tensor: pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor) + out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, self.qzeros, pack_factor) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2c8652ee..a12c82a2 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -119,7 +119,7 @@ def _prune_hidden_states( selected_token_indices.extend( range(start_idx, start_idx + prompt_len - 1)) selected_token_indices.append(start_idx + prompt_len - 1) - start_idx += prompt_len + start_idx += input_metadata.max_prompt_len else: num_seqs = len(seq_ids) selected_token_indices.extend( @@ -129,6 +129,7 @@ def _prune_hidden_states( selected_token_indices = torch.tensor(selected_token_indices, dtype=torch.long, 
device=hidden_states.device) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) return hidden_states.index_select(0, selected_token_indices) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5b0a60db..1b1f116a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -158,9 +158,9 @@ class Worker: seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: seq_groups: List[Tuple[List[int], SamplingParams]] = [] - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] # Add prompt tokens. prompt_lens: List[int] = [] @@ -180,24 +180,25 @@ class Worker: prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) - input_tokens.extend(prompt_tokens) + input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(range(len(prompt_tokens))) + input_positions.append(list(range(prompt_len))) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([0] * prompt_len) + slot_mapping.append([0] * prompt_len) continue # Compute the slot mapping. + slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] for i in range(prompt_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) + slot_mapping[-1].append(slot) # Add generation tokens. max_context_len = 0 @@ -215,13 +216,13 @@ class Worker: for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) + input_tokens.append([generation_token]) context_len = seq_data.get_len() position = context_len - 1 if self.sliding_window is not None: context_len = min(context_len, self.sliding_window) - input_positions.append(position) + input_positions.append([position]) block_table = seq_group_metadata.block_tables[seq_id] @@ -233,7 +234,7 @@ class Worker: block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) + slot_mapping.append([slot]) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // @@ -241,28 +242,36 @@ class Worker: block_table = block_table[-sliding_window_blocks:] generation_block_tables.append(block_table) - # Optimization: Pad the input length to be a multiple of 8. - # This is required for utilizing the Tensor Cores in NVIDIA GPUs. - input_tokens = _pad_to_alignment(input_tokens, multiple_of=8) - input_positions = _pad_to_alignment(input_positions, multiple_of=8) + max_seq_len = max(prompt_lens) if prompt_lens else 1 + padded_input_tokens = [ + _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens + ] + padded_input_positions = [ + _pad_to_max(positions, max_seq_len, pad=0) + for positions in input_positions + ] + padded_slot_mapping = [ + _pad_to_max(mapping, max_seq_len, pad=-1) + for mapping in slot_mapping + ] + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq, pad=0) + for block_table in generation_block_tables + ] # Convert to tensors. 
- tokens_tensor = torch.tensor(input_tokens, + tokens_tensor = torch.tensor(padded_input_tokens, dtype=torch.long, device="cuda") - positions_tensor = torch.tensor(input_positions, + positions_tensor = torch.tensor(padded_input_positions, dtype=torch.long, device="cuda") - slot_mapping_tensor = torch.tensor(slot_mapping, + slot_mapping_tensor = torch.tensor(padded_slot_mapping, dtype=torch.int, device="cuda") context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device="cuda") - padded_block_tables = [ - _pad_to_max(block_table, max_num_blocks_per_seq) - for block_table in generation_block_tables - ] block_tables_tensor = torch.tensor(padded_block_tables, dtype=torch.int, device="cuda") @@ -361,12 +370,12 @@ def _init_distributed_environment( parallel_config.pipeline_parallel_size) -def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]: - return x + [0] * ((-len(x)) % multiple_of) +def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]: + return x + [pad] * ((-len(x)) % multiple_of) -def _pad_to_max(x: List[int], max_len: int) -> List[int]: - return x + [0] * (max_len - len(x)) +def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + return x + [pad] * (max_len - len(x)) def _check_if_can_support_max_seq_len(max_seq_len: int,
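Note (not part of the patch): a minimal sketch of the padded [batch_size, max_seq_len] input
layout that this change introduces, assuming three hypothetical prompts of lengths 3, 1, and 2
and made-up slot values; it mirrors the _pad_to_max helper and the pad values (0 for tokens and
positions, -1 for slot_mapping) used in vllm/worker/worker.py above.

from typing import List

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    # Same helper as in vllm/worker/worker.py above.
    return x + [pad] * (max_len - len(x))

# Three hypothetical prompts (token ids are arbitrary).
prompts = [[11, 12, 13], [21], [31, 32]]
prompt_lens = [len(p) for p in prompts]          # [3, 1, 2]
max_seq_len = max(prompt_lens)                   # 3

# Tokens and positions are padded with 0.
input_tokens = [_pad_to_max(p, max_seq_len, pad=0) for p in prompts]
input_positions = [_pad_to_max(list(range(n)), max_seq_len, pad=0)
                   for n in prompt_lens]

# Slot mappings are padded with -1 so reshape_and_cache skips the padding tokens.
# (Slot values are made up; real ones come from the block tables.)
slots = [[0, 1, 2], [100], [200, 201]]
slot_mapping = [_pad_to_max(s, max_seq_len, pad=-1) for s in slots]

assert input_tokens == [[11, 12, 13], [21, 0, 0], [31, 32, 0]]
assert slot_mapping == [[0, 1, 2], [100, -1, -1], [200, 201, -1]]

# The scheduler now counts the padded batch:
# num_batched_tokens = len(seq_lens) * max(seq_lens) = 3 * 3 = 9,
# of which 9 - sum(seq_lens) = 3 are paddings (bounded by --max-paddings).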