# SPDX-License-Identifier: Apache-2.0
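"""Tests for EncoderDecoderModelRunner input preparation.

These tests build dummy SequenceGroupMetadata for a BART model and verify
that prefill- and decode-phase model inputs and attention metadata (including
the CUDA-Graph-padded decode path) are constructed correctly.
"""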
import itertools

import pytest
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner

BATCH_SIZES = [1, 4, 16, 64, 256]


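# Helper shared by all tests below: build an EncoderDecoderModelRunner
# directly from EngineArgs, without constructing a full LLM engine.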
def _create_model_runner(model: str, *args,
                         **kwargs) -> EncoderDecoderModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    model_runner = EncoderDecoderModelRunner(
        vllm_config=engine_config,
        is_driver_worker=True,
    )
    return model_runner


@pytest.mark.skipif(condition=current_platform.is_cpu(),
                    reason="CPU backend is currently "
                    "unsupported for encoder/decoder models")
def test_empty_seq_group():
    """Verify that prompt and decode preparation return empty output
    for an empty seq group list."""

    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    (
        input_tokens,
        input_positions,
        encoder_input_tokens,
        encoder_input_positions,
        attn_metadata,
        return_seq_lens,
    ) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.encoder_input_tokens,
        model_input.encoder_input_positions,
        model_input.attn_metadata,
        model_input.seq_lens,
    )
    assert input_tokens is None
    assert input_positions is None
    assert encoder_input_tokens is None
    assert encoder_input_positions is None
    assert attn_metadata is None
    assert return_seq_lens is None


@pytest.mark.skipif(condition=current_platform.is_cpu(),
                    reason="CPU backend is currently "
                    "unsupported for encoder/decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_prompt(batch_size):
    '''
    Test the ability of the encoder/decoder model runner subclass to
    produce prefill-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner
    * Construct sequence-group metadata for dummy prompts
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size: number of sequence groups in the dummy prefill batch
    '''

    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    seq_lens: list[int] = []
    encoder_seq_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
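    # Dummy single-block tables: block 1 backs the decoder self-attention KV
    # cache and block 2 backs the encoder/decoder cross-attention KV cache.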
    block_tables = {0: [1]}
    cross_block_table = [2]
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_lens.append(seq_len)
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_lens.append(encoder_seq_len)
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == seq_data.get_len()
        seq_group_metadata_list.append(seq_group_metadata)

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)

    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for prompts.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    assert torch.equal(attn_metadata.seq_lens_tensor,
                       torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == max(seq_lens)
    assert attn_metadata.max_decode_seq_len == 0
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)

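    # (query_start_loc is the running prefix sum of the per-sequence query
    # lengths; for prefill the query length equals the full prompt length.)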
    # Test decoder subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        start_loc.append(start_idx)
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )

    # Test decoder seq start locs & context lengths

    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.zeros(attn_metadata.context_lens_tensor.shape[0],
                    dtype=torch.int,
                    device=device),
    )

    # Verify block tables are correct for prompts
    # - Decoder self-attention
    expected = torch.tensor(
        [[] for _ in range(len(seq_group_metadata_list))],
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == sum(seq_lens)
    assert len(input_positions) == sum(seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == sum(encoder_seq_lens)
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )

    # Test that vLLM sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the prefill phase

    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for seq_len in seq_lens:
        # Compute the index offset of the final token in each
        # prompt (recall that the prompts are concatenated)
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len

    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)


@pytest.mark.skipif(condition=current_platform.is_cpu(),
                    reason="CPU backend is currently "
                    "unsupported for encoder/decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
    '''
    Test the ability of the encoder/decoder model runner subclass to
    produce decode-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner
    * Construct sequence-group metadata for dummy prompts
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size: number of sequence groups in the dummy decode batch
    * multiple_seqs_per_seq_group: if True, each sequence group contains two
      sequences instead of one
    '''

    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    seq_lens: list[int] = []
    encoder_seq_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
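    # Dummy block tables: with multiple sequences per group, each sequence
    # gets its own decoder self-attention block table, while the whole group
    # shares a single cross-attention block table for the encoder KV cache.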
    block_tables = {
        0: [1],
        1: [3]
    } if multiple_seqs_per_seq_group else {
        0: [1]
    }
    cross_block_table = [2]
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))

        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={
                0: seq_data,
                1: seq_data
            } if multiple_seqs_per_seq_group else {0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
        seq_lens.extend(
            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
        encoder_seq_lens.extend(
            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for decode phase.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(attn_metadata.seq_lens_tensor,
                       torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)

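    # (During decode each sequence contributes a single query token, so
    # query_start_loc advances by 1 per sequence while seq_start_loc below
    # advances by the full sequence length.)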
    # Test decoder subquery start locs.
    start_idx = 0
    start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += 1
        start_loc.append(start_idx)
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )

    # Test decoder seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    start_idx = 0
    seq_start_loc = [start_idx]
    for seq_len in seq_lens:
        start_idx += seq_len
        seq_start_loc.append(start_idx)

    # Test seq_start_loc and context lengths

    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.tensor([seq_len - 1 for seq_len in seq_lens],
                     dtype=torch.int,
                     device=device))

    # Verify block tables are correct for the decode phase
    # - Decoder self-attention
    flattened_block_tables = [
        block_table for block_table in block_tables.values()
    ]
    expected = torch.tensor(flattened_block_tables *
                            len(seq_group_metadata_list),
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention
    expected = torch.tensor([
        cross_block_table for seq_group_metadata in seq_group_metadata_list
        for _ in range(len(seq_group_metadata.seq_data))
    ],
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(seq_lens)
    assert len(input_positions) == len(seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )

    # Test that vLLM sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the decode phase

    expected_selected_token_indices = []
    for selected_token_start_idx, seq_len in enumerate(seq_lens):
        # Compute the index offset of the final token in each
        # sequence's decoded outputs; since a single token is
        # decoded per iteration per sequence, then the length
        # of the decoded tokens for a given sequence is 1 and
        # the final index offset into a given sequence's
        # generated tokens is 0 (i.e. the expected sampling index
        # for a given sequence is just `selected_token_start_idx`)
        expected_selected_token_indices.append(selected_token_start_idx)

    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
    """
    Tests that for encoder-decoder models with CUDA Graph capture and replay
    enabled, the tensors used during the decode phase are correctly padded
    for varying input batch sizes.
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=False,
    )
    block_tables = {
        0: [1],
        1: [3]
    } if multiple_seqs_per_seq_group else {
        0: [1]
    }
    seq_lens: list[int] = []
    encoder_seq_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []

    cross_block_table = [2]
    expanded_batch_size = 0
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={
                0: seq_data,
                1: seq_data
            } if multiple_seqs_per_seq_group else {0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_lens.extend(
            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
        encoder_seq_lens.extend(
            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
        expanded_batch_size = expanded_batch_size + len(
            seq_group_metadata.seq_data)
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping

    # With CUDA Graph capture and replay enabled, the decoder and encoder
    # input sequences will be padded. Create the expected padded tensors
    # accordingly.
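    # (pad_for_cudagraph rounds the batch up to the nearest captured graph
    # size; the extra padded slots are treated as dummy length-1 sequences.)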
    graph_batch_size = model_runner.vllm_config.pad_for_cudagraph(
        expanded_batch_size)
    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
    padded_encoder_seq_lens = encoder_seq_lens + list(
        itertools.repeat(1, cuda_graph_pad_size))

    assert return_seq_lens == padded_seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == padded_seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)

    # Verify block tables are correct for the decode phase
    # - Decoder self-attention. Pad the block tables as expected.
    flattened_block_tables = [
        block_table for _ in range(len(seq_group_metadata_list))
        for block_table in block_tables.values()
    ]
    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        flattened_block_tables,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
    # as expected.
    expected = [
        cross_block_table for seq_group_metadata in seq_group_metadata_list
        for _ in range(len(seq_group_metadata.seq_data))
    ]
    expected.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        expected,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is True

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(padded_seq_lens)
    assert len(input_positions) == len(padded_seq_lens)
    # -- An indirect check that model_input.input_tokens
    # and model_input.input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )
    # - Encoder
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- An indirect check that model_input.encoder_input_tokens
    # and model_input.encoder_input_positions are correct -
    # by design of the test, the input tokens are
    # equal to the input position values, so if
    # the model_input data structure has the correct
    # values then these two should be equal
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )