[Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (#5074)
youkaichao committed 2024-06-08 19:14:43 -07:00
commit 0373e1837e (parent c09dade2a2)

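The idea at a glance: rather than letting every captured graph allocate its own output tensor, the output of the first capture is kept and reused as a shared buffer that later captures copy into. Below is a minimal, self-contained sketch of that capture-into-a-buffer pattern; `toy_model`, the shapes, and the values are illustrative stand-ins, not vLLM code:

    import torch

    def toy_model(x: torch.Tensor) -> torch.Tensor:
        return x * 2 + 1

    max_batch_size = 8
    x = torch.zeros(max_batch_size, 16, device="cuda")
    # Persistent output buffer, allocated outside the graph's memory pool.
    hidden_states = torch.empty(max_batch_size, 16, device="cuda")

    # Warm up on a side stream so one-time setup work is not captured.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        toy_model(x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = toy_model(x)
        hidden_states.copy_(out)  # land the result in the shared buffer

    # Replay: update inputs in place, replay, read from the shared buffer.
    x.fill_(3.0)
    g.replay()
    torch.cuda.synchronize()
    assert hidden_states[0, 0].item() == 7.0  # 3 * 2 + 1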

@@ -1,3 +1,4 @@
+import gc
 import time
 import warnings
 from collections import defaultdict
@@ -894,6 +895,10 @@ class ModelRunner:
         seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()

+        # Prepare buffer for outputs. These will be reused for all batch sizes.
+        # It will be filled after the first graph capture.
+        hidden_states: Optional[torch.Tensor] = None
+
         graph_batch_size = _get_graph_batch_size(
             self.scheduler_config.max_num_seqs)
         batch_size_capture_list = [
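The buffer declared here starts out as `None`, is filled by the first (largest-batch) capture, and every smaller batch size then reuses a slice of it, as the next hunk shows. That works because slicing along dim 0 is a view into the same allocation. A quick illustration (shapes are made up):

    import torch

    # Slicing returns a view: smaller batches write into a prefix of the
    # buffer captured for the largest batch, with no extra allocation.
    buf = torch.zeros(8, 4, device="cuda")
    small = buf[:2]
    assert small.data_ptr() == buf.data_ptr()  # same storage
    small.fill_(1.0)
    torch.cuda.synchronize()
    assert buf[:2].eq(1.0).all() and buf[2:].eq(0.0).all()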
@@ -930,9 +935,11 @@ class ModelRunner:
                     self.set_active_loras(set(), lora_mapping)

                 graph_runner = CUDAGraphRunner(self.model)
-                graph_runner.capture(
+                hidden_states = graph_runner.capture(
                     input_tokens[:batch_size],
                     input_positions[:batch_size],
+                    hidden_states[:batch_size]
+                    if hidden_states is not None else None,
                     kv_caches,
                     attn_metadata,
                     memory_pool=self.graph_memory_pool,
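The `memory_pool=self.graph_memory_pool` argument lets every captured graph draw its intermediate allocations from one shared pool, so their workspaces can overlap instead of each graph reserving its own. In raw PyTorch the same effect looks roughly like this (the two graphs and the ops are placeholders):

    import torch

    # One pool shared by several captures.
    pool = torch.cuda.graph_pool_handle()
    x = torch.zeros(8, 16, device="cuda")
    _ = x * 2  # warmup so no one-time setup is recorded

    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):
        a = x * 2  # allocated from the shared pool
    with torch.cuda.graph(g2, pool=pool):
        b = x + 1  # same pool; freed blocks can be reused across graphs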
@@ -969,12 +976,13 @@ class CUDAGraphRunner:
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        hidden_states: Optional[torch.Tensor],
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         memory_pool: Optional[Tuple[int, int]],
         stream: torch.cuda.Stream,
         **kwargs,
-    ) -> None:
+    ) -> torch.Tensor:
         assert self._graph is None
         # Run the model a few times without capturing the graph.
         # This is to make sure that the captured graph does not include the
@@ -993,13 +1001,21 @@ class CUDAGraphRunner:
         # Capture the graph.
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
-            hidden_states = self.model(
+            output_hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
                 attn_metadata,
                 **kwargs,
             )
+            if hidden_states is not None:
+                hidden_states.copy_(output_hidden_states)
+            else:
+                hidden_states = output_hidden_states
+            del output_hidden_states
+            # make sure `output_hidden_states` is deleted
+            # in the graph's memory pool
+            gc.collect()
         torch.cuda.synchronize()

         # Save the input and output buffers.
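The `del` plus `gc.collect()` matter because `output_hidden_states` was allocated during capture, i.e. from the graph's memory pool: dropping the last Python reference before the next capture lets the pool hand that block to the next graph instead of growing by one full output tensor per batch size. Schematically, assuming two captures sharing a pool (names are illustrative):

    import gc
    import torch

    pool = torch.cuda.graph_pool_handle()
    x = torch.zeros(8, 16, device="cuda")
    persistent = torch.empty_like(x)  # lives outside the pool
    _ = x * 2  # warmup

    g1 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):
        tmp = x * 2              # allocated from the shared pool
        persistent.copy_(tmp)    # result is kept in the outside buffer
        del tmp                  # drop the only reference to the pool block
        gc.collect()             # guard against lingering cycle references

    g2 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g2, pool=pool):
        tmp2 = x + 1             # may reuse the block freed above
        persistent.copy_(tmp2)
        del tmp2
        gc.collect()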
@@ -1012,7 +1028,7 @@ class CUDAGraphRunner:
             "block_tables": attn_metadata.decode_metadata.block_tables,
         }
         self.output_buffers = {"hidden_states": hidden_states}
-        return
+        return hidden_states

     def forward(
         self,
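For context, replay then follows the usual CUDA graph pattern: `forward` copies the new inputs into the tensors the graph was captured on, replays the recorded kernels, and returns the shared output buffer. A sketch of that shape inside `CUDAGraphRunner` (schematic, not the exact vLLM body):

    def forward(self, input_ids, positions, kv_caches, attn_metadata, **kwargs):
        # Write the new inputs into the captured input buffers in place.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        # Replay the recorded kernels on the same memory addresses.
        self._graph.replay()
        # Hand back the shared output buffer filled by the replay.
        return self.output_buffers["hidden_states"]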