30 lines
1.3 KiB
Python
30 lines
1.3 KiB
Python
![]() |
def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
    """Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    For each prompt, the generated token ids of both models are walked in
    lockstep. At the first position where the two models picked different
    tokens, each model's token must appear among the other model's logprobs
    entries for that position. After that first divergence the sequences are
    expected to drift apart, so the remaining tokens of that prompt are not
    compared.

    Args:
        outputs_0_lst: Per-prompt outputs of the first model. Each entry is
            unpacked as ``(output_ids, output_str, logprobs)``, where
            ``logprobs[i]`` must support ``token_id in logprobs[i]``
            (presumably a mapping of top-N candidate token ids to logprobs;
            confirm against the caller).
        outputs_1_lst: Per-prompt outputs of the second model, same layout.
        name_0: Label for the first model, used in failure messages.
        name_1: Label for the second model, used in failure messages.

    Raises:
        AssertionError: If the two lists hold a different number of prompt
            outputs, or if at the first divergent position either token is
            missing from the other model's logprobs.
    """
    # zip() below would silently drop trailing prompts and hide missing
    # outputs, so require both sides to cover the same number of prompts.
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"{name_0} returned {len(outputs_0_lst)} outputs but "
        f"{name_1} returned {len(outputs_1_lst)}")

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        # Loop through generated tokens.
        for idx, (output_id_0, output_id_1) in enumerate(
                zip(output_ids_0, output_ids_1)):
            # Both models agree so far; nothing to check at this position.
            if output_id_0 == output_id_1:
                continue

            # First divergence: each predicted token must be in the top N
            # logprobs of the other model at this position.
            fail_msg = (f"Test{prompt_idx}:"
                        f"\n{name_0}:\t{output_str_0!r}"
                        f"\n{name_1}:\t{output_str_1!r}")
            assert output_id_0 in logprobs_1[idx], fail_msg
            assert output_id_1 in logprobs_0[idx], fail_msg

            # Break out since sequences will now diverge.
            break