[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)
This commit is contained in:
parent 69e1d2fb69
commit e95cd87959
@@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            # Use a small model for a fast test.
            "model": "facebook/opt-125m",

            # skip cuda graph creation for fast test.
            "enforce_eager": True,
            "enable_chunked_prefill": True,
            "max_num_batched_tokens": 2,
            "max_num_seqs": 2,
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [
    {
        "use_v2_block_manager": False,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "use_v2_block_manager": True,
        "num_lookahead_slots": 0,
    },
    {
        "use_v2_block_manager": True,
        "num_lookahead_slots": 5,
    },
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
                                          test_llm_generator, batch_size):
    """Verify that chunked prefill works with BlockManagerV2, with and without
    lookahead scheduling.
    """
    output_len = 32
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids with BlockManagerV1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with BlockManagerV2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
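The hunk above ends before `get_token_ids_from_llm_generator` returns; based on the analogous helper shown later in this diff, a hedged sketch of its likely remainder is:

```python
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    # Run generation for each parametrized LLM and collect per-prompt token ids.
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm  # release the engine between parametrized runs

    return token_ids
```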
@@ -1,5 +1,5 @@
import time
from typing import Optional, Tuple
from typing import Iterable, Optional, Tuple

from vllm import SamplingParams
from vllm.lora.request import LoRARequest
@@ -31,14 +31,17 @@ def create_dummy_prompt(


def create_seq_group(
        seq_prompt_len=1024,
        seq_output_lens=(128, ),
        request_id='0',
        seq_id_start=0,
) -> SequenceGroup:
        seq_prompt_len: int = 1024,
        seq_output_lens: Iterable[int] = (128, ),
        request_id: str = '0',
        seq_id_start: int = 0,
        sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:

    assert len(seq_output_lens) > 0

    if sampling_params is None:
        sampling_params = SamplingParams()

    prompt_token_ids = [0] * seq_prompt_len

    seqs = []

@@ -60,7 +63,7 @@ def create_seq_group(
    seq_group = SequenceGroup(
        request_id=request_id,
        seqs=seqs,
        sampling_params=SamplingParams(),
        sampling_params=sampling_params,
        arrival_time=time.time(),
    )
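The new `sampling_params` argument lets a test control stopping behavior per sequence group instead of always receiving a default `SamplingParams()`; the new tests below use it roughly like this (values are illustrative):

```python
from vllm import SamplingParams

# Build a sequence group whose max_tokens / ignore_eos are chosen by the test.
seq_group = create_seq_group(
    seq_prompt_len=1024,
    seq_output_lens=[128],
    sampling_params=SamplingParams(max_tokens=140, ignore_eos=True),
)
```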
tests/engine/output_processor/test_multi_step.py (new file, 270 lines)
@@ -0,0 +1,270 @@
import random
from unittest.mock import MagicMock

import pytest
from transformers import PreTrainedTokenizer

from tests.core.utils import create_seq_group
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter


@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [1, 12])
@pytest.mark.skip_global_cleanup
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
    """Verify multi-step decoding appends token ids correctly.

    We append token ids and verify all the token ids were appended correctly.
    Note that ignore_eos=True.
    """
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=scheduler,
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=1024,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(max_tokens=seq_output_len +
                                       num_new_tokens,
                                       ignore_eos=True),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))

    outputs = [
        SequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
    output_processor.process_outputs(seq_group, outputs)
    assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
@pytest.mark.parametrize("max_tokens", [128 + 3])
@pytest.mark.skip_global_cleanup
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
                             seq_output_len: int, max_tokens: int):
    """Verify tokens after max_tokens are dropped and not appended to the
    sequence.
    """
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=scheduler,
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(max_tokens=max_tokens, ),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))

    outputs = [
        SequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to not go over max tokens in len.
    assert seq.get_len() == seq_prompt_len + max_tokens

    # Expect the correct tokens were appended.
    expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
    assert seq.get_token_ids(
    )[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                               seq_output_len: int, seed: int):
    """Verify the eos token id is included in the sequence, but subsequent
    tokens are dropped (not appended to sequence).
    """
    random.seed(seed)
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    eos_token_id = 100

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=scheduler,
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Ensure enough space.
            max_tokens=seq_output_len + num_new_tokens, ),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    assert eos_token_id not in new_token_ids
    eos_index = random.randint(0, len(new_token_ids) - 1)
    new_token_ids[eos_index] = eos_token_id

    outputs = [
        SequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to not go beyond provided eos.
    assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)

    # Expect the correct tokens were appended.
    expected_appended_tokens = new_token_ids[:eos_index + 1]
    assert seq.get_token_ids(
    )[-len(expected_appended_tokens):] == expected_appended_tokens


@pytest.mark.parametrize("seq_prompt_len", [1024])
@pytest.mark.parametrize("seq_output_len", [128])
@pytest.mark.parametrize("num_new_tokens", [12])
@pytest.mark.parametrize("seed", list(range(6)))
@pytest.mark.skip_global_cleanup
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
                              seq_output_len: int, seed: int):
    """When sampling parameters dictate that we should ignore the eos token id,
    ensure all token ids are appended even if the eos token id is emitted.
    """
    random.seed(seed)
    detokenizer = MagicMock(spec=Detokenizer)
    scheduler = MagicMock(spec=Scheduler)
    stop_checker = MagicMock(spec=StopChecker)
    seq_counter = Counter()

    eos_token_id = 100

    output_processor = MultiStepOutputProcessor(
        detokenizer=detokenizer,
        scheduler=scheduler,
        seq_counter=seq_counter,
        get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
        stop_checker=stop_checker,
    )

    seq_group = create_seq_group(
        seq_prompt_len=seq_prompt_len,
        seq_output_lens=[seq_output_len],
        sampling_params=SamplingParams(
            # Ensure enough space.
            max_tokens=seq_output_len + num_new_tokens,
            ignore_eos=True,
        ),
    )

    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    new_token_ids = list(range(num_new_tokens))
    assert eos_token_id not in new_token_ids
    eos_index = random.randint(0, len(new_token_ids) - 1)
    new_token_ids[eos_index] = eos_token_id

    outputs = [
        SequenceGroupOutput(
            samples=[
                SequenceOutput(
                    parent_seq_id=seq.seq_id,
                    output_token=output_token,
                    logprobs={output_token: Logprob(0.0)},
                )
            ],
            prompt_logprobs=None,
        ) for output_token in new_token_ids
    ]

    assert seq.get_len() == seq_prompt_len + seq_output_len
    output_processor.process_outputs(seq_group, outputs)

    # Expect the processed sequence to go beyond eos.
    assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens

    # Expect the correct tokens were appended.
    expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
                                             seq_output_len]
    assert seq.get_token_ids(
    )[-len(expected_appended_tokens):] == expected_appended_tokens


def mock_tokenizer(eos_token_id=1000):
    tokenizer = MagicMock(spec=PreTrainedTokenizer)
    tokenizer.eos_token_id = eos_token_id
    return tokenizer
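For instance, in `test_respects_max_tokens` with `seq_output_len=128`, `max_tokens=131`, and `num_new_tokens=7`, only the first three proposed tokens should survive; a quick standalone check of that arithmetic:

```python
seq_output_len = 128
max_tokens = 128 + 3
num_new_tokens = 7

new_token_ids = list(range(num_new_tokens))
# Tokens past max_tokens are dropped, mirroring the assertion in the test.
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
assert expected_appended_tokens == [0, 1, 2]
```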
@@ -1,4 +1,8 @@
from itertools import cycle
from typing import List, Tuple

import pytest
from transformers import AutoTokenizer

from vllm import SamplingParams

@@ -7,18 +11,47 @@ from vllm import SamplingParams
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
        # Note this is repeated in the test body; to initialize a tokenizer.
        "model": "JackFram/llama-68m",

        # Skip real loading for fast test.
        "load_format": "dummy",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
        },
        {
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 1,
        },
        {
            # No spec decode.
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1])
# NOTE: We should run more permutations of this test (more BS, more seeds). But
# because our spec decode generates gibberish token ids, the likelihood of
# emitting an invalid token combination is nontrivial. This causes divergence in
# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
# start" bytes are emitted.
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator):
    output_len = 1024
def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
    """Run generation with speculative decoding on a batch. Verify the engine
    generates the correct number of tokens (via ignore_eos=True), and that the
    detokenization matches HF transformers.
    """
    output_len = 32
    temperature = 0.0

    prompts = [
@@ -28,23 +61,91 @@ def test_spec_decode_config(test_llm_generator):
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        skip_special_tokens=True,
        spaces_between_special_tokens=False,
    )

    batch_tokens, batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Expect a generation for each prompt in the batch.
    assert len(batch_token_ids) == len(prompts)

    # Expect each generation to have expected number of tokens (note
    # ignore_eos=True).
    assert all(len(token_ids) == output_len for token_ids in batch_token_ids)

    # Expect detokenized string to match.
    tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
    for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
        expected_tokens = tok.decode(actual_token_ids)
        print(f"{actual_token_ids=}")
        assert actual_tokens.strip() == expected_tokens.strip()


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "JackFram/llama-68m",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Skip real loading for fast test.
        "load_format": "dummy",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Expect failure as spec decode not supported by
            # Ray backend.
            "worker_use_ray": True,
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail(test_llm_generator):
    """Verify that speculative decoding with Ray fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(
            AssertionError,
            match="Speculative decoding not yet supported for GPU backend"):
        get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                         sampling_params)
    with pytest.raises(AssertionError,
                       match="Speculative decoding not yet supported for "):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]]]:
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]
        del llm

    return token_ids
    return tokens, token_ids
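Merged together, the kwargs these tests parametrize correspond roughly to an engine configured as below (a hedged sketch; the fixtures assemble kwargs dicts rather than calling `LLM` directly, and `speculative_model` / `num_speculative_tokens` are the spec decode knobs this PR series introduces):

```python
from vllm import LLM

# Approximate configuration exercised by test_spec_decode_e2e_logical_flow.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    load_format="dummy",        # skip real weight loading for speed
    enforce_eager=True,         # skip cuda graph capture
    use_v2_block_manager=True,  # required for spec decode
)
```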
@@ -125,7 +125,7 @@ def test_same_output_for_single_step():
    zero_kv_cache(worker.cache_engine)
    set_random_seed(seed)
    expected_output = worker.execute_model(
        **single_step_execute_model_data.to_dict(), )
        **single_step_execute_model_data.to_dict(), )[0]

    actual_token_ids = [
        output.samples[0].output_token for output in actual_output
@@ -219,7 +219,7 @@ def test_same_output_for_multi_step():
                continuations=continuations,
                final_seq_lens=final_seq_lens))

        single_step_output.append(
        single_step_output.extend(
            worker.execute_model(**execute_model_data.to_dict(), ))

        # Append output tokens to new sequence data.
@@ -6,6 +6,7 @@ import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)
@@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
    execute_model_data, _, _ = create_batch(batch_size, k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
        worker.execute_model(**execute_model_data.to_dict(),
                             num_lookahead_slots=k)

    call_args_list = draft_worker.get_spec_proposals.call_args_list
    assert len(call_args_list) == 1
@@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
    target_worker.execute_model.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
        worker.execute_model(**execute_model_data.to_dict(),
                             num_lookahead_slots=k)

    seen_contexts = []

@@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]
    target_worker.execute_model.return_value = [target_output[0]]

    exception_secret = 'artifical stop'
    rejection_sampler.side_effect = ValueError(exception_secret)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
        worker.execute_model(**execute_model_data.to_dict(),
                             num_lookahead_slots=k)

    assert len(rejection_sampler.call_args_list) == 1
    args, _ = rejection_sampler.call_args_list[0]
@@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]
    target_worker.execute_model.return_value = [target_output[0]]

    rejection_sampler_output = torch.randint(low=0,
                                             high=vocab_size,
@@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
    rejection_sampler.return_value = rejection_sampler_output

    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)
                                  num_lookahead_slots=k)

    expected_output = create_sampler_output_list(
        rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
@@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]
    target_worker.execute_model.return_value = [target_output[0]]

    rejection_sampler_output = torch.randint(low=0,
                                             high=vocab_size,
@@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
        mock_rejsample_metrics)

    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)
                                  num_lookahead_slots=k)
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = (
@@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int):
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

@@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int):
        batch_size, k, prev_output_token_len=0)

    out = worker.execute_model(**execute_model_data.to_dict(),
                               num_spec_tokens=k)
                               num_lookahead_slots=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].probs is None, "expect gpu tensor references to be None"
@@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int):
        0].sampled_tokens is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
        **execute_model_data.to_dict())
    target_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())

@@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int):
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)

    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]

    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

@@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int):
        batch_size, k, prev_output_token_len=0)

    out = worker.execute_model(**execute_model_data.to_dict(),
                               num_spec_tokens=k)
                               num_lookahead_slots=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    assert out[0].probs is None, "expect gpu tensor references to be None"
@@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int):
        0].sampled_tokens is None, "expect gpu tensor references to be None"

    draft_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
        **execute_model_data.to_dict())
    target_worker.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())
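The recurring change in these hunks is that the target worker's `execute_model` now returns a list of `SamplerOutput`s (one per step) and the proposal length is passed as `num_lookahead_slots` rather than `num_spec_tokens`. A small self-contained sketch of how the tests mock that shape:

```python
from unittest.mock import MagicMock

from vllm.sequence import SamplerOutput

target_worker = MagicMock()
# New convention: a list with one SamplerOutput per decoded step.
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]

outputs = target_worker.execute_model(num_lookahead_slots=5)
assert isinstance(outputs, list) and len(outputs) == 1
```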
@@ -212,7 +212,7 @@ def create_sampler_output_list(
                SequenceOutput(
                    output_token=token_id,
                    parent_seq_id=seq_ids[seq_index],
                    logprobs={token_id: 0},
                    logprobs={token_id: Logprob(0)},
                )
            ],
            prompt_logprobs=None,
@@ -104,7 +104,6 @@ class BlockTable:
            token_ids (List[int]): The sequence of token IDs to be appended.
        """
        assert self._is_allocated
        assert token_ids, "can't append empty token ids"

        self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                    num_lookahead_slots)
@@ -762,9 +762,7 @@ class Scheduler:
            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                       swapped_in.blocks_to_copy),
            ignored_seq_groups=prefills.ignored_seq_groups,
            num_lookahead_slots=(prefills.num_lookahead_slots +
                                 running_scheduled.num_lookahead_slots +
                                 swapped_in.num_lookahead_slots),
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
        )

    def _schedule_chunked_prefill(self):
@@ -850,9 +848,7 @@ class Scheduler:
            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                       swapped_in.blocks_to_copy),
            ignored_seq_groups=prefills.ignored_seq_groups,
            num_lookahead_slots=(prefills.num_lookahead_slots +
                                 running_scheduled.num_lookahead_slots +
                                 swapped_in.num_lookahead_slots),
            num_lookahead_slots=running_scheduled.num_lookahead_slots,
        )

    def _schedule(self) -> SchedulerOutputs:
@@ -217,7 +217,9 @@ class _AsyncLLMEngine(LLMEngine):
        else:
            output = []

        return self._process_model_outputs(output, scheduler_outputs)
        return self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups)

    async def encode_request_async(
        self,
@@ -1,5 +1,5 @@
import time
from typing import Iterable, List, Optional, Tuple, Type, Union
from typing import Iterable, List, Optional, Type, Union

from transformers import PreTrainedTokenizer

@@ -11,6 +11,10 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
@@ -18,8 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
                           SequenceGroup)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
@@ -187,6 +190,21 @@ class LLMEngine:
                labels=dict(model_name=model_config.model))
            self.stat_logger.info("cache_config", self.cache_config)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                self.get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    self.get_tokenizer_for_seq,
                ),
            ))

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).
@@ -412,240 +430,32 @@ class LLMEngine:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def _check_beam_search_early_stopping(
        self,
        early_stopping: Union[bool, str],
        sampling_params: SamplingParams,
        best_running_seq: Sequence,
        current_worst_seq: Sequence,
    ) -> bool:
        assert sampling_params.use_beam_search
        length_penalty = sampling_params.length_penalty
        if early_stopping is True:
            return True

        current_worst_score = current_worst_seq.get_beam_search_score(
            length_penalty=length_penalty,
            eos_token_id=current_worst_seq.eos_token_id)
        if early_stopping is False:
            highest_attainable_score = best_running_seq.get_beam_search_score(
                length_penalty=length_penalty,
                eos_token_id=best_running_seq.eos_token_id)
        else:
            assert early_stopping == "never"
            if length_penalty > 0.0:
                # If length_penalty > 0.0, beam search will prefer longer
                # sequences. The highest attainable score calculation is
                # based on the longest possible sequence length in this case.
                max_possible_length = max(
                    best_running_seq.get_prompt_len() +
                    sampling_params.max_tokens,
                    self.scheduler_config.max_model_len)
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id,
                        seq_len=max_possible_length))
            else:
                # Otherwise, beam search will prefer shorter sequences. The
                # highest attainable score calculation is based on the current
                # sequence length.
                highest_attainable_score = (
                    best_running_seq.get_beam_search_score(
                        length_penalty=length_penalty,
                        eos_token_id=best_running_seq.eos_token_id))
        return current_worst_score >= highest_attainable_score

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput) -> None:

        # Process prompt logprobs
        prompt_logprobs = outputs.prompt_logprobs
        if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
            self.detokenizer.decode_prompt_logprobs_inplace(
                seq_group, prompt_logprobs)
            seq_group.prompt_logprobs = prompt_logprobs

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        existing_finished_seqs = seq_group.get_finished_seqs()
        parent_child_dict = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            parent_child_dict[sample.parent_seq_id].append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                self.scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            if seq_group.sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, seq_group.sampling_params)
            else:
                new_char_count = 0
            self._check_stop(seq, new_char_count, seq_group.sampling_params)

        # Non-beam search case
        if not seq_group.sampling_params.use_beam_search:
            # For newly created child sequences, add them to the sequence group
            # and fork them in block manager if they are not finished.
            for seq, parent in child_seqs:
                if seq is not parent:
                    seq_group.add(seq)
                    if not seq.is_finished():
                        self.scheduler.fork_seq(parent, seq)

            # Free the finished and selected parent sequences' memory in block
            # manager. Keep them in the sequence group as candidate output.
            # NOTE: we need to fork the new sequences before freeing the
            # old sequences.
            for seq, parent in child_seqs:
                if seq is parent and seq.is_finished():
                    self.scheduler.free_seq(seq)
            return

        # Beam search case
        # Select the child sequences to keep in the sequence group.
        selected_child_seqs = []
        unselected_child_seqs = []
        beam_width = seq_group.sampling_params.best_of
        length_penalty = seq_group.sampling_params.length_penalty

        # Select the newly finished sequences with the highest scores
        # to replace existing finished sequences.
        # Tuple of (seq, parent, is_new)
        existing_finished_seqs = [(seq, None, False)
                                  for seq in existing_finished_seqs]
        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                             if seq.is_finished()]
        all_finished_seqs = existing_finished_seqs + new_finished_seqs
        # Sort the finished sequences by their scores.
        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                               reverse=True)
        for seq, parent, is_new in all_finished_seqs[:beam_width]:
            if is_new:
                # A newly generated child sequence finishes and has a high
                # score, so we will add it into the sequence group.
                selected_child_seqs.append((seq, parent))
        for seq, parent, is_new in all_finished_seqs[beam_width:]:
            if is_new:
                # A newly generated child sequence finishes but has a low
                # score, so we will not add it into the sequence group.
                # Additionally, if this sequence is a continuation of a
                # parent sequence, we will need remove the parent sequence
                # from the sequence group.
                unselected_child_seqs.append((seq, parent))
            else:
                # An existing finished sequence has a low score, so we will
                # remove it from the sequence group.
                seq_group.remove(seq.seq_id)

        # select the top beam_width sequences from the running
        # sequences for the next iteration to continue the beam
        # search.
        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                              if not seq.is_finished()]
        # Sort the running sequences by their scores.
        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)

        # Check if we can stop the beam search.
        if len(running_child_seqs) == 0:
            # No running sequences, stop the beam search.
            stop_beam_search = True
        elif len(all_finished_seqs) < beam_width:
            # Not enough finished sequences, continue the beam search.
            stop_beam_search = False
        else:
            # Check the early stopping criteria
            best_running_seq = running_child_seqs[0][0]
            current_worst_seq = all_finished_seqs[beam_width - 1][0]
            stop_beam_search = self._check_beam_search_early_stopping(
                seq_group.sampling_params.early_stopping,
                seq_group.sampling_params, best_running_seq, current_worst_seq)

        if stop_beam_search:
            # Stop the beam search and remove all the running sequences from
            # the sequence group.
            unselected_child_seqs.extend(running_child_seqs)
        else:
            # Continue the beam search and select the top beam_width sequences
            # to continue the beam search.
            selected_child_seqs.extend(running_child_seqs[:beam_width])
            # The remaining running sequences will not be used in the next
            # iteration. Again, if these sequences are continuations of
            # parent sequences, we will need to remove the parent sequences
            # from the sequence group.
            unselected_child_seqs.extend(running_child_seqs[beam_width:])

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in selected_child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        for seq, parent in selected_child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)

        # Remove the unselected parent sequences from the sequence group and
        # free their memory in block manager.
        for seq, parent in unselected_child_seqs:
            if seq is parent:
                # Remove the parent sequence if it is not selected for next
                # iteration
                seq_group.remove(seq.seq_id)
                self.scheduler.free_seq(seq)

    def _process_model_outputs(
            self, output: SamplerOutput,
            scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
            self, output: List[SamplerOutput],
            scheduled_seq_groups: List[SequenceGroup],
            ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
        """Apply the model output to the sequences in the scheduled seq groups.

        Returns RequestOutputs that can be returned to the client.
        """

        now = time.time()

        # Organize outputs by [sequence group][step] instead of
        # [step][sequence group].
        output_by_sequence_group = create_output_by_sequence_group(
            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))

        # Update the scheduled sequence groups with the model outputs.
        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
        for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
        for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
                                                output_by_sequence_group):
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)
            # If uncomputed tokens > 0, it means prefill is chunked.
            # We don't need to process outputs in that case.
            if seq_group.get_num_uncomputed_tokens() == 0:
                self._process_sequence_group_outputs(seq_group, outputs)
                self.output_processor.process_outputs(seq_group, outputs)

        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()
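The reorganization performed by `create_output_by_sequence_group` is just a transpose of the nested output lists; a standalone illustration with plain lists (hypothetical token labels):

```python
# Worker output arrives as [step][sequence group]; the engine wants
# [sequence group][step] so each group's processor sees all of its steps.
steps = [
    ["g0_t0", "g1_t0"],  # step 0: one sample per sequence group
    ["g0_t1", "g1_t1"],  # step 1
    ["g0_t2", "g1_t2"],  # step 2
]

output_by_sequence_group = [list(per_group) for per_group in zip(*steps)]
assert output_by_sequence_group[0] == ["g0_t0", "g0_t1", "g0_t2"]
assert output_by_sequence_group[1] == ["g1_t0", "g1_t1", "g1_t2"]
```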
@@ -657,13 +467,9 @@ class LLMEngine:
            seq_group.maybe_set_first_token_time(now)
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        for seq_group in scheduler_outputs.ignored_seq_groups:
        for seq_group in ignored_seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)

        # Log stats.
        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs))
        return request_outputs

    def step(self) -> List[RequestOutput]:
@@ -721,13 +527,23 @@ class LLMEngine:

        if not scheduler_outputs.is_empty():
            output = self.model_executor.execute_model(
                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
                scheduler_outputs.blocks_to_swap_out,
                scheduler_outputs.blocks_to_copy)
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
        else:
            output = []

        return self._process_model_outputs(output, scheduler_outputs)
        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups)

        # Log stats.
        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs))

        return request_outputs

    def do_log_stats(self) -> None:
        """Forced log when no requests active."""
@@ -807,87 +623,6 @@ class LLMEngine:
            time_e2e_requests=time_e2e_requests,
        )

    def _check_stop(self, seq: Sequence, new_char_count: int,
                    sampling_params: SamplingParams) -> None:
        """Stop the finished sequences.

        new_char_count is the number of chars added to the
        sequence's output text for the newly generated token
        """

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.get_output_len() < sampling_params.min_tokens:
            return

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id()
        if last_token_id in sampling_params.stop_token_ids:
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop_str = self._check_stop_strings(seq, new_char_count,
                                            sampling_params)
        if stop_str is not None:
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self.scheduler_config.max_model_len:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def _check_stop_strings(seq: Sequence, new_char_count: int,
                            sampling_params: SamplingParams) -> Optional[str]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Returns the stop string if matched or else None.
        """
        if not new_char_count:
            return None

        for stop_str in sampling_params.stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text.
            stop_index = seq.output_text.find(
                stop_str, -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if sampling_params.include_stop_str_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(seq.output_text):
                    # No truncation required.
                    return stop_str

            # Truncate the output text to either the beginning
            # or end of the stop string.
            seq.output_text = seq.output_text[:stop_index]
            return stop_str
        return None

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_executor.add_lora(lora_request)
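For orientation, the modified `step()` is typically driven by a loop like the following (a hedged sketch using the public `LLMEngine` API; the spec decode arguments mirror the e2e test configs earlier in this diff and the model name is illustrative):

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(
    EngineArgs(model="JackFram/llama-68m",
               speculative_model="JackFram/llama-68m",
               num_speculative_tokens=5,
               use_v2_block_manager=True))

engine.add_request("request-0", "The future of AI is",
                   SamplingParams(max_tokens=32, ignore_eos=True))

while engine.has_unfinished_requests():
    # Each step schedules a batch, runs the workers (possibly emitting several
    # tokens per sequence), and routes outputs through the output processor.
    request_outputs = engine.step()
```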
vllm/engine/output_processor/__init__.py (new file, empty)

vllm/engine/output_processor/interfaces.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List

from transformers import PreTrainedTokenizer

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer


class SequenceGroupOutputProcessor(ABC):
    """Interface for logic that processes new token ids in sequence groups,
    managing detokenization, stop checking, and freeing/forking sequences with
    the scheduler.

    This is highly coupled with the LLMEngine and should be seen as an extension
    of it. The logic is separated to simplify the LLMEngine class and allow
    separate implementations for single-step decoding (which supports beam
    search sequence forking) and multi-step decoding (which does not support
    beam search, but does support speculative decoding).
    """

    @staticmethod
    def create_output_processor(
        scheduler_config: SchedulerConfig,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: "StopChecker",
    ):
        """Create an output processor.

        This returns a single-step output processor if num_lookahead_slots is
        zero, else returns a multi-step output processor.
        """
        if scheduler_config.num_lookahead_slots == 0:
            # Importing here to avoid cycle.
            from vllm.engine.output_processor.single_step import (
                SingleStepOutputProcessor)
            return SingleStepOutputProcessor(
                scheduler_config,
                detokenizer,
                scheduler,
                seq_counter,
                stop_checker,
            )
        else:
            # Importing here to avoid cycle.
            from vllm.engine.output_processor.multi_step import (
                MultiStepOutputProcessor)
            return MultiStepOutputProcessor(
                detokenizer,
                scheduler,
                seq_counter,
                get_tokenizer_for_seq,
                stop_checker,
            )

    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Process new token ids for the sequence group. Handles logic such as
        detokenization, stop checking, and freeing/forking sequences in the
        scheduler.
        """
        pass
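A quick way to see the factory's selection rule in isolation (a sketch using mocks rather than a real engine; `num_lookahead_slots` is the only config field the factory reads):

```python
from unittest.mock import MagicMock

from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.single_step import SingleStepOutputProcessor


def make_processor(num_lookahead_slots: int):
    scheduler_config = MagicMock()
    scheduler_config.num_lookahead_slots = num_lookahead_slots
    return SequenceGroupOutputProcessor.create_output_processor(
        scheduler_config,
        detokenizer=MagicMock(),
        scheduler=MagicMock(),
        seq_counter=iter(range(1000)),
        get_tokenizer_for_seq=MagicMock(),
        stop_checker=MagicMock(),
    )


assert isinstance(make_processor(0), SingleStepOutputProcessor)
assert isinstance(make_processor(5), MultiStepOutputProcessor)
```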
vllm/engine/output_processor/multi_step.py (new file, 126 lines)
@@ -0,0 +1,126 @@
from typing import Callable, Iterable, List

from transformers import PreTrainedTokenizer

from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer

logger = init_logger(__name__)


class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per invocation.
    This is currently mutually exclusive with advanced sampling techniques like
    beam search, which motivates the separation of this logic from the single
    step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where the
    number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: Scheduler,
        seq_counter: Iterable[int],
        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
        stop_checker: StopChecker,
    ):
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput]) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization,
        including freeing finished sequences. It also handles cases where there
        are tokens emitted after the EOS token.
        """
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)

        assert seqs, "expected running sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]

        # Since there's only one sequence per sequence group, we can take the
        # first sample.
        samples = [outputs[step].samples[0] for step in range(len(outputs))]

        # -1 means the output token is not valid (eg. due to spec decode
        # rejecting tokens).
        valid_samples = [
            sample for sample in samples if sample.output_token != -1
        ]
        assert valid_samples

        self._process_seq_outputs(seq, valid_samples,
                                  sequence_group.sampling_params)

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        output_token_ids = [sample.output_token for sample in valid_samples]

        # Truncate to max_tokens if necessary.
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
                                                         len(output_token_ids))
        if remaining_tokens < 0:
            valid_samples = valid_samples[:remaining_tokens]
            output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i in range(len(output_token_ids)):
                if output_token_ids[i] == eos_token_id:
                    output_token_ids = output_token_ids[:i + 1]
                    valid_samples = valid_samples[:i + 1]
                    break

        # Incrementally append tokens to the sequence, as if we had only one new
        # token.
        for output_token_id in output_token_ids:
            seq.append_token_id(
                token_id=output_token_id,
                # TODO emit logprobs in multi-step decoding.
                logprobs={output_token_id: Logprob(0.0)},
            )

            new_char_count = 0
            if sampling_params.detokenize:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)

            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count=new_char_count,
                sampling_params=sampling_params)
            if seq.is_finished():
                break

        if seq.is_finished():
            self.scheduler.free_seq(seq)
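The EOS handling above can be pictured with plain lists: invalid (-1) tokens are filtered first, then anything after the first EOS is dropped unless `ignore_eos` is set (a standalone illustration with made-up token ids, not engine code):

```python
eos_token_id = 2
sampled = [5, 7, -1, 2, 9]  # one rejected token (-1), EOS at a later position

# Filter rejected tokens, mirroring the valid_samples computation.
valid = [t for t in sampled if t != -1]  # [5, 7, 2, 9]

# Truncate after the first EOS, mirroring the loop above.
for i, token in enumerate(valid):
    if token == eos_token_id:
        valid = valid[:i + 1]
        break

assert valid == [5, 7, 2]
```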
vllm/engine/output_processor/single_step.py (new file, 276 lines)
@@ -0,0 +1,276 @@
|
||||
from typing import Iterable, List, Tuple, Union
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
|
||||
SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
"""SequenceGroupOutputProcessor which handles "output processing" logic,
|
||||
which happens after the model returns generated token ids and before
|
||||
scheduling of the next batch. Output processing logic includes
|
||||
detokenization, and determining if a sequence is finished (e.g. via max len
|
||||
or eos token).
|
||||
|
||||
The SingleStepOutputProcessor is specialized to the case where the model
|
||||
emits at most a single token per invocation, which precludes configurations
|
||||
such as speculative decoding or multi-step decoding. This enables beam
|
||||
search sampling, which requires forking/finishing/freeing sequences in a way
|
||||
that is currently difficult to schedule multiple steps ahead of time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler_config: SchedulerConfig,
|
||||
detokenizer: Detokenizer,
|
||||
scheduler: Scheduler,
|
||||
seq_counter: Iterable[int],
|
||||
stop_checker: StopChecker,
|
||||
):
|
||||
self.scheduler_config = scheduler_config
|
||||
self.detokenizer = detokenizer
|
||||
self.scheduler = scheduler
|
||||
self.seq_counter = seq_counter
|
||||
self.stop_checker = stop_checker
|
||||
|
||||
def process_outputs(self, sequence_group: SequenceGroup,
|
||||
outputs: List[SequenceGroupOutput]) -> None:
|
||||
"""Append all new tokens to sequences in the sequence group. Fork any
|
||||
surviving beam candidates; free any unsurviving ones.
|
||||
|
||||
Invokes detokenizer to detokenize new tokens, and also marks sequences
|
||||
as finished if they meet stop conditions.
|
||||
"""
|
||||
assert (len(outputs) == 1
|
||||
), f"{type(self)} does not support multiple outputs per step"
|
||||
return self._process_sequence_group_outputs(sequence_group, outputs[0])
|
||||
|
||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||
outputs: SequenceGroupOutput) -> None:
|
||||
|
||||
# Process prompt logprobs
|
||||
prompt_logprobs = outputs.prompt_logprobs
|
||||
if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
|
||||
self.detokenizer.decode_prompt_logprobs_inplace(
|
||||
seq_group, prompt_logprobs)
|
||||
seq_group.prompt_logprobs = prompt_logprobs
|
||||
|
||||
# Process samples
|
||||
samples = outputs.samples
|
||||
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||
parent_child_dict = {
|
||||
parent_seq.seq_id: []
|
||||
for parent_seq in parent_seqs
|
||||
}
|
||||
for sample in samples:
|
||||
parent_child_dict[sample.parent_seq_id].append(sample)
|
||||
# List of (child, parent)
|
||||
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
||||
|
||||
# Process the child samples for each parent sequence
|
||||
for parent in parent_seqs:
|
||||
child_samples: List[SequenceOutput] = parent_child_dict[
|
||||
parent.seq_id]
|
||||
if len(child_samples) == 0:
|
||||
                # This parent sequence has no child samples. Remove
|
||||
# the parent sequence from the sequence group since it will
|
||||
# not be used in the future iterations.
|
||||
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||
seq_group.remove(parent.seq_id)
|
||||
self.scheduler.free_seq(parent)
|
||||
continue
|
||||
# Fork the parent sequence if there are multiple child samples.
|
||||
for child_sample in child_samples[:-1]:
|
||||
new_child_seq_id = next(self.seq_counter)
|
||||
child = parent.fork(new_child_seq_id)
|
||||
child.append_token_id(child_sample.output_token,
|
||||
child_sample.logprobs)
|
||||
child_seqs.append((child, parent))
|
||||
# Continue the parent sequence for the last child sample.
|
||||
# We reuse the parent sequence here to reduce redundant memory
|
||||
# copies, especially when using non-beam search sampling methods.
|
||||
last_child_sample = child_samples[-1]
|
||||
parent.append_token_id(last_child_sample.output_token,
|
||||
last_child_sample.logprobs)
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
if seq_group.sampling_params.detokenize:
|
||||
new_char_count = self.detokenizer.decode_sequence_inplace(
|
||||
seq, seq_group.sampling_params)
|
||||
else:
|
||||
new_char_count = 0
|
||||
self.stop_checker.maybe_stop_sequence(seq, new_char_count,
|
||||
seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
if not seq_group.sampling_params.use_beam_search:
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
# NOTE: we need to fork the new sequences before freeing the
|
||||
# old sequences.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
return
|
||||
|
||||
# Beam search case
|
||||
# Select the child sequences to keep in the sequence group.
|
||||
selected_child_seqs = []
|
||||
unselected_child_seqs = []
|
||||
beam_width = seq_group.sampling_params.best_of
|
||||
length_penalty = seq_group.sampling_params.length_penalty
|
||||
|
||||
# Select the newly finished sequences with the highest scores
|
||||
# to replace existing finished sequences.
|
||||
# Tuple of (seq, parent, is_new)
|
||||
existing_finished_seqs = [(seq, None, False)
|
||||
for seq in existing_finished_seqs]
|
||||
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
|
||||
if seq.is_finished()]
|
||||
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
||||
# Sort the finished sequences by their scores.
|
||||
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
|
||||
reverse=True)
|
||||
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes and has a high
|
||||
# score, so we will add it into the sequence group.
|
||||
selected_child_seqs.append((seq, parent))
|
||||
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes but has a low
|
||||
# score, so we will not add it into the sequence group.
|
||||
# Additionally, if this sequence is a continuation of a
|
||||
                # parent sequence, we will need to remove the parent sequence
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.append((seq, parent))
|
||||
else:
|
||||
# An existing finished sequence has a low score, so we will
|
||||
# remove it from the sequence group.
|
||||
seq_group.remove(seq.seq_id)
|
||||
|
||||
# select the top beam_width sequences from the running
|
||||
# sequences for the next iteration to continue the beam
|
||||
# search.
|
||||
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
|
||||
if not seq.is_finished()]
|
||||
# Sort the running sequences by their scores.
|
||||
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
|
||||
reverse=True)
|
||||
|
||||
# Check if we can stop the beam search.
|
||||
if len(running_child_seqs) == 0:
|
||||
# No running sequences, stop the beam search.
|
||||
stop_beam_search = True
|
||||
elif len(all_finished_seqs) < beam_width:
|
||||
# Not enough finished sequences, continue the beam search.
|
||||
stop_beam_search = False
|
||||
else:
|
||||
# Check the early stopping criteria
|
||||
best_running_seq = running_child_seqs[0][0]
|
||||
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||
stop_beam_search = self._check_beam_search_early_stopping(
|
||||
seq_group.sampling_params.early_stopping,
|
||||
seq_group.sampling_params, best_running_seq, current_worst_seq)
|
||||
|
||||
if stop_beam_search:
|
||||
# Stop the beam search and remove all the running sequences from
|
||||
# the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs)
|
||||
else:
|
||||
# Continue the beam search and select the top beam_width sequences
|
||||
# to continue the beam search.
|
||||
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
||||
# The remaining running sequences will not be used in the next
|
||||
# iteration. Again, if these sequences are continuations of
|
||||
# parent sequences, we will need to remove the parent sequences
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
||||
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
|
||||
# Remove the unselected parent sequences from the sequence group and
|
||||
# free their memory in block manager.
|
||||
for seq, parent in unselected_child_seqs:
|
||||
if seq is parent:
|
||||
# Remove the parent sequence if it is not selected for next
|
||||
# iteration
|
||||
seq_group.remove(seq.seq_id)
|
||||
self.scheduler.free_seq(seq)
|
||||
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
early_stopping: Union[bool, str],
|
||||
sampling_params: SamplingParams,
|
||||
best_running_seq: Sequence,
|
||||
current_worst_seq: Sequence,
|
||||
) -> bool:
|
||||
assert sampling_params.use_beam_search
|
||||
length_penalty = sampling_params.length_penalty
|
||||
if early_stopping is True:
|
||||
return True
|
||||
|
||||
current_worst_score = current_worst_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=current_worst_seq.eos_token_id)
|
||||
if early_stopping is False:
|
||||
highest_attainable_score = best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=best_running_seq.eos_token_id)
|
||||
else:
|
||||
assert early_stopping == "never"
|
||||
if length_penalty > 0.0:
|
||||
# If length_penalty > 0.0, beam search will prefer longer
|
||||
# sequences. The highest attainable score calculation is
|
||||
# based on the longest possible sequence length in this case.
|
||||
max_possible_length = max(
|
||||
best_running_seq.get_prompt_len() +
|
||||
sampling_params.max_tokens,
|
||||
self.scheduler_config.max_model_len)
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=best_running_seq.eos_token_id,
|
||||
seq_len=max_possible_length))
|
||||
else:
|
||||
# Otherwise, beam search will prefer shorter sequences. The
|
||||
# highest attainable score calculation is based on the current
|
||||
# sequence length.
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=best_running_seq.eos_token_id))
|
||||
return current_worst_score >= highest_attainable_score
|
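The fork-versus-continue rule in _process_sequence_group_outputs can be illustrated with a small sketch: samples are grouped by parent sequence id, all but the last child sample fork a new sequence, and the last child continues the parent in place to avoid a redundant copy. The names below are simplified stand-ins for the real vLLM types, not the actual implementation.

from collections import defaultdict
from itertools import count
from typing import Dict, List, Tuple

seq_counter = count(start=100)

def assign_children(parent_ids: List[int],
                    samples: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """samples are (parent_seq_id, output_token) pairs. Returns
    (seq_id, token) pairs, forking new seq ids for all but the last
    sample of each parent."""
    by_parent: Dict[int, List[int]] = defaultdict(list)
    for parent_id, token in samples:
        by_parent[parent_id].append(token)

    assigned: List[Tuple[int, int]] = []
    for parent_id in parent_ids:
        children = by_parent[parent_id]
        for token in children[:-1]:      # fork for all but the last sample
            assigned.append((next(seq_counter), token))
        if children:                     # the last sample reuses the parent
            assigned.append((parent_id, children[-1]))
    return assigned

# One parent (id 0) with two samples: one forked child plus the parent itself.
print(assign_children([0], [(0, 11), (0, 12)]))   # [(100, 11), (0, 12)]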
101  vllm/engine/output_processor/stop_checker.py  Normal file
@ -0,0 +1,101 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Sequence, SequenceStatus
|
||||
|
||||
|
||||
class StopChecker:
|
||||
"""LLMEngine helper class which separates out the logic involving stop
|
||||
checking. This checks things such as: whether the eos token was emitted,
|
||||
whether the max_tokens has been consumed, whether a stop string has been
|
||||
emitted, or if we have exceeded the max model len.
|
||||
"""
|
||||
|
||||
def __init__(self, max_model_len: int,
|
||||
get_tokenizer_for_seq: Callable[[Sequence],
|
||||
PreTrainedTokenizer]):
|
||||
self.max_model_len = max_model_len
|
||||
self.get_tokenizer_for_seq = get_tokenizer_for_seq
|
||||
|
||||
def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Stop the finished sequences.
|
||||
|
||||
new_char_count is the number of chars added to the
|
||||
sequence's output text for the newly generated token
|
||||
"""
|
||||
|
||||
# Check if the minimum number of tokens has been generated yet;
|
||||
# skip the stop string/token checks if not
|
||||
if seq.get_output_len() < sampling_params.min_tokens:
|
||||
return
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == seq.eos_token_id):
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if a stop token was encountered.
|
||||
# This assumes a single token produced per step.
|
||||
last_token_id = seq.get_last_token_id()
|
||||
if last_token_id in sampling_params.stop_token_ids:
|
||||
if new_char_count and (
|
||||
not sampling_params.include_stop_str_in_output):
|
||||
# Remove last token
|
||||
seq.output_text = seq.output_text[:-new_char_count]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = last_token_id
|
||||
return
|
||||
|
||||
# Check if any stop strings are matched.
|
||||
stop_str = self._check_stop_strings(seq, new_char_count,
|
||||
sampling_params)
|
||||
if stop_str is not None:
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = stop_str
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.max_model_len:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _check_stop_strings(seq: Sequence, new_char_count: int,
|
||||
sampling_params: SamplingParams) -> Optional[str]:
|
||||
"""Check if any stop strings are matched and truncate sequence
|
||||
output text accordingly.
|
||||
|
||||
Returns the stop string if matched or else None.
|
||||
"""
|
||||
if not new_char_count:
|
||||
return None
|
||||
|
||||
for stop_str in sampling_params.stop:
|
||||
stop_string_len = len(stop_str)
|
||||
# Avoid searching already-searched text.
|
||||
stop_index = seq.output_text.find(
|
||||
stop_str, -new_char_count - stop_string_len)
|
||||
if stop_index == -1:
|
||||
continue
|
||||
|
||||
if sampling_params.include_stop_str_in_output:
|
||||
# Truncate to end of stop string.
|
||||
stop_index += stop_string_len
|
||||
if stop_index >= len(seq.output_text):
|
||||
# No truncation required.
|
||||
return stop_str
|
||||
|
||||
# Truncate the output text to either the beginning
|
||||
# or end of the stop string.
|
||||
seq.output_text = seq.output_text[:stop_index]
|
||||
return stop_str
|
||||
return None
|
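The _check_stop_strings search above only rescans the newly produced tail of output_text: passing a negative start index to str.find restricts the search to the last new_char_count + len(stop_str) characters, which is just enough to catch a stop string that straddles the boundary between old and newly decoded text. A standalone illustration (not vLLM code):

def find_stop(output_text: str, new_char_count: int, stop_str: str) -> int:
    # Search only far enough back to catch a stop string spanning the
    # boundary between previously checked text and the new characters.
    return output_text.find(stop_str, -new_char_count - len(stop_str))

text = "Hello worl" + "d!"          # "d!" is the newly decoded text
assert find_stop(text, new_char_count=2, stop_str="world") == 6
assert find_stop(text, new_char_count=2, stop_str="Hello") == -1  # old text not rescanned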
16  vllm/engine/output_processor/util.py  Normal file
@ -0,0 +1,16 @@
|
||||
from typing import List
|
||||
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
|
||||
def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
|
||||
num_seq_groups: int):
|
||||
"""Helper method which transforms a 2d list organized by
|
||||
[step][sequence group] into [sequence group][step].
|
||||
"""
|
||||
output_by_sequence_group = [[] for _ in range(num_seq_groups)]
|
||||
for step in sampler_outputs:
|
||||
for i, sequence_group_output in enumerate(step):
|
||||
output_by_sequence_group[i].append(sequence_group_output)
|
||||
|
||||
return output_by_sequence_group
|
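A quick usage sketch of the helper above, showing the [step][sequence group] to [sequence group][step] transpose. The string entries stand in for the per-sequence-group outputs that a real SamplerOutput would yield; the helper only iterates and appends, so placeholders are enough here.

from vllm.engine.output_processor.util import create_output_by_sequence_group

# Two decode steps over three sequence groups.
sampler_outputs = [
    ["g0_step0", "g1_step0", "g2_step0"],   # step 0
    ["g0_step1", "g1_step1", "g2_step1"],   # step 1
]
out = create_output_by_sequence_group(sampler_outputs, num_seq_groups=3)
assert out[0] == ["g0_step0", "g0_step1"]
assert out[2] == ["g2_step0", "g2_step1"]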
@ -74,7 +74,8 @@ class CPUExecutor(ExecutorBase):
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int) -> List[SamplerOutput]:
|
||||
output = self.driver_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
|
@ -72,8 +72,9 @@ class ExecutorBase(ABC):
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||
"""Executes one model step on the given sequences."""
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int) -> List[SamplerOutput]:
|
||||
"""Executes at least one model step on the given sequences."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
@ -13,13 +13,17 @@ logger = init_logger(__name__)
|
||||
class GPUExecutor(ExecutorBase):
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
assert (not self.speculative_config
|
||||
), "Speculative decoding not yet supported for GPU backend"
|
||||
"""Initialize the worker and load the model.
|
||||
|
||||
# Instantiate the worker and load the model to GPU.
|
||||
self._init_worker()
|
||||
If speculative decoding is enabled, we instead create the speculative
|
||||
worker.
|
||||
"""
|
||||
if self.speculative_config is None:
|
||||
self._init_non_spec_worker()
|
||||
else:
|
||||
self._init_spec_worker()
|
||||
|
||||
def _init_worker(self):
|
||||
def _init_non_spec_worker(self):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker
|
||||
@ -46,6 +50,57 @@ class GPUExecutor(ExecutorBase):
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def _init_spec_worker(self):
|
||||
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
|
||||
"""
|
||||
assert self.speculative_config is not None
|
||||
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
|
||||
target_worker = Worker(
|
||||
model_config=self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
draft_worker = MultiStepWorker(
|
||||
model_config=self.speculative_config.draft_model_config,
|
||||
parallel_config=self.speculative_config.draft_parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
spec_decode_worker = SpecDecodeWorker.from_workers(
|
||||
proposer_worker=draft_worker, scorer_worker=target_worker)
|
||||
|
||||
assert self.parallel_config.world_size == 1, (
|
||||
"GPUExecutor only supports single GPU.")
|
||||
|
||||
self.driver_worker = spec_decode_worker
|
||||
|
||||
# Load model handled in spec decode worker.
|
||||
self.driver_worker.init_device()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
@ -63,16 +118,20 @@ class GPUExecutor(ExecutorBase):
|
||||
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int,
|
||||
) -> List[SamplerOutput]:
|
||||
output = self.driver_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
)
|
||||
return output
|
||||
|
||||
|
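With this change every executor accepts num_lookahead_slots and returns List[SamplerOutput], one entry per model step. A hedged sketch of how a caller might consume the new contract; the function and its wiring are illustrative, not the actual LLMEngine code.

from typing import List

def run_step(executor, seq_group_metadata_list, num_lookahead_slots: int) -> List:
    """Drive one scheduling iteration through any ExecutorBase implementation."""
    outputs = executor.execute_model(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in={},
        blocks_to_swap_out={},
        blocks_to_copy={},
        num_lookahead_slots=num_lookahead_slots,
    )
    # Single-step backends (GPU/CPU/Neuron workers) return a one-element
    # list; the speculative decoding path may return several steps at once.
    return list(outputs)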
@ -48,10 +48,13 @@ class NeuronExecutor(ExecutorBase):
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int) -> List[SamplerOutput]:
|
||||
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
|
||||
and blocks_to_copy == {}), (
|
||||
"Cache operations are not supported for Neuron backend.")
|
||||
assert num_lookahead_slots == 0, (
|
||||
"lookahead not supported for Neuron backend.")
|
||||
|
||||
output = self.driver_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list)
|
||||
|
@ -242,7 +242,8 @@ class RayGPUExecutor(ExecutorBase):
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int = 0) -> SamplerOutput:
|
||||
all_outputs = self._run_workers(
|
||||
"execute_model",
|
||||
driver_kwargs={
|
||||
|
@ -693,3 +693,16 @@ class SamplerOutput:
|
||||
def __eq__(self, other: object):
|
||||
return isinstance(other,
|
||||
self.__class__) and self.outputs == other.outputs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Show the shape of a tensor instead of its values to reduce noise.
|
||||
"""
|
||||
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
|
||||
else self.sampled_token_probs.shape)
|
||||
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
|
||||
self.sampled_token_ids.shape)
|
||||
return (
|
||||
f"SamplerOutput(outputs={self.outputs}, "
|
||||
f"sampled_token_probs={sampled_token_probs_repr}, "
|
||||
f"sampled_token_ids={sampled_token_ids_repr}, "
|
||||
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
|
||||
|
@ -6,10 +6,10 @@ import torch
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
||||
sampler_output_to_torch,
|
||||
from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
|
||||
nvtx_range, sampler_output_to_torch,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
SeqId = int
|
||||
TargetSeqId = int
|
||||
@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
of topk/tree.
|
||||
"""
|
||||
|
||||
def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
|
||||
def __init__(self, scorer_worker: WorkerBase, device: str,
|
||||
vocab_size: int):
|
||||
self._scorer_worker = scorer_worker
|
||||
self._device = device
|
||||
self._vocab_size = vocab_size
|
||||
@ -83,7 +84,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
return_python_output=False)
|
||||
)
|
||||
assert len(target_sampler_output) == 1, "expected single-step output"
|
||||
target_sampler_output = target_sampler_output[0]
|
||||
|
||||
all_tokens, all_probs = self._contract_batch(
|
||||
original_bs=len(seq_group_metadata_list),
|
||||
@ -142,6 +145,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
This maps the scores of speculative tokens back to their original
|
||||
sequences.
|
||||
"""
|
||||
|
||||
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
|
||||
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
|
||||
maybe_mock_device_tensors(
|
||||
sampler_output=target_sampler_output,
|
||||
batch_size=len(non_spec_indices) + num_scoring_tokens,
|
||||
vocab_size=self._vocab_size,
|
||||
device=self._device,
|
||||
)
|
||||
|
||||
(target_token_ids, target_probs, non_spec_target_token_ids,
|
||||
non_spec_target_probs) = self._split_scoring_output(
|
||||
target_sampler_output, num_scoring_tokens)
|
||||
|
@ -6,7 +6,8 @@ import torch
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeProposer)
|
||||
from vllm.spec_decode.util import sampler_output_to_torch
|
||||
from vllm.spec_decode.util import (maybe_mock_device_tensors,
|
||||
sampler_output_to_torch)
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
copied_seq_group_metadata_list)
|
||||
@ -341,6 +345,16 @@ class DraftModelTop1Proposer(SpeculativeProposer):
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
|
||||
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
|
||||
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
|
||||
for step_output in sampler_output:
|
||||
maybe_mock_device_tensors(
|
||||
sampler_output=step_output,
|
||||
batch_size=len(proposal_lens),
|
||||
vocab_size=self._vocab_size,
|
||||
device=self._device,
|
||||
)
|
||||
|
||||
proposal_tokens, proposal_probs = sampler_output_to_torch(
|
||||
sampler_output)
|
||||
|
||||
|
@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
|
||||
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceOutput)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
@ -13,8 +14,9 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_workers(cls, proposer_worker: MultiStepWorker,
|
||||
scorer_worker: WorkerBase) -> "SpecDecodeWorker":
|
||||
return SpecDecodeWorker(
|
||||
proposer_worker,
|
||||
scorer_worker,
|
||||
# TODO(cade) disable strict mode for speedup.
|
||||
rejection_sampler=RejectionSampler(strict_mode=True),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proposer_worker: MultiStepWorker,
|
||||
scorer_worker: Worker,
|
||||
scorer_worker: WorkerBase,
|
||||
rejection_sampler: RejectionSampler,
|
||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||
):
|
||||
@ -87,6 +99,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.scorer_worker.init_device()
|
||||
self.proposer_worker.init_device()
|
||||
|
||||
# NOTE(cade): load_model is not part of the WorkerBase interface.
|
||||
self.scorer_worker.load_model()
|
||||
self.proposer_worker.load_model()
|
||||
|
||||
self._metrics.init_gpu_tensors(self.rank)
|
||||
self.rejection_sampler.init_gpu_tensors(self.rank)
|
||||
self.scorer = BatchExpansionTop1Scorer(
|
||||
@ -131,7 +147,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
blocks_to_swap_in: Optional[Dict[int, int]],
|
||||
blocks_to_swap_out: Optional[Dict[int, int]],
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]],
|
||||
num_spec_tokens: int,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Perform speculative decoding on the input batch.
|
||||
"""
|
||||
@ -140,9 +156,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
"speculative decoding "
|
||||
"requires non-None seq_group_metadata_list")
|
||||
|
||||
logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
|
||||
|
||||
# If no spec tokens, call the proposer and scorer workers normally.
|
||||
# Used for prefill.
|
||||
if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
|
||||
if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
|
||||
return self._run_no_spec(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
@ -155,7 +173,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
k=num_spec_tokens,
|
||||
k=num_lookahead_slots,
|
||||
)
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_no_spec")
|
||||
@ -170,20 +188,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposer and scorer model so that the KV cache is consistent between the
|
||||
two.
|
||||
"""
|
||||
logger.info("run proposer worker no spec")
|
||||
|
||||
self.proposer_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
return_python_output=False)
|
||||
)
|
||||
|
||||
logger.info("run target worker no spec")
|
||||
sampler_output = self.scorer_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
assert len(sampler_output) == 1
|
||||
sampler_output = sampler_output[0]
|
||||
|
||||
# Clear device tensors from sampler output. This reduces communication
|
||||
# overhead when the engine runs in a different process than the workers.
|
||||
@ -209,11 +231,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
sequence.
|
||||
"""
|
||||
|
||||
logger.info("get spec proposals")
|
||||
# Generate proposals using draft worker.
|
||||
proposals = self.proposer_worker.get_spec_proposals(
|
||||
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
|
||||
blocks_to_copy, k)
|
||||
|
||||
logger.info("score proposals")
|
||||
proposal_scores = self.scorer.score_proposals(
|
||||
seq_group_metadata_list,
|
||||
blocks_to_swap_in,
|
||||
@ -223,9 +247,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposals,
|
||||
)
|
||||
|
||||
logger.info("verify proposals")
|
||||
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
|
||||
proposal_scores, proposals, k)
|
||||
|
||||
logger.info("create output list")
|
||||
return self._create_output_sampler_list(seq_group_metadata_list,
|
||||
accepted_token_ids, k)
|
||||
|
||||
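The logging above traces the four stages of a speculative decoding step. In outline (a sketch of the control flow only; the signatures are simplified assumptions, not the real SpecDecodeWorker API):

def speculative_step(proposer, scorer, verifier, seq_group_metadata_list, k: int):
    """Illustrative outline of a speculative decoding step. `proposer`,
    `scorer`, and `verifier` stand in for the draft worker, the
    BatchExpansionTop1Scorer, and the rejection sampler respectively."""
    # 1. Draft model proposes up to k tokens per sequence.
    proposals = proposer.get_spec_proposals(seq_group_metadata_list, k)
    # 2. Target model scores the proposed tokens in one batched pass.
    proposal_scores = scorer.score_proposals(seq_group_metadata_list, proposals)
    # 3. Rejection sampling accepts a prefix of each proposal.
    accepted_token_ids = verifier(proposal_scores, proposals)
    # 4. Accepted tokens are reshaped into per-step sampler outputs.
    return accepted_token_ids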
@ -311,7 +337,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
# TODO Add verifier logprobs.
|
||||
logprobs={token_id: 0.0},
|
||||
logprobs={token_id: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
|
@ -82,6 +82,32 @@ def sampler_output_to_torch(
|
||||
return sampled_token_ids, sampled_token_probs
|
||||
|
||||
|
||||
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
|
||||
vocab_size: int, device: str) -> None:
|
||||
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
|
||||
values. This will be removed in PR 7/9.
|
||||
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
|
||||
"""
|
||||
values = [
|
||||
sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
|
||||
]
|
||||
assert all(v is None for v in values) or not any(v is None for v in values)
|
||||
if not any(v is None for v in values):
|
||||
# Do nothing if the tensors are already created (usually in unit tests).
|
||||
return
|
||||
|
||||
# Softmax to ensure valid probs.
|
||||
sampler_output.sampled_token_probs = torch.nn.functional.softmax(
|
||||
torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
|
||||
dim=-1)
|
||||
|
||||
sampler_output.sampled_token_ids = torch.randint(low=10,
|
||||
high=100,
|
||||
size=(batch_size, ),
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def nvtx_range(msg, *args, **kwargs):
|
||||
"""
|
||||
|
@ -251,7 +251,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
blocks_to_swap_in: Optional[Dict[int, int]] = None,
|
||||
blocks_to_swap_out: Optional[Dict[int, int]] = None,
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
|
||||
) -> Optional[SamplerOutput]:
|
||||
) -> List[SamplerOutput]:
|
||||
if self.is_driver_worker:
|
||||
assert seq_group_metadata_list is not None
|
||||
num_seq_groups = len(seq_group_metadata_list)
|
||||
@ -274,11 +274,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if num_seq_groups == 0:
|
||||
return {}
|
||||
return []
|
||||
|
||||
output = self.model_runner.execute_model(seq_group_metadata_list,
|
||||
self.cpu_cache)
|
||||
return output
|
||||
|
||||
# CPU worker only supports single-step execution.
|
||||
return [output]
|
||||
|
||||
def init_distributed_environment(self) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""A Neuron worker class."""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -73,15 +73,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Optional[SamplerOutput]:
|
||||
) -> List[SamplerOutput]:
|
||||
num_seq_groups = len(seq_group_metadata_list)
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if num_seq_groups == 0:
|
||||
return {}
|
||||
return []
|
||||
|
||||
output = self.model_runner.execute_model(seq_group_metadata_list)
|
||||
return output
|
||||
|
||||
# Neuron worker only supports single-step output. Wrap the output in a
|
||||
        # list to conform to the interface.
|
||||
return [output]
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
"""Determine the size in bytes of a cache block.
|
||||
|
@ -210,7 +210,9 @@ class Worker(WorkerBase):
|
||||
blocks_to_swap_in: Optional[Dict[int, int]] = None,
|
||||
blocks_to_swap_out: Optional[Dict[int, int]] = None,
|
||||
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
|
||||
) -> Optional[SamplerOutput]:
|
||||
num_lookahead_slots: int = 0,
|
||||
) -> List[SamplerOutput]:
|
||||
|
||||
if self.is_driver_worker:
|
||||
assert seq_group_metadata_list is not None
|
||||
num_seq_groups = len(seq_group_metadata_list)
|
||||
@ -235,11 +237,14 @@ class Worker(WorkerBase):
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if num_seq_groups == 0:
|
||||
return {}
|
||||
return []
|
||||
|
||||
output = self.model_runner.execute_model(seq_group_metadata_list,
|
||||
self.gpu_cache)
|
||||
return output
|
||||
|
||||
# Worker only supports single-step execution. Wrap the output in a list
|
||||
        # to conform to the interface.
|
||||
return [output]
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
@ -40,12 +40,13 @@ class WorkerBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def execute_model(self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
|
||||
"""Executes one model step on the given sequences."""
|
||||
def execute_model(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
|
||||
int],
|
||||
blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
|
||||
"""Executes at least one model step on the given sequences, unless no
|
||||
sequences are provided."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
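For comparison, a minimal single-step worker conforming to the updated interface simply wraps its one SamplerOutput in a list, or returns an empty list when there is no work. A sketch under that assumption, with a stubbed model runner (not a real vLLM worker):

from typing import Dict, List

class ToyWorker:
    """Shows only the updated List[SamplerOutput] return contract."""

    def __init__(self, model_runner):
        self.model_runner = model_runner

    def execute_model(self, seq_group_metadata_list,
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]]) -> List:
        if not seq_group_metadata_list:
            # No input: return an empty list rather than an empty dict.
            return []
        output = self.model_runner.execute_model(seq_group_metadata_list)
        # Single-step execution: wrap in a list to conform to the interface.
        return [output]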