
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import (get_sampled_token_logprobs,
                                   split_batch_by_proposal_len)


def test_get_all_seq_ids():
    """Verify get_all_seq_ids extracts all seq ids."""
    expected_seq_ids = list(range(10)) + list(range(100, 110))

    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=str(seq_id),
            is_prompt=True,
            seq_data={
                seq_id: MagicMock(),
            },
            sampling_params=MagicMock(),
            block_tables={
                seq_id: MagicMock(),
            },
            lora_request=None,
        ) for seq_id in expected_seq_ids
    ]

    actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
    assert actual_seq_ids == expected_seq_ids


@pytest.fixture
def fake_sequence_group_metadata():
    seq_ids = list(range(3))
    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=True,
            seq_data={
                i: MagicMock(),
            },
            sampling_params=MagicMock(),
            block_tables={
                i: MagicMock(),
            },
            lora_request=None,
        ) for i in seq_ids
    ]
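

# Note (inferred from the tests below, not stated elsewhere in this file):
# split_batch_by_proposal_len appears to return two (groups, indices) pairs,
# the first for sequence groups with non-zero proposal lengths and the second
# for groups with zero-length proposals. A minimal unpacking sketch under
# that assumption:
#
#   (spec_groups, spec_indices), (non_spec_groups, non_spec_indices) = \
#       split_batch_by_proposal_len(seq_group_metadata_list, proposal_lens)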


def test_filter_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 0]
    _, (filtered_groups,
        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
                                               proposal_lens)

    expected_groups = [
        fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
    ]
    expected_indices = [0, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
    proposal_lens = [0, 1, 2]
    (filtered_groups,
     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
                                               proposal_lens)

    expected_groups = [
        fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
    ]
    expected_indices = [1, 2]

    assert filtered_groups == expected_groups
    assert indices == expected_indices


def test_empty_inputs():
    _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])

    assert filtered_groups == []
    assert indices == []


def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [0, 0, 0]
    (filtered_groups,
     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
                                               proposal_lens)

    assert filtered_groups == []
    assert indices == []


def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
    proposal_lens = [1, 1, 1]
    _, (filtered_groups,
        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
                                               proposal_lens)

    assert filtered_groups == []
    assert indices == []


def mock_spec_decode_sampler(acceptance_sampler_method):
    """
    Returns either a RejectionSampler or TypicalAcceptanceSampler
    object depending on whether acceptance_sampler_method is
    'rejection_sampler' or 'typical_acceptance_sampler' respectively.
    """
    if acceptance_sampler_method == "rejection_sampler":
        sampler = MagicMock(spec=RejectionSampler)
        sampler.token_id_dtype = torch.int64
        return sampler
    elif acceptance_sampler_method == "typical_acceptance_sampler":
        sampler = MagicMock(spec=TypicalAcceptanceSampler)
        sampler.token_id_dtype = torch.int64
        return sampler
    else:
        raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
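

# Illustrative only, not part of the original suite: a minimal sketch of how
# mock_spec_decode_sampler could be exercised via pytest parametrization over
# the two supported sampler names. The test name below is hypothetical.
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
def test_mock_spec_decode_sampler_dtype(acceptance_sampler_method):
    sampler = mock_spec_decode_sampler(acceptance_sampler_method)
    assert sampler.token_id_dtype == torch.int64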


def test_get_sampled_token_logprobs():
    """Verify get_sampled_token_logprobs returns consistent rankings
    with regular get_ranks when probabilities match exactly.
    """
    logprob_tensor = torch.tensor(
        [[[-.1, -.1]] * 2])  # shape (num_steps, batch_size, vocab_size)
    sampled_token_tensor = torch.tensor([[1,
                                          0]])  # shape (num_steps, batch_size)
    ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
                                                   sampled_token_tensor)

    ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
                               sampled_token_tensor.reshape(-1))

    assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)