vllm/tests/worker/test_model_runner.py

import random

import torch

from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


def test_prepare_prompt():
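    """Verify that _prepare_prompt pads prompts to the longest prompt in the
    batch and that _prepare_sample selects the last token of each prompt."""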
    model_runner = ModelRunner(None, None, None)
    model_runner.set_block_size(16)
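
    # Build a batch of prompt-only sequence groups with varying prompt lengths.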
    batch_size = random.randint(1, 256)
    prompt_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # make sure all tokens fit into one block
        prompt_len = i % (model_runner.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        seq_data = list(range(prompt_len))
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData(seq_data)},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
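
    # Prompts are padded to max_seq_len, so the last token of the i-th prompt
    # ends up at offset i * max_seq_len + (prompt_len - 1).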
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    max_seq_len = max(prompt_lens)
    for prompt_len in prompt_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               prompt_len - 1)
        selected_token_start_idx += max_seq_len

    input_tokens, input_positions, _ = model_runner._prepare_prompt(
        seq_group_metadata_list)
    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
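
    # Token ids were generated as range(prompt_len), so they should coincide
    # with the positions; both tensors are padded out to max_seq_len.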
    assert input_tokens.shape == (batch_size, max_seq_len)
    assert input_positions.shape == (batch_size, max_seq_len)
    torch.testing.assert_close(input_tokens, input_positions)

    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)