From 0373e1837e1a85c595fa9fc67c775bc6cbe105a2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Jun 2024 19:14:43 -0700 Subject: [PATCH] [Core][CUDA Graph] add output buffer for cudagraph (#5074) [Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (#5074) --- vllm/worker/model_runner.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c59288b4..7879a5de 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,3 +1,4 @@ +import gc import time import warnings from collections import defaultdict @@ -894,6 +895,10 @@ class ModelRunner: seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + # Prepare buffer for outputs. These will be reused for all batch sizes. + # It will be filled after the first graph capture. + hidden_states: Optional[torch.Tensor] = None + graph_batch_size = _get_graph_batch_size( self.scheduler_config.max_num_seqs) batch_size_capture_list = [ @@ -930,9 +935,11 @@ class ModelRunner: self.set_active_loras(set(), lora_mapping) graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( + hidden_states = graph_runner.capture( input_tokens[:batch_size], input_positions[:batch_size], + hidden_states[:batch_size] + if hidden_states is not None else None, kv_caches, attn_metadata, memory_pool=self.graph_memory_pool, @@ -969,12 +976,13 @@ class CUDAGraphRunner: self, input_ids: torch.Tensor, positions: torch.Tensor, + hidden_states: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, memory_pool: Optional[Tuple[int, int]], stream: torch.cuda.Stream, **kwargs, - ) -> None: + ) -> torch.Tensor: assert self._graph is None # Run the model a few times without capturing the graph. # This is to make sure that the captured graph does not include the @@ -993,13 +1001,21 @@ class CUDAGraphRunner: # Capture the graph. self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): - hidden_states = self.model( + output_hidden_states = self.model( input_ids, positions, kv_caches, attn_metadata, **kwargs, ) + if hidden_states is not None: + hidden_states.copy_(output_hidden_states) + else: + hidden_states = output_hidden_states + del output_hidden_states + # make sure `output_hidden_states` is deleted + # in the graph's memory pool + gc.collect() torch.cuda.synchronize() # Save the input and output buffers. @@ -1012,7 +1028,7 @@ class CUDAGraphRunner: "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} - return + return hidden_states def forward( self,