[Core] CUDA Graphs for Multi-Step + Chunked-Prefill (#8645)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Authored by Varun Sundar Rabindranath on 2024-10-02 15:44:39 -04:00, committed by GitHub
parent 7f60520deb
commit afb050b29d
3 changed files with 97 additions and 34 deletions
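
Note (reviewer context, not part of the diff): with multi-step scheduling plus chunked prefill, only the first of the N scheduled steps can contain prefill chunks; every later step is pure decode, so it can be padded to a captured batch size and replayed through a CUDA graph. The toy sketch below only illustrates that control flow; run_eager, replay_cuda_graph and advance_step are stand-ins for the real model-runner machinery, not functions from the tree.

from typing import Callable


def run_multi_step(num_steps: int,
                   has_prefills: bool,
                   run_eager: Callable[[], None],
                   replay_cuda_graph: Callable[[], None],
                   advance_step: Callable[[], None]) -> None:
    # Step 1 may mix chunked prefills with decodes, so it runs eagerly.
    # Steps 2..N are guaranteed to be decode-only, so they can replay a
    # captured CUDA graph; advance_step patches up the padded tail between
    # steps (see the advance_step_flashattn_kernel change below).
    for step in range(num_steps):
        if step == 0 and has_prefills:
            run_eager()
        else:
            replay_cuda_graph()
        if step < num_steps - 1:
            advance_step()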

View File

@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
     long const* sampled_token_ids_ptr, long* input_positions_ptr,
     int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
     int64_t const block_tables_stride) {
+  int const n_pad = num_seqs - num_queries;
+  if (n_pad && blockIdx.x == 0) {
+    // Handle cuda graph padding
+    int const offset = num_queries;
+    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
+      input_tokens_ptr[offset + i] = 0;
+      input_positions_ptr[offset + i] = 0;
+      slot_mapping_ptr[offset + i] = -1;
+    }
+  }
+
   int num_query_blocks = div_ceil(num_queries, num_threads);
 
   if (blockIdx.x >= num_query_blocks) {
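
Note on the kernel change above (not part of the diff): the sequences at indices num_queries through num_seqs - 1 exist only to fill out the captured graph's batch, so advance_step gives them a zero token, a zero position, and a pad slot so they never touch real KV-cache entries. A rough NumPy restatement of that bookkeeping, assuming the pad slot value is the -1 the kernel writes:

import numpy as np

PAD_SLOT_ID = -1  # assumed: matches the -1 written by the kernel above


def pad_advance_step_inputs(input_tokens: np.ndarray,
                            input_positions: np.ndarray,
                            slot_mapping: np.ndarray,
                            num_queries: int) -> None:
    # Entries past the real queries get a neutral token/position and a pad
    # slot, mirroring the CUDA-graph padding branch shown above.
    input_tokens[num_queries:] = 0
    input_positions[num_queries:] = 0
    slot_mapping[num_queries:] = PAD_SLOT_ID


num_seqs, num_queries = 8, 5  # 3 padding sequences appended for the graph
tokens = np.ones(num_seqs, dtype=np.int64)
positions = np.arange(num_seqs, dtype=np.int64)
slots = np.arange(num_seqs, dtype=np.int64)
pad_advance_step_inputs(tokens, positions, slots, num_queries)
assert (slots[num_queries:] == PAD_SLOT_ID).all()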

View File

@@ -500,6 +500,30 @@ class FlashAttentionMetadataBuilder(
                                  seq_len, context_len, start_idx,
                                  self.block_size, inter_data.block_tables)
 
+    def _get_graph_runner_block_tables(
+            self, num_seqs: int,
+            block_tables: List[List[int]]) -> torch.Tensor:
+        # The shape of graph_block_tables is
+        # [max batch size, max context len // block size].
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
+        for i, block_table in enumerate(block_tables):
+            if block_table:
+                num_blocks = len(block_table)
+                if num_blocks <= max_blocks:
+                    graph_block_tables[i, :num_blocks] = block_table
+                else:
+                    # It may be possible to have more blocks allocated due
+                    # to lookahead slots of multi-step, however, they are
+                    # not used anyway, so can be safely ignored.
+                    graph_block_tables[
+                        i, :max_blocks] = block_table[:max_blocks]
+
+        return torch.from_numpy(graph_block_tables).to(
+            device=self.runner.device, non_blocking=True)
+
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
         """Build attention metadata with on-device tensors.
@@ -533,29 +557,13 @@ class FlashAttentionMetadataBuilder(
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        num_seqs = len(seq_lens)
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size
-            # The shape of graph_block_tables is
-            # [max batch size, max context len // block size].
-            input_block_tables = self.runner.graph_block_tables[:batch_size]
-            max_blocks = input_block_tables.shape[1]
-            for i, block_table in enumerate(self.block_tables):
-                if block_table:
-                    num_blocks = len(block_table)
-                    if num_blocks <= max_blocks:
-                        input_block_tables[i, :num_blocks] = block_table
-                    else:
-                        # It may be possible to have more blocks allocated due
-                        # to lookahead slots of multi-step, however, they are
-                        # not used anyway, so can be safely ignored.
-                        input_block_tables[
-                            i, :max_blocks] = block_table[:max_blocks]
-            block_tables = torch.from_numpy(input_block_tables).to(
-                device=device, non_blocking=True)
+            num_decode_tokens = batch_size - self.num_prefill_tokens
+            block_tables = self._get_graph_runner_block_tables(
+                num_seqs, self.block_tables)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
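
Note on the new _get_graph_runner_block_tables helper (not part of the diff): it copies each sequence's block table into the preallocated graph_block_tables buffer and silently drops any surplus blocks that multi-step lookahead may have allocated, since those are never read. A standalone NumPy sketch of that behavior, leaving out the torch conversion and device transfer:

from typing import List

import numpy as np


def fill_graph_block_tables(graph_block_tables: np.ndarray,
                            block_tables: List[List[int]]) -> np.ndarray:
    # graph_block_tables stands in for the preallocated buffer of shape
    # [max batch size, max context len // block size].
    max_batch_size, max_blocks = graph_block_tables.shape
    num_seqs = len(block_tables)
    assert max_batch_size >= num_seqs
    out = graph_block_tables[:num_seqs]
    for i, block_table in enumerate(block_tables):
        if block_table:
            # Truncate surplus lookahead blocks; they are never dereferenced.
            num_blocks = min(len(block_table), max_blocks)
            out[i, :num_blocks] = block_table[:num_blocks]
    return out


buffer = np.zeros((4, 8), dtype=np.int32)  # toy shape for illustration
print(fill_graph_block_tables(buffer, [[7, 9], [], [1, 2, 3]]))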

View File

@@ -712,14 +712,62 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     def _use_captured_graph(self,
                             batch_size: int,
+                            decode_only: bool,
                             max_decode_seq_len: int,
                             max_encoder_seq_len: int = 0) -> bool:
-        return (self.decode_only and not self.runner.model_config.enforce_eager
+        return (decode_only and not self.runner.model_config.enforce_eager
                 and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                 and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                 and batch_size <= self.runner.max_batchsize_to_capture)
 
+    def _get_cuda_graph_pad_size(self,
+                                 num_seqs: int,
+                                 max_decode_seq_len: int,
+                                 max_encoder_seq_len: int = 0) -> int:
+        """
+        Determine the number of padding sequences required for running in
+        CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
+
+        In the multi-step + chunked-prefill case, only the first step
+        has Prefills (if any). The rest of the steps are guaranteed to be all
+        decodes. In this case, we set up the padding as if all the sequences
+        are decodes so we may run all steps except the first step in CUDA graph
+        mode. The padding is accounted for in the multi-step `advance_step`
+        family of functions.
+
+        Args:
+            num_seqs (int): Number of sequences scheduled to run.
+            max_decode_seq_len (int): Greatest of all the decode sequence
+                lengths. Used only in checking the viability of using
+                CUDA graphs.
+            max_encoder_seq_len (int, optional): Greatest of all the encode
+                sequence lengths. Defaults to 0. Used only in checking the
+                viability of using CUDA graphs.
+
+        Returns:
+            int: Returns the determined number of padding sequences. If
+                CUDA graphs is not viable, returns -1.
+        """
+        is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
+                        self.runner.scheduler_config.chunked_prefill_enabled
+        decode_only = self.decode_only or is_mscp
+        if not decode_only:
+            # Early exit so we can treat num_seqs as the batch_size below.
+            return -1
+
+        # batch_size out of this function refers to the number of input
+        # tokens being scheduled. This conflation of num_seqs as batch_size
+        # is valid as this is a decode-only case.
+        batch_size = num_seqs
+        if not self._use_captured_graph(batch_size, decode_only,
+                                        max_decode_seq_len,
+                                        max_encoder_seq_len):
+            return -1
+
+        graph_batch_size = _get_graph_batch_size(batch_size)
+        assert graph_batch_size >= batch_size
+        return graph_batch_size - batch_size
+
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
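
Note (not part of the diff): a condensed, self-contained restatement of the decision and arithmetic in _get_cuda_graph_pad_size above. The capture-viability checks are collapsed into a single can_capture callback and the rounding assumes the usual capture sizes (1, 2, 4, then multiples of 8); both are stand-ins, not the exact helpers in the tree.

from typing import Callable

_BATCH_SIZE_ALIGNMENT = 8  # assumed spacing between captured batch sizes


def _round_to_graph_batch_size(batch_size: int) -> int:
    # Stand-in for _get_graph_batch_size: 1, 2, 4, then multiples of 8.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def pad_size_for_cuda_graphs(num_seqs: int,
                             decode_only: bool,
                             is_multi_step: bool,
                             chunked_prefill_enabled: bool,
                             can_capture: Callable[[int], bool]) -> int:
    # Returns -1 when CUDA graphs cannot be used, otherwise how many padding
    # sequences to append so the batch matches a captured graph.
    is_mscp = is_multi_step and chunked_prefill_enabled
    if not (decode_only or is_mscp):
        return -1
    batch_size = num_seqs  # decode-only: one scheduled token per sequence
    if not can_capture(batch_size):
        return -1
    padded = _round_to_graph_batch_size(batch_size)
    assert padded >= batch_size
    return padded - batch_size


# 13 running sequences pad up to the batch-16 graph, so 3 padding sequences.
assert pad_size_for_cuda_graphs(13, False, True, True, lambda bs: True) == 3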
@@ -778,21 +826,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             for data in self.inter_data_list
         }
 
-        batch_size = len(input_tokens)
-        use_captured_graph = self._use_captured_graph(
-            batch_size,
-            max_decode_seq_len,
+        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
+            num_seqs=len(seq_lens),
+            max_decode_seq_len=max_decode_seq_len,
             max_encoder_seq_len=max_encoder_seq_len)
 
-        # If cuda graph can be used, pad tensors accordingly.
-        # See `capture_model` API for more details.
-        # vLLM uses cuda graph only for decoding requests.
-        cuda_graph_pad_size = -1
-        if use_captured_graph:
-            graph_batch_size = _get_graph_batch_size(batch_size)
-            assert graph_batch_size >= batch_size
-            cuda_graph_pad_size = graph_batch_size - batch_size
-            batch_size = graph_batch_size
+        batch_size = len(input_tokens)
+        if cuda_graph_pad_size != -1:
+            # If cuda graph can be used, pad tensors accordingly.
+            # See `capture_model` API for more details.
+            # vLLM uses cuda graph only for decoding requests.
+            batch_size += cuda_graph_pad_size
 
         # Tokens and positions.
         if cuda_graph_pad_size:
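
Note (not part of the diff): the excerpt is cut right after the padding check, but the surrounding build() presumably extends the flat token and position lists with zeros, which is what the kernel-side padding above expects. A condensed illustration of how build() consumes the pad size; the helper name and structure are illustrative, not the actual method:

from typing import List


def apply_cuda_graph_padding(input_tokens: List[int],
                             input_positions: List[int],
                             cuda_graph_pad_size: int) -> int:
    # Returns the effective batch size after padding. A pad size of -1 means
    # CUDA graphs are not used and everything is left untouched.
    batch_size = len(input_tokens)
    if cuda_graph_pad_size != -1:
        batch_size += cuda_graph_pad_size
        # Padded entries carry a zero token and position; on later multi-step
        # iterations the advance_step kernel keeps them neutral.
        input_tokens.extend([0] * cuda_graph_pad_size)
        input_positions.extend([0] * cuda_graph_pad_size)
    return batch_size


tokens, positions = [11, 12, 13], [0, 4, 9]
assert apply_cuda_graph_padding(tokens, positions, 1) == 4
assert len(tokens) == 4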