[Core] *Prompt* logprobs support in Multi-step (#8199)

afeldman-nm 2024-09-18 11:38:43 -04:00 committed by GitHub
parent 7c7714d856
commit a8c1d161a7
5 changed files with 299 additions and 58 deletions
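
For context, a minimal usage sketch of the capability this commit enables (not part of the diff itself): requesting prompt logprobs from an engine running with multi-step scheduling. The model name and engine settings below are illustrative assumptions, not values taken from the commit.

from vllm import LLM, SamplingParams

# Multi-step scheduling: several GPU-side steps per GPU -> CPU output transfer.
llm = LLM(
    model="facebook/opt-125m",      # illustrative small model
    num_scheduler_steps=8,          # > 1 enables multi-step scheduling
    use_v2_block_manager=True,      # required for multi-step at this point
)

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=5,
    logprobs=5,          # top-5 sample logprobs per generated token
    prompt_logprobs=5,   # top-5 logprobs per prompt token (the new capability)
)

for request_output in llm.generate(["The president of the United States is"],
                                   sampling_params):
    # One entry per prompt token (the first entry is typically None).
    print(request_output.prompt_logprobs)
    # One dict of top candidate tokens per generated token.
    print(request_output.outputs[0].logprobs)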

View File

@@ -20,6 +20,8 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
                           BatchFeature)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from tests.models.utils import (TokensTextLogprobs,
+                                TokensTextLogprobsPromptLogprobs)
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
@@ -33,7 +35,6 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
-from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
@@ -469,7 +470,7 @@ class HfRunner:
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[List[np.ndarray]] = None,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         all_logprobs: List[List[Dict[int, float]]] = []
         all_output_ids: List[List[int]] = []
         all_output_strs: List[str] = []
@@ -525,7 +526,7 @@ class HfRunner:
         max_tokens: int,
         num_logprobs: int,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
@@ -653,14 +654,16 @@ class VllmRunner:
     @staticmethod
     def _final_steps_generate_w_logprobs(
         req_outputs: List[RequestOutput],
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
+    ) -> List[TokensTextLogprobsPromptLogprobs]:
+        outputs: List[TokensTextLogprobsPromptLogprobs] = []
         for req_output in req_outputs:
+            assert len(req_output.outputs) > 0
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 output_logprobs = sample.logprobs
-                outputs.append((output_ids, output_str, output_logprobs))
+                outputs.append((output_ids, output_str, output_logprobs,
+                                req_output.prompt_logprobs))
         return outputs
 
     def generate_w_logprobs(
@@ -670,7 +673,8 @@ class VllmRunner:
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         assert sampling_params.logprobs is not None
 
         if images is not None:
@@ -695,13 +699,20 @@ class VllmRunner:
         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
 
-        return self._final_steps_generate_w_logprobs(req_outputs)
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_encoder_decoder_w_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         '''
         Logprobs generation for vLLM encoder/decoder models
         '''
@@ -709,7 +720,12 @@ class VllmRunner:
         assert sampling_params.logprobs is not None
         req_outputs = self.model.generate(encoder_decoder_prompts,
                                           sampling_params=sampling_params)
-        return self._final_steps_generate_w_logprobs(req_outputs)
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_greedy(
         self,
@@ -727,44 +743,48 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
+        num_prompt_logprobs: Optional[int] = None,
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs,
-                                                stop_token_ids=stop_token_ids)
-        outputs = self.generate_w_logprobs(prompts,
-                                           greedy_logprobs_params,
-                                           images=images,
-                                           audios=audios,
-                                           videos=videos)
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=(num_prompt_logprobs),
+            stop_token_ids=stop_token_ids)
+
+        return self.generate_w_logprobs(prompts,
+                                        greedy_logprobs_params,
+                                        images=images,
+                                        audios=audios,
+                                        videos=videos)
 
     def generate_encoder_decoder_greedy_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                use_beam_search=False,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs)
+        num_prompt_logprobs: Optional[int] = None,
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            use_beam_search=False,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=(num_prompt_logprobs),
+        )
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
-        outputs = self.generate_encoder_decoder_w_logprobs(
+        return self.generate_encoder_decoder_w_logprobs(
             encoder_decoder_prompts, greedy_logprobs_params)
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
 
     def generate_beam_search(
         self,
         prompts: List[str],
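
A minimal sketch of how a test might consume the updated helper above; `vllm_model` stands for a `VllmRunner` instance obtained from the `vllm_runner` fixture, and the prompt text is illustrative. With `num_prompt_logprobs` set, each result is a 4-tuple that also carries prompt logprobs; with it omitted, the helper keeps returning 3-tuples.

# Hypothetical call inside a test body (assumes a `vllm_model` runner exists).
results = vllm_model.generate_greedy_logprobs(
    ["Hello, my name is"],
    max_tokens=5,
    num_logprobs=5,
    num_prompt_logprobs=5,  # drop this kwarg to get 3-tuples instead
)
for output_ids, output_str, sample_logprobs, prompt_logprobs in results:
    # One sample-logprobs dict per generated token id.
    assert len(sample_logprobs) == len(output_ids)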

View File

@@ -1,7 +1,7 @@
 import warnings
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
-from vllm.sequence import Logprob, SampleLogprobs
+from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
 
 TokensText = Tuple[List[int], str]
@@ -34,20 +34,47 @@ def check_outputs_equal(
 
         assert output_ids_0 == output_ids_1, fail_msg
 
+# Representation of generated sequence as a tuple of
+# * Token ID list
+# * String
+# * List of top sample logprobs for each sampled token
+#
+# Assumes prompt logprobs were not requested.
 TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
                                                                     float]],
                                                            SampleLogprobs]]]
 
-# Allow for tokens to be represented as str's rather than IDs
+# Allow for tokens to be represented as str's rather than IDs;
+# tuple of
+# * Token string representations list
+# * String
+# * Optional list of top sample logprobs for each sampled token
+#
+# Assumes prompt logprobs were not requested.
 TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
                                                          List[Dict[str,
                                                                    Logprob]]]]]
 
+# Representation of generated sequence as a tuple of
+# * Token ID list
+# * String
+# * Optional list of top sample logprobs for each sampled token
+# * Optional list of top prompt logprobs for each prompt token
+#
+# Allows prompt logprobs to be requested.
+TokensTextLogprobsPromptLogprobs = Tuple[
+    List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
+    Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
+
 
 def check_logprobs_close(
     *,
-    outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
-    outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
+    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
+                                  TokensTextLogprobsPromptLogprobs,
+                                  TextTextLogprobs]],
+    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
+                                  TokensTextLogprobsPromptLogprobs,
+                                  TextTextLogprobs]],
     name_0: str,
     name_1: str,
     num_outputs_0_skip_tokens: int = 0,
@@ -57,6 +84,18 @@ def check_logprobs_close(
     """Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
 
+    How sample logprobs are compared:
+    * `always_check_logprobs == True`: set of highest-logprob token ids
+      must match between seq0 and seq1 at all sampled token offsets
+    * `always_check_logprobs == False`: highest-logprob token ids are
+      only compared at sampled token offsets for which generated token
+      ids don't match
+
+    Prompt logprobs must be provided either for both input sequences, or
+    for neither. If prompt logprobs are provided, then highest-logprob
+    prompt token ids must match between seq0 and seq1 at all prompt token
+    offsets.
+
     Args:
       outputs_0_lst: First sequence to compare
      outputs_0_lst: Second sequence to compare
@@ -78,8 +117,65 @@ def check_logprobs_close(
     for prompt_idx, (outputs_0,
                      outputs_1) in enumerate(zip(outputs_0_lst,
                                                  outputs_1_lst)):
-        output_ids_0, output_str_0, logprobs_0 = outputs_0
-        output_ids_1, output_str_1, logprobs_1 = outputs_1
+        assert len(outputs_0) == len(outputs_1)
+        if len(outputs_0) == 3:
+            assert len(outputs_1) == 3
+            # Break out tokens, text & sample logprobs
+            # (prompt logprobs were not provided)
+            output_ids_0, output_str_0, logprobs_0 = outputs_0
+            output_ids_1, output_str_1, logprobs_1 = outputs_1
+        elif len(outputs_0) == 4:
+            assert len(outputs_1) == 4
+            # Break out tokens, text, sample logprobs & prompt logprobs
+            (
+                output_ids_0,
+                output_str_0,
+                logprobs_0,
+                prompt_logprobs_0,
+            ) = outputs_0
+            (
+                output_ids_1,
+                output_str_1,
+                logprobs_1,
+                prompt_logprobs_1,
+            ) = outputs_1
+
+            # Test prompt logprobs closeness
+            if (prompt_logprobs_0 is not None
+                    and prompt_logprobs_1 is not None):
+                # Both sequences' prompt logprobs lists are not `None`
+                # (although individual list elements may be `None`);
+                # for each token's logprobs:
+                for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
+                        zip(prompt_logprobs_0, prompt_logprobs_1)):
+                    fail_msg = (
+                        f"Prompt logprobs test:"
+                        f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
+                        f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
+
+                    if logprobs_elem_0 is None:
+                        # If the seq 0 token's logprobs are `None`,
+                        # the seq 1 token's logprobs must be `None`
+                        assert logprobs_elem_1 is None, fail_msg
+                    else:
+                        # If the seq 0 token's logprobs are not `None`,
+                        # the seq 1 token's logprobs must not be `None`
+                        assert logprobs_elem_1 is not None, fail_msg
+
+                        # Logprobs check: top-k token choices must be the same
+                        assert (set(logprobs_elem_0.keys()) == set(
+                            logprobs_elem_1.keys())), fail_msg
+            else:
+                # Both sequence logprobs lists must be `None`
+                fail_msg = (f"Prompt logprobs test:"
+                            f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
+                            f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
+                assert (prompt_logprobs_0 is None
+                        and prompt_logprobs_1 is None), fail_msg
+        else:
+            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
+                             f"{len(outputs_0)} elements were provided: "
+                             f"{outputs_0}")
 
         if logprobs_0 is None:
             logprobs_0 = [None] * len(output_ids_0)
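
As a standalone illustration of the closeness rule the new branch enforces: at every prompt offset either both entries are None or both carry the same set of top token ids. The token ids and logprob values in the sketch below are made up.

# Toy data mimicking `prompt_logprobs` from two engines (values are invented).
prompt_logprobs_a = [None, {11: -0.1, 7: -2.3}, {42: -0.5, 3: -1.9}]
prompt_logprobs_b = [None, {7: -2.2, 11: -0.2}, {42: -0.4, 3: -2.0}]

for elem_a, elem_b in zip(prompt_logprobs_a, prompt_logprobs_b):
    # None-ness must agree at each prompt offset...
    assert (elem_a is None) == (elem_b is None)
    if elem_a is not None:
        # ...and so must the set of top-k token ids (values may differ).
        assert set(elem_a.keys()) == set(elem_b.keys())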

View File

@@ -100,3 +100,95 @@ def test_multi_step_llm(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
+def test_multi_step_llm_w_prompt_logprobs(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp_size: int,
+    max_tokens: int,
+    enforce_eager: int,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    num_logprobs: Optional[int],
+    num_prompt_logprobs: Optional[int],
+) -> None:
+    """Test prompt logprobs with multi-step scheduling via sync LLM Engine.
+
+    Set up a vLLM engine instance w/ single-step scheduling as a ground-truth
+    reference.
+
+    Prompt them with the same example prompts.
+
+    Validate:
+    * All generated logprobs are all very close
+
+    Args:
+      hf_runner: HF transformers model runner fixture
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      dtype: tensor datatype for engine to utilize
+      tp_size: degree of tensor-parallelism
+      max_tokens: the maximum number of tokens to generate
+      enforce_eager
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+                           GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
+                    completions endpoint; `None` -> no logprobs
+      num_prompt_logprobs: number of logprobs to return for each prompt token;
+                           note that this argument is not supported by the
+                           OpenAI completions endpoint.
+    """
+    prompts = example_prompts
+    if len(prompts) < num_prompts:
+        prompts = prompts * ((num_prompts // len(prompts)) + 1)
+    prompts = prompts[:num_prompts]
+    assert len(prompts) == num_prompts
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=num_scheduler_steps,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            prompts,
+            max_tokens,
+            num_logprobs,
+            num_prompt_logprobs=num_prompt_logprobs)
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+    ) as vllm_model:
+        single_step_vllm_outputs = vllm_model.generate_greedy_logprobs(
+            prompts,
+            max_tokens,
+            num_logprobs,
+            num_prompt_logprobs=num_prompt_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=single_step_vllm_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )

View File

@@ -493,6 +493,7 @@ async def completions_with_server_args(
     '''
 
     outputs = None
+    max_wait_seconds = 240 * 3  # 240 is default
     with RemoteOpenAIServer(model_name,
                             server_cli_args,
                             max_wait_seconds=max_wait_seconds) as server:
@@ -503,7 +504,7 @@
                                                  stream=False,
                                                  max_tokens=5,
                                                  logprobs=num_logprobs)
-    assert outputs is not None
+    assert outputs is not None, "Completion API call failed."
 
     return outputs

View File

@@ -614,34 +614,66 @@ def _pythonize_sampler_output(
     frozen_model_input = model_input.frozen_model_input
     assert frozen_model_input.sampling_metadata is not None
+    sampling_metadata = frozen_model_input.sampling_metadata
     # samples generation should have been skipped
     assert not output.outputs
 
     pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
 
-    # CPU GPU sync
-    pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
+    # We guarantee output tensors are ready, so it is safe to
+    # pythonize the sampler output & obtain CPU-side logprobs.
+    #
+    # However we should check whether logprobs pythonization may
+    # be skipped entirely, i.e. because no logprobs were requested
+    # or pythonization was not deferred. To that end,
+    #
+    # * `prompt_logprobs_are_requested_for_prefill` signals that
+    #   there are *any* prefill-phase requests which specify that
+    #   prompt logprobs should be returned.
+    #
+    # * `any_logprobs_are_requested` signals that there are any
+    #   requests which (1) specify that sample logprobs should be
+    #   returned, or (2) are in the prefill phase AND specify that
+    #   prompt logprobs should be returned.
+    #
+    # Later on, these flags cause adjustments to the pythonization
+    # process to accommodate logprobs.
+
+    seq_groups = sampling_metadata.seq_groups
+    prompt_logprobs_are_requested_for_prefill = any([
+        sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
+        for sg in seq_groups
+    ])
+    any_logprobs_are_requested = (
+        prompt_logprobs_are_requested_for_prefill
+        or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
+
+    if prompt_logprobs_are_requested_for_prefill:
+        # CPU GPU sync, after gathering *only* sampled tokens (since
+        # requesting prompt logprobs leads `sampled_token_ids` to
+        # include prompt token ids in addition to sampled token ids.)
+        sample_idx_tensor = torch.tensor(
+            [sdx for sg in seq_groups for sdx in sg.sample_indices])
+        pinned_buffer = pinned_buffer.copy_(
+            sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
+    else:
+        # CPU GPU sync
+        pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
+                                            non_blocking=False)
 
     # this will not block as the tensors are already on CPU
     samples_list = pinned_buffer.tolist()
 
-    sampling_metadata = frozen_model_input.sampling_metadata
-
     skip_sampler_cpu_output = (
         frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
 
-    # We are guaranteed output tensors are ready, so it is safe to
-    # pythonize the sampler output & obtain CPU-side logprobs.
-    #
-    # However this computation may be skipped entirely
-    # if no pythonization was deferred.
-    seq_groups = sampling_metadata.seq_groups
-    logprobs_are_requested = any([
-        sg.sampling_params.logprobs is not None
-        or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
-    ])
+    # *Don't* skip logprobs pythonization *if*:
+    # * Any requests require logprobs to be returned in this
+    #   iteration AND
+    # * These requests are being scheduled in a fashion which
+    #   defers pythonization (i.e. multi-step scheduling.)
     do_pythonize_logprobs = (skip_sampler_cpu_output
-                             and logprobs_are_requested)
+                             and any_logprobs_are_requested)
     (
         prompt_logprobs,
         sample_logprobs,
@@ -666,7 +698,7 @@ def _pythonize_sampler_output(
                 prompt_logprobs[sgdx],
                 sample_logprobs[sgdx],
             )
-        elif logprobs_are_requested:
+        elif any_logprobs_are_requested:
            (
                 group_prompt_logprobs,
                 group_sample_logprobs,
@@ -696,7 +728,7 @@
                     seq_output.parent_seq_id = seq_ids[parent_id]
                     seq_output.output_token = next_token_id
 
-                    if logprobs_are_requested:
+                    if any_logprobs_are_requested:
                         seq_output.logprobs = group_sample_logprobs[tdx]
                     else:
                         logprobs = next(iter(seq_output.logprobs.values()))
@@ -714,7 +746,7 @@
                 seq_outputs.append(
                     SequenceOutput(seq_ids[parent_id], next_token_id,
                                    (group_sample_logprobs[tdx]
-                                    if logprobs_are_requested else {
+                                    if any_logprobs_are_requested else {
                                         next_token_id:
                                         Logprob(logprob=float('inf'),
                                                 rank=None,
@@ -722,12 +754,12 @@
                                     })))
         if cache is not None:
             completion_seq_group_output.prompt_logprobs = \
-                group_prompt_logprobs if logprobs_are_requested else None
+                group_prompt_logprobs if any_logprobs_are_requested else None
             output.outputs.append(completion_seq_group_output)
         else:
             output.outputs.append(
                 CompletionSequenceGroupOutput(
                     seq_outputs, (group_prompt_logprobs
-                                  if logprobs_are_requested else None)))
+                                  if any_logprobs_are_requested else None)))
 
     assert len(output.outputs) > 0
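
A toy illustration of the gather performed above when prompt logprobs are requested (shapes and indices are made up): `sampled_token_ids` then holds rows for prompt positions as well as sampled positions, so only the rows at each sequence group's sample indices are copied out before the host-side conversion.

import torch

# Pretend two sequence groups, each contributing prompt rows plus one
# sampled row; the sampled rows sit at indices 2 and 5.
sampled_token_ids = torch.tensor([[101], [102], [103], [201], [202], [203]])
sample_indices = [2, 5]

sample_idx_tensor = torch.tensor(sample_indices)
gathered = sampled_token_ids[sample_idx_tensor, :]
print(gathered.tolist())  # [[103], [203]]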