# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch
import torch.nn as nn
import triton
import triton.language as tl

from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.utils import compiled_softmax
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

logger = init_logger(__name__)

# Value written into unused slots of the output buffer (see rejection_sample);
# parse_output filters these entries out.
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
# Temperature sentinel that marks a request as greedy. rejection_sample
# compares against it; compute_probs replaces it with 1 before the softmax.
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32


class RejectionSampler(nn.Module):
    """
    The implementation strictly follows the algorithm described in
        https://arxiv.org/abs/2211.17192.
    However, we want to clarify the terminology used in the implementation:
    accepted tokens: tokens that are accepted based on the relationship
            between the "raw" draft and target probabilities.
    recovered tokens: tokens that are sampled based on the adjusted probability
        distribution, which is derived from both the draft and target
        probabilities.
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. The bonus token is only sampled from the target
        probabilities. We pass in the bonus tokens instead of sampling them
        in the rejection sampler to allow for more flexibility in the
        sampling process. For example, we can use top_p, top_k sampling for
        bonus tokens, while spec decode does not support these sampling
        strategies.
    output tokens:
        Tokens are finally generated with the rejection sampler.
        output tokens = accepted tokens + recovered tokens + bonus tokens
    """

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: Optional[torch.Tensor],
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        '''
        Run rejection sampling for one speculative-decoding step.

        Args:
            metadata:
                Metadata for spec decoding. This method reads its
                draft_token_ids, num_draft_tokens, max_spec_len and
                cu_num_draft_tokens fields; max_spec_len must not exceed
                MAX_SPEC_LEN.
            draft_probs (Optional[torch.Tensor]):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
                Target model's logits. Shape is [num_tokens, vocab_size].
                Here, logits from different requests are flattened into a
                single tensor because this is the shape of the output logits.
                They are converted to probabilities by compute_probs, which
                returns them unchanged when every request is greedy.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
            sampling_metadata (SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            output_token_ids (torch.Tensor):
                An int32 tensor of shape [batch_size, max_spec_len + 1]
                containing the final output token IDs. Unused slots hold
                PLACEHOLDER_TOKEN_ID; use parse_output to strip them.
        '''
        # The Triton helpers below size their per-request work by
        # MAX_SPEC_LEN (e.g. expand_kernel's MAX_NUM_TOKENS).
        assert metadata.max_spec_len <= MAX_SPEC_LEN
        # [num_tokens, vocab_size]
        target_probs = compute_probs(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
        )

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
            metadata.num_draft_tokens,
            metadata.max_spec_len,
            metadata.cu_num_draft_tokens,
            draft_probs,
            target_probs,
            bonus_token_ids,
            sampling_metadata,
        )
        return output_token_ids

    @staticmethod
    def parse_output(
        output_token_ids: torch.Tensor,
        vocab_size: int,
    ) -> list[list[int]]:
        """Convert the padded output tensor into per-request token ID lists.

        Args:
            output_token_ids: Output of `forward`
                ([batch_size, max_spec_len + 1]), one row per request. It is
                copied to the CPU as a NumPy array.
            vocab_size: Token IDs greater than or equal to this value are
                dropped.

        Returns:
            One list per row, holding that row's entries that are neither
            PLACEHOLDER_TOKEN_ID nor >= vocab_size, in their original order.
        """
        output_token_ids_np = output_token_ids.cpu().numpy()
        # Create mask for valid tokens.
        valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
                      (output_token_ids_np < vocab_size))
        outputs = [
            row[valid_mask[i]].tolist()
            for i, row in enumerate(output_token_ids_np)
        ]
        return outputs


