2023-02-23 09:26:09 +00:00
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
from cacheflow.models import InputMetadata
|
2023-03-10 09:58:21 -08:00
|
|
|
from cacheflow.sampling_params import SamplingParams
|
|
|
|
from cacheflow.sequence import SequenceOutputs
|
2023-03-22 04:45:42 +08:00
|
|
|
from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region
|
2023-02-23 09:26:09 +00:00
|
|
|
|
2023-03-26 08:00:39 +00:00
|
|
|
|
2023-02-23 09:26:09 +00:00
|
|
|
class Sampler(nn.Module):
    """Turn final hidden states into sampled next tokens.

    The module projects the hidden states onto the vocabulary with the
    given embedding matrix, applies per-row temperature scaling and top-p
    truncation, and delegates the actual token selection to `_sample`.
    """

    def __init__(self, vocab_size: int) -> None:
        """Remember the vocabulary size used to trim the logits columns."""
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> Dict[int, SequenceOutputs]:
        """Compute next-token outputs for every sequence in the batch.

        Args:
            embedding: Matrix whose transpose is multiplied with the hidden
                states to obtain logits.
            hidden_states: Hidden states for all input tokens; only the rows
                selected by `_prune_hidden_states` are used.
            input_metadata: Batch description consumed by the helper
                functions (sequence groups, sampling parameters, ...).

        Returns:
            The mapping produced by `_sample` (sequence id -> output).
        """
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        logits = gather_from_tensor_model_parallel_region(logits)
        # Remove paddings in vocab.
        logits = logits[:, :self.vocab_size]

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(
                temperatures, dtype=logits.dtype, device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities (before applying top-p).
        # log_softmax is used instead of log(softmax(x)): the two-step form
        # underflows to -inf for tokens whose probability rounds to zero.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Apply top-p truncation.
        top_ps = _get_top_ps(input_metadata)
        assert len(top_ps) == probs.shape[0]
        if any(p < 1.0 for p in top_ps):
            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
            probs = _apply_top_p(probs, p)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    """Select the rows of `hidden_states` that are used for sampling.

    For each entry of `input_metadata.prompt_lens` the row of the prompt's
    last token is kept; after the prompts, the next
    `input_metadata.num_generation_tokens` rows are all kept.
    """
    selected_rows: List[int] = []
    offset = 0
    for length in input_metadata.prompt_lens:
        # After advancing past this prompt, its last token sits one row
        # before the new offset.
        offset += length
        selected_rows.append(offset - 1)
    # Generation tokens follow the prompts, one row each.
    generation_end = offset + input_metadata.num_generation_tokens
    selected_rows += range(offset, generation_end)
    return hidden_states[selected_rows]
|
|
|
|
|
|
|
|
|
|
|
|
def _get_temperatures(
    input_metadata: InputMetadata,
) -> List[float]:
    """Collect one temperature per logits row.

    The first `input_metadata.num_prompts` sequence groups contribute a
    single entry each; every later group contributes one entry per
    sequence id. A temperature of zero is reported as 1.0.
    """
    temperatures: List[float] = []
    num_prompts = input_metadata.num_prompts
    for group_idx, (seq_ids, sampling_params) in enumerate(
            input_metadata.seq_groups):
        raw_temperature = sampling_params.temperature
        # NOTE: Zero temperature means deterministic sampling
        # (i.e., greedy sampling or beam search).
        # Use 1 instead so the later division never divides by zero.
        temperature = 1.0 if raw_temperature == 0.0 else raw_temperature

        # A prompt input yields one row; generation yields one per sequence.
        repeat = 1 if group_idx < num_prompts else len(seq_ids)
        temperatures.extend([temperature] * repeat)
    return temperatures
|
|
|
|
|
|
|
|
|
|
|
|
def _get_top_ps(
    input_metadata: InputMetadata,
) -> List[float]:
    """Collect one top-p value per logits row.

    The first `input_metadata.num_prompts` sequence groups contribute a
    single entry each; every later group contributes one entry per
    sequence id.
    """
    top_ps: List[float] = []
    num_prompts = input_metadata.num_prompts
    for group_idx, (seq_ids, sampling_params) in enumerate(
            input_metadata.seq_groups):
        # A prompt input yields one row; generation yields one per sequence.
        repeat = 1 if group_idx < num_prompts else len(seq_ids)
        top_ps.extend([sampling_params.top_p] * repeat)
    return top_ps
|
|
|
|
|
|
|
|
|
|
|
|
def _apply_top_p(
    probs: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    """Zero out the low-probability tail of each row and renormalize.

    Within each row, entries are visited in descending probability order;
    an entry is dropped when the probability mass of the entries before it
    already exceeds that row's threshold in `p`. The surviving entries are
    rescaled to sum to one and returned in the original column order.
    """
    # TODO(woosuk): Optimize.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # Mass accumulated strictly before each entry in sorted order.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    drop = mass_before > p.unsqueeze(dim=1)
    sorted_probs[drop] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    # argsort of the sort permutation is its inverse, which restores the
    # original column order.
    restore_order = torch.argsort(sorted_idx, dim=-1)
    return torch.gather(sorted_probs, dim=-1, index=restore_order)
|
|
|
|
|
|
|
|
|
|
|
|
def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: int,
) -> Dict[int, float]:
    """Return the `num_logprobs` largest entries as {index: value}.

    An empty dict is returned when `num_logprobs` is zero.
    """
    if num_logprobs == 0:
        return {}

    best_values, best_indices = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        # Single result: pull the scalars out directly.
        best_value = best_values.item()
        pairs = [(best_indices.item(), best_value)]
    else:
        value_list = best_values.tolist()
        pairs = zip(best_indices.tolist(), value_list)
    return dict(pairs)
|
|
|
|
|
|
|
|
|
|
|
|
def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
    """Pick the first generated token(s) for a prompt.

    Depending on `sampling_params`, returns the `n` most probable token ids
    (beam search), the single most probable id (temperature zero), or `n`
    ids drawn at random from `prob` with replacement.
    """
    if sampling_params.use_beam_search:
        # Beam search: the beam width is the number of requested outputs.
        _, best_ids = torch.topk(prob, sampling_params.n)
        return best_ids.tolist()

    if sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert sampling_params.n == 1
        return [torch.argmax(prob).item()]

    # Nucleus sampling.
    # Sample n tokens for the prompt.
    sampled_ids = torch.multinomial(
        prob, num_samples=sampling_params.n, replacement=True)
    return sampled_ids.tolist()
|
|
|
|
|
|
|
|
|
|
|
|
def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    """Pick the next token for every running sequence of one group.

    Returns two lists aligned with `seq_ids`: the parent sequence id each
    sequence continues from, and the token id chosen for it. Outside of
    beam search each sequence is its own parent.
    """
    # NOTE(woosuk): sampling_params.n can be greater than
    # len(seq_ids) because some sequences in the group might have
    # been already terminated.
    if not sampling_params.use_beam_search:
        if sampling_params.temperature == 0.0:
            # Greedy sampling.
            assert len(seq_ids) == 1
            greedy_id = torch.argmax(probs, dim=-1)
            return seq_ids, [greedy_id.item()]

        # Nucleus sampling.
        # Sample 1 token for each sequence in the group.
        drawn = torch.multinomial(probs, num_samples=1, replacement=True)
        return seq_ids, drawn.squeeze(dim=-1).tolist()

    # Beam search.
    # Add cumulative logprobs for the sequences in the group.
    cumulative = torch.tensor(
        seq_logprobs, dtype=torch.float, device=logprobs.device)
    scores = logprobs + cumulative.unsqueeze(dim=1)

    num_tokens = scores.size(-1)
    num_beams = len(seq_ids)
    _, flat_best = torch.topk(scores.flatten(), num_beams)

    # A flat index encodes (row, token); the row identifies the parent.
    candidates: List[Tuple[int, int]] = []
    for flat_idx in flat_best.tolist():
        row, token_id = divmod(flat_idx, num_tokens)
        candidates.append((seq_ids[row], token_id))

    assigned: Dict[int, Tuple[int, int]] = {}
    spare: List[Tuple[int, int]] = []
    # If a beam survives, continue with it.
    for parent_id, token_id in candidates:
        if parent_id in assigned:
            spare.append((parent_id, token_id))
        else:
            assigned[parent_id] = (parent_id, token_id)

    # If a beam is discarded, fork another beam.
    for seq_id in seq_ids:
        if seq_id not in assigned:
            assigned[seq_id] = spare.pop()
    assert not spare

    parent_seq_ids = [assigned[seq_id][0] for seq_id in seq_ids]
    next_token_ids = [assigned[seq_id][1] for seq_id in seq_ids]
    return parent_seq_ids, next_token_ids
|
|
|
|
|
|
|
|
|
|
|
|
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
    """Sample next tokens for every sequence group and build the outputs.

    Rows of `probs` / `logprobs` are consumed in the order of
    `input_metadata.seq_groups`: one row for each of the first
    `input_metadata.num_prompts` groups, then one row per sequence id for
    the remaining groups. Each output's logprob dict holds the top
    `num_logprobs` entries plus the logprob of the chosen token.
    """
    seq_outputs: Dict[int, SequenceOutputs] = {}

    # TODO(woosuk): Optimize.
    row = 0
    for group_idx, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group

        if group_idx < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == sampling_params.n
            prob = probs[row]
            logprob = logprobs[row]
            row += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)
            # Get top-k log probabilities for the next tokens.
            shared_topk = _get_topk_logprobs(
                logprob, sampling_params.num_logprobs)

            # Build the output. Every sequence gets its own copy of the
            # top-k dict, extended with the logprob of its own token.
            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                output_logprobs = shared_topk.copy()
                output_logprobs[next_token_id] = (
                    logprob[next_token_id].item())
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id, seq_id, next_token_id, output_logprobs)
            continue

        # Generate the next tokens for generation tokens.
        group_size = len(seq_ids)
        prob = probs[row:row + group_size]
        logprob = logprobs[row:row + group_size]
        row += group_size

        # Sample the next tokens.
        seq_logprobs = [
            input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
        parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
            seq_ids, prob, logprob, seq_logprobs, sampling_params)

        # Get top-k log probabilities for the next tokens.
        topk_by_seq: Dict[int, Dict[int, float]] = {
            seq_id: _get_topk_logprobs(
                logprob[pos], sampling_params.num_logprobs)
            for pos, seq_id in enumerate(seq_ids)
        }

        # Build the output.
        for seq_id, parent_seq_id, next_token_id in zip(
                seq_ids, parent_seq_ids, next_token_ids):
            # Logprobs are read from the parent's row, since that is the
            # sequence this output continues from.
            parent_pos = seq_ids.index(parent_seq_id)
            output_logprobs = topk_by_seq[parent_seq_id].copy()
            output_logprobs[next_token_id] = (
                logprob[parent_pos, next_token_id].item())
            seq_outputs[seq_id] = SequenceOutputs(
                seq_id,
                parent_seq_id,
                next_token_id,
                output_logprobs,
            )

    return seq_outputs
|