# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional

import pytest
import torch
import torch.nn.functional as F

from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                              RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

DEVICE = "cuda"


@pytest.fixture
def rejection_sampler() -> RejectionSampler:
    """Provide a fresh RejectionSampler for each test."""
    return RejectionSampler()


def create_logits_tensor(output_token_ids: list[list[int]],
                         vocab_size: int = 100) -> torch.Tensor:
    """Create a logits tensor that produces the desired token ids on argmax.

    The last token of every sequence (the bonus token) is dropped. Each
    remaining token gets one row in which its own id is set to 100.0 and
    every other entry is -100.0.

    Args:
        output_token_ids: Per-sequence target token ids, each ending with
            the bonus token.
        vocab_size: Width of every logits row.

    Returns:
        A tensor of shape (num_non_bonus_tokens, vocab_size) on DEVICE, with
        the rows of all sequences concatenated in order.
    """
    token_ids = [tokens[:-1] for tokens in output_token_ids]
    num_total_tokens = sum(len(tokens) for tokens in token_ids)
    logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
    start_loc = 0
    for tokens in token_ids:
        for j, token_id in enumerate(tokens):
            logits[start_loc + j, token_id] = 100.0
        start_loc += len(tokens)
    return logits


def create_sampling_metadata(
    all_greedy: bool,
    temperature: Optional[torch.Tensor] = None,
    generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
    """Create a v1 sampling metadata object with all_greedy set to the
    given value. Either all greedy or all random sampling is used.

    Args:
        all_greedy: If True, the temperature is discarded and greedy
            sampling is requested; otherwise random sampling is requested.
        temperature: Temperature tensor; required when all_greedy is False.
        generators: Mapping from request index to a seeded generator.
            Defaults to an empty mapping.

    Returns:
        A SamplingMetadata with all other fields set to neutral values.
    """
    generators = generators or {}
    if all_greedy:
        temperature = None
    else:
        assert temperature is not None

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=all_greedy,
        all_random=not all_greedy,
        top_p=None,
        top_k=None,
        min_p=torch.empty(1, ),
        generators=generators,
        max_num_logprobs=0,
        no_penalties=False,
        prompt_token_ids=None,
        frequency_penalties=torch.tensor([]),
        presence_penalties=torch.tensor([]),
        repetition_penalties=torch.tensor([]),
        output_token_ids=[],
        min_tokens={},
        logit_bias=[None],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
    )


def _assert_greedy_output(
    rejection_sampler: RejectionSampler,
    spec_tokens: list[list[int]],
    output_tokens: list[list[int]],
    expected: list[list[int]],
) -> None:
    """Run the sampler in greedy mode and compare against expected ids.

    Args:
        rejection_sampler: Sampler under test.
        spec_tokens: Draft (speculated) token ids per sequence.
        output_tokens: Target token ids per sequence; the last entry of
            each sequence is used as that sequence's bonus token.
        expected: Expected sampler output, one row per sequence.
    """
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
                                      device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                         device=logits.device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )
    expected_tensor = torch.tensor(expected,
                                   dtype=torch.int,
                                   device=logits.device)
    assert torch.equal(output, expected_tensor)


########################### Tests for Greedy Sampling ###################
def test_perfect_match(rejection_sampler):
    """Test when output tokens perfectly match speculated tokens"""
    _assert_greedy_output(
        rejection_sampler,
        spec_tokens=[[1, 2, 3]],
        output_tokens=[[1, 2, 3, 4]],  # 4 is the bonus token
        expected=[[1, 2, 3, 4]],
    )


def test_early_mismatch(rejection_sampler):
    """Test when there's an early mismatch in tokens"""
    _assert_greedy_output(
        rejection_sampler,
        spec_tokens=[[1, 2, 3]],
        output_tokens=[[1, 5, 3, 4]],  # Mismatch at position 1
        expected=[[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
    )


def test_multiple_sequences(rejection_sampler):
    """Test handling multiple sequences of speculated tokens"""
    _assert_greedy_output(
        rejection_sampler,
        spec_tokens=[[1, 2], [3]],
        # Two sequences with bonus tokens 5 and 4
        output_tokens=[[1, 2, 5], [3, 4]],
        expected=[[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
    )


def test_single_token_sequence(rejection_sampler):
    """Test handling sequences with single token"""
    _assert_greedy_output(
        rejection_sampler,
        spec_tokens=[[1]],
        output_tokens=[[1, 2]],  # Single token with bonus token 2
        expected=[[1, 2]],
    )


def test_empty_sequence(rejection_sampler):
    """Test handling empty sequence of speculated tokens"""
    spec_tokens: list[list[int]] = [[]]
    _assert_greedy_output(
        rejection_sampler,
        spec_tokens=spec_tokens,
        output_tokens=[[5]],  # Just the bonus token
        expected=[[5]],
    )


def test_multiple_mismatches(rejection_sampler):
    """Test handling multiple sequences with mismatches"""
    _assert_greedy_output(
        rejection_sampler,
        spec_tokens=[[1, 2, 3], [4, 5, 6]],
        # Mismatches in both sequences
        output_tokens=[[1, 2, 7, 6], [4, 8, 6, 9]],
        expected=[
            [1, 2, 7, PLACEHOLDER_TOKEN_ID],
            [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],
        ],
    )


@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
    ])
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
                            expected):
    """Parametrized test for various matching scenarios"""
    _assert_greedy_output(rejection_sampler, spec_tokens, output_tokens,
                          expected)


########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
    rejection_sampler,
    k: int,
    vocab_size: int,
    batch_size: int,
    frac_seeded: float,
    n_rep: int,
):
    """Test that seeded requests yield identical output on every repetition.

    The same random draft/target inputs are fed to the sampler n_rep times.
    Requests selected by the seeded mask get a generator re-created with the
    same seed on every repetition, and their output rows must all match the
    first repetition. Unseeded requests are not checked.
    """
    num_tokens = batch_size * k
    draft_probs = torch.rand(num_tokens,
                             vocab_size,
                             dtype=torch.float32,
                             device=DEVICE)
    draft_probs = F.softmax(draft_probs, dim=-1)
    target_logits = torch.rand_like(draft_probs)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device=DEVICE)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device=DEVICE)

    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded

    results = []
    for _ in range(n_rep):
        # Fresh generators every repetition so that each run starts from the
        # same per-request seed.
        seeded_seqs = {
            i: torch.Generator(device=DEVICE).manual_seed(i)
            for i in range(batch_size) if seeded_mask[i]
        }

        temperature = torch.ones(batch_size,
                                 dtype=torch.float32,
                                 device=DEVICE)
        sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                     temperature=temperature,
                                                     generators=seeded_seqs)
        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
            draft_token_ids.tolist(), device=DEVICE)
        rep_result = rejection_sampler(
            spec_decode_metadata,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
        )

        results.append(rep_result)

    for i in range(batch_size):
        if seeded_mask[i]:
            for rep_result in results[1:]:
                assert torch.equal(rep_result[i], results[0][i])


