[Core][CUDA Graph] add output buffer for cudagraph (#5074)

[Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (#5074)
This commit is contained in:
youkaichao 2024-06-08 19:14:43 -07:00 committed by GitHub
parent c09dade2a2
commit 0373e1837e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,4 @@
import gc
import time
import warnings
from collections import defaultdict
@ -894,6 +895,10 @@ class ModelRunner:
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
# Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture.
hidden_states: Optional[torch.Tensor] = None
graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
@ -930,9 +935,11 @@ class ModelRunner:
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
hidden_states = graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
hidden_states[:batch_size]
if hidden_states is not None else None,
kv_caches,
attn_metadata,
memory_pool=self.graph_memory_pool,
@ -969,12 +976,13 @@ class CUDAGraphRunner:
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream,
**kwargs,
) -> None:
) -> torch.Tensor:
assert self._graph is None
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
@ -993,13 +1001,21 @@ class CUDAGraphRunner:
# Capture the graph.
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
hidden_states = self.model(
output_hidden_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
if hidden_states is not None:
hidden_states.copy_(output_hidden_states)
else:
hidden_states = output_hidden_states
del output_hidden_states
# make sure `output_hidden_states` is deleted
# in the graph's memory pool
gc.collect()
torch.cuda.synchronize()
# Save the input and output buffers.
@ -1012,7 +1028,7 @@ class CUDAGraphRunner:
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
return hidden_states
def forward(
self,