vllm/tests/samplers/test_seeded_generate.py

86 lines
2.4 KiB
Python
Raw Permalink Normal View History

# SPDX-License-Identifier: Apache-2.0
2024-02-21 11:47:00 -08:00
"""Verify that seeded random sampling is deterministic.
Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import copy
import random
from itertools import combinations
import pytest
from vllm import SamplingParams
2024-03-25 23:59:47 +09:00
from vllm.model_executor.utils import set_random_seed
2024-02-21 11:47:00 -08:00
# Model identifier handed to `vllm_runner` by the `vllm_model` fixture.
MODEL = "facebook/opt-125m"
# Seeds the test is parametrized over; each one is passed to `set_random_seed`.
RANDOM_SEEDS = list(range(5))
@pytest.fixture
def vllm_model(vllm_runner, monkeypatch):
    """Yield a runner for ``MODEL`` created with ``dtype="half"``.

    ``VLLM_USE_V1`` is set to ``"0"`` before the runner is created, and the
    runner's context manager is exited once the consuming test finishes.
    """
    # This file relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")
    with vllm_runner(MODEL, dtype="half") as runner:
        yield runner
2024-02-21 11:47:00 -08:00
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    """Check that per-request seeds make random sampling reproducible.

    For every prompt, six requests are queued: unseeded, seed 100, seed 200,
    and then the same three again. After running the engine once, each group
    of six outputs is checked so that:

    * requests that do not share a seed produce different outputs,
    * requests that share a seed produce identical outputs,
    * the parallel samples inside a single request all differ.

    Args:
        vllm_model: Fixture whose ``.model`` attribute exposes
            ``_add_request`` and ``_run_engine``.
        example_prompts: Iterable of prompts to generate from.
        seed: Global seed applied via ``set_random_seed`` before the
            sampling parameters are drawn from ``random``.
    """
    set_random_seed(seed)

    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=3.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    # Request layout per prompt; the index-based checks below depend on
    # exactly this order.
    params_per_prompt = (
        sampling_params,
        sampling_params_seed_1,
        sampling_params_seed_2,
        sampling_params,
        sampling_params_seed_1,
        sampling_params_seed_2,
    )
    group_size = len(params_per_prompt)

    llm = vllm_model.model

    for prompt in example_prompts:
        for params in params_per_prompt:
            llm._add_request(prompt, params=params)

    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    # The grouping below assumes one result per queued request, returned in
    # submission order; fail loudly here rather than with an IndexError.
    assert len(all_outputs) == len(example_prompts) * group_size

    # Step over the outputs (not the prompts): there are `group_size` outputs
    # per prompt, so bounding the range by the prompt count would leave most
    # groups unchecked.
    for start in range(0, len(all_outputs), group_size):
        outputs = all_outputs[start:start + group_size]

        # verify requests that do not share a seed differ
        # (unseeded, seed 100, seed 200, unseeded again)
        for output_a, output_b in combinations(outputs[:4], 2):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]

        # verify generations within the same parallel sampling group differ
        for output in outputs:
            for sub_output_a, sub_output_b in combinations(output, 2):
                assert sub_output_a != sub_output_b