"""A layer that samples the next tokens from the model's outputs."""
|
2024-03-25 11:14:26 -06:00
|
|
|
import itertools
|
2024-07-30 00:47:31 +08:00
|
|
|
from math import inf
|
2023-09-22 17:48:04 -07:00
|
|
|
from typing import Dict, List, Optional, Tuple
|
2023-02-23 09:26:09 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2024-07-29 23:51:27 +02:00
|
|
|
from vllm.triton_utils import HAS_TRITON
|
|
|
|
|
|
|
|
if HAS_TRITON:
|
|
|
|
from vllm.model_executor.layers.ops.sample import sample as sample_triton
|
|
|
|
|
2024-03-10 19:49:14 -07:00
|
|
|
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
|
2024-04-26 22:02:02 +09:00
|
|
|
SamplingTensors,
|
|
|
|
SequenceGroupToSample)
|
|
|
|
from vllm.sampling_params import SamplingType
|
2024-05-11 11:30:37 -07:00
|
|
|
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
|
|
|
PromptLogprobs, SampleLogprobs, SamplerOutput,
|
|
|
|
SequenceOutput)
|
2023-02-23 09:26:09 +00:00
|
|
|
|
2024-04-29 11:01:26 +09:00
|
|
|
# (num_token_ids, num_parent_ids) per sequence group.
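# e.g. a batch with two sequence groups, where the first sampled token id 42
# from parent seq 0 and the second skipped sampling (do_sample=False), is
# represented as [([42], [0]), ([], [])].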
SampleResultType = List[Tuple[List[int], List[int]]]


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row
    in logits for the next token to be sampled; however, for a seq_group with
    a prompt request with the prompt_logprobs sampling parameter, there are
    rows in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False
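        # Whether to overwrite the probability distribution of greedily
        # sampled tokens in place with a one-hot distribution (see the
        # _should_modify_greedy_probs_inplace property below).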
        self.should_modify_greedy_probs_inplace = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_min_p = do_min_p

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            prompt_logprobs, sample_logprobs = _get_logprobs(
                logprobs, sampling_metadata, sample_results)

        return _build_sampler_output(
            sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability
        distribution of greedily-sampled tokens such that multinomial sampling
        would sample the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        return self.should_modify_greedy_probs_inplace


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
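    # e.g. with vocab_size=4 and a row of tokens [2, 2, 3] (padded with the
    # extra index vocab_size, hence the vocab_size + 1 columns), the first
    # four columns become [0, 0, 2, 1]: two occurrences of token 2 and one of
    # token 3. The padding column is dropped below and the mask marks token
    # ids 2 and 3 as seen.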
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask


def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
    have not been generated yet.
    """
    # list of indices in logits that will be set to -inf
    logits_to_penalize: List[Tuple[int, int]] = []
    logits_applied = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params

        sample_indices = seq_group.sample_indices
        logits_applied += len(sample_indices) + len(
            seq_group.prompt_logprob_indices)
        if not seq_group.do_sample:
            continue

        start_idx = sample_indices[0]
        min_tokens = sampling_params.min_tokens
        token_ids_to_penalize = sampling_params.all_stop_token_ids
        if min_tokens > 0 and token_ids_to_penalize:
            seqs_to_penalize: List[int] = []
            for j, seq_id in enumerate(seq_ids):
                seq_data = seq_group.seq_data[seq_id]
                if len(seq_data.output_token_ids_array) < min_tokens:
                    seqs_to_penalize.append(j)

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert logits_applied == logits.shape[0]
    return logits


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
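    # A repetition penalty > 1 lowers the logits of tokens already seen in the
    # prompt or output: positive logits are divided by the penalty, negative
    # logits are multiplied by it. Unseen tokens keep a penalty of 1.0 (no-op).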
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits


def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
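    # The sort above is ascending, so the top-k values occupy the last k
    # columns; everything strictly below the k-th largest value (gathered at
    # column vocab_size - k) is masked. e.g. for a row sorted to
    # [0.1, 0.5, 2.0] with k=2, the threshold is 0.5 and only 0.1 is dropped.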
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
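    # With the ascending order, cumulative probabilities <= 1 - p mark the
    # low-probability tail outside the nucleus; the most likely token (last
    # column) is never masked, so at least one token always survives.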
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits


def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
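    # The threshold scales with the peak of each distribution, e.g. min_p=0.1
    # with a top probability of 0.6 removes every token whose probability is
    # below 0.06.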
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Run greedy sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. The length of
            samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of the returned list
        is the same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, the tuple contains ([], []).
    """
    samples_lst = samples.tolist()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples_lst[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Run random sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples. The
            length of samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of the returned list
        is the same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, the tuple contains ([], []).
    """
    # Copy the sampled tokens to CPU once, instead of triggering a GPU<->CPU
    # sync per sequence group in the loop below.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        is_prompt = seq_group.is_prompt
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> SampleResultType:
    """Run beam search sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
            on selected sample indices.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of the returned list
        is the same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, the tuple contains ([], []).
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        is_prompt = seq_group.is_prompt
        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs: List[float] = [
                seq_group.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            cumulative_logprobs_tensor = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs_tensor.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
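            # topk was taken over the flattened
            # (num_parent_seqs * vocab_size) scores, so the parent beam is
            # recovered with // vocab_size and the token id within that beam
            # with % vocab_size.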
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(
                generator=seq_group.generator)
            sample_idx = next_sample_idx
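    # Dividing the probabilities by i.i.d. Exponential(1) noise and taking
    # the argmax is an exponential-race variant of the Gumbel-max trick:
    # argmax(p_i / E_i) = argmin(E_i / p_i) selects index i with probability
    # proportional to p_i, so this samples from probs without a call to
    # torch.multinomial.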
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)


def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata: Dict[SamplingType,
                          Tuple[List[int], List[SequenceGroupToSample]]] = {}
    multinomial_samples: Dict[SamplingType, torch.Tensor] = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
            }

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            sample_results_dict.update(zip(seq_group_id, sample_results))

        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []

    return sample_results, sampled_token_ids_tensor


def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> SampleResultType:
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata: Dict[SamplingType,
                          Tuple[List[int], List[SequenceGroupToSample],
                                torch.Tensor, torch.Tensor]] = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _sample(
    probs: torch.Tensor, logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
    """
    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.

    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in a batch.
            If sampling is skipped, it returns ([], [])
        sampled_token_ids_tensor: A tensor of sampled token ids.
    """
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob
    tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
            where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): List of chosen token indices.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
            Each element in the returned tensor represents the rank
            of the chosen token in the input logprob tensor.
    """
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    result = (x > vals[:, None])
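    # Rank = 1 + the number of tokens with a strictly larger logprob in the
    # same row, so the most likely token gets rank 1.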
    del vals
    return result.sum(1).add_(1)


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query
            token's logprob per vocab. Sequence groups' query tokens are
            batched in a single flattened tensor. For example, assuming there
            are N seq groups, it is sorted by prefill tokens for seq_group_1
            (if prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain different number of seq_ids from
            sampling_metadata.seq_groups. It is because beam search creates
            2 * BEAM_WIDTH number of samples (whereas there are only up to
            BEAM_WIDTH number of seq_ids).

    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: List[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: List[int] = []
    # The largest requested number of logprobs. We find logprobs as many as
    # the largest num logprobs in this API. If all logprobs are None, it will
    # be set to -1.
    largest_num_logprobs = -1
    # If beam search is enabled.
    use_beam_search = False

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

            use_beam_search = use_beam_search or sampling_params.use_beam_search

    assert len(next_token_ids) == len(query_indices)

    if len(query_indices) == 0:
        empty_sampled_logprob: SampleLogprobs = []
        empty_prompt_logprob: Optional[PromptLogprobs] = None
        return [empty_prompt_logprob], [empty_sampled_logprob]

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None

    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
    # skip the whole logprob calculation.
    if largest_num_logprobs >= 0 or use_beam_search:
        query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
        next_token_ids_gpu = torch.tensor(next_token_ids,
                                          device=logprobs.device)

        # (num_selected_query_tokens, num_logprobs). Note that query_indices
        # can contain duplicates if beam search is enabled.
        selected_logprobs = logprobs[[
            query_indices_gpu,
            next_token_ids_gpu,
        ]]
        ranks = _get_ranks(
            logprobs[query_indices_gpu],
            next_token_ids_gpu,
        )
        assert selected_logprobs.shape[0] == ranks.shape[0]

        # We need to compute top k only if there exists logprobs > 0.
        if largest_num_logprobs > 0:
            # Logprobs of topk tokens for a batch of sequence groups.
            # (num_query_tokens_across_batch).
            top_logprobs, top_token_ids = torch.topk(logprobs,
                                                     largest_num_logprobs,
                                                     dim=-1)
            top_logprobs = top_logprobs.to('cpu')
            top_token_ids = top_token_ids.to('cpu')

        selected_logprobs = selected_logprobs.to('cpu')
        ranks = ranks.to('cpu')

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
             top_logprobs, selected_logprobs_idx, top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the prompt logprob from a sequence group if needed."""
    sampling_params = seq_group.sampling_params
    is_prompt = seq_group.is_prompt

    # Find prompt logprobs
    prompt_logprobs: Optional[PromptLogprobs] = None
    if is_prompt and sampling_params.prompt_logprobs is not None:
        prompt_logprobs = []
        num_logprobs = sampling_params.prompt_logprobs
        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
        # Pre-select indexes and create a list. It is faster than calling
        # .item repetitively.
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_logprobs_idx +
            len(next_prompt_tokens)].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                           len(next_prompt_tokens)].tolist()

        for idx, token_id in enumerate(next_prompt_tokens):
            # Calculate the prompt logprob of the real prompt tokens.
            # {token_id: (logprob, rank_from_vocab)}
            prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
                token_id: (selected_logprob_items[idx], rank_items[idx])
            }

            # Add top K prompt logprobs along with its rank.
            if num_logprobs > 0:
                top_ids = top_token_ids[
                    top_logprob_idx, :num_logprobs].tolist()
                top_probs = top_logprobs[
                    top_logprob_idx, :num_logprobs].tolist()
                # Top K is already sorted by rank, so we can use 1 ~
                # num_logprobs + 1 for rank.
                top_ranks = range(1, num_logprobs + 1)
                prompt_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for top_id, top_prob, rank in zip(top_ids, top_probs,
                                                      top_ranks)
                })
            prompt_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in prompt_logprobs_dict.items()
            })
            # + 1 to go to the next prompt token.
            top_logprob_idx += 1

        # + len(next_prompt_tokens) to go to the next prompt.
        selected_logprobs_idx += len(next_prompt_tokens)
    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the sample logprob if needed."""
    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    use_beam_search = seq_group.sampling_params.use_beam_search
    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if seq_group.do_sample:
        assert len(next_token_ids) > 0
        if num_logprobs is None and not use_beam_search:
            for next_token_id in next_token_ids:
                # Use a dummy logprob
                sampled_logprobs.append({next_token_id: Logprob(inf)})
        else:
            # Pre-select items from tensor. tolist() is faster than repetitive
            # `.item()` calls.
            selected_logprob_items = selected_logprobs[
                selected_logprobs_idx:selected_logprobs_idx +
                len(next_token_ids)].tolist()
            rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                               len(next_token_ids)].tolist()
            for idx, (next_token_id, parent_id) in enumerate(
                    zip(next_token_ids, parent_seq_ids)):
                # Get the logprob of a sampled token.
                sampled_logprobs_dict = {
                    next_token_id:
                    (selected_logprob_items[idx], rank_items[idx])
                }
                if num_logprobs is not None and num_logprobs > 0:
                    # Get top K logprobs.
                    top_ids = top_token_ids[top_logprob_idx +
                                            parent_id, :num_logprobs].tolist()
                    top_probs = top_logprobs[
                        top_logprob_idx + parent_id, :num_logprobs].tolist()
                    # Top K is already sorted by rank, so we can use 1 ~
                    # num_logprobs + 1 for rank.
                    top_ranks = range(1, num_logprobs + 1)
                    sampled_logprobs_dict.update({
                        top_id: (top_prob, rank)
                        for top_id, top_prob, rank in zip(
                            top_ids, top_probs, top_ranks)
                    })

                sampled_logprobs.append({
                    token_id: Logprob(*logprob_and_rank)
                    for token_id, logprob_and_rank in
                    sampled_logprobs_dict.items()
                })

        # NOTE: This part of code is not intuitive. `selected_logprobs` include
        # logprobs for the current step, which has len(next_token_ids) tokens
        # per sequence group. `logprobs` includes logprobs from the previous
        # steps, which has len(seq_ids) tokens per sequence group.

        # Iterate to the next sequence group in a batch.
        selected_logprobs_idx += len(next_token_ids)
    # Iterate to the next sequence group in a batch.
    top_logprob_idx += len(seq_ids)
    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens
    such that each sampled token has a "probability" of 1.0. This is required
    by speculative decoding, which depends on the sampling method being
    encoded within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
    1. Get logits from model.
    2. Modify logits according to per-sequence sampling parameters.
        - Multiply by temperature, top-k and top-p masking, penalize tokens
            according to their frequency, etc.
    3. Sample a token.
        - Random sampling simply samples from the modified probability
            distribution.
        - Greedy sampling performs `argmax` to obtain the token with the
            highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed
    probability distribution has the following property: we can sample from
    it independently and find that the token sampled by the Sampler has a
    frequency corresponding to how often we see it in our sampling. In other
    words, for tokens sampled with vLLM's random SamplingType, the computed
    probability distribution encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths.
    This has implications on the overall design of the sampler, e.g. how to
    record accurate logprobs for the user, so this improvement is deferred to
    later.
    """
    # NOTE: logprobs are not modified so they can be returned to the user.
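    # The greedy rows of probs are overwritten with a one-hot distribution
    # that assigns probability 1.0 to the argmax token.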
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


def _build_sampler_output(
    sample_results: SampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g.
            in speculative decoding rejection sampling.
    """
    sampler_output: List[CompletionSequenceGroupOutput] = []
    if not skip_sampler_cpu_output:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           sample_results, prompt_logprobs,
                                           sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            seq_outputs: List[SequenceOutput] = []
            for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs):
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   logprobs))
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
    )


def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
    """Get a list of next prompt tokens to compute logprob from a
    given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Query token needs to know the next prompt token id to compute
    prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in prefill
    stage.

    Returns:
        A list of next prompt tokens to compute logprob.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
    query_len = seq_group.query_len
    assert query_len is not None
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
    next_token_index_end = min(computed_len + query_len + 1,
                               len(prompt_tokens))
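    # e.g. for prompt tokens [p0, p1, p2, p3] with computed_len=1 and
    # query_len=2, this chunk scores query tokens p1 and p2, whose "next"
    # prompt tokens are p2 and p3, i.e. prompt_tokens[2:4].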
    next_prompt_tokens = prompt_tokens[
        next_token_index_start:next_token_index_end]
    return next_prompt_tokens