vllm/tests/samplers/test_sampler.py

import random
from typing import List, Optional, Tuple
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
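    """Sampler test double that keeps a handle on a pre-built logits
    tensor, so tests control exactly which tokens can be sampled."""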
def __init__(self, fake_logits: torch.Tensor):
super().__init__()
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
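    """Build shared fixtures: dummy hidden states, uniform fake logits
    (every token equally likely), a mock sampler, and a bare ModelRunner
    used only for its _prepare_sample helper."""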
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
    return input_tensor, fake_logits, sampler, model_runner


RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def _do_sample(
batch_size: int,
input_tensor: torch.Tensor,
sampler: MockLogitsSampler,
model_runner: ModelRunner,
sampling_params: SamplingParams,
):
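    """Run the sampler once over a batch of single-prompt sequence groups
    that all share the given sampling parameters."""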
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
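    """With temperature=0, every sample must be the argmax token."""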
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item()
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
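    """Random sampling must pick the token whose logit dominates its row."""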
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
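    """Seeded random sampling must still pick the dominant-logit token."""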
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
sampling_params)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
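    """Two runs with the same seed must produce identical outputs."""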
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                      model_runner, sampling_params)
    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       model_runner, sampling_params)
assert first_sampler_output == second_sampler_output
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
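    """Smoke test: an all-beam-search batch must run without raising."""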
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
    # There is no assertion here because it is unclear what the expected
    # outputs should be; this just checks that the sampler raises no
    # exceptions when handling an all-beam-search batch.
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
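    """Mix greedy, random, seeded-random, and beam-search groups in one
    batch, then shuffle the batch and check the outputs stay consistent."""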
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
seq_group_metadata_list = []
expected_tokens: List[Optional[List[int]]] = []
prompt_lens = []
for i in range(batch_size):
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
expected = [torch.argmax(fake_logits[i], dim=-1).item()]
elif sampling_type in (1, 2):
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
top_p=min(random.random() + 0.1, 1),
top_k=random.randint(0, 10) or -1,
n=n,
presence_penalty=random.randint(0, 1),
)
if sampling_type == 2:
sampling_params.seed = random.randint(0, 10000)
else:
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)
for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)):
if metadata.sampling_params.use_beam_search:
continue
if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None):
# Record seeded random result to compare with results of
# second invocation
expected_tokens[i] = [
nth_output.output_token
for nth_output in sequence_output.samples
]
continue
for n, nth_output in enumerate(sequence_output.samples):
if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n]
else:
                    # For non-seeded random sampling, check that one of the
                    # high-logit tokens was chosen
                    assert nth_output.output_token in expected_tokens[i]
# Test batch
test_sampling(model_runner)
# Shuffle the batch and resample
target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens):
random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index)
# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
test_sampling(model_runner)
    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
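    """Check that the sampler's top-k/top-p filtering matches the
    HuggingFace logits warpers applied to the same logits."""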
set_random_seed(seed)
batch_size = random.randint(1, 256)
top_k = random.randint(100, 500)
top_p = random.random() * 0.1
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device=device,
dtype=torch.float16)
fake_logits = torch.normal(0,
5,
size=(batch_size, vocab_size),
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,
top_p=top_p,
do_sample=True)
warpers = generation_model._get_logits_warper(generation_config)
assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
top_p=top_p,
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
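        # Capture the probabilities that _sample receives so they can be
        # compared against HuggingFace's warped distribution below.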
nonlocal sample_probs
sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
del model_runner