[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631)
Parent: 1b6de8352b
Commit: 1009e93c5d
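This change lets encoder-decoder models run the decode phase through captured CUDA graphs instead of always forcing eager execution. A minimal usage sketch (assuming a CUDA GPU and the facebook/bart-large-cnn checkpoint; leaving enforce_eager unset now defaults to CUDA graph mode, where previously encoder/decoder models were forced into eager mode):

from vllm import LLM, SamplingParams

# enforce_eager is left unset; with this commit that means CUDA graphs are
# captured at startup and replayed for each decode step of the BART model.
llm = LLM(model="facebook/bart-large-cnn", dtype="float16")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["An encoder prompt to be summarized."], params)
print(outputs[0].outputs[0].text)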
@@ -252,6 +252,13 @@ steps:
    - export VLLM_WORKER_MULTIPROC_METHOD=spawn
    - bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: Encoder Decoder tests # 5min
  source_file_dependencies:
  - vllm/
  - tests/encoder_decoder
  commands:
  - pytest -v -s encoder_decoder

- label: OpenAI-Compatible Tool Use # 20 min
  fast_check: false
  mirror_hardwares: [ amd ]
tests/encoder_decoder/__init__.py (new file, 0 lines)
tests/encoder_decoder/test_e2e_correctness.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""E2E tests to verify the correctness of the encoder-decoder framework

Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
"""
from typing import List, Optional, Tuple

import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu

from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close


def vllm_to_hf_output(
    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
    decoder_prompt_type: DecoderPromptType,
):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    hf_output_str = output_str + "</s>"
    if decoder_prompt_type == DecoderPromptType.NONE:
        hf_output_str = "<s>" + hf_output_str

    return output_ids, hf_output_str, out_logprobs


@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif(
    is_cpu(),
    reason="CPU backend is not currently supported with encoder/decoder models"
)
def test_encoder_decoder_e2e(
    hf_runner,
    vllm_runner,
    example_encoder_decoder_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    decoder_prompt_type: DecoderPromptType,
    enforce_eager: bool,
) -> None:
    '''
    End-to-End (E2E) test for the encoder-decoder framework.
    This test evaluates the encoder-decoder functionality using the BART
    model. We compare the outputs of the Hugging Face and vLLM
    implementations to ensure that both implementations produce consistent
    and correct results.
    '''
    test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]

    # Configuration settings for HF baseline
    hf_kwargs = {
        "top_k": None,
        "num_beams": 1,
        "repetition_penalty": 1.0,
        "top_p": 1.0,
        "length_penalty": 1.0,
        "early_stopping": False,
        "no_repeat_ngram_size": None,
        "min_length": 0
    }

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            test_case_prompts,
            max_tokens,
            num_logprobs,
            **hf_kwargs,
        ))
    with vllm_runner(model, dtype=dtype,
                     enforce_eager=enforce_eager) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            test_case_prompts, max_tokens, num_logprobs)

    hf_skip_tokens = (1
                      if decoder_prompt_type == DecoderPromptType.NONE else 0)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            vllm_to_hf_output(vllm_output, decoder_prompt_type)
            for vllm_output in vllm_outputs
        ],
        name_0="hf",
        name_1="vllm",
        num_outputs_0_skip_tokens=hf_skip_tokens,
    )
@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
from array import array
|
||||
from typing import List
|
||||
|
||||
@ -7,13 +8,9 @@ import torch
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
|
||||
SequenceData, SequenceGroupMetadata)
|
||||
from vllm.utils import is_cpu
|
||||
from vllm.utils import is_cpu, make_tensor_with_pad
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
|
||||
# CUDA graph scenarios to test
|
||||
#
|
||||
# Currently CUDA graph is not supported
|
||||
ENFORCE_EAGER = [True]
|
||||
from vllm.worker.model_runner import _get_graph_batch_size
|
||||
|
||||
BATCH_SIZES = [1, 4, 16, 64, 256]
|
||||
|
||||
@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args,
|
||||
reason="CPU backend is currently "
|
||||
"unsupported for encoder/ "
|
||||
"decoder models")
|
||||
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||
def test_empty_seq_group(enforce_eager, ):
|
||||
def test_empty_seq_group():
|
||||
"""Verify prepare prompt and decode returns empty output
|
||||
for empty seq group list"""
|
||||
|
||||
@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ):
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enforce_eager=enforce_eager,
|
||||
enforce_eager=True,
|
||||
)
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ):
|
||||
"unsupported for encoder/ "
|
||||
"decoder models")
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||
def test_prepare_prompt(
|
||||
batch_size,
|
||||
enforce_eager,
|
||||
):
|
||||
def test_prepare_prompt(batch_size):
|
||||
'''
|
||||
Test the ability of the encoder/decoder model runner subclass to
|
||||
produce prefill-phase model inputs & attention metadata.
|
||||
@ -115,7 +107,7 @@ def test_prepare_prompt(
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enforce_eager=enforce_eager,
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
seq_lens: List[int] = []
|
||||
@ -281,11 +273,7 @@ def test_prepare_prompt(
|
||||
"unsupported for encoder/ "
|
||||
"decoder models")
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
|
||||
def test_prepare_decode(
|
||||
batch_size,
|
||||
enforce_eager,
|
||||
):
|
||||
def test_prepare_decode(batch_size):
|
||||
'''
|
||||
Test the ability of the encoder/decoder model runner subclass to
|
||||
produce decode-phase model inputs & attention metadata.
|
||||
@ -311,7 +299,7 @@ def test_prepare_decode(
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enforce_eager=enforce_eager,
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
seq_lens: List[int] = []
|
||||
@ -428,7 +416,8 @@ def test_prepare_decode(
|
||||
expected,
|
||||
)
|
||||
|
||||
# Cuda graph should is currently not supported for encoder/decoer.
|
||||
# Model runner's CUDAGraph setting should be propagated to attention
|
||||
# metadata.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
|
||||
# Verify the lengths of input tokens & positions
|
||||
@ -484,3 +473,152 @@ def test_prepare_decode(
|
||||
dtype=actual.dtype,
|
||||
)
|
||||
assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
|
||||
def test_prepare_decode_cuda_graph(batch_size):
|
||||
"""
|
||||
Tests that for encoder-decoder models with CUDA Graph capture and replay
|
||||
enabled, the tensors used during the decode phase are correctly padded
|
||||
for varying input batch sizes.
|
||||
"""
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/bart-base",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enforce_eager=False,
|
||||
)
|
||||
|
||||
seq_lens: List[int] = []
|
||||
encoder_seq_lens: List[int] = []
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
block_tables = {0: [1]}
|
||||
cross_block_table = [2]
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
|
||||
encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
|
||||
encoder_seq_lens.append(encoder_seq_len)
|
||||
encoder_seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
encoder_seq_data=encoder_seq_data,
|
||||
cross_block_table=cross_block_table,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
encoder_input_tokens = model_input.encoder_input_tokens
|
||||
encoder_input_positions = model_input.encoder_input_positions
|
||||
cross_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
|
||||
# With CUDA Graph capture and replay enabled, the decoder and encoder
|
||||
# input sequences will be padded. Create the expected padded tensors
|
||||
# accordingly.
|
||||
graph_batch_size = _get_graph_batch_size(batch_size)
|
||||
cuda_graph_pad_size = graph_batch_size - batch_size
|
||||
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
|
||||
padded_encoder_seq_lens = encoder_seq_lens + list(
|
||||
itertools.repeat(1, cuda_graph_pad_size))
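# Assumption about this vLLM version: _get_graph_batch_size rounds the batch
# size up to the nearest captured size (1, 2, 4, then multiples of 8), so e.g.
# batch_size=5 pads to 8 and batch_size=9 pads to 16; the padded slots are
# treated as dummy length-1 sequences, which is why the expected padded
# sequence-length lists above are extended with 1s.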
|
||||
|
||||
assert return_seq_lens == padded_seq_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
assert len(cross_slot_mapping) == len(encoder_input_tokens)
|
||||
|
||||
# Verify attention metadata
|
||||
device = model_runner.device
|
||||
assert attn_metadata.num_prefills == 0
|
||||
assert attn_metadata.num_decode_tokens > 0
|
||||
assert torch.equal(
|
||||
attn_metadata.seq_lens_tensor,
|
||||
torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.seq_lens == padded_seq_lens
|
||||
assert attn_metadata.max_prefill_seq_len == 0
|
||||
assert attn_metadata.max_decode_seq_len == max(seq_lens)
|
||||
# - Encoder attention metadata
|
||||
assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
|
||||
assert torch.equal(
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
|
||||
assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)
|
||||
|
||||
# Verify block tables are correct for prompts
|
||||
# - Decoder self-attention. Pad the block tables as expected.
|
||||
expected = [block_tables[0] for _ in range(batch_size)]
|
||||
expected.extend([[] for _ in range(cuda_graph_pad_size)])
|
||||
expected = make_tensor_with_pad(
|
||||
expected,
|
||||
max_len=64,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device,
|
||||
)
|
||||
assert torch.equal(
|
||||
attn_metadata.block_tables,
|
||||
expected,
|
||||
)
|
||||
# - Encoder/decoder cross-attention. Pad the cross-attention block tables
|
||||
# as expected.
|
||||
expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
|
||||
expected.extend([[] for _ in range(cuda_graph_pad_size)])
|
||||
expected = make_tensor_with_pad(
|
||||
expected,
|
||||
max_len=64,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device,
|
||||
)
|
||||
assert torch.equal(
|
||||
attn_metadata.cross_block_tables,
|
||||
expected,
|
||||
)
|
||||
|
||||
# Model runner's CUDAGraph setting should be propagated to attention
|
||||
# metadata.
|
||||
assert attn_metadata.use_cuda_graph is True
|
||||
|
||||
# Verify the lengths of input tokens & positions
|
||||
# - Decoder
|
||||
assert len(input_tokens) == len(padded_seq_lens)
|
||||
assert len(input_positions) == len(padded_seq_lens)
|
||||
# -- An indirect check that model_input.input_tokens
|
||||
# and model_input.input_positions are correct -
|
||||
# by design of the test, the input tokens are
|
||||
# equal to the input position values, so if
|
||||
# the model_input data structure has the correct
|
||||
# values then these two should be equal
|
||||
assert torch.equal(
|
||||
input_tokens,
|
||||
input_positions,
|
||||
)
|
||||
# - Encoder
|
||||
assert len(encoder_input_tokens) == 0
|
||||
assert len(encoder_input_positions) == 0
|
||||
# -- An indirect check that model_input.encoder_input_tokens
|
||||
# and model_input.encoder_input_positions are correct -
|
||||
# by design of the test, the input tokens are
|
||||
# equal to the input position values, so if
|
||||
# the model_input data structure has the correct
|
||||
# values then these two should be equal
|
||||
assert torch.equal(
|
||||
encoder_input_tokens,
|
||||
encoder_input_positions,
|
||||
)
|
||||
|
@@ -156,18 +156,27 @@ class AttentionState(ABC, Generic[T]):
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size."""
        ...

    @abstractmethod
    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture."""
        ...

    @abstractmethod
    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata: T) -> None:
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay."""
        ...

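The model runner now passes an is_encoder_decoder_model flag into each of these hooks. A minimal, self-contained sketch of how a backend state could thread the flag through capture and replay (the ToyAttentionState class and its dict-based metadata are illustrative stand-ins, not vLLM API):

from typing import Any, Dict

import torch


class ToyAttentionState:
    """Illustrative stand-in; real backends subclass vLLM's AttentionState."""

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        # Placeholder metadata is enough at capture time; replay refills it.
        metadata: Dict[str, Any] = {
            "slot_mapping": torch.zeros(batch_size, dtype=torch.long),
        }
        if is_encoder_decoder_model:
            metadata["cross_slot_mapping"] = torch.zeros(batch_size,
                                                         dtype=torch.long)
        return metadata

    def get_graph_input_buffers(
            self, attn_metadata: Dict[str, Any],
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        buffers = {"slot_mapping": attn_metadata["slot_mapping"]}
        if is_encoder_decoder_model:
            buffers["cross_slot_mapping"] = attn_metadata["cross_slot_mapping"]
        return buffers

    def prepare_graph_input_buffers(
            self, input_buffers: Dict[str, Any],
            attn_metadata: Dict[str, Any],
            is_encoder_decoder_model: bool = False) -> None:
        # Copy the fresh per-step values into the captured buffers in place.
        input_buffers["slot_mapping"].copy_(attn_metadata["slot_mapping"])
        if is_encoder_decoder_model:
            input_buffers["cross_slot_mapping"].copy_(
                attn_metadata["cross_slot_mapping"])

CommonAttentionState below applies the same pattern, additionally registering cross-block tables and encoder sequence lengths.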
@ -172,7 +172,8 @@ class FlashInferState(AttentionState):
|
||||
state._prefill_wrapper = self._get_prefill_wrapper()
|
||||
return state
|
||||
|
||||
def graph_capture_get_metadata_for_batch(self, batch_size: int):
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
_indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
|
||||
_last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
|
||||
@ -232,12 +233,17 @@ class FlashInferState(AttentionState):
|
||||
attn_metadata.begin_forward()
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(self, attn_metadata):
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
return {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
}
|
||||
|
||||
def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
return
|
||||
|
||||
def begin_forward(self, model_input):
|
||||
|
@ -304,7 +304,8 @@ class CommonAttentionState(AttentionState):
|
||||
assert self._is_graph_capturing
|
||||
return self.__class__(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(self, batch_size: int):
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
@ -322,21 +323,121 @@ class CommonAttentionState(AttentionState):
|
||||
block_tables=self._graph_block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "xformers", \
|
||||
f"Expected attn_backend name to be 'xformers', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
self._update_captured_metadata_for_enc_dec_model(
|
||||
batch_size=batch_size, attn_metadata=attn_metadata)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
|
||||
return {
|
||||
def get_graph_input_buffers(
|
||||
self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
|
||||
input_buffers = {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "xformers", \
|
||||
f"Expected attn_backend name to be 'xformers', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
self._add_additonal_input_buffers_for_enc_dec_model(
|
||||
attn_metadata=attn_metadata, input_buffers=input_buffers)
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self, input_buffers,
|
||||
attn_metadata) -> None:
|
||||
def prepare_graph_input_buffers(
|
||||
self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False) -> None:
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
# The encoder decoder model works only with XFormers backend.
|
||||
# Assert the same.
|
||||
assert self.runner.attn_backend.get_name() == "xformers", \
|
||||
f"Expected attn_backend name to be 'xformers', but "\
|
||||
f" got '{self.runner.attn_backend.get_name()}'"
|
||||
self._prepare_input_buffers_for_enc_dec_model(
|
||||
attn_metadata, input_buffers)
|
||||
|
||||
def begin_forward(self, model_input) -> None:
|
||||
return
|
||||
|
||||
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
|
||||
attn_metadata):
|
||||
"""
|
||||
Updates the attention metadata parameters for CUDA graph capture in an
|
||||
encoder-decoder model.
|
||||
|
||||
This method modifies attention-related tensors and metadata required
|
||||
for CUDA graph capture in encoder-decoder models. Specifically, it
|
||||
updates the cross-attention and encoder sequence tensors in the
|
||||
AttentionMetadata object.
|
||||
"""
|
||||
# During decode phase the cross_slot_mapping will be empty. Hence set
|
||||
# an empty tensor for CUDA Graph capture.
|
||||
attn_metadata.cross_slot_mapping = torch.tensor(
|
||||
[], dtype=torch.int).cuda()
|
||||
attn_metadata.cross_block_tables = torch.full(
|
||||
(batch_size, self.runner.get_max_block_per_batch()),
|
||||
1,
|
||||
dtype=torch.int).cuda()
|
||||
attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
|
||||
1,
|
||||
dtype=torch.int).cuda()
|
||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
||||
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
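# The values filled in above (empty cross_slot_mapping, all-ones cross block
# tables and encoder sequence lengths) are only placeholders for graph
# capture; at replay time _prepare_input_buffers_for_enc_dec_model copies the
# real decode-step values into the corresponding captured buffers.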
|
||||
|
||||
def _add_additonal_input_buffers_for_enc_dec_model(
|
||||
self, attn_metadata, input_buffers: Dict[str, Any]):
|
||||
"""
|
||||
Saves additional input buffers specific to the encoder-decoder model
|
||||
from the attention metadata.
|
||||
|
||||
This method extracts and stores encoder-decoder related input buffers
|
||||
from the `attn_metadata` into the `input_buffers` dictionary. The
|
||||
buffers include encoder sequence lengths, cross-slot mappings, and
|
||||
cross-block tables, which are essential for the encoder-decoder model
|
||||
during CUDA graph replay.
|
||||
"""
|
||||
input_buffers["encoder_seq_lens_tensor"] = (
|
||||
attn_metadata.decode_metadata.encoder_seq_lens_tensor)
|
||||
input_buffers["cross_slot_mapping"] = (
|
||||
attn_metadata.decode_metadata.cross_slot_mapping)
|
||||
input_buffers["cross_block_tables"] = (
|
||||
attn_metadata.decode_metadata.cross_block_tables)
|
||||
|
||||
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
|
||||
input_buffers: Dict[str,
|
||||
Any]):
|
||||
"""
|
||||
Populates input buffers with data from the encoder-decoder model's
|
||||
attention metadata.
|
||||
|
||||
This method fills the input buffers with encoder-decoder specific
|
||||
tensors. It copies data from the `attn_metadata` and keyword arguments
|
||||
(`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
|
||||
The copied data includes attention-related metadata as well as input
|
||||
IDs and positional information for the encoder.
|
||||
"""
|
||||
input_buffers["encoder_seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.encoder_seq_lens_tensor,
|
||||
non_blocking=True)
|
||||
input_buffers["cross_slot_mapping"].copy_(
|
||||
attn_metadata.decode_metadata.cross_slot_mapping,
|
||||
non_blocking=True)
|
||||
input_buffers["cross_block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.cross_block_tables,
|
||||
non_blocking=True)
|
||||
|
@ -16,9 +16,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
from vllm.transformers_utils.config import (ConfigFormat, get_config,
|
||||
get_hf_image_processor_config,
|
||||
get_hf_text_config)
|
||||
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
|
||||
cuda_device_count_stateless, get_cpu_memory, is_cpu,
|
||||
is_hip, is_neuron, is_openvino, is_xpu,
|
||||
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
|
||||
is_cpu, is_hip, is_neuron, is_openvino, is_xpu,
|
||||
print_warning_once)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -96,15 +95,15 @@ class ModelConfig:
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
If None, the user did not specify, so default to False -
|
||||
except for encoder/decoder models, which currently require
|
||||
eager mode.
|
||||
If None, the user did not specify, so default to False.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
|
||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode
|
||||
to eager mode. Additionally for encoder-decoder models, if the
|
||||
sequence length of the encoder input is larger than this, we fall
|
||||
back to the eager mode.
|
||||
disable_sliding_window: Whether to disable sliding window. If True,
|
||||
we will disable the sliding window functionality of the model.
|
||||
If the model does not support sliding window, this argument is
|
||||
@ -186,32 +185,8 @@ class ModelConfig:
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
self.use_async_output_proc = use_async_output_proc
|
||||
|
||||
# Choose a default enforce_eager value if the user did not specify
|
||||
# a value (enforce_eager is None)
|
||||
if getattr(self.hf_config, 'is_encoder_decoder', False):
|
||||
if self.enforce_eager is None:
|
||||
# *Only for encoder/decoder models* and
|
||||
# *only if enforce_eager is unset*, override
|
||||
# to enforce_eager=True
|
||||
#
|
||||
# Add a logger message since it is *somewhat* non-intuitive that
|
||||
# enforce_eager is True when the user has not specified its
|
||||
# value.
|
||||
logger.info("Forcing enforce_eager == True because "
|
||||
"enforce_eager setting was unspecified and "
|
||||
"CUDAGraph is not supported with encoder/ "
|
||||
"decoder models.")
|
||||
self.enforce_eager = True
|
||||
|
||||
if not self.enforce_eager:
|
||||
# Eager mode explicitly disabled by user for an encoder/
|
||||
# decoder model; however CUDAGRAPH + encoder/decoder is
|
||||
# not currently supported
|
||||
raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
|
||||
elif self.enforce_eager is None:
|
||||
# *Only for decoder-only models*, enforce_eager
|
||||
# defaults to False if unset. This is intuitive
|
||||
# so no logging message needed.
|
||||
# Set enforce_eager to False if the value is unset.
|
||||
if self.enforce_eager is None:
|
||||
self.enforce_eager = False
|
||||
|
||||
if (not self.disable_sliding_window
|
||||
|
@ -472,7 +472,10 @@ class EngineArgs:
|
||||
default=EngineArgs.max_seq_len_to_capture,
|
||||
help='Maximum sequence length covered by CUDA '
|
||||
'graphs. When a sequence has context length '
|
||||
'larger than this, we fall back to eager mode.')
|
||||
'larger than this, we fall back to eager mode. '
|
||||
'Additionally for encoder-decoder models, if the '
|
||||
'sequence length of the encoder input is larger '
|
||||
'than this, we fall back to the eager mode.')
|
||||
parser.add_argument('--disable-custom-all-reduce',
|
||||
action='store_true',
|
||||
default=EngineArgs.disable_custom_all_reduce,
|
||||
|
@ -88,7 +88,9 @@ class LLM:
|
||||
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
|
||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode.
|
||||
to eager mode. Additionally for encoder-decoder models, if the
|
||||
sequence length of the encoder input is larger than this, we fall
|
||||
back to the eager mode.
|
||||
disable_custom_all_reduce: See ParallelConfig
|
||||
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
|
||||
:ref:`engine_args`)
|
||||
@ -137,9 +139,7 @@ class LLM:
|
||||
LLM constructor.
|
||||
|
||||
Note: if enforce_eager is unset (enforce_eager is None)
|
||||
it defaults to False for decoder-only models and True
|
||||
for encoder/decoder models, since encoder/decoder models
|
||||
do not currently support CUDAGraph.
|
||||
it defaults to False.
|
||||
'''
|
||||
|
||||
if "disable_log_stats" not in kwargs:
|
||||
|
@ -848,11 +848,13 @@ class BartForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
*,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
|
@ -71,10 +71,6 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
|
||||
"currently supported with encoder/"
|
||||
"decoder models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
|
||||
"currently supported with encoder/"
|
||||
"decoder models.")
|
||||
|
||||
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
|
||||
"currently supported with encoder/"
|
||||
"decoder models.")
|
||||
@ -98,7 +94,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
|
||||
"STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
|
||||
"STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
|
||||
"STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
|
||||
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
|
||||
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
|
||||
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
|
||||
"STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU
|
||||
|
@ -1,4 +1,5 @@
|
||||
import dataclasses
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, cast
|
||||
|
||||
import torch
|
||||
@ -24,7 +25,8 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
|
||||
from vllm.worker.model_runner import (GPUModelRunnerBase,
|
||||
ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
ModelInputForGPUWithSamplingMetadata,
|
||||
_get_graph_batch_size)
|
||||
from vllm.worker.model_runner_base import (
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict)
|
||||
@ -178,7 +180,15 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
raise ValueError("num_steps > 1 is not supported in "
|
||||
"EncoderDecoderModelRunner")
|
||||
|
||||
model_executable = self.model
|
||||
if (model_input.attn_metadata is not None
|
||||
and model_input.attn_metadata.prefill_metadata is None
|
||||
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[
|
||||
model_input.virtual_engine][graph_batch_size]
|
||||
else:
|
||||
model_executable = self.model
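# In the CUDA graph branch above, the input builder has already padded
# input_tokens up to the captured batch size, so input_tokens.shape[0]
# selects the matching CUDAGraphRunner for this virtual engine.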
|
||||
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
@ -200,6 +210,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
output: SamplerOutput = self.model.sample(
|
||||
logits=logits,
|
||||
@ -231,14 +244,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
"""
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
|
||||
(
|
||||
attn_metadata,
|
||||
encoder_input_tokens_tensor,
|
||||
encoder_input_positions_tensor,
|
||||
) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
|
||||
model_input))
|
||||
|
||||
# Inject attn_metadata encoder/cross-attention fields &
|
||||
# encoder input tokens/positions into model_input.
|
||||
# Frozen dataclass fields cannot be modified, so use
|
||||
@ -437,11 +448,29 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
cross_block_tables.append([] if (
|
||||
cross_block_table is None) else cross_block_table)
|
||||
|
||||
# Convert cross-attention block tables to encoder input tensor
|
||||
if (model_input.attn_metadata is not None
|
||||
and model_input.attn_metadata.use_cuda_graph):
|
||||
# We will be using CUDA graph replay for this decode.
|
||||
max_len_of_block_table = self.get_max_block_per_batch()
|
||||
batch_size = len(encoder_seq_lens)
|
||||
graph_batch_size = _get_graph_batch_size(batch_size)
|
||||
assert graph_batch_size >= batch_size
|
||||
cuda_graph_pad_size = graph_batch_size - batch_size
|
||||
# extend the cross_block_tables and encoder_seq_lens to match
|
||||
# the graph_batch_size.
|
||||
cross_block_tables.extend([[]
|
||||
for _ in range(cuda_graph_pad_size)
|
||||
])
|
||||
encoder_seq_lens.extend(
|
||||
itertools.repeat(1, cuda_graph_pad_size))
|
||||
|
||||
else:
|
||||
max_len_of_block_table = max(
|
||||
len(block_table) for block_table in cross_block_tables)
|
||||
|
||||
cross_block_tables = make_tensor_with_pad(
|
||||
cross_block_tables,
|
||||
max_len=max(
|
||||
len(block_table) for block_table in cross_block_tables),
|
||||
max_len=max_len_of_block_table,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
|
@ -243,6 +243,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
prefix_cache_hit: bool = False,
|
||||
reinit: bool = False,
|
||||
reinit_use_defaults: bool = False,
|
||||
encoder_seq_len: int = 0,
|
||||
):
|
||||
if reinit:
|
||||
assert len(self.seq_ids) == len(seq_ids) # type: ignore
|
||||
@ -256,6 +257,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.block_tables = block_tables
|
||||
self.computed_block_nums = computed_block_nums
|
||||
self.n_seqs = n_seqs
|
||||
self.encoder_seq_len = encoder_seq_len
|
||||
|
||||
if reinit:
|
||||
if len(self.seq_ids) == 1 and reinit_use_defaults:
|
||||
@ -702,6 +704,11 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
assert n_seqs == 1
|
||||
self.decode_only = False
|
||||
|
||||
encoder_seq_len = 0
|
||||
|
||||
if self.runner.model_config.is_encoder_decoder_model:
|
||||
encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
|
||||
|
||||
inter_data = self.init_cached_inter_data(
|
||||
request_id=seq_group_metadata.request_id,
|
||||
seq_ids=seq_ids,
|
||||
@ -709,7 +716,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
block_tables=seq_group_metadata.block_tables,
|
||||
computed_block_nums=seq_group_metadata.computed_block_nums,
|
||||
reinit=True,
|
||||
reinit_use_defaults=True)
|
||||
reinit_use_defaults=True,
|
||||
encoder_seq_len=encoder_seq_len)
|
||||
|
||||
self.inter_data_list.append(inter_data)
|
||||
|
||||
@ -719,11 +727,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
for per_seq_group_fn in self.per_seq_group_compute_fns:
|
||||
per_seq_group_fn(inter_data, seq_group_metadata)
|
||||
|
||||
def _use_captured_graph(self, batch_size: int,
|
||||
max_decode_seq_len: int) -> bool:
|
||||
def _use_captured_graph(self,
|
||||
batch_size: int,
|
||||
max_decode_seq_len: int,
|
||||
max_encoder_seq_len: int = 0) -> bool:
|
||||
return (self.decode_only and not self.runner.model_config.enforce_eager
|
||||
and batch_size <= self.runner.max_batchsize_to_capture
|
||||
and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
|
||||
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
||||
and max_decode_seq_len <= self.runner.max_seq_len_to_capture
|
||||
and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
|
||||
and batch_size <= self.runner.max_batchsize_to_capture)
|
||||
|
||||
def build(self) -> ModelInputForGPU:
|
||||
"""Finalize the builder intermediate data and
|
||||
@ -763,15 +775,18 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
input_positions.extend(cur_input_positions)
|
||||
|
||||
seq_lens = []
|
||||
query_lens = []
|
||||
max_decode_seq_len = 0
|
||||
max_encoder_seq_len = 0
|
||||
for inter_data in self.inter_data_list:
|
||||
seq_lens.extend(inter_data.seq_lens)
|
||||
query_lens.extend(inter_data.query_lens)
|
||||
if not inter_data.is_prompt:
|
||||
max_decode_seq_len = max(max_decode_seq_len,
|
||||
max(inter_data.seq_lens))
|
||||
query_lens = []
|
||||
for inter_data in self.inter_data_list:
|
||||
query_lens.extend(inter_data.query_lens)
|
||||
if self.runner.model_config.is_encoder_decoder_model:
|
||||
max_encoder_seq_len = max(max_encoder_seq_len,
|
||||
inter_data.encoder_seq_len)
|
||||
|
||||
# Mapping from request IDs to sequence IDs. Used for Jamba models
|
||||
# that manages the cache by itself.
|
||||
@ -781,8 +796,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
}
|
||||
|
||||
batch_size = len(input_tokens)
|
||||
use_captured_graph = self._use_captured_graph(batch_size,
|
||||
max_decode_seq_len)
|
||||
use_captured_graph = self._use_captured_graph(
|
||||
batch_size,
|
||||
max_decode_seq_len,
|
||||
max_encoder_seq_len=max_encoder_seq_len)
|
||||
|
||||
# If cuda graph can be used, pad tensors accordingly.
|
||||
# See `capture_model` API for more details.
|
||||
@ -1364,7 +1381,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
attn_metadata = (
|
||||
self.attn_state.graph_capture_get_metadata_for_batch(
|
||||
batch_size))
|
||||
batch_size,
|
||||
is_encoder_decoder_model=self.model_config.
|
||||
is_encoder_decoder_model))
|
||||
|
||||
if self.lora_config:
|
||||
lora_mapping = LoRAMapping(
|
||||
@ -1380,10 +1399,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
)
|
||||
self.set_active_prompt_adapters(
|
||||
set(), prompt_adapter_mapping)
|
||||
|
||||
graph_runner = CUDAGraphRunner(
|
||||
self.model, self.attn_backend.get_name(),
|
||||
self.attn_state.graph_clone(batch_size))
|
||||
self.attn_state.graph_clone(batch_size),
|
||||
self.model_config.is_encoder_decoder_model)
|
||||
|
||||
capture_inputs = {
|
||||
"input_ids":
|
||||
@ -1420,6 +1439,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.model.get_seqlen_agnostic_capture_inputs(
|
||||
batch_size)
|
||||
})
|
||||
if self.model_config.is_encoder_decoder_model:
|
||||
# add the additional inputs to capture for
|
||||
# encoder-decoder models.
|
||||
self._update_inputs_to_capture_for_enc_dec_model(
|
||||
capture_inputs)
|
||||
|
||||
graph_runner.capture(**capture_inputs)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[virtual_engine][batch_size] = (
|
||||
@ -1430,6 +1455,24 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
# This usually takes < 10 seconds.
|
||||
logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
|
||||
|
||||
def _update_inputs_to_capture_for_enc_dec_model(self,
|
||||
capture_inputs: Dict[str,
|
||||
Any]):
|
||||
"""
|
||||
Updates the set of input tensors needed for CUDA graph capture in an
|
||||
encoder-decoder model.
|
||||
|
||||
This method modifies the provided `capture_inputs` dictionary by
|
||||
adding tensors specific to encoder-decoder models that
|
||||
need to be captured for CUDA Graph replay.
|
||||
"""
|
||||
# During the decode phase encoder_input_ids and encoder_positions are
|
||||
# unset. Do the same thing for graph capture.
|
||||
capture_inputs["encoder_input_ids"] = torch.tensor(
|
||||
[], dtype=torch.long).cuda()
|
||||
capture_inputs["encoder_positions"] = torch.tensor(
|
||||
[], dtype=torch.long).cuda()
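# encoder_input_ids and encoder_positions are empty during the decode phase,
# so empty tensors are what get captured here; CUDAGraphRunner.forward copies
# the per-step values passed in kwargs into these buffers before replaying
# the graph when _is_encoder_decoder_model is set.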
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_config.get_vocab_size()
|
||||
@ -1629,7 +1672,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
class CUDAGraphRunner:
|
||||
|
||||
def __init__(self, model: nn.Module, backend_name: str,
|
||||
attn_state: AttentionState):
|
||||
attn_state: AttentionState, is_encoder_decoder_model: bool):
|
||||
self.model = model
|
||||
self.backend_name = backend_name
|
||||
self.attn_state = attn_state
|
||||
@ -1638,6 +1681,7 @@ class CUDAGraphRunner:
|
||||
self.output_buffers: Dict[str, torch.Tensor] = {}
|
||||
|
||||
self._graph: Optional[torch.cuda.CUDAGraph] = None
|
||||
self._is_encoder_decoder_model = is_encoder_decoder_model
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
@ -1671,8 +1715,9 @@ class CUDAGraphRunner:
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
# Wait for the warm up operations to finish before proceeding with
|
||||
# Graph Capture.
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture the graph.
|
||||
self._graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
||||
@ -1704,10 +1749,14 @@ class CUDAGraphRunner:
|
||||
|
||||
# Save the input and output buffers.
|
||||
self.input_buffers = {
|
||||
"input_ids": input_ids,
|
||||
"positions": positions,
|
||||
"kv_caches": kv_caches,
|
||||
**self.attn_state.get_graph_input_buffers(attn_metadata),
|
||||
"input_ids":
|
||||
input_ids,
|
||||
"positions":
|
||||
positions,
|
||||
"kv_caches":
|
||||
kv_caches,
|
||||
**self.attn_state.get_graph_input_buffers(
|
||||
attn_metadata, self._is_encoder_decoder_model),
|
||||
**kwargs,
|
||||
}
|
||||
if intermediate_inputs is not None:
|
||||
@ -1737,8 +1786,8 @@ class CUDAGraphRunner:
|
||||
self.input_buffers["positions"].copy_(positions, non_blocking=True)
|
||||
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
|
||||
non_blocking=True)
|
||||
self.attn_state.prepare_graph_input_buffers(self.input_buffers,
|
||||
attn_metadata)
|
||||
self.attn_state.prepare_graph_input_buffers(
|
||||
self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
|
||||
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
|
||||
**kwargs)
|
||||
@ -1752,6 +1801,12 @@ class CUDAGraphRunner:
|
||||
if key != "model_execute_time" and key != "model_forward_time":
|
||||
self.input_buffers[key].copy_(intermediate_tensors[key],
|
||||
non_blocking=True)
|
||||
if self._is_encoder_decoder_model:
|
||||
self.input_buffers["encoder_input_ids"].copy_(
|
||||
kwargs['encoder_input_ids'], non_blocking=True)
|
||||
self.input_buffers["encoder_positions"].copy_(
|
||||
kwargs['encoder_positions'], non_blocking=True)
|
||||
|
||||
# Run the graph.
|
||||
self.graph.replay()
|
||||
# Return the output tensor.
|
||||
|
@ -47,10 +47,6 @@ def assert_enc_dec_mr_supported_scenario(
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
|
||||
|
||||
if not enc_dec_mr.model_config.enforce_eager:
|
||||
raise NotImplementedError(
|
||||
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'])
|
||||
|
||||
if enc_dec_mr.prompt_adapter_config is not None:
|
||||
raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
|
||||
'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])
|
||||
|