def rejection_sample(
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [batch_size]
    num_draft_tokens: list[int],
    max_spec_len: int,
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Accept or reject the flattened draft tokens of a whole batch.

    Greedy requests (temperature == GREEDY_TEMPERATURE) are handled by
    rejection_greedy_sample_kernel, which compares each draft token with the
    target argmax. Random requests are handled by
    rejection_random_sample_kernel, which uses per-token uniform samples and
    pre-sampled recovered tokens. Each kernel skips the requests that belong
    to the other kind.

    `cu_num_draft_tokens` is the inclusive prefix sum of `num_draft_tokens`:
    request i owns flattened positions
    [cu_num_draft_tokens[i-1], cu_num_draft_tokens[i]), starting at 0.

    When called from RejectionSampler.forward with an all-greedy batch,
    `target_probs` actually holds raw logits (compute_probs returns them
    unchanged); only the argmax is used in that case, so this is harmless.

    Returns:
        An int32 tensor of shape [batch_size, max_spec_len + 1], pre-filled
        with PLACEHOLDER_TOKEN_ID; each row holds the accepted tokens,
        followed by either the recovered token at the first rejection or the
        bonus token if nothing was rejected.
    """
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    assert target_probs.ndim == 2

    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
    vocab_size = target_probs.shape[-1]
    device = target_probs.device
    # The Triton kernels index these buffers with raw pointer arithmetic.
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
        dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
        device=device,
    )
    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)

    # None tells the greedy kernel that every request is greedy.
    if sampling_metadata.all_greedy:
        is_greedy = None
    else:
        is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
        rejection_greedy_sample_kernel[(batch_size, )](
            output_token_ids,
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
            is_greedy,
            max_spec_len,
            num_warps=1,
        )
        if sampling_metadata.all_greedy:
            return output_token_ids

    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
        num_tokens,
        num_draft_tokens,
        sampling_metadata.generators,
        device,
    )

    # Sample recovered tokens for each position.
    # [num_tokens]
    recovered_token_ids = sample_recovered_tokens(
        max_spec_len,
        num_draft_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
        device,
    )

    # Rejection sampling for random sampling requests.
    rejection_random_sample_kernel[(batch_size, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_token_ids,
        uniform_probs,
        is_greedy,
        max_spec_len,
        vocab_size,
        IS_NGRAM=draft_probs is None,
        num_warps=1,
    )
    return output_token_ids


def compute_probs(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Compute probability distribution from logits based on sampling metadata.

    This function applies temperature scaling to the logits and converts them
    to probabilities using softmax (via compiled_softmax). If every request
    is greedy (`sampling_metadata.all_greedy`), the logits are returned
    unchanged. In a mixed batch, greedy requests (temperature ==
    GREEDY_TEMPERATURE) are given temperature 1 so that their rows are a
    plain softmax.

    Args:
        logits: Input logits tensor to be converted to probabilities.
        cu_num_draft_tokens: Inclusive cumulative sum of draft tokens per
            request; used to broadcast each request's temperature to its
            token rows.
        sampling_metadata: Metadata containing sampling parameters such as
            temperature and whether greedy sampling is used.

    Returns:
        torch.Tensor: Probability distribution (softmax of scaled logits)
            if non-greedy sampling is used, otherwise returns the
            original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
    if sampling_metadata.all_greedy:
        return logits

    num_tokens = logits.shape[0]
    batch_size = cu_num_draft_tokens.shape[0]
    # One temperature per token row, broadcast from the per-request values.
    expanded_temperature = torch.empty(
        (num_tokens, 1),
        dtype=torch.float32,
        device=logits.device,
    )
    expand_kernel[(batch_size, )](
        expanded_temperature,
        sampling_metadata.temperature,
        cu_num_draft_tokens,
        GREEDY_TEMPERATURE,  # replace_from
        1,  # replace_to
        MAX_NUM_TOKENS=MAX_SPEC_LEN,
        num_warps=1,
    )
    output_prob = compiled_softmax(logits, expanded_temperature)
    return output_prob


def generate_uniform_probs(
    num_tokens: int,
    num_draft_tokens: list[int],
    generators: dict[int, torch.Generator],
    device: torch.device,
) -> torch.Tensor:
    """
    Generates a batch of uniform random samples, with optional seeding
    if available.

    This method creates a tensor of shape `(num_tokens, )` filled
    with uniform random values in the range [0, 1). Requests that have an
    entry in `generators` have their slice regenerated with that
    `torch.Generator` for reproducibility. The samples for the other
    requests come from the default (unseeded) RNG.

    Args:
        num_tokens : int
            Total number of tokens.
        num_draft_tokens : list[int]
            Number of draft tokens per request.
        generators : dict[int, torch.Generator]
            A dictionary mapping indices in the batch to
            `torch.Generator` objects. May be empty, but not None.
        device : torch.device
            The device on which to allocate the tensor.
    Returns:
        uniform_rand : torch.Tensor
            A tensor of shape `(num_tokens, )` containing uniform
            random values in the range [0, 1).
    """
    uniform_probs = torch.rand(
        (num_tokens, ),
        dtype=torch.float32,
        device=device,
    )
    start_idx = 0
    for req_idx, n in enumerate(num_draft_tokens):
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if n == 0:
            continue
        end_idx = start_idx + n
        generator = generators.get(req_idx)
        if generator is not None:
            uniform_probs[start_idx:end_idx].uniform_(generator=generator)
        start_idx = end_idx
    return uniform_probs


def sample_recovered_tokens(
    max_spec_len: int,
    num_draft_tokens: list[int],
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor,
    # [num_tokens]
    draft_token_ids: torch.Tensor,
    # [num_tokens, vocab_size]
    draft_probs: Optional[torch.Tensor],
    # [num_tokens, vocab_size]
    target_probs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    device: torch.device,
) -> torch.Tensor:
    """Sample one recovered token for every draft position.

    A single row of exponential noise `q` is drawn per request and shared by
    all of that request's draft positions. Requests with their own
    `torch.Generator` in `sampling_metadata.generators` redraw their row with
    that generator, but only if they have at least one draft token. The
    per-position sampling is done by sample_recovered_tokens_kernel on a
    (batch_size, max_spec_len) grid.

    Returns:
        A tensor with the same shape and dtype as `draft_token_ids`, holding
        the recovered token ID for each flattened draft position.
    """
    # NOTE(woosuk): Create only one distribution for each request.
    batch_size = len(num_draft_tokens)
    vocab_size = target_probs.shape[-1]
    q = torch.empty(
        (batch_size, vocab_size),
        dtype=torch.float32,
        device=device,
    )
    q.exponential_()
    for i, generator in sampling_metadata.generators.items():
        # Do not generate random numbers for requests with no draft tokens.
        # This can be important for reproducibility.
        if num_draft_tokens[i] > 0:
            q[i].exponential_(generator=generator)

    recovered_token_ids = torch.empty_like(draft_token_ids)
    sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
        recovered_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        q,
        vocab_size,
        triton.next_power_of_2(vocab_size),
        IS_NGRAM=draft_probs is None,
    )
    return recovered_token_ids


# Greedy verification: one program per request (program_id(0)). Each draft
# token is compared with the target argmax at the same position; the argmax is
# written to the output up to and including the first mismatch. If no draft
# token mismatches, the bonus token is appended after the last draft slot.
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    max_spec_len,
):
    req_idx = tl.program_id(0)
    # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
    # re-compilation may happen during runtime when is_greedy_ptr is None.
    if is_greedy_ptr is None:
        is_greedy = True
    else:
        is_greedy = tl.load(is_greedy_ptr + req_idx)
    if not is_greedy:
        # Early exit for non-greedy sampling requests.
        return

    # This request's draft tokens occupy [start_idx, end_idx) in the
    # flattened inputs (cu_num_draft_tokens is an inclusive prefix sum).
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # A flag is used instead of `break` to stop processing after the first
    # rejection.
    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     target_argmax_id)
            if draft_token_id != target_argmax_id:
                # Reject.
                rejected = True

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)


# Random (non-greedy) verification: one program per request (program_id(0)).
# A draft token is accepted when draft_prob > 0 and
# target_prob / draft_prob >= uniform_prob, i.e. with probability
# min(1, p_target / p_draft) for uniform_prob in [0, 1). On the first rejection
# the pre-sampled recovered token is written and later positions are skipped;
# if nothing is rejected the bonus token is appended. is_greedy_ptr is loaded
# unconditionally: rejection_sample only launches this kernel when the batch
# is not all-greedy, so it is never None here. With IS_NGRAM there are no
# draft probabilities and draft_prob is taken as 1.
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    recovered_token_ids_ptr,  # [num_tokens]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    IS_NGRAM: tl.constexpr,
):
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        # Early exit for greedy sampling requests.
        return

    # This request's draft tokens occupy [start_idx, end_idx) in the
    # flattened inputs (cu_num_draft_tokens is an inclusive prefix sum).
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # A flag is used instead of `break` to stop processing after the first
    # rejection.
    rejected = False
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if IS_NGRAM:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            # NOTE(woosuk): While the draft probability should never be 0,
            # we check it to avoid NaNs. If it happens to be 0, we reject.
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept.
                token_id = draft_token_id
            else:
                # Reject. Use recovered token.
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if not rejected:
        # If all tokens are accepted, append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)


# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    MAX_NUM_TOKENS: tl.constexpr,
):
    # Broadcast the per-request value input_ptr[req_idx] to every token slot
    # of that request in output_ptr, substituting replace_to wherever the
    # value equals replace_from. program_id(0) is the request index; the
    # request's slots are [cu_num_tokens[req_idx - 1], cu_num_tokens[req_idx])
    # (starting at 0 for the first request).
    # Only the first MAX_NUM_TOKENS slots of a request are written, so callers
    # must guarantee that no request has more than MAX_NUM_TOKENS tokens.
    req_idx = tl.program_id(0)
    if req_idx == 0:  # noqa: SIM108
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
    num_tokens = end_idx - start_idx

    src_val = tl.load(input_ptr + req_idx)
    src_val = tl.where(src_val == replace_from, replace_to, src_val)
    offset = tl.arange(0, MAX_NUM_TOKENS)
    tl.store(output_ptr + start_idx + offset,
             src_val,
             mask=offset < num_tokens)


@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,  # [num_tokens]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    q_ptr,  # [batch_size, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
    IS_NGRAM: tl.constexpr,
):
    # Sample the "recovered" token for draft position `pos` of request
    # `req_idx` (grid axis 0 = request, axis 1 = draft position; programs past
    # the request's draft count exit immediately).
    #
    # The recovered distribution is max(target - draft, 0). With IS_NGRAM
    # there is no draft distribution; the draft token is treated as having
    # probability 1, which amounts to the target distribution with the draft
    # token's entry set to 0.
    #
    # Sampling uses argmax(prob / q). Assuming q holds i.i.d. Exp(1) noise
    # (one row per request), this picks index i with probability
    # prob[i] / sum(prob), so prob needs no explicit normalization.
    #
    # This kernel never writes to its probability inputs.
    req_idx = tl.program_id(0)
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Early exit for out-of-range positions.
    pos = tl.program_id(1)
    if pos >= num_draft_tokens:
        return

    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
    if IS_NGRAM:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        # Drop the draft token by masking it out of the load (it then reads as
        # 0). We deliberately do not zero it in target_probs and restore it
        # afterwards. That approach needs a scalar store followed by a vector
        # load of the same row by other threads with no barrier in between,
        # so the load may not observe the store. It also mutates an input
        # tensor.
        prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                       vocab_offset,
                       mask=((vocab_offset < vocab_size) &
                             (vocab_offset != draft_token_id)),
                       other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
                             vocab_offset,
                             mask=vocab_offset < vocab_size,
                             other=0)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size + vocab_offset,
                              mask=vocab_offset < vocab_size,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)
        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
        # `tl.argmax` will select the maximum value.

    # Padding lanes load -inf, so prob / q is -0.0 there, which never beats
    # a real entry (real entries are >= 0).
    q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
                mask=vocab_offset < vocab_size,
                other=float("-inf"))
    recovered_id = tl.argmax(prob / q, axis=-1)
    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)