2024-06-30 11:44:25 +08:00
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
# One generated sequence: (token ids, decoded text).
TokensText = Tuple[List[int], str]


def check_outputs_equal(outputs_0_lst: List[TokensText],
                        outputs_1_lst: List[TokensText], name_0: str,
                        name_1: str):
    """
    Compare the two sequences generated by different models,
    which should be equal.

    Args:
        outputs_0_lst: One ``(token_ids, text)`` pair per prompt for the
            first model.
        outputs_1_lst: One ``(token_ids, text)`` pair per prompt for the
            second model, in the same prompt order.
        name_0: Label used for the first model in failure messages.
        name_1: Label used for the second model in failure messages.

    Raises:
        AssertionError: If the two lists differ in length, or if the text
            or the token ids differ for any prompt.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"{name_0} produced {len(outputs_0_lst)} outputs but "
        f"{name_1} produced {len(outputs_1_lst)}")

    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
                                              f"\n{name_0}:\t{output_str_0!r}"
                                              f"\n{name_1}:\t{output_str_1!r}")

        # The texts are known to be equal at this point, so the message
        # must also carry the token ids to show what actually differs.
        assert output_ids_0 == output_ids_1, (
            f"Test{prompt_idx}:"
            f"\n{name_0}:\t{output_str_0!r}\t{output_ids_0}"
            f"\n{name_1}:\t{output_str_1!r}\t{output_ids_1}")
|
|
|
|
|
|
|
|
|
|
|
|
# One generated sequence with logprobs: (token ids, decoded text, and one
# {token id: logprob} mapping per generated position).
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]


def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
                         outputs_1_lst: List[TokensTextLogprobs], name_0: str,
                         name_1: str):
    """
    Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    For each prompt the two token sequences are walked in lockstep. At the
    first position where the tokens differ, each model's token must be a key
    of the other model's logprob mapping for that position. Nothing after
    that position is compared, and the logprob values themselves are never
    compared. The walk stops at the end of the shorter sequence.

    Args:
        outputs_0_lst: One ``(token_ids, text, logprobs)`` triple per prompt
            for the first model.
        outputs_1_lst: One ``(token_ids, text, logprobs)`` triple per prompt
            for the second model, in the same prompt order.
        name_0: Label used for the first model in failure messages.
        name_1: Label used for the second model in failure messages.

    Raises:
        AssertionError: If the two lists differ in length, or if at the
            first diverging position either model's token is missing from
            the other model's logprob mapping.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"{name_0} produced {len(outputs_0_lst)} outputs but "
        f"{name_1} produced {len(outputs_1_lst)}")

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            # Matching tokens need no further checking.
            if output_id_0 == output_id_1:
                continue

            # Shared part of the failure message: the original header plus
            # where the sequences diverged and which tokens were generated.
            mismatch = (f"Test{prompt_idx}:"
                        f"\n{name_0}:\t{output_str_0!r}"
                        f"\n{name_1}:\t{output_str_1!r}"
                        f"\nFirst mismatch at token index {idx}: "
                        f"{name_0} generated {output_id_0}, "
                        f"{name_1} generated {output_id_1}")

            # Each predicted token must be in top N logprobs of the other
            assert output_id_0 in logprobs_1[idx], (
                f"{mismatch}"
                f"\n{output_id_0} not in {name_1} logprobs: "
                f"{list(logprobs_1[idx])}")
            assert output_id_1 in logprobs_0[idx], (
                f"{mismatch}"
                f"\n{output_id_1} not in {name_0} logprobs: "
                f"{list(logprobs_0[idx])}")

            # Break out since sequences will now diverge.
            break
|