246 lines
9.3 KiB
Python
246 lines
9.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import dataclasses
|
|
|
|
import torch
|
|
|
|
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
|
|
from vllm.attention.backends.abstract import AttentionBackend
|
|
from vllm.attention.backends.utils import CommonAttentionState
|
|
from vllm.model_executor import SamplingMetadata
|
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
|
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
|
from vllm.worker.multi_step_model_runner import StatefulModelInput
|
|
from vllm.worker.pooling_model_runner import (
|
|
ModelInputForGPUWithPoolingMetadata)
|
|
|
|
|
|
class MockAttentionBackend(AttentionBackend):
    """Stub attention backend handed to ``from_broadcasted_tensor_dict``.

    The class-accessor methods return the generic vLLM metadata, builder
    and state classes. ``get_name``, ``get_impl_cls`` and
    ``get_kv_cache_shape`` raise ``NotImplementedError``; the block
    swap/copy methods do nothing.
    """

    # --- accessors that return a class -----------------------------------

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the base ``AttentionMetadata`` class."""
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        """Return the base ``AttentionMetadataBuilder`` class."""
        return AttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        """Return the ``CommonAttentionState`` class."""
        return CommonAttentionState

    # --- methods this mock does not implement ----------------------------

    @staticmethod
    def get_name() -> str:
        """Not implemented by this mock."""
        raise NotImplementedError()

    @staticmethod
    def get_impl_cls():
        """Not implemented by this mock."""
        raise NotImplementedError()

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Not implemented by this mock."""
        raise NotImplementedError()

    # --- no-op cache operations ------------------------------------------

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Do nothing; the arguments are ignored."""
        return None
|
|
|
|
|
|
def test_model_runner_input():
    """Round-trip a sampling model input through its broadcastable dict.

    Builds a ``ModelInputForGPUWithSamplingMetadata``, converts it with
    ``as_broadcastable_tensor_dict``, rebuilds it with
    ``from_broadcasted_tensor_dict`` and compares the rebuilt object with
    the original.
    """
    sampling_md = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_md = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    original = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_md,
        attn_metadata=attn_md,
    )
    assert isinstance(original, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    restored = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            original.as_broadcastable_tensor_dict(),
            attn_backend=MockAttentionBackend(),
        ))

    # Check that received copy has correct values.
    assert isinstance(restored, ModelInputForGPUWithSamplingMetadata)

    # Tensor fields must be present and element-wise equal.
    for tensor_attr in ("input_tokens", "input_positions"):
        restored_tensor = getattr(restored, tensor_attr)
        assert restored_tensor is not None
        assert (restored_tensor == getattr(original, tensor_attr)).all()

    # Optional fields were never set, so they are None on both sides.
    for optional_attr in ("multi_modal_kwargs", "lora_requests",
                          "lora_mapping"):
        restored_value = getattr(restored, optional_attr)
        assert restored_value is None
        assert restored_value == getattr(original, optional_attr)

    # Compare every declared AttentionMetadata field; a field missing on
    # either object is read as None.
    restored_attn = restored.attn_metadata
    for md_field in dataclasses.fields(AttentionMetadata):
        field_name = md_field.name
        assert getattr(restored_attn, field_name,
                       None) == getattr(attn_md, field_name, None)

    # For sampling metadata, only selected_token_indices is copied.
    restored_sampling = restored.sampling_metadata
    assert (restored_sampling.selected_token_indices ==
            sampling_md.selected_token_indices)
    assert restored_sampling.seq_groups is None
|
|
|
|
|
|
def test_embedding_model_runner_input():
    """Round-trip a pooling model input through its broadcastable dict.

    Builds a ``ModelInputForGPUWithPoolingMetadata``, converts it with
    ``as_broadcastable_tensor_dict``, rebuilds it with
    ``from_broadcasted_tensor_dict`` and compares the rebuilt object with
    the original.
    """
    pooling_md = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_md = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    original = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_md,
        attn_metadata=attn_md,
    )
    assert isinstance(original, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    restored = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            original.as_broadcastable_tensor_dict(),
            attn_backend=MockAttentionBackend(),
        ))

    # Check that received copy has correct values.
    assert isinstance(restored, ModelInputForGPUWithPoolingMetadata)

    # Tensor fields must be present and element-wise equal.
    for tensor_attr in ("input_tokens", "input_positions"):
        restored_tensor = getattr(restored, tensor_attr)
        assert restored_tensor is not None
        assert (restored_tensor == getattr(original, tensor_attr)).all()

    # Optional fields were never set, so they are None on both sides.
    for optional_attr in ("multi_modal_kwargs", "lora_requests",
                          "lora_mapping"):
        restored_value = getattr(restored, optional_attr)
        assert restored_value is None
        assert restored_value == getattr(original, optional_attr)

    # Compare every declared AttentionMetadata field; a field missing on
    # either object is read as None.
    restored_attn = restored.attn_metadata
    for md_field in dataclasses.fields(AttentionMetadata):
        field_name = md_field.name
        assert getattr(restored_attn, field_name,
                       None) == getattr(attn_md, field_name, None)

    # Pooling metadata is not broadcast.
    assert restored.pooling_metadata is None
|
|
|
|
|
|
def test_multi_step_model_runner_input():
    """Round-trip a ``StatefulModelInput`` through its broadcastable dict.

    Builds a frozen ``ModelInputForGPUWithSamplingMetadata``, wraps it in a
    ``StatefulModelInput``, converts it with
    ``as_broadcastable_tensor_dict`` and rebuilds it with
    ``from_broadcasted_tensor_dict``. Both the nested frozen input and the
    multi-step fields of the rebuilt object are compared with the original.
    """
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=True,
    )
    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    model_input = StatefulModelInput(
        frozen_model_input=frozen_model_input,
        is_last_step=True,
        is_first_multi_step=False,
        current_step=4,
        last_sampled_token_ids=torch.ones((10, 1)),
        is_multi_step=True,
        num_queries=8,
        num_seqs=5,
        cached_outputs=[],
    )

    assert isinstance(model_input, StatefulModelInput)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=attn_backend))

    received_frozen_input = received_model_input.frozen_model_input

    # Check that received copy has correct values.
    assert isinstance(received_model_input, StatefulModelInput)
    assert received_frozen_input.input_tokens is not None
    assert (received_frozen_input.input_tokens ==
            frozen_model_input.input_tokens).all()
    assert received_frozen_input.input_positions is not None
    assert (received_frozen_input.input_positions ==
            frozen_model_input.input_positions).all()
    assert received_frozen_input.multi_modal_kwargs is None
    # Compare the received copy against the original. Comparing the
    # original with itself would always pass and check nothing.
    assert (received_frozen_input.multi_modal_kwargs ==
            frozen_model_input.multi_modal_kwargs)
    assert received_frozen_input.lora_requests is None
    assert (received_frozen_input.lora_requests ==
            frozen_model_input.lora_requests)
    assert received_frozen_input.lora_mapping is None
    assert (
        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
    # Compare every declared AttentionMetadata field; a field missing on
    # either object is read as None.
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_frozen_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_frozen_input.sampling_metadata.seq_groups is None

    # Check the non-frozen (multi-step state) fields.
    assert received_model_input.is_last_step == model_input.is_last_step
    assert (received_model_input.is_first_multi_step ==
            model_input.is_first_multi_step)
    assert received_model_input.current_step == model_input.current_step
    assert (received_model_input.last_sampled_token_ids ==
            model_input.last_sampled_token_ids).all()
    assert received_model_input.is_multi_step == model_input.is_multi_step
|