Change scheduler & input tensor shape (#1381)
parent 651c614aa4
commit c1376e0f82
@@ -13,8 +13,8 @@ __device__ __forceinline__ T silu(const T& x) {
 
 template<typename scalar_t>
 __global__ void silu_and_mul_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, 2, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
@@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel(
 } // namespace vllm
 
 void silu_and_mul(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, 2 * d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
 {
-  int num_tokens = input.size(0);
-  int d = input.size(1) / 2;
+  int num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1) / 2;
 
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
@@ -52,8 +52,8 @@ namespace vllm {
 // Element-wise activation kernel template.
 template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void activation_kernel(
-  scalar_t* __restrict__ out,               // [num_tokens, d]
-  const scalar_t* __restrict__ input,       // [num_tokens, d]
+  scalar_t* __restrict__ out,               // [..., d]
+  const scalar_t* __restrict__ input,       // [..., d]
   const int d) {
   const int token_idx = blockIdx.x;
   for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
@@ -66,8 +66,8 @@ __global__ void activation_kernel(
 
 // Launch element-wise activation kernel.
 #define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                \
-  int num_tokens = input.size(0);                                       \
-  int d = input.size(1);                                                \
+  int d = input.size(-1);                                               \
+  int num_tokens = input.numel() / d;                                   \
   dim3 grid(num_tokens);                                                \
   dim3 block(std::min(d, 1024));                                        \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
@@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
 } // namespace vllm
 
 void gelu_new(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
 }
 
 void gelu_fast(
-  torch::Tensor& out,      // [num_tokens, d]
-  torch::Tensor& input)    // [num_tokens, d]
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
 }
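The launch math above no longer assumes a flat [num_tokens, d] input: every dimension except the last is folded into the grid size. A small PyTorch sketch of the same shape arithmetic, with a pure-Python stand-in for the silu_and_mul kernel (illustrative only):

import torch
import torch.nn.functional as F

def reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # x may be [num_tokens, 2 * d] or [batch_size, seq_len, 2 * d];
    # the kernel treats everything except the last dim as "tokens".
    d = x.size(-1) // 2
    # The CUDA kernel launches one block per token, flattening any
    # leading dims: grid = num_tokens, block = min(d, 1024).
    num_tokens = x.numel() // x.size(-1)
    assert num_tokens == x.shape[:-1].numel()
    out = F.silu(x[..., :d]) * x[..., d:]  # same math as silu_and_mul_kernel
    return out

x = torch.randn(2, 3, 8)  # batch_size=2, seq_len=3, 2*d=8
print(reference_silu_and_mul(x).shape)  # torch.Size([2, 3, 4])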
@@ -154,6 +154,11 @@ __global__ void reshape_and_cache_kernel(
   const int x) {
   const int token_idx = blockIdx.x;
   const int slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
   const int block_idx = slot_idx / block_size;
   const int block_offset = slot_idx % block_size;
 
@@ -176,8 +181,8 @@ __global__ void reshape_and_cache_kernel(
                               + head_idx * head_size * block_size
                               + head_offset * block_size
                               + block_offset;
-    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
-    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
+    key_cache[tgt_key_idx] = key[src_key_idx];
+    value_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
 
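The slot mapping that feeds this kernel is now padded per sequence, with -1 marking padding positions, and the kernel skips them. A rough Python model of that guard (plain lists and a dict standing in for the KV cache):

from typing import Dict, List

def reference_reshape_and_cache(slot_mapping: List[int],
                                keys: List[float],
                                kv_cache: Dict[int, float]) -> None:
    # One entry per (possibly padded) token; -1 marks a padding token.
    for token_idx, slot_idx in enumerate(slot_mapping):
        if slot_idx < 0:
            # Padding token that should be ignored (mirrors the CUDA kernel).
            continue
        kv_cache[slot_idx] = keys[token_idx]

cache: Dict[int, float] = {}
reference_reshape_and_cache([5, 6, -1, -1], [1.0, 2.0, 0.0, 0.0], cache)
print(cache)  # {5: 1.0, 6: 2.0} -- padded positions never touch the cache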
@@ -9,8 +9,8 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template<typename scalar_t>
 __global__ void rms_norm_kernel(
-  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
-  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
+  scalar_t* __restrict__ out,             // [..., hidden_size]
+  const scalar_t* __restrict__ input,     // [..., hidden_size]
   const scalar_t* __restrict__ weight,    // [hidden_size]
   const float epsilon,
   const int num_tokens,
@@ -37,12 +37,12 @@ __global__ void rms_norm_kernel(
 } // namespace vllm
 
 void rms_norm(
-  torch::Tensor& out,      // [num_tokens, hidden_size]
-  torch::Tensor& input,    // [num_tokens, hidden_size]
+  torch::Tensor& out,      // [..., hidden_size]
+  torch::Tensor& input,    // [..., hidden_size]
   torch::Tensor& weight,   // [hidden_size]
   float epsilon) {
-  int num_tokens = input.size(0);
-  int hidden_size = input.size(1);
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(
 
 template<typename scalar_t, bool IS_NEOX>
 __global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [num_tokens]
-  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int query_stride,
@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
 } // namespace vllm
 
 void rotary_embedding(
-  torch::Tensor& positions,         // [num_tokens]
-  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
+  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
   int head_size,
   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.size(0);
+  int num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(1) / head_size;
-  int num_kv_heads = key.size(1) / head_size;
-  int query_stride = query.stride(0);
-  int key_stride = key.stride(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int query_stride = query.stride(-2);
+  int key_stride = key.stride(-2);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
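Because query and key may now keep their [batch_size, seq_len, ...] shape, the host code derives the token count and per-token stride from the trailing dimensions. A quick PyTorch check of that arithmetic:

import torch

num_heads, head_size = 8, 16
query = torch.randn(2, 5, num_heads * head_size)  # [batch_size, seq_len, ...]

num_tokens = query.numel() // query.size(-1)      # 2 * 5 = 10
query_stride = query.stride(-2)                   # elements between consecutive tokens
print(num_tokens, query_stride)                   # 10 128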
@@ -268,6 +268,7 @@ class SchedulerConfig:
             iteration.
         max_model_len: Maximum length of a sequence (including prompt
             and generated text).
+        max_paddings: Maximum number of paddings to be added to a batch.
     """
 
     def __init__(
@@ -275,6 +276,7 @@ class SchedulerConfig:
         max_num_batched_tokens: Optional[int],
         max_num_seqs: int,
         max_model_len: int,
+        max_paddings: int,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
@@ -284,6 +286,7 @@ class SchedulerConfig:
             self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self.max_paddings = max_paddings
         self._verify_args()
 
     def _verify_args(self) -> None:
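Usage sketch for the new knob, assuming the SchedulerConfig signature from this diff (the values below are arbitrary):

from vllm.config import SchedulerConfig

# max_paddings caps the total number of padding slots a prompt batch may carry.
scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2560,
    max_num_seqs=256,
    max_model_len=2048,
    max_paddings=256,
)
print(scheduler_config.max_paddings)  # 256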
@@ -131,7 +131,8 @@ class Scheduler:
             # requests in the generation phase.
             num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                 for seq_group in self.running)
-            num_batched_tokens = 0
+            seq_lens: List[int] = []
+
             # Optimization: We do not sort the waiting queue since the preempted
             # sequence groups are added to the front and the new sequence groups
             # are added to the back.
@@ -157,7 +158,9 @@ class Scheduler:
                     break
 
                 # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
+                new_seq_lens = seq_lens + [num_prompt_tokens]
+                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
+                if (num_batched_tokens >
                         self.scheduler_config.max_num_batched_tokens):
                     break
 
@@ -168,10 +171,14 @@ class Scheduler:
                         self.scheduler_config.max_num_seqs):
                     break
 
+                num_paddings = num_batched_tokens - sum(new_seq_lens)
+                if num_paddings > self.scheduler_config.max_paddings:
+                    break
+                seq_lens = new_seq_lens
+
                 seq_group = self.waiting.pop(0)
                 self._allocate(seq_group)
                 self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)
 
@@ -179,7 +186,7 @@ class Scheduler:
             scheduler_outputs = SchedulerOutputs(
                 scheduled_seq_groups=scheduled,
                 prompt_run=True,
-                num_batched_tokens=num_batched_tokens,
+                num_batched_tokens=len(seq_lens) * max(seq_lens),
                 blocks_to_swap_in=blocks_to_swap_in,
                 blocks_to_swap_out=blocks_to_swap_out,
                 blocks_to_copy=blocks_to_copy,
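Since prompts are padded to a common length before batching, the scheduler now charges each batch len(seq_lens) * max(seq_lens) tokens and bounds the waste with max_paddings. A standalone sketch of that accounting with made-up limits:

from typing import List

MAX_NUM_BATCHED_TOKENS = 4096
MAX_PADDINGS = 256

def try_add_prompt(seq_lens: List[int], num_prompt_tokens: int) -> List[int]:
    """Return the updated seq_lens if the prompt fits, else the original list."""
    new_seq_lens = seq_lens + [num_prompt_tokens]
    # After padding, every sequence occupies max(new_seq_lens) token slots.
    num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
    if num_batched_tokens > MAX_NUM_BATCHED_TOKENS:
        return seq_lens
    num_paddings = num_batched_tokens - sum(new_seq_lens)
    if num_paddings > MAX_PADDINGS:
        return seq_lens
    return new_seq_lens

seq_lens: List[int] = []
for prompt_len in [100, 120, 700]:
    seq_lens = try_add_prompt(seq_lens, prompt_len)
print(seq_lens)  # [100, 120]: the 700-token prompt would add 1180 paddings (> 256)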
@@ -27,6 +27,7 @@ class EngineArgs:
     gpu_memory_utilization: float = 0.90
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
+    max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
@@ -156,6 +157,10 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
+        parser.add_argument('--max-paddings',
+                            type=int,
+                            default=EngineArgs.max_paddings,
+                            help='maximum number of paddings in a batch')
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
@@ -193,7 +198,8 @@ class EngineArgs:
                                          self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
-                                           model_config.max_model_len)
+                                           model_config.max_model_len,
+                                           self.max_paddings)
         return model_config, cache_config, parallel_config, scheduler_config
 
 
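A self-contained argparse sketch of the new flag, detached from EngineArgs (defaults copied from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-paddings',
                    type=int,
                    default=256,
                    help='maximum number of paddings in a batch')
args = parser.parse_args(['--max-paddings', '512'])
print(args.max_paddings)  # 512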
@@ -39,11 +39,12 @@ class InputMetadata:
         self.max_context_len = max_context_len
         self.block_tables = block_tables
 
+        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
         if sliding_window is not None:
             # We need to keep the positions of sliding windows within
             # the key / value tables, this is helpful to know which
-            # elements we need to cache and where
+            # elements we need to cache.
             to_cache, start_idx = [], 0
             for prompt_len in self.prompt_lens:
                 to_cache.extend(
@@ -51,16 +52,15 @@ class InputMetadata:
                         start_idx + max(0, prompt_len - sliding_window),
                         start_idx + prompt_len,
                     ))
-                start_idx += prompt_len
+                start_idx += self.max_prompt_len
             to_cache.extend(range(start_idx, slot_mapping.shape[0]))
             self.to_cache = torch.tensor(to_cache,
                                          dtype=torch.int32,
                                          device=self.slot_mapping.device)
 
         self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = sum(prompt_lens)
+        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
@@ -69,12 +69,11 @@ class InputMetadata:
             assert context_lens.shape[0] == self.num_generation_tokens
 
         # Set during the execution of the first attention op.
-        self.attn_bias: List[AttentionBias] = []
+        self.attn_bias: Optional[AttentionBias] = None
 
     def __repr__(self) -> str:
         # Print only useful metadata.
         return (f'InputMetadata('
-                f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
                 f'num_prompts={self.num_prompts}, '
                 f'prompt_lens={self.prompt_lens}, '
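InputMetadata now derives its prompt bookkeeping from the padded length rather than the raw lengths. A tiny illustration of the new arithmetic in plain Python:

from typing import List

prompt_lens: List[int] = [7, 3, 5]

max_prompt_len = max(prompt_lens) if prompt_lens else 0
num_prompts = len(prompt_lens)
# With every prompt padded to max_prompt_len, the prompt token count is exact:
num_prompt_tokens = num_prompts * max_prompt_len
print(max_prompt_len, num_prompt_tokens)  # 7 21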
@@ -8,17 +8,17 @@ from vllm import activation_ops
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
 
-    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
+    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
 
     Shapes:
-        x: (num_tokens, 2 * d)
-        return: (num_tokens, d)
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
     """
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1] // 2
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out
 
@@ -26,9 +26,7 @@ class SiluAndMul(nn.Module):
 class NewGELU(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_new(out, x)
         return out
 
@@ -36,9 +34,7 @@ class NewGELU(nn.Module):
 class FastGELU(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        num_tokens = x.shape[0]
-        d = x.shape[1]
-        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        out = torch.empty_like(x)
         activation_ops.gelu_fast(out, x)
         return out
 
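The activation layers keep arbitrary leading dimensions instead of assuming a flat [num_tokens, ...] input. A reference forward pass without the custom CUDA op, for illustration only:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RefSiluAndMul(nn.Module):
    """Same shape contract as SiluAndMul above, computed in plain PyTorch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d, )
        out = F.silu(x[..., :d]) * x[..., d:]
        assert out.shape == output_shape
        return out

x = torch.randn(2, 4, 12)            # [batch_size, seq_len, 2 * d]
print(RefSiluAndMul()(x).shape)      # torch.Size([2, 4, 6])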
@@ -23,25 +23,9 @@ class PagedAttention(nn.Module):
     # pylint: disable=line-too-long
     """GPT-style multi-head PagedAttention.
 
-    This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can either contain prompt tokens or generation tokens, in
-    addition to paddings.
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens, in addition to
+    paddings.
 
-    If the input tensors contain prompt tokens, the layout is as follows:
-
-    |<---------------------- num_valid_tokens ---------------------->|
-    |<--------------- num_prompt_tokens -------------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
-
-    Otherwise, the layout is as follows:
-
-    |<------------------ num_valid_tokens ------------------->|
-    |<------- num_generation_tokens (M) ------->|
-    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
-
-    The prompts might have different lengths, while the generation tokens always
-    have length 1. The paddings are appended to make the input length a multiple
-    of 8, which is desirable for Tensor Cores.
-
     The class does the following:
     1. Perform multi_query_kv_attention for the prompts. This operation does
@@ -53,7 +37,7 @@ class PagedAttention(nn.Module):
     4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
-    5. Output a flattened 1D tensor.
+    5. Return the output tensor.
     """
 
     def __init__(self,
@@ -85,14 +69,15 @@ class PagedAttention(nn.Module):
         dtype: torch.dtype,
     ) -> None:
         del dtype  # Unused.
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        prompt_lens = input_metadata.prompt_lens
+        prompt_lens = [input_metadata.max_prompt_len
+                       ] * input_metadata.num_prompts
         attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         if self.sliding_window is not None:
             attn_bias = attn_bias.make_local_attention(self.sliding_window)
-        input_metadata.attn_bias.append(attn_bias)
+        input_metadata.attn_bias = attn_bias
 
     def multi_query_kv_attention(
         self,
@@ -111,7 +96,6 @@ class PagedAttention(nn.Module):
             value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
             key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
@@ -124,7 +108,7 @@ class PagedAttention(nn.Module):
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=input_metadata.attn_bias[0],
+            attn_bias=input_metadata.attn_bias,
             p=0.0,
             scale=self.scale,
         )
@@ -232,12 +216,12 @@ class PagedAttention(nn.Module):
         """PagedAttention forward pass.
 
         NOTE: The query, key, and value tensors must be sliced from a qkv
-        tensor of shape [num_tokens, 3 * num_heads * head_size].
+        tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].
 
         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -246,9 +230,9 @@ class PagedAttention(nn.Module):
             cache_event: event to wait for the cache operations to finish.
 
         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """
-
+        batch_size, seq_len, _ = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -264,10 +248,10 @@ class PagedAttention(nn.Module):
             assert input_metadata.num_generation_tokens == 0
             self.set_attn_bias(input_metadata, dtype=query.dtype)
             self.multi_query_kv_attention(
-                output[:num_prompt_tokens],
-                query[:num_prompt_tokens],
-                key[:num_prompt_tokens],
-                value[:num_prompt_tokens],
+                output,
+                query,
+                key,
+                value,
                 input_metadata,
             )
 
@@ -278,13 +262,10 @@ class PagedAttention(nn.Module):
         # Reshape the keys and values and store them in the cache.
         # When key_cache and value_cache are not provided, the new key
         # and value vectors will not be cached.
-        num_valid_tokens = input_metadata.num_valid_tokens
-        if (num_valid_tokens > 0 and key_cache is not None
-                and value_cache is not None):
-            # The stride is 3 because the key and value are sliced from qkv.
-            key_to_cache = key[:num_valid_tokens]
-            value_to_cache = value[:num_valid_tokens]
-            slot_mapping = input_metadata.slot_mapping
+        if key_cache is not None and value_cache is not None:
+            key_to_cache = key
+            value_to_cache = value
+            slot_mapping = input_metadata.slot_mapping.view(-1)
             if input_metadata.to_cache is not None:
                 key_to_cache = key_to_cache[input_metadata.to_cache]
                 value_to_cache = value_to_cache[input_metadata.to_cache]
@@ -305,14 +286,14 @@ class PagedAttention(nn.Module):
                     "key_cache and value_cache must be provided when "
                     "generating tokens.")
             # Compute the attention op for generation tokens.
-            self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:num_valid_tokens],
-                query[num_prompt_tokens:num_valid_tokens], key_cache,
-                value_cache, input_metadata, self.get_alibi_slopes())
+            self.single_query_cached_kv_attention(output, query, key_cache,
+                                                  value_cache, input_metadata,
+                                                  self.get_alibi_slopes())
 
         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
-        return output.view(-1, self.num_heads * self.head_size)
+        return output.view(batch_size, seq_len,
+                           self.num_heads * self.head_size)
 
 
 class PagedAttentionWithRoPE(PagedAttention):
@@ -368,10 +349,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         """ PagedAttention forward pass with rotary embedding.
 
         Args:
-            positions: shape = [num_tokens]
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            positions: shape = [batch_size, seq_len]
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -380,7 +361,7 @@ class PagedAttentionWithRoPE(PagedAttention):
             cache_event: event to wait for the cache operations to finish.
 
         Returns:
-            shape = [num_tokens, num_heads * head_size]
+            shape = [batch_size, seq_len, num_heads * head_size]
         """
 
         # Apply rotary embedding to the query and key before passing them
@@ -414,12 +395,12 @@ class PagedAttentionWithALiBi(PagedAttention):
 
     def set_attn_bias(self, input_metadata: InputMetadata,
                       dtype: torch.dtype) -> None:
-        if input_metadata.attn_bias:
+        if input_metadata.attn_bias is not None:
             # Already set by a previous layer.
             return
-        # Generates ALiBi mask for each prompt.
-        for prompt_len in input_metadata.prompt_lens:
-            bias = torch.arange(prompt_len, dtype=dtype)
+        # Generates ALiBi mask based on the max prompt length.
+        max_prompt_len = input_metadata.max_prompt_len
+        bias = torch.arange(max_prompt_len, dtype=dtype)
         # NOTE(zhuohan): HF uses
         #     `bias = bias[None, :].repeat(prompt_len, 1)`
         # here. We find that both biases give the same results, but
@@ -430,18 +411,18 @@ class PagedAttentionWithALiBi(PagedAttention):
 
         # When using custom attention bias, xformers requires the bias to
         # be sliced from a tensor whose length is a multiple of 8.
-        padded_len = (prompt_len + 7) // 8 * 8
+        padded_len = (max_prompt_len + 7) // 8 * 8
         bias = torch.empty(
-            1,  # batch_size
+            input_metadata.num_prompts,
             self.num_heads,
-            prompt_len,
+            max_prompt_len,
             padded_len,
             device=self.alibi_slopes.device,
             dtype=dtype,
-        )[:, :, :, :prompt_len].copy_(bias)
+        )[:, :, :, :max_prompt_len].copy_(bias)
         bias.mul_(self.alibi_slopes[:, None, None])
         attn_bias = LowerTriangularMaskWithTensorBias(bias)
-        input_metadata.attn_bias.append(attn_bias)
+        input_metadata.attn_bias = attn_bias
 
     def multi_query_kv_attention(
         self,
@@ -466,24 +447,19 @@ class PagedAttentionWithALiBi(PagedAttention):
             value = torch.repeat_interleave(value,
                                             self.num_queries_per_kv,
                                             dim=1)
+        batch_size = input_metadata.num_prompts
+        seq_len = input_metadata.max_prompt_len
 
-        # FIXME(woosuk): Because xformers does not support dynamic sequence
-        # lengths with custom attention bias, we process each prompt one by
-        # one. This is inefficient, especially when we have many short prompts.
-        start = 0
-        for i, prompt_len in enumerate(input_metadata.prompt_lens):
-            end = start + prompt_len
         out = xops.memory_efficient_attention_forward(
-            query[None, start:end],
-            key[None, start:end],
-            value[None, start:end],
-            attn_bias=input_metadata.attn_bias[i],
+            query.view(batch_size, seq_len, self.num_heads, self.head_size),
+            key.view(batch_size, seq_len, self.num_heads, self.head_size),
+            value.view(batch_size, seq_len, self.num_heads, self.head_size),
+            attn_bias=input_metadata.attn_bias,
             p=0.0,
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output[start:end].copy_(out.squeeze(0))
-        start += prompt_len
+        output.copy_(out.view(-1, self.num_heads, self.head_size))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
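With every prompt padded to max_prompt_len, the ALiBi path can issue one batched attention call instead of looping over prompts. A shape-level sketch of that batching in plain PyTorch; scaled_dot_product_attention stands in here for xops.memory_efficient_attention_forward, and no ALiBi bias is applied:

import torch
import torch.nn.functional as F

batch_size, seq_len, num_heads, head_size = 2, 8, 4, 16

# Padded prompt tensors, already reshaped to [tokens, heads, head_size].
query = torch.randn(batch_size * seq_len, num_heads, head_size)
key = torch.randn(batch_size * seq_len, num_heads, head_size)
value = torch.randn(batch_size * seq_len, num_heads, head_size)

def batched_prompt_attention(q, k, v):
    # One batched call over [batch, heads, seq, head_size]; padding rows are
    # still computed and simply ignored downstream, as in the padded vLLM layout.
    q = q.view(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
    k = k.view(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
    v = v.view(batch_size, seq_len, num_heads, head_size).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2).reshape(-1, num_heads, head_size)

output = batched_prompt_attention(query, key, value)
print(output.shape)  # torch.Size([16, 4, 16])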
@@ -50,7 +50,7 @@ class AWQColumnParallelLinear(ColumnParallelLinear):
         bias: Optional[torch.Tensor],
     ) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)
@@ -95,7 +95,7 @@ class AWQRowParallelLinear(RowParallelLinear):
 
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
         out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                         self.qzeros, pack_factor)
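The AWQ layers now build the output shape by appending the unpacked out-features to the input's leading dimensions, so 3-D activations pass through unchanged. A shape-only sketch with zero tensors standing in for awq_gemm and the quantized weights:

import torch

pack_factor = 8
x = torch.randn(2, 5, 64)                  # [batch_size, seq_len, in_features]
qweight = torch.zeros(64, 96 // pack_factor, dtype=torch.int32)

out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor, )
reshaped_x = x.reshape(-1, x.shape[-1])    # the GEMM still runs on a 2-D view
out = torch.zeros(reshaped_x.shape[0], out_shape[-1])  # awq_gemm placeholder
print(out.reshape(out_shape).shape)        # torch.Size([2, 5, 96])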
@@ -119,7 +119,7 @@ def _prune_hidden_states(
             selected_token_indices.extend(
                 range(start_idx, start_idx + prompt_len - 1))
             selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += prompt_len
+            start_idx += input_metadata.max_prompt_len
         else:
             num_seqs = len(seq_ids)
             selected_token_indices.extend(
@@ -129,6 +129,7 @@ def _prune_hidden_states(
     selected_token_indices = torch.tensor(selected_token_indices,
                                           dtype=torch.long,
                                           device=hidden_states.device)
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0, selected_token_indices)
 
 
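_prune_hidden_states flattens the padded [batch_size, seq_len, hidden_size] activations before gathering the positions used for sampling. A minimal sketch of that step (the selected indices below are hypothetical):

import torch

batch_size, seq_len, hidden_size = 2, 4, 8
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Flattened token index = prompt_index * max_prompt_len + position_in_prompt.
selected_token_indices = torch.tensor([2, 6], dtype=torch.long)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
print(hidden_states.index_select(0, selected_token_indices).shape)  # torch.Size([2, 8])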
@@ -158,9 +158,9 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        input_tokens: List[int] = []
-        input_positions: List[int] = []
-        slot_mapping: List[int] = []
+        input_tokens: List[List[int]] = []
+        input_positions: List[List[int]] = []
+        slot_mapping: List[List[int]] = []
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -180,24 +180,25 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
-            input_tokens.extend(prompt_tokens)
+            input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.extend(range(len(prompt_tokens)))
+            input_positions.append(list(range(prompt_len)))
 
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
-                slot_mapping.extend([0] * prompt_len)
+                slot_mapping.append([0] * prompt_len)
                 continue
 
             # Compute the slot mapping.
+            slot_mapping.append([])
             block_table = seq_group_metadata.block_tables[seq_id]
             for i in range(prompt_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping[-1].append(slot)
 
         # Add generation tokens.
         max_context_len = 0
@@ -215,13 +216,13 @@ class Worker:
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
-                input_tokens.append(generation_token)
+                input_tokens.append([generation_token])
 
                 context_len = seq_data.get_len()
                 position = context_len - 1
                 if self.sliding_window is not None:
                     context_len = min(context_len, self.sliding_window)
-                input_positions.append(position)
+                input_positions.append([position])
 
                 block_table = seq_group_metadata.block_tables[seq_id]
 
@@ -233,7 +234,7 @@ class Worker:
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
+                slot_mapping.append([slot])
 
                 if self.sliding_window is not None:
                     sliding_window_blocks = (self.sliding_window //
@@ -241,28 +242,36 @@ class Worker:
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)
 
-        # Optimization: Pad the input length to be a multiple of 8.
-        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
-        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
-        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
+        padded_input_tokens = [
+            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
+        ]
+        padded_input_positions = [
+            _pad_to_max(positions, max_seq_len, pad=0)
+            for positions in input_positions
+        ]
+        padded_slot_mapping = [
+            _pad_to_max(mapping, max_seq_len, pad=-1)
+            for mapping in slot_mapping
+        ]
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
+            for block_table in generation_block_tables
+        ]
 
         # Convert to tensors.
-        tokens_tensor = torch.tensor(input_tokens,
+        tokens_tensor = torch.tensor(padded_input_tokens,
                                      dtype=torch.long,
                                      device="cuda")
-        positions_tensor = torch.tensor(input_positions,
+        positions_tensor = torch.tensor(padded_input_positions,
                                         dtype=torch.long,
                                         device="cuda")
-        slot_mapping_tensor = torch.tensor(slot_mapping,
+        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
                                            dtype=torch.int,
                                            device="cuda")
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables
-        ]
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
@@ -361,12 +370,12 @@ def _init_distributed_environment(
                              parallel_config.pipeline_parallel_size)
 
 
-def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
-    return x + [0] * ((-len(x)) % multiple_of)
+def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
+    return x + [pad] * ((-len(x)) % multiple_of)
 
 
-def _pad_to_max(x: List[int], max_len: int) -> List[int]:
-    return x + [0] * (max_len - len(x))
+def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
+    return x + [pad] * (max_len - len(x))
 
 
 def _check_if_can_support_max_seq_len(max_seq_len: int,
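The worker pads every per-sequence list to a common length; the slot mapping uses pad=-1 so padded positions are skipped by the cache kernel. The _pad_to_max helper below is copied from this diff, followed by a small usage example:

from typing import List

def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    return x + [pad] * (max_len - len(x))

slot_mapping = [[10, 11, 12], [20]]
max_seq_len = max(len(m) for m in slot_mapping)
padded = [_pad_to_max(m, max_seq_len, pad=-1) for m in slot_mapping]
print(padded)  # [[10, 11, 12], [20, -1, -1]]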