def test_rejection_sampling_approximates_target_distribution():
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.
    """
    # Scope the default device to this test. torch.set_default_device()
    # would change it for the whole process and leak into later tests.
    with torch.device(DEVICE):
        vocab_size = 10
        k = 2
        num_reference_probs = 100

        # Prepare draft, target, and reference probability distributions
        draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
                                dim=-1)
        target_logits = torch.rand(vocab_size, dtype=torch.float32)
        target_probs = F.softmax(target_logits, dim=-1)
        reference_probs = F.softmax(
            torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
            dim=-1,
        )

        sample_sizes = [10, 100, 1_000, 10_000, 100_000]
        distance_wrt_reference: list[float] = []
        distance_wrt_target: list[float] = []

        for num_samples in sample_sizes:
            # Sample using rejection sampling.
            rej_sample_probs = estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples)
            rej_sample_probs = rej_sample_probs.to(DEVICE)

            # Average distance from reference probs.
            reference_vs_rejsample_dist = torch.dist(
                reference_probs,
                rej_sample_probs).item() / reference_probs.shape[0]
            target_vs_rejsample_dist = torch.dist(target_probs,
                                                  rej_sample_probs).item()

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            # Running ratios, reported for debugging only.
            relative_change_in_distance_wrt_target = get_ratio_first_to_last(
                distance_wrt_target)
            relative_change_in_distance_wrt_reference = (
                get_ratio_first_to_last(distance_wrt_reference))

            print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                  f"{reference_vs_rejsample_dist=:.05f}")
            print(
                f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
                f"{relative_change_in_distance_wrt_reference=:.02f}")

    # Final ratios over the full range of sample sizes.
    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target
            > relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: list[float]) -> float:
    """Return the first element divided by the last element."""
    return elements[0] / elements[-1]


def estimate_rejection_sampling_pdf(
    draft_probs: torch.Tensor,
    target_logits: torch.Tensor,
    k: int,
    vocab_size: int,
    num_samples: int,
) -> torch.Tensor:
    """Estimate the probability distribution of the output tokens
    using rejection sampling.

    Args:
        draft_probs: Draft probability distribution over the vocabulary
            (vocab_size elements).
        target_logits: Target logits over the vocabulary
            (vocab_size elements).
        k: Number of draft tokens per sample.
        vocab_size: Vocabulary size; also the number of histogram bins.
        num_samples: Number of samples to draw.

    Returns:
        Estimated probability distribution of the output tokens, as a
        density histogram with vocab_size unit-width bins.
    """
    rejection_sampler = RejectionSampler()
    num_tokens = num_samples * k
    # Repeat draft probs num_samples * k times.
    draft_probs = draft_probs.reshape(1, 1,
                                      vocab_size).repeat(num_samples, k, 1)

    # Repeat target logits num_tokens times.
    target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)

    # Randomly sample draft token ids from draft probs.
    draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                        num_samples=k,
                                        replacement=True).reshape(
                                            num_samples, k)
    draft_probs = draft_probs.view(num_tokens, vocab_size)

    # Bonus tokens not used but required.
    bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
                                  device=DEVICE).repeat(num_samples, 1)

    temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
    sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                 temperature=temperature)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids.tolist(), device=bonus_token_ids.device)
    output_token_ids = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
    )
    # Drop the last column (the bonus-token slot) before counting.
    output_token_ids = output_token_ids[:, :-1].flatten()

    # NOTE(review): placeholder ids written for rejected positions presumably
    # fall outside [0, vocab_size] and are ignored by the histogram - confirm
    # against PLACEHOLDER_TOKEN_ID.
    hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                               device="cpu"),
                           bins=vocab_size,
                           range=(0, vocab_size),
                           density=True)

    return hist.hist