[Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (#5074)

parent c09dade2a2
commit 0373e1837e
@@ -1,3 +1,4 @@
+import gc
 import time
 import warnings
 from collections import defaultdict
@@ -894,6 +895,10 @@ class ModelRunner:
         seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()

+        # Prepare buffer for outputs. These will be reused for all batch sizes.
+        # It will be filled after the first graph capture.
+        hidden_states: Optional[torch.Tensor] = None
+
         graph_batch_size = _get_graph_batch_size(
             self.scheduler_config.max_num_seqs)
         batch_size_capture_list = [
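The placeholder above is what lets every captured graph share a single output tensor. Below is a minimal, self-contained sketch of the same pattern (the names `capture_shared_output`, `model`, and `hidden_size` are hypothetical, and vLLM's real capture path also wires in attention metadata and a dedicated capture stream):

```python
import torch

@torch.inference_mode()
def capture_shared_output(model, batch_sizes, hidden_size):
    # One static input and, eventually, one output buffer sized for the
    # largest batch; smaller batches reuse slices of the same storage.
    max_bs = max(batch_sizes)
    static_input = torch.zeros(max_bs, hidden_size, device="cuda")
    hidden_states = None  # filled by the first (largest) capture
    pool = None           # memory pool shared by all captures
    graphs = {}
    for bs in sorted(batch_sizes, reverse=True):
        model(static_input[:bs])  # warm-up run, not recorded
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            out = model(static_input[:bs])
            if hidden_states is None:
                hidden_states = out  # the largest batch owns the buffer
            else:
                hidden_states[:bs].copy_(out)  # later batches fill a slice
        pool = graph.pool()
        graphs[bs] = graph
    return graphs, static_input, hidden_states
```

Capturing the largest batch first means the one full-size allocation happens a single time; every later graph only records a `copy_` into a view of that tensor.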
@@ -930,9 +935,11 @@ class ModelRunner:
                     self.set_active_loras(set(), lora_mapping)

                 graph_runner = CUDAGraphRunner(self.model)
-                graph_runner.capture(
+                hidden_states = graph_runner.capture(
                     input_tokens[:batch_size],
                     input_positions[:batch_size],
+                    hidden_states[:batch_size]
+                    if hidden_states is not None else None,
                     kv_caches,
                     attn_metadata,
                     memory_pool=self.graph_memory_pool,
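Note the call-site change: the first capture is handed `None` and returns a freshly allocated buffer, while every subsequent capture receives `hidden_states[:batch_size]`, a view into that same tensor. This relies on the surrounding loop capturing batch sizes from largest to smallest (`for batch_size in reversed(batch_size_capture_list)`), so the slice is always in range.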
@@ -969,12 +976,13 @@ class CUDAGraphRunner:
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        hidden_states: Optional[torch.Tensor],
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         memory_pool: Optional[Tuple[int, int]],
         stream: torch.cuda.Stream,
         **kwargs,
-    ) -> None:
+    ) -> torch.Tensor:
         assert self._graph is None
         # Run the model a few times without capturing the graph.
         # This is to make sure that the captured graph does not include the
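The comment cut off by the hunk boundary describes the standard warm-up-then-capture idiom: a few eager runs ensure that one-time work (lazy allocations, autotuning) is not recorded into the graph. A minimal sketch of that idiom, following the PyTorch CUDA graphs documentation (`warmup_then_capture` is a hypothetical name):

```python
import torch

def warmup_then_capture(model, static_input, num_warmup=3):
    # Warm-up runs go on a side stream so their one-time work is kept
    # out of the capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(num_warmup):
            model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)
    return graph, static_output
```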
@@ -993,13 +1001,21 @@ class CUDAGraphRunner:
         # Capture the graph.
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
-            hidden_states = self.model(
+            output_hidden_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
                 attn_metadata,
                 **kwargs,
             )
+            if hidden_states is not None:
+                hidden_states.copy_(output_hidden_states)
+            else:
+                hidden_states = output_hidden_states
+            del output_hidden_states
+            # make sure `output_hidden_states` is deleted
+            # in the graph's memory pool
+            gc.collect()
         torch.cuda.synchronize()

         # Save the input and output buffers.
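The `del` plus `gc.collect()` are deliberately placed inside the `with torch.cuda.graph(...)` block: memory freed during capture goes back to the graph's private memory pool rather than the regular caching allocator, so releasing the temporary `output_hidden_states` there lets later captures, which share the pool via `memory_pool=`, reuse that space. `del` drops the local reference; `gc.collect()` additionally sweeps any lingering reference cycles so the tensor is actually freed before the block exits.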
@@ -1012,7 +1028,7 @@ class CUDAGraphRunner:
             "block_tables": attn_metadata.decode_metadata.block_tables,
         }
         self.output_buffers = {"hidden_states": hidden_states}
-        return
+        return hidden_states

     def forward(
         self,
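With the output stored in `self.output_buffers`, replaying the graph requires no allocation at all. A minimal sketch of that replay path (`replay_step` is a hypothetical name; vLLM's actual `forward` also copies attention inputs such as `seq_lens` and `block_tables` before replaying):

```python
def replay_step(graph, input_buffers, output_buffers, input_ids, positions):
    # Copy fresh inputs into the captured buffers, replay the graph,
    # and return the shared output buffer; no per-step allocations.
    input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
    input_buffers["positions"].copy_(positions, non_blocking=True)
    graph.replay()
    return output_buffers["hidden_states"]
```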