import time
from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple

from vllm import SamplingParams
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup


def create_dummy_prompt(
    request_id: str,
    prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    best_of: int = 1,
    prompt_tokens: Optional[List[int]] = None,
    min_tokens: int = 0,
    max_tokens: int = 16,
) -> Tuple[Sequence, SequenceGroup]:
    if not block_size:
        block_size = prompt_length

    if prompt_tokens is None:
        # Create a dummy prompt sequence with tokens 0...prompt_length-1
        # and prompt string "0 1 ... prompt_length-1".
        prompt_tokens = list(range(prompt_length))

    prompt_str = " ".join([str(t) for t in prompt_tokens])
    prompt = Sequence(int(request_id),
                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
                      block_size=block_size)
    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[prompt],
                              arrival_time=time.time(),
                              sampling_params=SamplingParams(
                                  best_of=best_of,
                                  max_tokens=max_tokens,
                                  min_tokens=min_tokens),
                              lora_request=lora_request)

    return prompt, seq_group
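
# Illustrative usage sketch, not part of the original helpers: exercises
# create_dummy_prompt and checks the shape of what it returns. Assumes the
# standard Sequence.get_len() / SequenceGroup.get_seqs() accessors.
def _example_create_dummy_prompt() -> None:
    seq, seq_group = create_dummy_prompt("0", prompt_length=8, block_size=4)
    # The prompt is the token ids 0..7, so the sequence length is 8.
    assert seq.get_len() == 8
    assert seq_group.get_seqs()[0] is seq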
def create_dummy_prompt_encoder_decoder(
    request_id: str,
    decoder_prompt_length: int,
    encoder_prompt_length: int,
    block_size: Optional[int] = None,
    lora_request: Optional[LoRARequest] = None,
    best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
    if not block_size:
        block_size = decoder_prompt_length

    # Create dummy prompt sequences: the decoder tokens count up from 0 and
    # the encoder tokens count down, so the two are distinguishable. The
    # prompt strings are just the token ids joined with spaces, not real
    # detokenized text.
    decoder_prompt_tokens = list(range(decoder_prompt_length))
    decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
    encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
    encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])

    inputs: EncoderDecoderInputs = {
        "decoder": token_inputs(decoder_prompt_tokens,
                                prompt=decoder_prompt_str),
        "encoder": token_inputs(encoder_prompt_tokens,
                                prompt=encoder_prompt_str),
    }

    decoder_prompt = Sequence(int(request_id),
                              inputs=inputs["decoder"],
                              block_size=block_size)

    encoder_prompt = Sequence(int(request_id),
                              inputs=inputs["encoder"],
                              block_size=block_size)

    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              sampling_params=SamplingParams(best_of=best_of),
                              arrival_time=time.time(),
                              lora_request=lora_request,
                              encoder_seq=encoder_prompt)

    return decoder_prompt, encoder_prompt, seq_group
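
# Illustrative sketch, not part of the original helpers: the decoder tokens
# count up while the encoder tokens count down, so tests can tell the two
# sequences apart. Assumes the standard Sequence.get_token_ids() accessor.
def _example_create_dummy_prompt_encoder_decoder() -> None:
    decoder_seq, encoder_seq, _ = create_dummy_prompt_encoder_decoder(
        "0", decoder_prompt_length=4, encoder_prompt_length=4)
    assert decoder_seq.get_token_ids() == [0, 1, 2, 3]
    assert encoder_seq.get_token_ids() == [3, 2, 1, 0]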
def create_seq_group(
    seq_prompt_len: int = 1024,
    seq_output_lens: GenericSequence[int] = (128, ),
    request_id: str = '0',
    seq_id_start: int = 0,
    sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:

    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    seqs: List[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs=token_inputs(prompt_token_ids),
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )

    return seq_group
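
# Illustrative sketch, not part of the original helpers: one child sequence is
# created per entry in seq_output_lens, each pre-filled with that many dummy
# output tokens. Assumes the standard Sequence.get_output_len() accessor.
def _example_create_seq_group() -> None:
    seq_group = create_seq_group(seq_prompt_len=16, seq_output_lens=(2, 4))
    seqs = seq_group.get_seqs()
    assert len(seqs) == 2
    assert seqs[0].get_output_len() == 2
    assert seqs[1].get_output_len() == 4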
def create_seq_group_encoder_decoder(
    seq_prompt_len: int = 1024,
    seq_output_lens: GenericSequence[int] = (128, ),
    request_id: str = '0',
    seq_id_start: int = 0,
    sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:

    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    inputs: EncoderDecoderInputs = {
        "decoder": token_inputs(prompt_token_ids),
        "encoder": token_inputs(prompt_token_ids),
    }

    seqs: List[Sequence] = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
        # Construct decoder input sequences
        seq = Sequence(
            seq_id=seq_id_start + seq_id_offset,
            inputs=inputs["decoder"],
            block_size=16,
        )

        for i in range(output_len):
            seq.append_token_id(
                token_id=i,
                logprobs={i: Logprob(0.0)},
            )
        seqs.append(seq)

    # Encoder input sequence
    encoder_seq = Sequence(
        seq_id=seq_id_start + len(seq_output_lens),
        inputs=inputs["encoder"],
        block_size=16,
    )

    return SequenceGroup(request_id=request_id,
                         seqs=seqs,
                         sampling_params=sampling_params,
                         arrival_time=time.time(),
                         encoder_seq=encoder_seq)
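
# Illustrative sketch, not part of the original helpers: the encoder sequence
# rides along on the group rather than appearing among the decoder seqs.
# Assumes SequenceGroup exposes the encoder_seq constructor argument as an
# attribute of the same name.
def _example_create_seq_group_encoder_decoder() -> None:
    seq_group = create_seq_group_encoder_decoder(seq_prompt_len=16,
                                                 seq_output_lens=(2, ))
    assert len(seq_group.get_seqs()) == 1
    assert seq_group.encoder_seq is not None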
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    # Despite the name, this returns the number of blocks needed to hold
    # seq_len tokens, i.e. ceil(seq_len / block_size).
    return (seq_len + block_size - 1) // block_size
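
# Worked example, not part of the original helpers: the helper returns a block
# *count*, e.g. (10 + 4 - 1) // 4 == 13 // 4 == 3.
def _example_round_up_to_next_block() -> None:
    assert round_up_to_next_block(10, 4) == 3  # 10 tokens need 3 blocks of 4
    assert round_up_to_next_block(8, 4) == 2  # exact multiples don't round up
    assert round_up_to_next_block(1, 4) == 1  # any nonzero length needs a block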

# Helper functions for scheduler tests

def get_sequence_groups(scheduler_output):
    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def append_new_token(out, token_id: int):
    seq_groups = get_sequence_groups(out)
    for seq_group in seq_groups:
        for seq in seq_group.get_seqs():
            seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def schedule_and_update_computed_tokens(scheduler):
    metas, out, _ = scheduler.schedule()
    for s, meta in zip(out.scheduled_seq_groups, metas):
        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
    return metas, out
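
# Illustrative sketch, not part of the original helpers: the typical decode
# step in a scheduler test, built from the utilities above. The `scheduler`
# argument is assumed to be a vllm.core.scheduler.Scheduler that already has
# sequence groups added.
def _example_decode_step(scheduler, token_id: int = 0):
    metas, out = schedule_and_update_computed_tokens(scheduler)
    # Pretend the model emitted one token for every scheduled sequence.
    append_new_token(out, token_id)
    return metas, out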
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
    seq_group.update_num_computed_tokens(token_chunk_size)
    for seq in seq_group.get_seqs():
        seq.append_token_id(token_id, {token_id: Logprob(token_id)})