2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
import pytest
|
|
|
|
|
2024-10-08 08:12:56 -06:00
|
|
|
from vllm.inputs import zip_enc_dec_prompts
|
2024-08-09 10:39:41 +08:00
|
|
|
from vllm.inputs.parse import parse_and_batch_prompt
|
2024-05-29 04:29:31 +08:00
|
|
|
|
|
|
|
# Plain-text prompts of increasing word count, starting from the empty
# string. Kept as a list: the tests below pass it straight to the parser
# and slice it.
STRING_INPUTS = ['', 'foo', 'foo bar', 'foo baz bar', 'foo bar qux baz']
|
|
|
|
|
|
|
|
# Token-id prompts, each a list of ints.
# NOTE(review): [-1] is a negative token id; presumably deliberate, to show
# that parsing does not validate id ranges. Confirm.
TOKEN_INPUTS = [[-1], [1], [1, 2], [1, 3, 4], [1, 2, 4, 3]]
|
|
|
|
|
|
|
|
# Full-range slices that differ only in step: reverse, every other element,
# and every other element reversed.
INPUTS_SLICES = [slice(None, None, step) for step in (-1, 2, -2)]
|
|
|
|
|
|
|
|
|
|
|
|
def test_parse_single_batch_empty():
    """Reject batches that contain no prompt.

    Both an empty list and a list holding one empty list must make
    ``parse_and_batch_prompt`` raise ``ValueError`` whose message matches
    "at least one prompt".
    """
    for empty_batch in ([], [[]]):
        with pytest.raises(ValueError, match="at least one prompt"):
            parse_and_batch_prompt(empty_batch)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
    """A bare string parses identically to a one-element list of it."""
    as_single = parse_and_batch_prompt(string_input)
    as_batch = parse_and_batch_prompt([string_input])
    assert as_single == as_batch
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: list[int]):
    """A bare token list parses identically to a one-element batch of it."""
    as_single = parse_and_batch_prompt(token_input)
    as_batch = parse_and_batch_prompt([token_input])
    assert as_single == as_batch
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
    """Slicing commutes with parsing for a batch of string prompts.

    Parsing the whole batch and then slicing the result must equal parsing
    the batch after it has already been sliced.
    """
    sliced_after_parse = parse_and_batch_prompt(STRING_INPUTS)[inputs_slice]
    sliced_before_parse = parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
    assert sliced_after_parse == sliced_before_parse
|
2024-10-08 08:12:56 -06:00
|
|
|
|
|
|
|
|
|
|
|
# yapf: disable
@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
    (None, [{}, {}]),
    ({}, [{}, {}]),
    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
])
# yapf: enable
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
    """Test mm_processor_kwargs init for zipping enc/dec prompts.

    Each parametrization passes ``mm_processor_kwargs`` as ``None``, a
    single dict, or a per-prompt list of dicts. It also gives the kwargs
    that each zipped prompt is expected to carry, in prompt order.
    """
    encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
    decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
    zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
                                         mm_processor_kwargs)
    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
    # zip() below stops at the shortest input. Pin the expectation length too,
    # otherwise a short expected list would leave prompts unchecked.
    assert len(expected_mm_kwargs) == len(zipped_prompts)
    for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
                                            expected_mm_kwargs,
                                            zipped_prompts):
        assert isinstance(zipped, dict)
        # Exactly these three keys: none missing and none extra.
        assert zipped.keys() == {
            'encoder_prompt', 'decoder_prompt', 'mm_processor_kwargs'
        }
        assert zipped['encoder_prompt'] == enc
        assert zipped['decoder_prompt'] == dec
        assert zipped['mm_processor_kwargs'] == exp_kwargs
|