[core] cudagraph output with tensor weak reference (#9724)
Signed-off-by: youkaichao <youkaichao@gmail.com>
commit 8549c82660
parent 67a6882da4
csrc/ops.h (+24 lines)
@@ -5,6 +5,30 @@
 #include "core/scalar_type.hpp"

+#include <vector>
+
+torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
+  // Ensure tensor is on CUDA
+  if (!tensor.is_cuda()) {
+    throw std::runtime_error("Tensor must be on CUDA device");
+  }
+
+  // Get the raw data pointer
+  void* data_ptr = tensor.data_ptr();
+
+  // Get tensor sizes and strides
+  std::vector<int64_t> sizes = tensor.sizes().vec();
+  std::vector<int64_t> strides = tensor.strides().vec();
+
+  // Get tensor options (dtype, device)
+  auto options = tensor.options();
+
+  // Create a new tensor from the raw data pointer
+  auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
+
+  return new_tensor;
+}
+
 void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
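The helper relies on `torch::from_blob`, which wraps an existing allocation without taking ownership: no deleter is supplied, so the returned tensor aliases the input's storage but does not extend its lifetime. A minimal Python sketch of the resulting semantics, assuming a vLLM build with the binding registered below (tensor names are illustrative):

    import torch

    src = torch.arange(4, dtype=torch.float32, device="cuda")
    ref = torch.ops._C.weak_ref_tensor(src)

    # Same storage and layout: no copy was made.
    assert ref.data_ptr() == src.data_ptr()
    assert ref.shape == src.shape and ref.stride() == src.stride()

    # Writes through either handle are visible through the other.
    ref.fill_(1.0)
    assert src.sum().item() == 4.0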
csrc/torch_bindings.cpp

@@ -18,6 +18,9 @@
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops

+  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
+  ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
+
   // Attention ops
   // Compute the attention between an input query and the cached
   // keys/values using PagedAttention.
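The schema line makes the op visible to the dispatcher (and hence to `torch.compile`) as `torch.ops._C.weak_ref_tensor`, while `ops.impl` registers a kernel for the CUDA dispatch key only. A hedged sketch of what that means for callers:

    import torch

    x = torch.zeros(8, device="cuda")
    y = torch.ops._C.weak_ref_tensor(x)  # dispatches to the C++ helper above

    try:
        torch.ops._C.weak_ref_tensor(torch.zeros(8))  # CPU tensor
    except RuntimeError:
        # No CPU kernel is registered, and the helper itself also rejects
        # non-CUDA tensors ("Tensor must be on CUDA device").
        pass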
vllm/utils.py

@@ -1479,3 +1479,12 @@ class LazyDict(Mapping, Generic[T]):

     def __len__(self):
         return len(self._factory)
+
+
+def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Create a weak reference to a tensor.
+    The new tensor will share the same data as the original tensor,
+    but will not keep the original tensor alive.
+    """
+    return torch.ops._C.weak_ref_tensor(tensor)
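As the docstring stresses, the returned tensor does not keep the original alive: callers may only read through the weak reference while the underlying allocation is known to be valid. A cautionary sketch (illustrative, assumes a CUDA build of vLLM):

    import torch
    from vllm.utils import weak_ref_tensor

    src = torch.randn(1024, device="cuda")
    ref = weak_ref_tensor(src)
    del src  # the allocation can now be returned to the caching allocator

    # `ref` still points at the old address, so reading it is only safe if
    # something else guarantees the memory stays put. In this commit that
    # guarantee comes from the CUDA graph's memory pool, which the captured
    # kernels keep writing into on every replay.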
vllm/worker/model_runner.py

@@ -50,7 +50,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
                         flatten_2d_lists, is_hip, is_pin_memory_available,
-                        supports_dynamo)
+                        supports_dynamo, weak_ref_tensor)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -1426,12 +1426,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     dtype=self.model_config.dtype,
                     device=self.device)

-        # Prepare buffer for outputs. These will be reused for all batch sizes.
-        # It will be filled after the first graph capture.
-        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
-            None
-        ] * self.parallel_config.pipeline_parallel_size
-
         graph_batch_size = self.max_batchsize_to_capture
         batch_size_capture_list = [
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
@@ -1474,12 +1468,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     input_tokens[:batch_size],
                     "positions":
                     input_positions[..., :batch_size],
-                    "hidden_or_intermediate_states":
-                    hidden_or_intermediate_states[
-                        virtual_engine]  # type: ignore
-                    [:batch_size]
-                    if hidden_or_intermediate_states[virtual_engine]
-                    is not None else None,
                     "intermediate_inputs":
                     intermediate_inputs[:batch_size]
                     if intermediate_inputs is not None else None,
@@ -1762,15 +1750,13 @@ class CUDAGraphRunner(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
-                                                      torch.Tensor]],
         intermediate_inputs: Optional[IntermediateTensors],
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         memory_pool: Optional[Tuple[int, int]],
         stream: torch.cuda.Stream,
         **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ):
         assert self._graph is None
         # Run the model a few times without capturing the graph.
         # This is to make sure that the captured graph does not include the
@@ -1799,20 +1785,21 @@ class CUDAGraphRunner(nn.Module):
                 intermediate_tensors=intermediate_inputs,
                 **kwargs,
             )
-            if hidden_or_intermediate_states is not None:
-                if get_pp_group().is_last_rank:
-                    hidden_or_intermediate_states.copy_(
-                        output_hidden_or_intermediate_states)
-                else:
-                    for key in hidden_or_intermediate_states.tensors:
-                        hidden_or_intermediate_states[key].copy_(
-                            output_hidden_or_intermediate_states[key])
-            else:
-                hidden_or_intermediate_states = (
-                    output_hidden_or_intermediate_states)
+            if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
+                hidden_or_intermediate_states = weak_ref_tensor(
+                    output_hidden_or_intermediate_states)
+            elif isinstance(output_hidden_or_intermediate_states,
+                            IntermediateTensors):
+                hidden_or_intermediate_states = IntermediateTensors(
+                    tensors={
+                        key: weak_ref_tensor(value)
+                        for key, value in
+                        output_hidden_or_intermediate_states.tensors.items()
+                    })
+
             del output_hidden_or_intermediate_states
-            # make sure `output_hidden_states` is deleted
+            # make sure `output_hidden_or_intermediate_states` is deleted
             # in the graph's memory pool
             gc.collect()
         torch.cuda.synchronize()
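This hunk is the heart of the change: instead of pre-allocating persistent output buffers and `copy_`-ing the model output into them, `capture` now stores non-owning references to the tensors the model produced inside the graph's memory pool, and `del` plus `gc.collect()` drop the strong references so the pool stays reusable. The overall pattern, reduced to a standalone sketch with illustrative names:

    import torch
    from vllm.utils import weak_ref_tensor

    static_input = torch.zeros(8, device="cuda")
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        output = torch.relu(static_input)

    # Keep only a weak reference; the graph's pool owns the buffer, and
    # every replay rewrites the same address the reference points to.
    output_ref = weak_ref_tensor(output)
    del output

    static_input.copy_(torch.ones(8, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(output_ref)  # reflects the latest replay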
@@ -1837,7 +1824,6 @@ class CUDAGraphRunner(nn.Module):
             }
         else:
             self.output_buffers = hidden_or_intermediate_states
-        return hidden_or_intermediate_states

     def forward(
         self,