diff --git a/vllm/outputs.py b/vllm/outputs.py index 4f9eddee..b8173fd7 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -87,12 +87,12 @@ class RequestOutput: @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": - # Get the top-n sequences. - n = seq_group.sampling_params.n seqs = seq_group.get_seqs() - if n == 1: + if len(seqs) == 1: top_n_seqs = seqs else: + # Get the top-n sequences. + n = seq_group.sampling_params.n if seq_group.sampling_params.use_beam_search: sorting_key = lambda seq: seq.get_beam_search_score( seq_group.sampling_params.length_penalty)