[Core] CUDA Graphs for Multi-Step + Chunked-Prefill (#8645)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
7f60520deb
commit
afb050b29d
@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
|
||||
long const* sampled_token_ids_ptr, long* input_positions_ptr,
|
||||
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride) {
|
||||
int const n_pad = num_seqs - num_queries;
|
||||
if (n_pad && blockIdx.x == 0) {
|
||||
// Handle cuda graph padding
|
||||
int const offset = num_queries;
|
||||
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
|
||||
input_tokens_ptr[offset + i] = 0;
|
||||
input_positions_ptr[offset + i] = 0;
|
||||
slot_mapping_ptr[offset + i] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x >= num_query_blocks) {
|
||||
|
@ -500,6 +500,30 @@ class FlashAttentionMetadataBuilder(
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int,
|
||||
block_tables: List[List[int]]) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
# It may be possible to have more blocks allocated due
|
||||
# to lookahead slots of multi-step, however, they are
|
||||
# not used anyway, so can be safely ignored.
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
@ -533,29 +557,13 @@ class FlashAttentionMetadataBuilder(
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size
|
||||
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
input_block_tables = self.runner.graph_block_tables[:batch_size]
|
||||
max_blocks = input_block_tables.shape[1]
|
||||
for i, block_table in enumerate(self.block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
input_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
# It may be possible to have more blocks allocated due
|
||||
# to lookahead slots of multi-step, however, they are
|
||||
# not used anyway, so can be safely ignored.
|
||||
input_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
block_tables = torch.from_numpy(input_block_tables).to(
|
||||
device=device, non_blocking=True)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
|
@ -712,14 +712,62 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
|
||||
def _use_captured_graph(self,
|
||||
batch_size: int,
|
||||
decode_only: bool,
|
||||
max_decode_seq_len: int,
|
||||
max_encoder_seq_len: int = 0) -> bool:
|
||||
return (self.decode_only and not self.runner.model_config.enforce_eager
|
||||
return (decode_only and not self.runner.model_config.enforce_eager
|
||||
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
||||
and max_decode_seq_len <= self.runner.max_seq_len_to_capture
|
||||
and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
|
||||
and batch_size <= self.runner.max_batchsize_to_capture)
|
||||
|
||||
def _get_cuda_graph_pad_size(self,
|
||||
num_seqs: int,
|
||||
max_decode_seq_len: int,
|
||||
max_encoder_seq_len: int = 0) -> int:
|
||||
"""
|
||||
Determine the number of padding sequences required for running in
|
||||
CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
|
||||
|
||||
In the multi-step + chunked-prefill case, only the first step
|
||||
has Prefills (if any). The rest of the steps are guaranteed to be all
|
||||
decodes. In this case, we set up the padding as if all the sequences
|
||||
are decodes so we may run all steps except the first step in CUDA graph
|
||||
mode. The padding is accounted for in the multi-step `advance_step`
|
||||
family of functions.
|
||||
|
||||
Args:
|
||||
num_seqs (int): Number of sequences scheduled to run.
|
||||
max_decode_seq_len (int): Greatest of all the decode sequence
|
||||
lengths. Used only in checking the viablility of using
|
||||
CUDA graphs.
|
||||
max_encoder_seq_len (int, optional): Greatest of all the encode
|
||||
sequence lengths. Defaults to 0. Used only in checking the
|
||||
viability of using CUDA graphs.
|
||||
Returns:
|
||||
int: Returns the determined number of padding sequences. If
|
||||
CUDA graphs is not viable, returns -1.
|
||||
"""
|
||||
is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
|
||||
self.runner.scheduler_config.chunked_prefill_enabled
|
||||
decode_only = self.decode_only or is_mscp
|
||||
if not decode_only:
|
||||
# Early exit so we can treat num_seqs as the batch_size below.
|
||||
return -1
|
||||
|
||||
# batch_size out of this function refers to the number of input
|
||||
# tokens being scheduled. This conflation of num_seqs as batch_size
|
||||
# is valid as this is a decode-only case.
|
||||
batch_size = num_seqs
|
||||
if not self._use_captured_graph(batch_size, decode_only,
|
||||
max_decode_seq_len,
|
||||
max_encoder_seq_len):
|
||||
return -1
|
||||
|
||||
graph_batch_size = _get_graph_batch_size(batch_size)
|
||||
assert graph_batch_size >= batch_size
|
||||
return graph_batch_size - batch_size
|
||||
|
||||
def build(self) -> ModelInputForGPU:
|
||||
"""Finalize the builder intermediate data and
|
||||
create on-device tensors.
|
||||
@ -778,21 +826,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
for data in self.inter_data_list
|
||||
}
|
||||
|
||||
batch_size = len(input_tokens)
|
||||
use_captured_graph = self._use_captured_graph(
|
||||
batch_size,
|
||||
max_decode_seq_len,
|
||||
cuda_graph_pad_size = self._get_cuda_graph_pad_size(
|
||||
num_seqs=len(seq_lens),
|
||||
max_decode_seq_len=max_encoder_seq_len,
|
||||
max_encoder_seq_len=max_encoder_seq_len)
|
||||
|
||||
# If cuda graph can be used, pad tensors accordingly.
|
||||
# See `capture_model` API for more details.
|
||||
# vLLM uses cuda graph only for decoding requests.
|
||||
cuda_graph_pad_size = -1
|
||||
if use_captured_graph:
|
||||
graph_batch_size = _get_graph_batch_size(batch_size)
|
||||
assert graph_batch_size >= batch_size
|
||||
cuda_graph_pad_size = graph_batch_size - batch_size
|
||||
batch_size = graph_batch_size
|
||||
batch_size = len(input_tokens)
|
||||
if cuda_graph_pad_size != -1:
|
||||
# If cuda graph can be used, pad tensors accordingly.
|
||||
# See `capture_model` API for more details.
|
||||
# vLLM uses cuda graph only for decoding requests.
|
||||
batch_size += cuda_graph_pad_size
|
||||
|
||||
# Tokens and positions.
|
||||
if cuda_graph_pad_size:
|
||||
|
Loading…
x
Reference in New Issue
Block a user