[Core] Reduce unnecessary compute when logprobs=None (#6532)
This commit is contained in:
parent
766435e660
commit
db9e5708a9
@ -14,7 +14,7 @@ MODELS = ["facebook/opt-125m"]
|
||||
@pytest.mark.parametrize("dtype",
|
||||
["float"]) # needed for comparing logprobs with HF
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
|
||||
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
|
||||
@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size
|
||||
@pytest.mark.parametrize("detokenize", [True, False])
|
||||
def test_get_prompt_logprobs(
|
||||
hf_runner,
|
||||
@ -63,7 +63,10 @@ def test_get_prompt_logprobs(
|
||||
assert result.outputs[0].logprobs is not None
|
||||
assert len(result.outputs[0].logprobs) == max_tokens
|
||||
for logprobs in result.outputs[0].logprobs:
|
||||
assert len(logprobs) == num_top_logprobs
|
||||
# If the output token is not included in the top X
|
||||
# logprob, it can return 1 more data
|
||||
assert (len(logprobs) == num_top_logprobs
|
||||
or len(logprobs) == num_top_logprobs + 1)
|
||||
output_text = result.outputs[0].text
|
||||
output_string_from_most_likely_tokens_lst: List[str] = []
|
||||
for top_logprobs in result.outputs[0].logprobs:
|
||||
@ -135,3 +138,35 @@ def test_max_logprobs():
|
||||
bad_sampling_params = SamplingParams(logprobs=2)
|
||||
with pytest.raises(ValueError):
|
||||
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
|
||||
@pytest.mark.parametrize("detokenize", [True, False])
|
||||
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
|
||||
detokenize: bool, example_prompts):
|
||||
max_num_seqs = 256
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
max_tokens = 5
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs,
|
||||
) as vllm_model:
|
||||
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=None,
|
||||
temperature=0.0,
|
||||
detokenize=detokenize)
|
||||
results_logprobs_none = vllm_model.model.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_none)
|
||||
|
||||
for i in range(len(results_logprobs_none)):
|
||||
assert results_logprobs_none[i].outputs[0].logprobs is None
|
||||
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
import itertools
|
||||
from math import inf
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -774,8 +775,11 @@ def _get_logprobs(
|
||||
# The next token ids to get the logprob value from.
|
||||
next_token_ids: List[int] = []
|
||||
# The largest requested number of logprobs. We find logprobs as many as the
|
||||
# largest num logprobs in this API.
|
||||
largest_num_logprobs = 1
|
||||
# largest num logprobs in this API. If every logprobs is None, it will be
|
||||
# set to -1.
|
||||
largest_num_logprobs = -1
|
||||
# If beam search is enabled.
|
||||
use_beam_search = False
|
||||
|
||||
# Select indices to compute logprob from, ranks of token ids, and the top
|
||||
# k token ids from logprobs.
|
||||
@ -808,6 +812,8 @@ def _get_logprobs(
|
||||
largest_num_logprobs = max(largest_num_logprobs,
|
||||
sampling_params.logprobs)
|
||||
|
||||
use_beam_search = use_beam_search or sampling_params.use_beam_search
|
||||
|
||||
assert len(next_token_ids) == len(query_indices)
|
||||
|
||||
if len(query_indices) == 0:
|
||||
@ -815,35 +821,40 @@ def _get_logprobs(
|
||||
empty_prompt_logprob: Optional[PromptLogprobs] = None
|
||||
return [empty_prompt_logprob], [empty_sampled_logprob]
|
||||
|
||||
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
|
||||
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
|
||||
selected_logprobs, ranks = None, None
|
||||
top_logprobs, top_token_ids = None, None
|
||||
|
||||
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
|
||||
# contain duplicates if beam search is enabled.
|
||||
selected_logprobs = logprobs[[
|
||||
query_indices_gpu,
|
||||
next_token_ids_gpu,
|
||||
]]
|
||||
ranks = _get_ranks(
|
||||
logprobs[query_indices_gpu],
|
||||
next_token_ids_gpu,
|
||||
)
|
||||
assert selected_logprobs.shape[0] == ranks.shape[0]
|
||||
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
|
||||
# skip the whole logprob calculation.
|
||||
if largest_num_logprobs >= 0 or use_beam_search:
|
||||
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
|
||||
next_token_ids_gpu = torch.tensor(next_token_ids,
|
||||
device=logprobs.device)
|
||||
|
||||
# Logprobs of topk tokens for a batch of sequence groups.
|
||||
# (num_query_tokens_across_batch).
|
||||
if largest_num_logprobs > 0:
|
||||
top_logprobs, top_token_ids = torch.topk(logprobs,
|
||||
largest_num_logprobs,
|
||||
dim=-1)
|
||||
else:
|
||||
top_logprobs, top_token_ids = None, None
|
||||
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
|
||||
# contain duplicates if beam search is enabled.
|
||||
selected_logprobs = logprobs[[
|
||||
query_indices_gpu,
|
||||
next_token_ids_gpu,
|
||||
]]
|
||||
ranks = _get_ranks(
|
||||
logprobs[query_indices_gpu],
|
||||
next_token_ids_gpu,
|
||||
)
|
||||
assert selected_logprobs.shape[0] == ranks.shape[0]
|
||||
|
||||
selected_logprobs = selected_logprobs.to('cpu')
|
||||
ranks = ranks.to('cpu')
|
||||
if top_logprobs is not None and top_token_ids is not None:
|
||||
top_logprobs = top_logprobs.to('cpu')
|
||||
top_token_ids = top_token_ids.to('cpu')
|
||||
# We need to compute top k only if there exists logprobs > 0.
|
||||
if largest_num_logprobs > 0:
|
||||
# Logprobs of topk tokens for a batch of sequence groups.
|
||||
# (num_query_tokens_across_batch).
|
||||
top_logprobs, top_token_ids = torch.topk(logprobs,
|
||||
largest_num_logprobs,
|
||||
dim=-1)
|
||||
top_logprobs = top_logprobs.to('cpu')
|
||||
top_token_ids = top_token_ids.to('cpu')
|
||||
|
||||
selected_logprobs = selected_logprobs.to('cpu')
|
||||
ranks = ranks.to('cpu')
|
||||
|
||||
# Find prompt/sample logprobs.
|
||||
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
|
||||
@ -940,45 +951,52 @@ def _get_sampled_logprob_if_needed(
|
||||
):
|
||||
"""Compute the sample logprob if needed."""
|
||||
seq_ids = seq_group.seq_ids
|
||||
num_logprobs = seq_group.sampling_params.logprobs or 0
|
||||
num_logprobs = seq_group.sampling_params.logprobs
|
||||
use_beam_search = seq_group.sampling_params.use_beam_search
|
||||
sampled_logprobs: SampleLogprobs = []
|
||||
next_token_ids, parent_seq_ids = sample_result
|
||||
|
||||
if seq_group.do_sample:
|
||||
assert len(next_token_ids) > 0
|
||||
# Pre-select items from tensor. tolist() is faster than repetitive
|
||||
# `.item()` calls.
|
||||
selected_logprob_items = selected_logprobs[
|
||||
selected_logprobs_idx:selected_logprobs_idx +
|
||||
len(next_token_ids)].tolist()
|
||||
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
|
||||
len(next_token_ids)].tolist()
|
||||
for idx, (next_token_id,
|
||||
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
|
||||
# Get the logprob of a sampled token.
|
||||
sampled_logprobs_dict = {
|
||||
next_token_id: (selected_logprob_items[idx], rank_items[idx])
|
||||
}
|
||||
# Get top K logprobs.
|
||||
if num_logprobs > 0:
|
||||
top_ids = top_token_ids[top_logprob_idx +
|
||||
parent_id, :num_logprobs].tolist()
|
||||
top_probs = top_logprobs[top_logprob_idx +
|
||||
parent_id, :num_logprobs].tolist()
|
||||
# Top K is already sorted by rank, so we can use 1 ~
|
||||
# num_logprobs + 1 for rank.
|
||||
top_ranks = range(1, num_logprobs + 1)
|
||||
sampled_logprobs_dict.update({
|
||||
top_id: (top_prob, rank)
|
||||
for top_id, top_prob, rank in zip(top_ids, top_probs,
|
||||
top_ranks)
|
||||
})
|
||||
if num_logprobs is None and not use_beam_search:
|
||||
for next_token_id in next_token_ids:
|
||||
# Use a dummy logprob
|
||||
sampled_logprobs.append({next_token_id: Logprob(inf)})
|
||||
else:
|
||||
# Pre-select items from tensor. tolist() is faster than repetitive
|
||||
# `.item()` calls.
|
||||
selected_logprob_items = selected_logprobs[
|
||||
selected_logprobs_idx:selected_logprobs_idx +
|
||||
len(next_token_ids)].tolist()
|
||||
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
|
||||
len(next_token_ids)].tolist()
|
||||
for idx, (next_token_id, parent_id) in enumerate(
|
||||
zip(next_token_ids, parent_seq_ids)):
|
||||
# Get the logprob of a sampled token.
|
||||
sampled_logprobs_dict = {
|
||||
next_token_id:
|
||||
(selected_logprob_items[idx], rank_items[idx])
|
||||
}
|
||||
if num_logprobs is not None and num_logprobs > 0:
|
||||
# Get top K logprobs.
|
||||
top_ids = top_token_ids[top_logprob_idx +
|
||||
parent_id, :num_logprobs].tolist()
|
||||
top_probs = top_logprobs[
|
||||
top_logprob_idx + parent_id, :num_logprobs].tolist()
|
||||
# Top K is already sorted by rank, so we can use 1 ~
|
||||
# num_logprobs + 1 for rank.
|
||||
top_ranks = range(1, num_logprobs + 1)
|
||||
sampled_logprobs_dict.update({
|
||||
top_id: (top_prob, rank)
|
||||
for top_id, top_prob, rank in zip(
|
||||
top_ids, top_probs, top_ranks)
|
||||
})
|
||||
|
||||
sampled_logprobs.append({
|
||||
token_id: Logprob(*logprob_and_rank)
|
||||
for token_id, logprob_and_rank in
|
||||
sampled_logprobs_dict.items()
|
||||
})
|
||||
sampled_logprobs.append({
|
||||
token_id: Logprob(*logprob_and_rank)
|
||||
for token_id, logprob_and_rank in
|
||||
sampled_logprobs_dict.items()
|
||||
})
|
||||
|
||||
# NOTE: This part of code is not intuitive. `selected_logprobs` include
|
||||
# logprobs for the current step, which has len(next_token_ids) tokens
|
||||
|
@ -29,7 +29,7 @@ class CompletionOutput:
|
||||
index: int
|
||||
text: str
|
||||
token_ids: Tuple[int, ...]
|
||||
cumulative_logprob: float
|
||||
cumulative_logprob: Optional[float]
|
||||
logprobs: Optional[SampleLogprobs]
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Union[int, str, None] = None
|
||||
@ -124,13 +124,14 @@ class RequestOutput:
|
||||
include_logprobs = seq_group.sampling_params.logprobs is not None
|
||||
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
|
||||
outputs = [
|
||||
CompletionOutput(seqs.index(seq),
|
||||
seq.get_output_text_to_return(text_buffer_length),
|
||||
seq.get_output_token_ids(),
|
||||
seq.get_cumulative_logprob(),
|
||||
seq.output_logprobs if include_logprobs else None,
|
||||
SequenceStatus.get_finished_reason(seq.status),
|
||||
seq.stop_reason) for seq in top_n_seqs
|
||||
CompletionOutput(
|
||||
seqs.index(seq),
|
||||
seq.get_output_text_to_return(text_buffer_length),
|
||||
seq.get_output_token_ids(),
|
||||
seq.get_cumulative_logprob() if include_logprobs else None,
|
||||
seq.output_logprobs if include_logprobs else None,
|
||||
SequenceStatus.get_finished_reason(seq.status),
|
||||
seq.stop_reason) for seq in top_n_seqs
|
||||
]
|
||||
|
||||
# Every sequence in the sequence group should have the same prompt.
|
||||
|
@ -92,11 +92,12 @@ class SamplingParams:
|
||||
min_tokens: Minimum number of tokens to generate per output sequence
|
||||
before EOS or stop_token_ids can be generated
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
Note that the implementation follows the OpenAI API: The return
|
||||
result includes the log probabilities on the `logprobs` most likely
|
||||
tokens, as well the chosen tokens. The API will always return the
|
||||
log probability of the sampled token, so there may be up to
|
||||
`logprobs+1` elements in the response.
|
||||
When set to None, no probability is returned. If set to a non-None
|
||||
value, the result includes the log probabilities of the specified
|
||||
number of most likely tokens, as well as the chosen tokens.
|
||||
Note that the implementation follows the OpenAI API: The API will
|
||||
always return the log probability of the sampled token, so there
|
||||
may be up to `logprobs+1` elements in the response.
|
||||
prompt_logprobs: Number of log probabilities to return per prompt token.
|
||||
detokenize: Whether to detokenize the output. Defaults to True.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
@ -168,8 +169,8 @@ class SamplingParams:
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.min_tokens = min_tokens
|
||||
self.logprobs = logprobs
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.logprobs = 1 if logprobs is True else logprobs
|
||||
self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
|
||||
# NOTE: This parameter is only exposed at the engine level for now.
|
||||
# It is not exposed in the OpenAI API server, as the OpenAI API does
|
||||
# not support returning only a list of token IDs.
|
||||
|
Loading…
x
Reference in New Issue
Block a user