import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


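# Test double: a LogitsProcessor whose forward() patches out hidden-state
# pruning and the real logits computation so that it always operates on the
# pre-built `fake_logits` tensor.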
class MockLogitsProcessor(LogitsProcessor):

    def __init__(self, vocab_size: int, scale: float,
                 fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size, scale=scale)
        self.fake_logits = fake_logits.clone()

    def forward(self, *args, **kwargs):
        with patch(
                "vllm.model_executor.layers.logits_processor._prune_hidden_states",
                lambda x, y: x
        ), patch(
                "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
                lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


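# Builds the shared fixtures for a given batch size: a dummy fp16 hidden-state
# tensor, uniform fake logits, a mock logits processor with scale=0.5, and a
# ModelRunner constructed from None configs.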
def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
    model_runner = ModelRunner(model_config=None,
                               parallel_config=None,
                               scheduler_config=None,
                               device_config=None,
                               load_config=None,
                               lora_config=None)
    return input_tensor, fake_logits, logits_processor, model_runner


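# Exercise many random seeds; use a second CUDA device when more than one is
# available.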
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_logits_processors(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
        batch_size)

    # This sample logits processor gives an infinite score to the i-th token,
    # where i is the number of tokens generated so far. No tokens have been
    # generated yet, so token 0 gets the infinite score here; under greedy
    # sampling the output sequence would therefore be [0, 1, 2, ...].
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

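    # Build one prompt-phase sequence group per batch element, each carrying
    # the pick_ith processor in its SamplingParams.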
    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    logits_processor_output = logits_processor(
        embedding=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata)

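    # Every row should have +inf at index 0 (pick_ith was applied before any
    # tokens were generated), while the remaining columns should match the
    # fake logits after the processor's 0.5 scaling.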
    assert torch.isinf(logits_processor_output[:, 0]).all()

    fake_logits *= logits_processor.scale
    assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
                          1e-4)

    del model_runner