from typing import Dict, List, Tuple

import torch
import torch.nn as nn

from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    gather_from_tensor_model_parallel_region)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer 1) prunes the hidden states so that only the positions used
    for sampling remain, 2) computes the logits for the next tokens,
    3) applies temperature scaling as well as top-p and top-k truncation,
    and 4) samples the next tokens according to each sequence group's
    sampling parameters (greedy, random, or beam search).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> Dict[int, SequenceOutputs]:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        logits = gather_from_tensor_model_parallel_region(logits)
        # Remove paddings in vocab (if any).
        logits = logits[:, :self.vocab_size]

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(
                temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities (before applying top-p and top-k).
        logprobs = torch.log(probs)

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == probs.shape[0]
        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
            probs = _apply_top_p_top_k(probs, p, k)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    # Keep only the hidden states of the last token of each prompt and of
    # every generation token; the other positions are not sampled from.
    start_idx = 0
    last_token_indices: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indices.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indices.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states[last_token_indices]


def _get_temperatures(
    input_metadata: InputMetadata,
) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature == 0.0:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0

        if i < input_metadata.num_prompts:
            # A prompt input.
            temperatures.append(temperature)
        else:
            # A generation token.
            temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    top_ps: List[float] = []
    top_ks: List[int] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
        if i < input_metadata.num_prompts:
            # A prompt input.
            top_ps.append(top_p)
            top_ks.append(top_k)
        else:
            # A generation token.
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks


def _apply_top_p_top_k(
    probs: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    # TODO(woosuk): Optimize.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

    # Apply top-p.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    probs_sort[top_p_mask] = 0.0

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    probs_sort[top_k_mask] = 0.0

    # Re-sort the probabilities.
    probs = torch.gather(
        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
    return probs


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: int,
) -> Dict[int, float]:
    if num_logprobs == 0:
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
    if sampling_params.use_beam_search:
        # Beam search.
        beam_width = sampling_params.n
        _, next_token_ids = torch.topk(prob, beam_width)
        next_token_ids = next_token_ids.tolist()
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert sampling_params.n == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
    else:
        # Random sampling.
        # Sample n tokens for the prompt.
        n = sampling_params.n
        next_token_ids = torch.multinomial(
            prob, num_samples=n, replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    # NOTE(woosuk): sampling_params.n can be greater than len(seq_ids)
    # because some sequences in the group might have already been
    # terminated.
    if sampling_params.use_beam_search:
        # Beam search.
        # Add cumulative logprobs for the sequences in the group.
        seq_logprobs = torch.tensor(
            seq_logprobs, dtype=torch.float, device=logprobs.device)
        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

        vocab_size = logprobs.size(-1)
        beam_width = len(seq_ids)
        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
        topk_ids = topk_ids.tolist()
        seq_idx = [i // vocab_size for i in topk_ids]
        beam_seq_ids = [seq_ids[i] for i in seq_idx]
        token_ids = [i % vocab_size for i in topk_ids]

        beam_outputs: Dict[int, Tuple[int, int]] = {}
        outstanding_beams: List[Tuple[int, int]] = []
        # If a beam survives, continue with it.
        for seq_id, token_id in zip(beam_seq_ids, token_ids):
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = (seq_id, token_id)
            else:
                outstanding_beams.append((seq_id, token_id))
        # If a beam is discarded, fork another beam.
        for seq_id in seq_ids:
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = outstanding_beams.pop()
        assert not outstanding_beams

        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert len(seq_ids) == 1
        next_token_id = torch.argmax(probs, dim=-1)
        next_token_ids = [next_token_id.item()]
        parent_seq_ids = seq_ids
    else:
        # Random sampling.
        # Sample 1 token for each sequence in the group.
        next_token_ids = torch.multinomial(
            probs, num_samples=1, replacement=True)
        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
        parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
    seq_outputs: Dict[int, SequenceOutputs] = {}

    # TODO(woosuk): Optimize.
    idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == sampling_params.n
            prob = probs[idx]
            logprob = logprobs[idx]
            idx += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)
            # Get top-k log probabilities for the next tokens.
            next_logprobs = _get_topk_logprobs(
                logprob, sampling_params.num_logprobs)

            # Build the output.
            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                output_logprobs = next_logprobs.copy()
                output_logprobs[next_token_id] = logprob[next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id, seq_id, next_token_id, output_logprobs)
        else:
            # Generate the next tokens for generation tokens.
            prob = probs[idx:idx + len(seq_ids)]
            logprob = logprobs[idx:idx + len(seq_ids)]
            idx += len(seq_ids)

            # Sample the next tokens.
            seq_logprobs = [
                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                seq_ids, prob, logprob, seq_logprobs, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs: Dict[int, Dict[int, float]] = {}
            for j, seq_id in enumerate(seq_ids):
                next_logprobs[seq_id] = _get_topk_logprobs(
                    logprob[j], sampling_params.num_logprobs)

            # Build the output.
            for seq_id, parent_seq_id, next_token_id in zip(
                seq_ids, parent_seq_ids, next_token_ids):
                parent_idx = seq_ids.index(parent_seq_id)
                output_logprobs = next_logprobs[parent_seq_id].copy()
                output_logprobs[next_token_id] = logprob[
                    parent_idx, next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id,
                    parent_seq_id,
                    next_token_id,
                    output_logprobs,
                )

    return seq_outputs
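

if __name__ == "__main__":
    # Illustrative sketch, not part of the original module: a tiny smoke
    # test of _apply_top_p_top_k on a hand-made batch of two distributions.
    # The probabilities and per-sequence top-p/top-k values below are
    # invented for demonstration only; running this file directly requires
    # the cacheflow imports above to resolve.
    example_probs = torch.tensor([[0.4, 0.3, 0.2, 0.1],
                                  [0.7, 0.1, 0.1, 0.1]])
    example_p = torch.tensor([0.8, 1.0])                # per-sequence top-p
    example_k = torch.tensor([3, 1], dtype=torch.int)   # per-sequence top-k
    truncated = _apply_top_p_top_k(example_probs, example_p, example_k)
    # Row 0 keeps its three most likely tokens (the smallest prefix whose
    # cumulative probability reaches 0.8); row 1 keeps only its single most
    # likely token because its top_k is 1.
    print(truncated)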