[Fix] Fix best_of behavior when n=1 (#3298)
This commit is contained in:
parent
9e8744a545
commit
4b59f00e91
@ -87,12 +87,12 @@ class RequestOutput:
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
||||
# Get the top-n sequences.
|
||||
n = seq_group.sampling_params.n
|
||||
seqs = seq_group.get_seqs()
|
||||
if n == 1:
|
||||
if len(seqs) == 1:
|
||||
top_n_seqs = seqs
|
||||
else:
|
||||
# Get the top-n sequences.
|
||||
n = seq_group.sampling_params.n
|
||||
if seq_group.sampling_params.use_beam_search:
|
||||
sorting_key = lambda seq: seq.get_beam_search_score(
|
||||
seq_group.sampling_params.length_penalty)
|
||||
|
Loading…
x
Reference in New Issue
Block a user