Align vLLM's beam search implementation with HF generate (#857)

Zhuohan Li 2023-09-04 17:29:42 -07:00 committed by GitHub
parent e15932bb60
commit 002800f081
24 changed files with 596 additions and 260 deletions

View File

@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+ kv_caches: List[KVCache], + kv_caches: List[KVCache],
+ input_metadata: InputMetadata, + input_metadata: InputMetadata,
+ cache_events: Optional[List[torch.cuda.Event]], + cache_events: Optional[List[torch.cuda.Event]],
+) -> Dict[int, SequenceOutputs]: +) -> SamplerOutput:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture. 4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
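For orientation, once the steps above are applied, the rewritten forward method ends up looking roughly like the sketch below. MyModelForCausalLM, self.model, self.lm_head, and self.sampler are placeholder names for the model-specific pieces; compare the per-model forward changes later in this diff.

from typing import List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class MyModelForCausalLM(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,       # flattened 1D tensor of token ids
        positions: torch.Tensor,       # flattened 1D tensor of positions
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        # Run the model over the flattened inputs, then sample the next tokens.
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head.weight, hidden_states,
                            input_metadata)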

View File

@ -67,8 +67,8 @@ class HfRunner:
output_ids, output_ids,
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
)[0] )
output_ids = output_ids[0].cpu().tolist() output_ids = output_ids.cpu().tolist()
outputs.append((output_ids, output_str)) outputs.append((output_ids, output_str))
return outputs return outputs
@ -77,8 +77,34 @@ class HfRunner:
prompts: List[str], prompts: List[str],
max_tokens: int, max_tokens: int,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
return self.generate(prompts, do_sample=False, outputs = self.generate(prompts,
max_new_tokens=max_tokens) do_sample=False,
max_new_tokens=max_tokens)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
outputs[i] = (output_ids[0], output_str[0])
return outputs
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens,
num_beams=beam_width,
num_return_sequences=beam_width)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
for j in range(len(output_ids)):
output_ids[j] = [
x for x in output_ids[j]
if x != self.tokenizer.pad_token_id
]
outputs[i] = (output_ids, output_str)
return outputs
@pytest.fixture @pytest.fixture
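For reference, the Hugging Face path that generate_beam_search exercises reduces to a plain generate call with beam-search arguments. A minimal standalone sketch using the same settings as the test added below (facebook/opt-125m, beam width 4, 128 new tokens); the prompt is a placeholder:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    do_sample=False,
    num_beams=4,
    num_return_sequences=4,   # return all beams, as the runner does
    max_new_tokens=128,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))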
@ -107,15 +133,20 @@ class VllmRunner:
prompts: List[str], prompts: List[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
req_outputs = self.model.generate( req_outputs = self.model.generate(prompts,
prompts, sampling_params=sampling_params) sampling_params=sampling_params)
outputs = [] outputs = []
for req_output in req_outputs: for req_output in req_outputs:
prompt_str = req_output.prompt prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids prompt_ids = req_output.prompt_token_ids
output_str = req_output.outputs[0].text req_sample_output_ids = []
output_ids = req_output.outputs[0].token_ids req_sample_output_strs = []
outputs.append((prompt_ids + output_ids, prompt_str + output_str)) for sample in req_output.outputs:
output_str = sample.text
output_ids = sample.token_ids
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs return outputs
def generate_greedy( def generate_greedy(
@ -124,7 +155,22 @@ class VllmRunner:
max_tokens: int, max_tokens: int,
) -> List[Tuple[List[int], str]]: ) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
return self.generate(prompts, greedy_params) outputs = self.generate(prompts, greedy_params)
return [(output_ids[0], output_str[0]) for output_ids, output_str in
outputs]
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[int], str]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens)
outputs = self.generate(prompts, beam_search_params)
return outputs
@pytest.fixture @pytest.fixture
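The equivalent request through vLLM's public API is a SamplingParams with use_beam_search enabled; a minimal sketch mirroring the parameters used by generate_beam_search above (model name and prompt are placeholders):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
beam_params = SamplingParams(
    n=4,                  # number of beams to return
    use_beam_search=True,
    temperature=0.0,
    max_tokens=128,
)
for request_output in llm.generate(["The capital of France is"], beam_params):
    for beam in request_output.outputs:
        print(beam.text)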

View File

@ -0,0 +1,46 @@
"""Compare the outputs of HF and vLLM when using beam search.
Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import pytest
# FIXME(zhuohan): The test cannot pass if we:
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [128]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")

View File

@ -172,9 +172,7 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block. # CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
if seq.is_finished():
continue
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
@ -203,9 +201,7 @@ class BlockSpaceManager:
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block. # GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.is_finished():
continue
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
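The status-filtered loops above rely on SequenceGroup.get_seqs accepting an optional status argument. The helper itself is not part of this diff; a plausible sketch of it, as a method of vllm.sequence.SequenceGroup:

def get_seqs(self,
             status: Optional[SequenceStatus] = None) -> List[Sequence]:
    # Return every sequence in the group, or only those in the given state.
    if status is None:
        return self.seqs
    return [seq for seq in self.seqs if seq.status == status]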

View File

@ -7,8 +7,7 @@ from vllm.core.block_manager import BlockSpaceManager
from vllm.core.policy import PolicyFactory from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs, SequenceGroupMetadata, SequenceStatus)
SequenceStatus)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -76,6 +75,7 @@ class Scheduler:
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
) )
# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = [] self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state. # Sequence groups in the RUNNING state.
@ -96,10 +96,11 @@ class Scheduler:
if seq_group.request_id in request_ids: if seq_group.request_id in request_ids:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(seq_group) state_queue.remove(seq_group)
for seq in seq_group.seqs: for seq in seq_group.get_seqs():
if seq.is_finished(): if seq.is_finished():
continue continue
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED) seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
request_ids.remove(seq_group.request_id) request_ids.remove(seq_group.request_id)
if not request_ids: if not request_ids:
return return
@ -123,6 +124,10 @@ class Scheduler:
if not self.swapped: if not self.swapped:
ignored_seq_groups: List[SequenceGroup] = [] ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = [] scheduled: List[SequenceGroup] = []
# The total number of sequences on the fly, including the
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
num_batched_tokens = 0 num_batched_tokens = 0
# Optimization: We do not sort the waiting queue since the preempted # Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups # sequence groups are added to the front and the new sequence groups
@ -130,6 +135,9 @@ class Scheduler:
while self.waiting: while self.waiting:
seq_group = self.waiting[0] seq_group = self.waiting[0]
assert seq_group.num_seqs() == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len() num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if num_prompt_tokens > self.prompt_limit: if num_prompt_tokens > self.prompt_limit:
logger.warning( logger.warning(
@ -152,11 +160,7 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not # The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences. # exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs( num_new_seqs = seq_group.get_max_num_running_seqs()
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs > if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
@ -165,6 +169,7 @@ class Scheduler:
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens num_batched_tokens += num_prompt_tokens
num_curr_seqs += num_new_seqs
scheduled.append(seq_group) scheduled.append(seq_group)
if scheduled: if scheduled:
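The admission checks above switch from counting RUNNING sequences to SequenceGroup.get_max_num_running_seqs(), an upper bound on how many sequences a group can occupy at once. The method lives in vllm/sequence.py and is not shown in this diff; a hedged sketch of its intended behavior:

def get_max_num_running_seqs(self) -> int:
    # Upper bound on the number of sequences this group can have in the
    # RUNNING state now or at any future scheduling step.
    if self.sampling_params.use_beam_search:
        # Beam search keeps `best_of` candidate beams alive.
        return self.sampling_params.best_of
    if self.sampling_params.best_of > self.num_seqs():
        # The single prompt sequence has not been forked into
        # `best_of` samples yet.
        return self.sampling_params.best_of
    # Otherwise, only the still-unfinished sequences can run.
    return self.num_unfinished_seqs()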
@ -210,30 +215,32 @@ class Scheduler:
# Swap in the sequence groups in the SWAPPED state if possible. # Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped) self.swapped = self.policy.sort_by_priority(now, self.swapped)
while self.swapped and not blocks_to_swap_out: if not preempted:
seq_group = self.swapped[0] num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
# If the sequence group has been preempted in this step, stop. for seq_group in self.running)
if seq_group in preempted:
break
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
# The total number of sequences in the RUNNING state should not while self.swapped:
# exceed the maximum number of sequences. seq_group = self.swapped[0]
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) # If the sequence group cannot be swapped in, stop.
num_curr_seqs = sum( if not self.block_manager.can_swap_in(seq_group):
seq_group.num_seqs(status=SequenceStatus.RUNNING) break
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.swapped.pop(0) # The total number of sequences in the RUNNING state should not
self._swap_in(seq_group, blocks_to_swap_in) # exceed the maximum number of sequences.
self._append_slot(seq_group, blocks_to_copy) num_new_seqs = seq_group.get_max_num_running_seqs()
self.running.append(seq_group) if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
self.running.append(seq_group)
# Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state.
num_batched_tokens = sum( num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING) seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running) for seq_group in self.running)
@ -275,40 +282,10 @@ class Scheduler:
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs return seq_group_metadata_list, scheduler_outputs
def update( def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
self, self.block_manager.fork(parent_seq, child_seq)
seq_outputs: Dict[int, SequenceOutputs],
) -> List[SequenceGroup]:
scheduled: List[SequenceGroup] = []
for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.seq_id in seq_outputs:
scheduled.append(seq_group)
break
# Update the scheduled sequences and free blocks. def free_seq(self, seq: Sequence) -> None:
for seq_group in scheduled:
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam
# search). Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token_id(output.output_token, output.logprobs)
return scheduled
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
seq.status = finish_status
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
@ -345,8 +322,8 @@ class Scheduler:
# If preemption mode is not specified, we determine the mode as follows: # If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than # We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences # swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not supported. In such a case, # (e.g., beam search), recomputation is not currently supported. In
# we use swapping instead. # such a case, we use swapping instead.
# FIXME(woosuk): This makes our scheduling policy a bit bizarre. # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
# As swapped sequences are prioritized over waiting sequences, # As swapped sequences are prioritized over waiting sequences,
# sequence groups with multiple sequences are implicitly prioritized # sequence groups with multiple sequences are implicitly prioritized
@ -354,8 +331,7 @@ class Scheduler:
# TODO(woosuk): Support recomputation for sequence groups with multiple # TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel. # sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None: if preemption_mode is None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) if seq_group.get_max_num_running_seqs() == 1:
if len(seqs) == 1:
preemption_mode = PreemptionMode.RECOMPUTE preemption_mode = PreemptionMode.RECOMPUTE
else: else:
preemption_mode = PreemptionMode.SWAP preemption_mode = PreemptionMode.SWAP

View File

@ -11,7 +11,8 @@ from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata, from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus) SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer) get_tokenizer)
@ -258,14 +259,11 @@ class LLMEngine:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seqs: List[Sequence] = [] seq_id = next(self.seq_counter)
for _ in range(sampling_params.best_of): seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seqs.append(seq)
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, seqs, sampling_params, seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time) arrival_time)
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler.
@ -303,22 +301,230 @@ class LLMEngine:
] ]
return seq_group_metadata_list, scheduler_outputs, None return seq_group_metadata_list, scheduler_outputs, None
def _process_worker_outputs( def _check_beam_search_early_stopping(
self, output, self,
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: early_stopping: Union[bool, str],
# Update the scheduler with the model outputs. sampling_params: SamplingParams,
seq_groups = self.scheduler.update(output) best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id))
return current_worst_score >= highest_attainable_score
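The early-stopping comparison is expressed in terms of Sequence.get_beam_search_score, which follows the Hugging Face convention of dividing the cumulative log-probability by the sequence length raised to length_penalty. The scorer itself is defined elsewhere; a rough sketch of its assumed shape:

def get_beam_search_score(self,
                          length_penalty: float = 0.0,
                          seq_len: Optional[int] = None,
                          eos_token_id: Optional[int] = None) -> float:
    # Beam score = sum of token log-probs / (seq_len ** length_penalty),
    # matching the Hugging Face beam-hypothesis scoring.
    if seq_len is None:
        seq_len = self.get_len()
        # Do not count a trailing EOS token toward the length penalty.
        if (eos_token_id is not None
                and self.get_last_token_id() == eos_token_id):
            seq_len -= 1
    return self.get_cumulative_logprob() / (seq_len**length_penalty)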
def _process_sequence_group_samples(
self, seq_group: SequenceGroup,
samples: List[SequenceOutputs]) -> None:
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutputs] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no child samples. Remove
# the parent sequence from the sequence group since it will
# not be used in future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
self._decode_sequence(seq)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need to remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params, best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
self.scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for
# the next iteration.
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _process_model_outputs(
self, output: SamplerOutput,
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, samples in zip(scheduled_seq_groups, output):
self._process_sequence_group_samples(seq_group, samples)
# Decode the sequences.
self._decode_sequences(seq_groups)
# Stop the sequences that meet the stopping criteria.
self._stop_sequences(seq_groups)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() self.scheduler.free_finished_seq_groups()
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[RequestOutput] = []
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups: for seq_group in (scheduled_seq_groups +
scheduler_outputs.ignored_seq_groups):
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
@ -351,7 +557,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy, blocks_to_copy=scheduler_outputs.blocks_to_copy,
) )
return self._process_worker_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
def _log_system_stats( def _log_system_stats(
self, self,
@ -416,55 +622,44 @@ class LLMEngine:
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now self.last_logging_time = now
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None: def _decode_sequence(self, seq: Sequence) -> None:
"""Decodes the sequence outputs.""" """Decodes the new token for a sequence."""
for seq_group in seq_groups: new_token, new_output_text = detokenize_incrementally(
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): self.tokenizer,
new_token, new_output_text = detokenize_incrementally( seq.output_tokens,
self.tokenizer, seq.get_last_token_id(),
seq.output_tokens, skip_special_tokens=True,
seq.get_last_token_id(), )
skip_special_tokens=True, if new_token is not None:
) seq.output_tokens.append(new_token)
if new_token is not None: seq.output_text = new_output_text
seq.output_tokens.append(new_token)
seq.output_text = new_output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None: def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences.""" """Stop the finished sequences."""
for seq_group in seq_groups: for stop_str in sampling_params.stop:
sampling_params = seq_group.sampling_params if seq.output_text.endswith(stop_str):
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): # Truncate the output text so that the stop string is
# Check if the sequence has generated a stop string. # not included in the output.
stopped = False seq.output_text = seq.output_text[:-len(stop_str)]
for stop_str in sampling_params.stop: seq.status = SequenceStatus.FINISHED_STOPPED
if seq.output_text.endswith(stop_str): return
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
stopped = True
break
if stopped:
continue
# Check if the sequence has reached max_model_len. # Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len: if seq.get_len() > self.scheduler_config.max_model_len:
self.scheduler.free_seq( seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
seq, SequenceStatus.FINISHED_LENGTH_CAPPED) return
continue
# Check if the sequence has reached max_tokens. # Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens: if seq.get_output_len() == sampling_params.max_tokens:
self.scheduler.free_seq( seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
seq, SequenceStatus.FINISHED_LENGTH_CAPPED) return
continue
# Check if the sequence has generated the EOS token. # Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos: if ((not sampling_params.ignore_eos)
if seq.get_last_token_id() == self.tokenizer.eos_token_id: and seq.get_last_token_id() == self.tokenizer.eos_token_id):
self.scheduler.free_seq( seq.status = SequenceStatus.FINISHED_STOPPED
seq, SequenceStatus.FINISHED_STOPPED) return
continue
def _run_workers( def _run_workers(
self, self,

View File

@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region) gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput, SequenceOutputs
_SAMPLING_EPS = 1e-5 _SAMPLING_EPS = 1e-5
@ -39,7 +39,7 @@ class Sampler(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
# Get the hidden states that we use for sampling. # Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata) hidden_states = _prune_hidden_states(hidden_states, input_metadata)
@ -292,7 +292,13 @@ def _sample_from_prompt(
if sampling_params.use_beam_search: if sampling_params.use_beam_search:
# Beam search. # Beam search.
beam_width = sampling_params.best_of beam_width = sampling_params.best_of
_, next_token_ids = torch.topk(prob, beam_width) # Sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
_, next_token_ids = torch.topk(prob, 2 * beam_width)
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature < _SAMPLING_EPS: elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling. # Greedy sampling.
@ -330,29 +336,11 @@ def _sample_from_generation_tokens(
vocab_size = logprobs.size(-1) vocab_size = logprobs.size(-1)
beam_width = len(seq_ids) beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), beam_width) _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist() topk_ids = topk_ids.tolist()
seq_idx = [i // vocab_size for i in topk_ids] seq_idx = [i // vocab_size for i in topk_ids]
beam_seq_ids = [seq_ids[i] for i in seq_idx] parent_seq_ids = [seq_ids[i] for i in seq_idx]
token_ids = [i % vocab_size for i in topk_ids] next_token_ids = [i % vocab_size for i in topk_ids]
beam_outputs: Dict[int, Tuple[int, int]] = {}
outstanding_beams: List[Tuple[int, int]] = []
# If a beam survives, continue with it.
for seq_id, token_id in zip(beam_seq_ids, token_ids):
if seq_id not in beam_outputs:
beam_outputs[seq_id] = (seq_id, token_id)
else:
outstanding_beams.append((seq_id, token_id))
# If a beam is discarded, fork another beam.
for seq_id in seq_ids:
if seq_id not in beam_outputs:
beam_outputs[seq_id] = outstanding_beams.pop()
assert not outstanding_beams
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
elif sampling_params.temperature < _SAMPLING_EPS: elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling. # Greedy sampling.
assert len(seq_ids) == 1 assert len(seq_ids) == 1
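The flattened top-k above selects 2 * beam_width candidates across all running beams at once, then maps each flat index back to a (parent beam, token) pair. A self-contained toy example of that index arithmetic:

import torch

vocab_size = 5
beam_width = 2
seq_ids = [7, 8]   # ids of the running parent sequences

# In the real sampler these logprobs already include each beam's
# accumulated sequence log-probability.
logprobs = torch.log_softmax(torch.randn(beam_width, vocab_size), dim=-1)

_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()

parent_seq_ids = [seq_ids[i // vocab_size] for i in topk_ids]  # which beam
next_token_ids = [i % vocab_size for i in topk_ids]            # which token
print(list(zip(parent_seq_ids, next_token_ids)))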
@ -374,16 +362,18 @@ def _sample(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
seq_outputs: Dict[int, SequenceOutputs] = {} seq_outputs: SamplerOutput = []
# TODO(woosuk): Optimize. # TODO(woosuk): Optimize.
idx = 0 idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
seq_group_outputs: List[SequenceOutputs] = []
seq_ids, sampling_params = seq_group seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts: if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input. # Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.best_of assert len(seq_ids) == 1, "Prompt input should have only one seq."
parent_seq_id = seq_ids[0]
prob = probs[idx] prob = probs[idx]
logprob = logprobs[idx] logprob = logprobs[idx]
idx += 1 idx += 1
@ -395,17 +385,18 @@ def _sample(
sampling_params.logprobs) sampling_params.logprobs)
# Build the output. # Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids): for next_token_id in next_token_ids:
output_logprobs = next_logprobs.copy() output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item() output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id, seq_group_outputs.append(
next_token_id, SequenceOutputs(parent_seq_id, next_token_id,
output_logprobs) output_logprobs))
else: else:
# Generate the next tokens for generation tokens. # Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)] num_parent_seqs = len(seq_ids)
logprob = logprobs[idx:idx + len(seq_ids)] prob = probs[idx:idx + num_parent_seqs]
idx += len(seq_ids) logprob = logprobs[idx:idx + num_parent_seqs]
idx += num_parent_seqs
# Sample the next tokens. # Sample the next tokens.
seq_logprobs = [ seq_logprobs = [
@ -422,17 +413,15 @@ def _sample(
logprob[j], sampling_params.logprobs) logprob[j], sampling_params.logprobs)
# Build the output. # Build the output.
for seq_id, parent_seq_id, next_token_id in zip( for parent_seq_id, next_token_id in zip(parent_seq_ids,
seq_ids, parent_seq_ids, next_token_ids): next_token_ids):
j = seq_ids.index(parent_seq_id) j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy() output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[j, output_logprobs[next_token_id] = logprob[j,
next_token_id].item() next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs( seq_group_outputs.append(
seq_id, SequenceOutputs(parent_seq_id, next_token_id,
parent_seq_id, output_logprobs))
next_token_id, seq_outputs.append(seq_group_outputs)
output_logprobs,
)
return seq_outputs return seq_outputs
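Across this diff, SamplerOutput replaces the old Dict[int, SequenceOutputs] mapping keyed by sequence id. Based on how it is built above (one list of SequenceOutputs per scheduled sequence group, each entry naming the parent sequence it extends), the new definitions in vllm/sequence.py look roughly like the following sketch:

from typing import Dict, List


class SequenceOutputs:
    """One sampled token together with the parent sequence it extends."""

    def __init__(self, parent_seq_id: int, output_token: int,
                 logprobs: Dict[int, float]) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs


# One inner list per scheduled sequence group, one SequenceOutputs per
# sampled child sequence within that group.
SamplerOutput = List[List[SequenceOutputs]]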

View File

@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -41,7 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig from vllm.transformers_utils.configs.aquila import AquilaConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -273,7 +273,7 @@ class AquilaForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -23,12 +23,11 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
import math import math
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -42,6 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -290,7 +290,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
import math import math
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -19,7 +19,7 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
from typing import Dict, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region) reduce_from_tensor_model_parallel_region)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -397,7 +397,7 @@ class FalconForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
positions, positions,

View File

@ -21,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -218,7 +218,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -22,7 +22,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -39,7 +39,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -246,7 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -20,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -203,7 +203,7 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -20,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.embed_out.weight, hidden_states, next_tokens = self.sampler(self.embed_out.weight, hidden_states,

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -17,7 +17,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
from vllm.model_executor.weight_utils import ( from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab, hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights) load_tensor_parallel_weights)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -218,7 +218,7 @@ class InternLMForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -43,7 +43,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -256,7 +256,7 @@ class LlamaForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -1,7 +1,7 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -16,7 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -230,7 +230,7 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -21,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,

View File

@ -8,7 +8,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input. InputMetadata to extract the original 2D shape of the input.
""" """
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -32,7 +32,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, ColumnParallelLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.sequence import SequenceOutputs from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -235,7 +235,7 @@ class QWenLMHeadModel(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]], cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events) input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,

View File

@ -75,10 +75,12 @@ class RequestOutput:
# Get the top-n sequences. # Get the top-n sequences.
n = seq_group.sampling_params.n n = seq_group.sampling_params.n
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
assert n <= len(seqs) if seq_group.sampling_params.use_beam_search:
sorted_seqs = sorted(seqs, sorting_key = lambda seq: seq.get_beam_search_score(
key=lambda seq: seq.get_cumulative_logprob(), seq_group.sampling_params.length_penalty)
reverse=True) else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n] top_n_seqs = sorted_seqs[:n]
# Create the outputs. # Create the outputs.
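A self-contained toy illustration of this ranking switch (the Seq class below is a stand-in, not vLLM's Sequence): beam-search requests rank candidates by the length-penalized score, while sampled requests keep the raw cumulative log-probability.

from dataclasses import dataclass

@dataclass
class Seq:  # toy stand-in for vllm.sequence.Sequence
    cumulative_logprob: float
    length: int

    def get_cumulative_logprob(self) -> float:
        return self.cumulative_logprob

    def get_beam_search_score(self, length_penalty: float) -> float:
        return self.cumulative_logprob / (self.length ** length_penalty)

seqs = [Seq(-4.0, 8), Seq(-4.5, 12), Seq(-3.9, 6)]
use_beam_search, length_penalty, n = True, 1.0, 2

if use_beam_search:
    sorting_key = lambda seq: seq.get_beam_search_score(length_penalty)
else:
    sorting_key = lambda seq: seq.get_cumulative_logprob()
top_n = sorted(seqs, key=sorting_key, reverse=True)[:n]

# Under beam search the longer beam (-4.5 / 12 = -0.375) outranks the others;
# sorting by cumulative logprob alone would have put Seq(-3.9, 6) first.
assert top_n[0].length == 12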

View File

@ -34,6 +34,15 @@ class SamplingParams:
top_k: Integer that controls the number of top tokens to consider. Set top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens. to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling. use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated. stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings. The returned output will not contain the stop strings.
ignore_eos: Whether to ignore the EOS token and continue generating ignore_eos: Whether to ignore the EOS token and continue generating
@ -52,6 +61,8 @@ class SamplingParams:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = -1, top_k: int = -1,
use_beam_search: bool = False, use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Union[None, str, List[str]] = None, stop: Union[None, str, List[str]] = None,
ignore_eos: bool = False, ignore_eos: bool = False,
max_tokens: int = 16, max_tokens: int = 16,
@ -65,6 +76,8 @@ class SamplingParams:
self.top_p = top_p self.top_p = top_p
self.top_k = top_k self.top_k = top_k
self.use_beam_search = use_beam_search self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None: if stop is None:
self.stop = [] self.stop = []
elif isinstance(stop, str): elif isinstance(stop, str):
@ -78,9 +91,11 @@ class SamplingParams:
self._verify_args() self._verify_args()
if self.use_beam_search: if self.use_beam_search:
self._verify_beam_search() self._verify_beam_search()
elif self.temperature < _SAMPLING_EPS: else:
# Zero temperature means greedy sampling. self._verify_non_beam_search()
self._verify_greedy_sampling() if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.n < 1: if self.n < 1:
@ -119,6 +134,20 @@ class SamplingParams:
raise ValueError("top_p must be 1 when using beam search.") raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1: if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.") raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None: def _verify_greedy_sampling(self) -> None:
if self.best_of > 1: if self.best_of > 1:
@ -138,6 +167,8 @@ class SamplingParams:
f"top_p={self.top_p}, " f"top_p={self.top_p}, "
f"top_k={self.top_k}, " f"top_k={self.top_k}, "
f"use_beam_search={self.use_beam_search}, " f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, " f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, " f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, " f"max_tokens={self.max_tokens}, "

View File

@ -69,6 +69,9 @@ class SequenceData:
def get_len(self) -> int: def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids) return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_prompt_len(self) -> int:
return len(self.prompt_token_ids)
def get_output_len(self) -> int: def get_output_len(self) -> int:
return len(self.output_token_ids) return len(self.output_token_ids)
@ -155,6 +158,9 @@ class Sequence:
def get_len(self) -> int: def get_len(self) -> int:
return self.data.get_len() return self.data.get_len()
def get_prompt_len(self) -> int:
return self.data.get_prompt_len()
def get_output_len(self) -> int: def get_output_len(self) -> int:
return self.data.get_output_len() return self.data.get_output_len()
@ -170,14 +176,32 @@ class Sequence:
def get_cumulative_logprob(self) -> float: def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 0.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# Note: the HF implementation does not count the EOS token
# towards the length; we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
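A quick dependency-free check of the scoring rule (toy numbers; the helper merely restates the formula above):

def beam_search_score(cumulative_logprob: float, seq_len: int,
                      length_penalty: float) -> float:
    # Same rule as above: divide the summed logprob by seq_len**length_penalty.
    return cumulative_logprob / (seq_len ** length_penalty)

shorter = beam_search_score(-4.0, seq_len=8, length_penalty=1.0)    # -0.5
longer = beam_search_score(-4.5, seq_len=12, length_penalty=1.0)    # -0.375
assert longer > shorter  # with length_penalty=1.0 the longer beam ranks higher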
def is_finished(self) -> bool: def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status) return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: "Sequence") -> None: def fork(self, new_seq_id: int) -> "Sequence":
child_seq.logical_token_blocks = copy.deepcopy( new_seq = copy.deepcopy(self)
self.logical_token_blocks) new_seq.seq_id = new_seq_id
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs) return new_seq
child_seq.data = copy.deepcopy(self.data)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"Sequence(seq_id={self.seq_id}, " return (f"Sequence(seq_id={self.seq_id}, "
@ -203,35 +227,66 @@ class SequenceGroup:
arrival_time: float, arrival_time: float,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.seqs = seqs self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.arrival_time = arrival_time self.arrival_time = arrival_time
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return self.sampling_params.best_of
else:
if self.sampling_params.best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only has one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
return self.sampling_params.best_of
# At sampling stages, return the number of actual sequences
# running.
return self.num_seqs(status=SequenceStatus.RUNNING)
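For intuition, a stand-alone restatement of this reservation rule with plain integers (not the actual method):

def max_num_running_seqs(use_beam_search: bool, best_of: int,
                         num_seqs: int, num_running: int) -> int:
    if use_beam_search:
        return best_of      # beam search always keeps best_of candidates alive
    if best_of > num_seqs:
        return best_of      # prompt stage: the group will fan out to best_of
    return num_running      # decode stage: count the sequences still running

assert max_num_running_seqs(True, 4, 1, 1) == 4    # beam search
assert max_num_running_seqs(False, 4, 1, 1) == 4   # sampling, prompt stage
assert max_num_running_seqs(False, 4, 4, 3) == 3   # sampling, one seq finished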
def get_seqs( def get_seqs(
self, self,
status: Optional[SequenceStatus] = None, status: Optional[SequenceStatus] = None,
) -> List[Sequence]: ) -> List[Sequence]:
if status is None: if status is None:
return self.seqs return list(self.seqs_dict.values())
else: else:
return [seq for seq in self.seqs if seq.status == status] return [
seq for seq in self.seqs_dict.values() if seq.status == status
]
def get_finished_seqs(self) -> List[Sequence]:
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status)) return len(self.get_seqs(status))
def find(self, seq_id: int) -> Sequence: def find(self, seq_id: int) -> Sequence:
for seq in self.seqs: if seq_id not in self.seqs_dict:
if seq.seq_id == seq_id: raise ValueError(f"Sequence {seq_id} not found.")
return seq return self.seqs_dict[seq_id]
raise ValueError(f"Sequence {seq_id} not found.")
def add(self, seq: Sequence) -> None:
if seq.seq_id in self.seqs_dict:
raise ValueError(f"Sequence {seq.seq_id} already exists.")
self.seqs_dict[seq.seq_id] = seq
def remove(self, seq_id: int) -> None:
if seq_id not in self.seqs_dict:
raise ValueError(f"Sequence {seq_id} not found.")
del self.seqs_dict[seq_id]
def is_finished(self) -> bool: def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.seqs) return all(seq.is_finished() for seq in self.get_seqs())
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, " return (f"SequenceGroup(request_id={self.request_id}, "
f"sampling_params={self.sampling_params}, " f"sampling_params={self.sampling_params}, "
f"num_seqs={len(self.seqs)})") f"num_seqs={len(self.seqs_dict)})")
class SequenceGroupMetadata: class SequenceGroupMetadata:
@ -266,7 +321,6 @@ class SequenceOutputs:
"""The model output associated with a sequence. """The model output associated with a sequence.
Args: Args:
seq_id: The ID of the sequence.
parent_seq_id: The ID of the parent sequence (for forking in beam parent_seq_id: The ID of the parent sequence (for forking in beam
search). search).
output_token: The output token ID. output_token: The output token ID.
@ -276,26 +330,27 @@ class SequenceOutputs:
def __init__( def __init__(
self, self,
seq_id: int,
parent_seq_id: int, parent_seq_id: int,
output_token: int, output_token: int,
logprobs: Dict[int, float], logprobs: Dict[int, float],
) -> None: ) -> None:
self.seq_id = seq_id
self.parent_seq_id = parent_seq_id self.parent_seq_id = parent_seq_id
self.output_token = output_token self.output_token = output_token
self.logprobs = logprobs self.logprobs = logprobs
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceOutputs(seq_id={self.seq_id}, " return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
f"parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}), " f"output_token={self.output_token}), "
f"logprobs={self.logprobs}") f"logprobs={self.logprobs}")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs): if not isinstance(other, SequenceOutputs):
return NotImplemented raise NotImplementedError()
return (self.seq_id == other.seq_id return (self.parent_seq_id == other.parent_seq_id
and self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token and self.output_token == other.output_token
and self.logprobs == other.logprobs) and self.logprobs == other.logprobs)
# For each sequence group, we generate a list of SequenceOutputs object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[List[SequenceOutputs]]
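To make the nesting concrete, a hedged example of one decoding step's output for a two-group batch (token ids and logprobs are made up; assumes a build containing this change):

from vllm.sequence import SamplerOutput, SequenceOutputs

step_output: SamplerOutput = [
    [   # group 0: two beam candidates, both forked from parent sequence 0
        SequenceOutputs(parent_seq_id=0, output_token=464, logprobs={464: -0.2}),
        SequenceOutputs(parent_seq_id=0, output_token=317, logprobs={317: -1.1}),
    ],
    [   # group 1: a single sampled continuation of parent sequence 3
        SequenceOutputs(parent_seq_id=3, output_token=11, logprobs={11: -0.7}),
    ],
]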

View File

@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel) initialize_model_parallel)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory from vllm.utils import get_gpu_memory
@ -260,7 +260,7 @@ class Worker:
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]], blocks_to_copy: Dict[int, List[int]],
) -> Dict[int, SequenceOutputs]: ) -> SamplerOutput:
# Issue cache operations. # Issue cache operations.
issued_cache_op = False issued_cache_op = False
if blocks_to_swap_in: if blocks_to_swap_in: