"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn

from cacheflow.model_executor.input_metadata import InputMetadata
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
    gather_from_tensor_model_parallel_region)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> Dict[int, SequenceOutputs]:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        logits = gather_from_tensor_model_parallel_region(logits)
        # Remove paddings in vocab (if any).
        logits = logits[:, :self.vocab_size]
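        # Shape sketch (assumed for clarity): after pruning, hidden_states is
        # [num_seqs, hidden_size]; the matmul yields this rank's vocab shard,
        # the gather concatenates shards along the vocab dim, and the slice
        # above trims any vocab padding back to [num_seqs, vocab_size].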

        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties = _get_penalties(
            input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(
            logits, output_tokens, presence_penalties, frequency_penalties,
            self.vocab_size)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(
                temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities (before applying top-p and top-k).
        logprobs = torch.log(probs)
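        # NOTE: The log is taken before the top-p/top-k step below, so the
        # logprobs reported to users describe the full softmax distribution;
        # truncation only affects which tokens can be sampled.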

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == probs.shape[0]
        do_top_p = any(p < 1.0 for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            probs = _apply_top_p_top_k(probs, top_ps, top_ks)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
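    # Keep only the positions that feed the sampler: the last token of each
    # prompt plus every generation token. Worked example (made-up numbers):
    # with prompt_lens=[3, 2] and num_generation_tokens=2, the rows kept are
    # [2, 4, 5, 6].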
    start_idx = 0
    last_token_indices: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indices.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indices.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states[last_token_indices]


def _get_penalties(
    input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
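    # One penalty pair is emitted per logits row: a prompt group contributes
    # a single row, while a generation group contributes one row per running
    # sequence, mirroring the layout built by _prune_hidden_states.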
    # Collect the presence and frequency penalties.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        if i < input_metadata.num_prompts:
            # A prompt input.
            presence_penalties.append(p)
            frequency_penalties.append(f)
        else:
            # A generation token.
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
    return presence_penalties, frequency_penalties


def _get_output_tokens(
    input_metadata: InputMetadata,
) -> List[List[int]]:
    output_tokens: List[List[int]] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, _ = seq_group
        if i < input_metadata.num_prompts:
            # A prompt input.
            # NOTE: While the prompt input usually has no output tokens,
            # it may have output tokens in the case of recomputation.
            seq_id = seq_ids[0]
            seq_data = input_metadata.seq_data[seq_id]
            output_tokens.append(seq_data.output_token_ids)
        else:
            # A generation token.
            for seq_id in seq_ids:
                seq_data = input_metadata.seq_data[seq_id]
                output_tokens.append(seq_data.output_token_ids)
    return output_tokens


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    vocab_size: int,
) -> torch.Tensor:
    num_seqs = logits.shape[0]
    # Collect the indices of sequences that have non-zero penalties.
    indices = []
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue
        p = presence_penalties[i]
        f = frequency_penalties[i]
        if p == 0.0 and f == 0.0:
            continue
        indices.append(i)

    # Return early if all sequences have zero penalties.
    if not indices:
        return logits

    bin_counts = []
    for i in indices:
        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
    bin_counts = np.stack(bin_counts, axis=0)
    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
                                                 device=logits.device)

    frequency_penalties = [frequency_penalties[i] for i in indices]
    frequency_penalties = torch.tensor(
        frequency_penalties, dtype=logits.dtype, device=logits.device)
    presence_penalties = [presence_penalties[i] for i in indices]
    presence_penalties = torch.tensor(
        presence_penalties, dtype=logits.dtype, device=logits.device)

    # We follow the definition in the OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
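    # Spelled out, with count = number of times a token appears in the output:
    #   logit -= frequency_penalty * count
    #   logit -= presence_penalty * (1 if count > 0 else 0)
    # which is what the two in-place updates below compute.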
    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
    return logits


def _get_temperatures(
    input_metadata: InputMetadata,
) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature == 0.0:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0

        if i < input_metadata.num_prompts:
            # A prompt input.
            temperatures.append(temperature)
        else:
            # A generation token.
            temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    top_ps: List[float] = []
    top_ks: List[int] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
        if i < input_metadata.num_prompts:
            # A prompt input.
            top_ps.append(top_p)
            top_ks.append(top_k)
        else:
            # A generation token.
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks


def _apply_top_p_top_k(
    probs: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
) -> torch.Tensor:
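    # Worked example (made-up numbers): with sorted probs [0.5, 0.3, 0.15,
    # 0.05] and top_p=0.7, the cumulative mass *before* each token is
    # [0.0, 0.5, 0.8, 0.95], so the last two entries are zeroed and only
    # [0.5, 0.3] survive. top_k then zeroes everything past the k-th sorted
    # position. The masked rows are not renormalized here; torch.multinomial
    # accepts unnormalized weights downstream.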
    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

    # Apply top-p.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    probs_sort[top_p_mask] = 0.0

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    probs_sort[top_k_mask] = 0.0

    # Re-sort the probabilities.
    probs = torch.gather(
        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
    return probs


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: int,
) -> Dict[int, float]:
    if num_logprobs == 0:
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
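    # Three mutually exclusive modes: beam search takes the top `best_of`
    # tokens as initial beams, greedy takes the argmax (best_of must be 1),
    # and random sampling draws `best_of` tokens with replacement, so the
    # same token can seed more than one sequence.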
    if sampling_params.use_beam_search:
        # Beam search.
        beam_width = sampling_params.best_of
        _, next_token_ids = torch.topk(prob, beam_width)
        next_token_ids = next_token_ids.tolist()
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert sampling_params.best_of == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
    else:
        # Random sampling.
        # Sample `best_of` tokens for the prompt.
        num_seqs = sampling_params.best_of
        next_token_ids = torch.multinomial(
            prob, num_samples=num_seqs, replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    # NOTE(woosuk): sampling_params.best_of can be greater than
    # len(seq_ids) because some sequences in the group might have
    # already been terminated.
    if sampling_params.use_beam_search:
        # Beam search.
        # Add cumulative logprobs for the sequences in the group.
        seq_logprobs = torch.tensor(
            seq_logprobs, dtype=torch.float, device=logprobs.device)
        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

        vocab_size = logprobs.size(-1)
        beam_width = len(seq_ids)
        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
        topk_ids = topk_ids.tolist()
        seq_idx = [i // vocab_size for i in topk_ids]
        beam_seq_ids = [seq_ids[i] for i in seq_idx]
        token_ids = [i % vocab_size for i in topk_ids]
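        # Flattening [num_beams, vocab_size] lets a single topk pick the best
        # beam/token pairs jointly: index // vocab_size recovers the beam row
        # and index % vocab_size recovers the token id within that beam.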

        beam_outputs: Dict[int, Tuple[int, int]] = {}
        outstanding_beams: List[Tuple[int, int]] = []
        # If a beam survives, continue with it.
        for seq_id, token_id in zip(beam_seq_ids, token_ids):
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = (seq_id, token_id)
            else:
                outstanding_beams.append((seq_id, token_id))

        # If a beam is discarded, fork another beam.
        for seq_id in seq_ids:
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = outstanding_beams.pop()
        assert not outstanding_beams

        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert len(seq_ids) == 1
        next_token_id = torch.argmax(probs, dim=-1)
        next_token_ids = [int(next_token_id.item())]
        parent_seq_ids = seq_ids
    else:
        # Random sampling.
        # Sample 1 token for each sequence in the group.
        next_token_ids = torch.multinomial(
            probs, num_samples=1, replacement=True)
        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
        parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
    seq_outputs: Dict[int, SequenceOutputs] = {}

    # TODO(woosuk): Optimize.
    idx = 0
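    # `idx` walks the rows of `probs`/`logprobs`, which are laid out by
    # _prune_hidden_states: one row per prompt group first, then one row per
    # sequence of each generation group, in seq_groups order.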
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == sampling_params.best_of
            prob = probs[idx]
            logprob = logprobs[idx]
            idx += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)
            # Get top-k log probabilities for the next tokens.
            next_logprobs = _get_topk_logprobs(
                logprob, sampling_params.logprobs)

            # Build the output.
            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                output_logprobs = next_logprobs.copy()
                output_logprobs[next_token_id] = logprob[next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id, seq_id, next_token_id, output_logprobs)
        else:
            # Generate the next tokens for generation tokens.
            prob = probs[idx:idx + len(seq_ids)]
            logprob = logprobs[idx:idx + len(seq_ids)]
            idx += len(seq_ids)

            # Sample the next tokens.
            seq_logprobs = [
                input_metadata.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids]
            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                seq_ids, prob, logprob, seq_logprobs, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs: Dict[int, Dict[int, float]] = {}
            for j, seq_id in enumerate(seq_ids):
                next_logprobs[seq_id] = _get_topk_logprobs(
                    logprob[j], sampling_params.logprobs)

            # Build the output.
            for seq_id, parent_seq_id, next_token_id in zip(
                    seq_ids, parent_seq_ids, next_token_ids):
                j = seq_ids.index(parent_seq_id)
                output_logprobs = next_logprobs[parent_seq_id].copy()
                output_logprobs[next_token_id] = logprob[
                    j, next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id,
                    parent_seq_id,
                    next_token_id,
                    output_logprobs,
                )

    return seq_outputs