# SPDX-License-Identifier: Apache-2.0
"""Verify that seeded random sampling is deterministic.

Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import copy
import random
from itertools import combinations

import pytest

from vllm import SamplingParams
from vllm.model_executor.utils import set_random_seed

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))


@pytest.fixture
def vllm_model(vllm_runner, monkeypatch):
    # This file relies on V0 internals.
    monkeypatch.setenv("VLLM_USE_V1", "0")
    with vllm_runner(MODEL, dtype="half") as vllm_model:
        yield vllm_model


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    set_random_seed(seed)

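    # Unseeded baseline parameters; the sampling knobs are drawn at random
    # (seeded above), so each parametrized run exercises a different setting.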
    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=3.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )
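    # Two seeded copies of the same parameters: requests repeated with the
    # same seed should reproduce their outputs, different seeds should not.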
    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    llm = vllm_model.model

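    # For each prompt, enqueue six requests: unseeded, seed 100, seed 200,
    # then the same three again, giving one comparable group of six results
    # per prompt.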
    for prompt in example_prompts:
        for params in (
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
                sampling_params,
                sampling_params_seed_1,
                sampling_params_seed_2,
        ):
            llm._add_request(prompt, params=params)
    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

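    # Results are expected in submission order, so each consecutive block of
    # six outputs belongs to a single prompt.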
    for i in range(0, len(example_prompts), 6):
        outputs = all_outputs[i:i + 6]

        # verify that the unseeded requests and the two differently-seeded
        # requests all produce different outputs
        for output_a, output_b in combinations(
            (outputs[0], outputs[1], outputs[2], outputs[3]),
            2,
        ):
            assert output_a != output_b
        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]

        # verify generations within the same parallel sampling group differ
        for output in outputs:
            for sub_output_a, sub_output_b in combinations(output, 2):
                assert sub_output_a != sub_output_b