[Misc] Keep only one implementation of the create_dummy_prompt function. (#4716)
This commit is contained in:
parent
208b71bcc1
commit
e965d46184
@ -1,36 +1,8 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceData,
|
||||
SequenceGroup, SequenceGroupOutput, SequenceOutput)
|
||||
|
||||
|
||||
def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    use_beam_search: bool = False,
    best_of: int = 1,
) -> SequenceGroup:
    """Build a SequenceGroup around one synthetic prompt for tests.

    The prompt's token ids are simply ``0 .. prompt_length-1`` and its text
    is those ids joined with spaces (e.g. ``"0 1 2"``). When *block_size*
    is falsy it defaults to *prompt_length*, so the whole prompt fits in a
    single block.
    """
    # Falsy block_size (None or 0) falls back to the prompt length,
    # mirroring the original `if not block_size` check.
    effective_block_size = block_size if block_size else prompt_length

    token_ids = list(range(prompt_length))
    prompt_text = " ".join(str(tok) for tok in token_ids)

    dummy_seq = Sequence(int(request_id), prompt_text, token_ids,
                         effective_block_size)
    sampling = SamplingParams(use_beam_search=use_beam_search,
                              best_of=best_of)
    # Arrival time is "now"; lora_request may be None for non-LoRA tests.
    return SequenceGroup(request_id, [dummy_seq], sampling, time.time(),
                         lora_request)
|
||||
from tests.core.utils import create_dummy_prompt
|
||||
from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
|
||||
SequenceOutput)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -102,7 +74,7 @@ def test_sequence_data_prefill():
|
||||
|
||||
|
||||
def test_sequence_group_stage():
|
||||
seq_group = create_dummy_prompt("1", 12)
|
||||
_, seq_group = create_dummy_prompt("1", 12)
|
||||
assert seq_group.is_prefill() is True
|
||||
seq_group.update_num_computed_tokens(6)
|
||||
assert seq_group.is_prefill() is True
|
||||
|
Loading…
x
Reference in New Issue
Block a user