vllm/tests/models/utils.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files
    
    This commit adds SPDX license headers to Python source files, as
    recommended to the project by the Linux Foundation. These headers
    provide a concise, human- and machine-readable way of communicating
    license information for each source file. They help avoid any
    ambiguity about the license of the code and can also be easily used
    by tools to help manage license compliance.
    
    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool
    do its job.
    
    More information can be found on the SPDX site:
    
    - https://spdx.dev/learn/handling-license-info/
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit
    
    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00

288 lines
12 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from vllm.config import ModelConfig, TaskOption
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
# One generated sequence as a pair of (token ID list, decoded output text).
TokensText = Tuple[List[int], str]


def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[TokensText],
    outputs_1_lst: Sequence[TokensText],
    name_0: str,
    name_1: str,
):
    """
    Assert that two models produced exactly the same outputs, prompt by
    prompt.

    Each element of ``outputs_0_lst`` / ``outputs_1_lst`` is a
    ``(token_ids, text)`` pair for one prompt. For every prompt, the decoded
    text must match first, then the token ids.

    Raises:
        AssertionError: if the two lists have different lengths, or if any
            prompt's text or token ids differ. The message names the prompt
            index and shows both texts, labelled with ``name_0``/``name_1``.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for i, pair in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        (ids_a, text_a), (ids_b, text_b) = pair

        msg = (f"Test{i}:"
               f"\n{name_0}:\t{text_a!r}"
               f"\n{name_1}:\t{text_b!r}")

        # Text is checked before token ids, so a text mismatch is what
        # surfaces first when both differ.
        assert text_a == text_b, msg
        assert ids_a == ids_b, msg
# A generated sequence, when prompt logprobs were NOT requested, as a
# 3-tuple of:
#   * list of generated token IDs
#   * decoded output text
#   * optional list of top sample logprobs, one dict per sampled token
TokensTextLogprobs = Tuple[
    List[int],
    str,
    Optional[Union[List[Dict[int, float]], SampleLogprobs]]
]

# Variant of ``TokensTextLogprobs`` in which tokens are given by their
# string representations rather than by IDs (prompt logprobs NOT
# requested), as a 3-tuple of:
#   * list of token strings
#   * decoded output text
#   * optional list of top sample logprobs keyed by token string, one dict
#     per sampled token
TextTextLogprobs = Tuple[
    List[str],
    str,
    Optional[Union[List[Dict[str, float]], List[Dict[str, Logprob]]]]
]

# A generated sequence, when prompt logprobs MAY have been requested, as a
# 4-tuple of:
#   * list of generated token IDs
#   * decoded output text
#   * optional list of top sample logprobs, one dict per sampled token
#   * optional list of top prompt logprobs, one entry per prompt token
#     (individual entries may themselves be None)
TokensTextLogprobsPromptLogprobs = Tuple[
    List[int],
    str,
    Optional[Union[List[Dict[int, float]], SampleLogprobs]],
    Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]
]
def _warn_always(message: str) -> None:
    """Issue ``message`` through ``warnings.warn``, even if already shown.

    Must be called directly from ``check_logprobs_close``: ``stacklevel=3``
    skips this helper (frame 1) and ``check_logprobs_close`` (frame 2) so
    the warning is attributed to the code that called the checker.
    """
    with warnings.catch_warnings():
        # This ensures that repeated warnings are shown
        # in the output, not just the first occurrence
        warnings.simplefilter("always")
        warnings.warn(message, stacklevel=3)


def _check_prompt_logprobs_close(prompt_logprobs_0, prompt_logprobs_1,
                                 name_0: str, name_1: str) -> None:
    """Assert that two sequences' prompt logprobs agree.

    Either both lists are ``None``, or both are present. When present, at
    each prompt position (over the common prefix produced by ``zip``) the
    two entries must be both ``None``, or both non-``None`` with identical
    sets of token-id keys.

    Raises:
        AssertionError: if the prompt logprobs disagree as described above.
    """
    if prompt_logprobs_0 is None or prompt_logprobs_1 is None:
        # Prompt logprobs must be provided for both sequences or for neither
        fail_msg = (f"Prompt logprobs test:"
                    f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
                    f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
        assert (prompt_logprobs_0 is None
                and prompt_logprobs_1 is None), fail_msg
        return

    # Both lists are present (although individual elements may be `None`)
    for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
            zip(prompt_logprobs_0, prompt_logprobs_1)):
        fail_msg = (
            f"Prompt logprobs test:"
            f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
            f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
        if logprobs_elem_0 is None:
            # A `None` entry in seq 0 requires a `None` entry in seq 1
            assert logprobs_elem_1 is None, fail_msg
        else:
            assert logprobs_elem_1 is not None, fail_msg
            # Top-k token choices must be the same
            assert (set(logprobs_elem_0.keys()) == set(
                logprobs_elem_1.keys())), fail_msg


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
    always_check_logprobs: bool = False,
) -> None:
    """Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    For each prompt, the sampled tokens are walked in order over the common
    prefix of the two token lists. At a "checked" offset, each sequence's
    sampled token id must appear in the other sequence's top logprobs at
    that offset (and both logprobs entries must be present):

    * `always_check_logprobs == True`: every offset is checked, up to and
      including the first one where the generated token ids differ
    * `always_check_logprobs == False`: only the first offset where the
      generated token ids differ is checked

    Comparison of a prompt's tokens stops at the first mismatching token,
    since the sequences diverge afterwards (a warning is issued if
    `warn_on_mismatch`). If no compared token mismatches but the output
    texts differ, a warning is issued if `warn_on_mismatch`; text
    differences never fail the check.

    Prompt logprobs must be provided either for both input sequences, or
    for neither. If prompt logprobs are provided, then the sets of
    top-logprob token ids must be identical at all prompt token offsets.

    Args:
        outputs_0_lst: First list of outputs to compare, one per prompt
        outputs_1_lst: Second list of outputs to compare, one per prompt
        name_0: sequence #0 name (used in failure/warning messages)
        name_1: sequence #1 name (used in failure/warning messages)
        num_outputs_0_skip_tokens: If > 0, specifies the number of initial
                                   sequence #0 tokens & logprobs to discard
                                   before comparison, i.e. all
                                   of sequence #1 will be compared to
                                   sequence #0 beginning at index
                                   num_outputs_0_skip_tokens
        warn_on_mismatch: Issue a warning if there is token-wise or text-wise
                          mismatch between the two sequences
        always_check_logprobs: If true, check logprobs even when tokens match

    Raises:
        AssertionError: if any of the checks above fails.
        ValueError: if `num_outputs_0_skip_tokens` is negative, or an output
            tuple does not have 3 or 4 elements.
    """
    # Validate once up front instead of once per prompt.
    if num_outputs_0_skip_tokens < 0:
        raise ValueError("num_outputs_0_skip_tokens must be non-negative")

    assert len(outputs_0_lst) == len(outputs_1_lst)

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        assert len(outputs_0) == len(outputs_1)

        if len(outputs_0) == 3:
            # Break out tokens, text & sample logprobs
            # (prompt logprobs were not provided)
            output_ids_0, output_str_0, logprobs_0 = outputs_0
            output_ids_1, output_str_1, logprobs_1 = outputs_1
        elif len(outputs_0) == 4:
            # Break out tokens, text, sample logprobs & prompt logprobs
            (
                output_ids_0,
                output_str_0,
                logprobs_0,
                prompt_logprobs_0,
            ) = outputs_0
            (
                output_ids_1,
                output_str_1,
                logprobs_1,
                prompt_logprobs_1,
            ) = outputs_1
            _check_prompt_logprobs_close(prompt_logprobs_0,
                                         prompt_logprobs_1, name_0, name_1)
        else:
            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
                             f"{len(outputs_0)} elements were provided: "
                             f"{outputs_0}")

        # Missing sample logprobs become per-token `None` placeholders so
        # indexing below works; a checked offset then fails its assertion.
        if logprobs_0 is None:
            logprobs_0 = [None] * len(output_ids_0)
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens.
        for idx, (output_id_0, output_id_1) in enumerate(
                zip(output_ids_0, output_ids_1)):
            is_tok_mismatch = output_id_0 != output_id_1

            # Only inspect logprobs where tokens diverge, unless it is
            # desired to always check logprobs.
            if not (is_tok_mismatch or always_check_logprobs):
                continue

            logprobs_elem_0 = logprobs_0[idx]
            logprobs_elem_1 = logprobs_1[idx]

            # Each predicted token must be in top N logprobs of the other
            fail_msg = (
                f"Test{prompt_idx}:"
                f"\nMatched tokens:\t{output_ids_0[:idx]}"
                f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
                f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")

            assert logprobs_elem_0 is not None, fail_msg
            assert logprobs_elem_1 is not None, fail_msg
            assert output_id_0 in logprobs_elem_1, fail_msg
            assert output_id_1 in logprobs_elem_0, fail_msg

            if is_tok_mismatch:
                if warn_on_mismatch:
                    _warn_always(fail_msg)
                # Break out since sequences will now diverge. This happens
                # only on an actual mismatch, so `always_check_logprobs`
                # keeps checking every matching offset before it.
                break
        else:
            # No compared token mismatched (the loop did not break)
            if output_str_0 != output_str_1 and warn_on_mismatch:
                # The token outputs exactly match,
                # so the text outputs should exactly match as well
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")
                _warn_always(fail_msg)
def build_model_context(model_name: str,
                        task: TaskOption = "auto",
                        tokenizer_name: Optional[str] = None,
                        trust_remote_code: bool = False,
                        dtype: Optional[Union[str, torch.dtype]] = None,
                        mm_processor_kwargs: Optional[Dict] = None,
                        limit_mm_per_prompt: Optional[Dict] = None):
    """Wrap a freshly built ``ModelConfig`` for a model in an InputContext.

    Args:
        model_name: Name of the model being considered.
        task: Task option forwarded to ``ModelConfig``.
        tokenizer_name: Name of the tokenizer being considered; falls back
            to ``model_name`` when ``None``.
        trust_remote_code: Whether or not to allow loading remote code.
        dtype: dtype forwarded to ``ModelConfig``; falls back to ``"half"``
            when ``None``.
        mm_processor_kwargs: Optional processor kwargs to be leveraged
            in the input processor, mapper, dummy data creation, etc.
        limit_mm_per_prompt: Multimodal limits.

    Returns:
        InputContext for the model being considered. The underlying config
        always uses ``tokenizer_mode="auto"`` and ``seed=0``.
    """
    config_kwargs = {
        "task": task,
        "tokenizer": model_name if tokenizer_name is None else tokenizer_name,
        "tokenizer_mode": "auto",
        "trust_remote_code": trust_remote_code,
        "dtype": "half" if dtype is None else dtype,
        "seed": 0,
        "mm_processor_kwargs": mm_processor_kwargs,
        "limit_mm_per_prompt": limit_mm_per_prompt,
    }
    return InputContext(ModelConfig(model_name, **config_kwargs))