vllm/cacheflow/models/sample.py

from typing import Dict, List, Tuple

import torch
import torch.nn as nn

from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs
from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region


class Sampler(nn.Module):
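    """Samples the next tokens from the model's outputs.

    Given the hidden states and the embedding matrix, this layer computes the
    logits for the last token of each prompt and for every generation token,
    applies temperature scaling and top-p truncation, and samples the next
    token for each sequence.
    """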
    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
self,
        embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
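        """Compute the logits for the positions being sampled from and pick
        the next token for every sequence in the batch."""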
# Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
        logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab.
        logits = logits[:, :self.vocab_size]

        # Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(
temperatures, dtype=logits.dtype, device=logits.device)
# Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p).
        logprobs = torch.log(probs)
# Apply top-p truncation.
top_ps = _get_top_ps(input_metadata)
assert len(top_ps) == probs.shape[0]
if any(p < 1.0 for p in top_ps):
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
            probs = _apply_top_p(probs, p)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
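    # Keep only the rows that are sampled from: the last token of each prompt
    # (prompts come first in the batch), followed by every generation token.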
start_idx = 0
    last_token_indices: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indices.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indices.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states[last_token_indices]


def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature == 0.0:
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if i < input_metadata.num_prompts:
# A prompt input.
temperatures.append(temperature)
else:
# A generation token.
temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_ps(
input_metadata: InputMetadata,
) -> List[float]:
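    # Collect one top-p value per row of the probability tensor: one entry per
    # prompt and one per sequence for each generation group.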
top_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
top_ps.append(sampling_params.top_p)
else:
# A generation token.
top_ps += [sampling_params.top_p] * len(seq_ids)
    return top_ps


def _apply_top_p(
probs: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
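    # Sort the probabilities in descending order, zero out every token whose
    # preceding cumulative probability already exceeds p, renormalize, and
    # scatter the result back to the original token order.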
# TODO(woosuk): Optimize.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
    return probs


def _get_topk_logprobs(
logprobs: torch.Tensor,
num_logprobs: int,
) -> Dict[int, float]:
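    # Return a {token_id: logprob} mapping for the num_logprobs most likely
    # next tokens.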
if num_logprobs == 0:
return {}
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
if num_logprobs == 1:
topk_logprobs = [topk_logprobs.item()]
topk_ids = [topk_ids.item()]
else:
topk_logprobs = topk_logprobs.tolist()
topk_ids = topk_ids.tolist()
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
prob: torch.Tensor,
sampling_params: SamplingParams,
) -> List[int]:
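    # A prompt yields the first token for each of its n sequences: the top-n
    # tokens for beam search, the argmax for greedy decoding, or n independent
    # samples for random sampling.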
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
        # Nucleus sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


def _sample_from_generation_tokens(
seq_ids: List[int],
probs: torch.Tensor,
logprobs: torch.Tensor,
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.n can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
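    # For beam search, pick the top-`beam_width` (sequence, token) pairs by
    # cumulative logprob across the whole group, then map each pair back to a
    # sequence id, reusing the slot of a discarded beam when necessary.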
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(
seq_logprobs, dtype=torch.float, device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
topk_ids = topk_ids.tolist()
seq_idx = [i // vocab_size for i in topk_ids]
beam_seq_ids = [seq_ids[i] for i in seq_idx]
token_ids = [i % vocab_size for i in topk_ids]
beam_outputs: Dict[int, Tuple[int, int]] = {}
outstanding_beams: List[Tuple[int, int]] = []
# If a beam survives, continue with it.
for seq_id, token_id in zip(beam_seq_ids, token_ids):
if seq_id not in beam_outputs:
beam_outputs[seq_id] = (seq_id, token_id)
else:
outstanding_beams.append((seq_id, token_id))
# If a beam is discarded, fork another beam.
for seq_id in seq_ids:
if seq_id not in beam_outputs:
beam_outputs[seq_id] = outstanding_beams.pop()
assert not outstanding_beams
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert len(seq_ids) == 1
next_token_id = torch.argmax(probs, dim=-1)
next_token_ids = [next_token_id.item()]
parent_seq_ids = seq_ids
else:
        # Nucleus sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(
probs, num_samples=1, replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
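    # Iterate over the sequence groups in the same order as the rows of
    # `probs`/`logprobs`: prompts first (one row each), then one row per
    # sequence for each generation group.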
seq_outputs: Dict[int, SequenceOutputs] = {}
# TODO(woosuk): Optimize.
idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.num_logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id, next_token_id, output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
logprob = logprobs[idx:idx + len(seq_ids)]
idx += len(seq_ids)
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.num_logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
next_token_id,
output_logprobs,
)
    return seq_outputs