[Fix] Fix best_of behavior when n=1 (#3298)

This commit is contained in:
Nick Hill 2024-03-10 19:17:46 -07:00 committed by GitHub
parent 9e8744a545
commit 4b59f00e91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -87,12 +87,12 @@ class RequestOutput:
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
if n == 1:
if len(seqs) == 1:
top_n_seqs = seqs
else:
# Get the top-n sequences.
n = seq_group.sampling_params.n
if seq_group.sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)