[Core][CUDA Graph] add output buffer for cudagraph (#5074)
[Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (#5074)
This commit is contained in:
parent
c09dade2a2
commit
0373e1837e
@ -1,3 +1,4 @@
|
|||||||
|
import gc
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@ -894,6 +895,10 @@ class ModelRunner:
|
|||||||
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
|
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
|
||||||
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
||||||
|
|
||||||
|
# Prepare buffer for outputs. These will be reused for all batch sizes.
|
||||||
|
# It will be filled after the first graph capture.
|
||||||
|
hidden_states: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
graph_batch_size = _get_graph_batch_size(
|
graph_batch_size = _get_graph_batch_size(
|
||||||
self.scheduler_config.max_num_seqs)
|
self.scheduler_config.max_num_seqs)
|
||||||
batch_size_capture_list = [
|
batch_size_capture_list = [
|
||||||
@ -930,9 +935,11 @@ class ModelRunner:
|
|||||||
self.set_active_loras(set(), lora_mapping)
|
self.set_active_loras(set(), lora_mapping)
|
||||||
|
|
||||||
graph_runner = CUDAGraphRunner(self.model)
|
graph_runner = CUDAGraphRunner(self.model)
|
||||||
graph_runner.capture(
|
hidden_states = graph_runner.capture(
|
||||||
input_tokens[:batch_size],
|
input_tokens[:batch_size],
|
||||||
input_positions[:batch_size],
|
input_positions[:batch_size],
|
||||||
|
hidden_states[:batch_size]
|
||||||
|
if hidden_states is not None else None,
|
||||||
kv_caches,
|
kv_caches,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
memory_pool=self.graph_memory_pool,
|
memory_pool=self.graph_memory_pool,
|
||||||
@ -969,12 +976,13 @@ class CUDAGraphRunner:
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
hidden_states: Optional[torch.Tensor],
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
memory_pool: Optional[Tuple[int, int]],
|
memory_pool: Optional[Tuple[int, int]],
|
||||||
stream: torch.cuda.Stream,
|
stream: torch.cuda.Stream,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> torch.Tensor:
|
||||||
assert self._graph is None
|
assert self._graph is None
|
||||||
# Run the model a few times without capturing the graph.
|
# Run the model a few times without capturing the graph.
|
||||||
# This is to make sure that the captured graph does not include the
|
# This is to make sure that the captured graph does not include the
|
||||||
@ -993,13 +1001,21 @@ class CUDAGraphRunner:
|
|||||||
# Capture the graph.
|
# Capture the graph.
|
||||||
self._graph = torch.cuda.CUDAGraph()
|
self._graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
||||||
hidden_states = self.model(
|
output_hidden_states = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
positions,
|
positions,
|
||||||
kv_caches,
|
kv_caches,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
if hidden_states is not None:
|
||||||
|
hidden_states.copy_(output_hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = output_hidden_states
|
||||||
|
del output_hidden_states
|
||||||
|
# make sure `output_hidden_states` is deleted
|
||||||
|
# in the graph's memory pool
|
||||||
|
gc.collect()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
# Save the input and output buffers.
|
# Save the input and output buffers.
|
||||||
@ -1012,7 +1028,7 @@ class CUDAGraphRunner:
|
|||||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||||
}
|
}
|
||||||
self.output_buffers = {"hidden_states": hidden_states}
|
self.output_buffers = {"hidden_states": hidden_states}
|
||||||
return
|
return hidden_states
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user