[Spec Decode] (1/2) Remove batch expansion (#8839)

Lily Liu 2024-10-01 16:04:42 -07:00 committed by GitHub
parent 22f5851b80
commit 1570203864
29 changed files with 531 additions and 99 deletions
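
For orientation: this PR adds an MQA-based scorer for speculative-decoding proposals and a switch to fall back to the existing batch-expansion scorer. A minimal sketch (illustrative only, not vLLM code) of the scheduling difference between the two strategies:

def batch_expansion_rows(batch_size: int, k: int) -> int:
    # Batch expansion flattens every request into k + 1 single-query-token
    # pseudo-requests (k proposal positions plus the bonus token), so the
    # target model forward pass sees batch_size * (k + 1) rows.
    return batch_size * (k + 1)

def mqa_rows(batch_size: int, k: int) -> int:
    # The MQA scorer keeps one request per sequence and scores all k + 1
    # positions in a single multi-query pass (decode_query_len = k + 1 in
    # the attention metadata), so the batch size is unchanged.
    return batch_size

assert batch_expansion_rows(batch_size=8, k=3) == 32
assert mqa_rows(batch_size=8, k=3) == 8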

View File

@ -208,7 +208,7 @@ steps:
- tests/spec_decode
commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- label: LoRA Test %N # 15min each
mirror_hardwares: [amd]

View File

@ -434,7 +434,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else [1] * batch_size,
device=device,
pin_memory=is_pin_memory_available())
# the logits tensor is modified in-place by the sampler

View File

@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
max_output_len=32,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(

View File

@ -0,0 +1,65 @@
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.worker import Worker
from .utils import create_batch, create_worker
def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
device: str) -> SpeculativeProposals:
proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
device=device)
proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
return SpeculativeProposals(proposal_token_ids, proposal_probs,
proposal_lens)
def assert_score_equal(score1: SpeculativeScores,
score2: SpeculativeScores) -> None:
assert torch.allclose(score1.probs, score2.probs)
assert torch.allclose(score1.logprobs, score2.logprobs)
assert torch.equal(score1.token_ids, score2.token_ids)
@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scorer(model_name: str, batch_size: int, propose_len: int,
device: str) -> None:
"""
Verify that the batch expansion scorer and the MQA scorer return the same scores.
"""
seed = 0
block_size = 32
num_gpu_blocks = 2048 // block_size
scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed)
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\
should_modify_greedy_probs_inplace = True
vocab_size = scorer_worker.vocab_size
proposals = create_proposal(batch_size, propose_len, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=propose_len)
batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
vocab_size)
batch_expansion_score = batch_expansion_scorer.score_proposals(
requests, proposals)
mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
mqa_score = mqa_scorer.score_proposals(requests, proposals)
assert_score_equal(batch_expansion_score, mqa_score)

View File

@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int,
acceptance_sampler_method: str):
def test_batch_expansion_correctly_calls_target_model(
k: int, batch_size: int, acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
inputs with batch expansion. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
metrics_collector=metrics_collector,
disable_mqa_scorer=True)
worker.init_device()
vocab_size = 32_000

View File

@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
for i, final_len in enumerate(final_prompt_lens)
}
return [
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data={
i: SequenceData.from_seqs(prompt_token_ids[:],
cont_token_ids[:]),
},
sampling_params=SamplingParams(temperature=0.0, ),
block_tables={i: block_allocations[i][:]},
) for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations))
]
seq_group_metadata_list = []
for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations)):
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
data.update_num_computed_tokens(
len(prompt_token_ids) + len(cont_token_ids) - 1)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations[i][:]},
))
return seq_group_metadata_list
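
A small worked example (assumed lengths) of the num_computed_tokens setup above: marking all but the last token as computed leaves exactly one query token for the decode step.

prompt_token_ids = [1, 2, 3, 4]
cont_token_ids = [5, 6]
# All tokens except the last one are treated as already computed.
num_computed = len(prompt_token_ids) + len(cont_token_ids) - 1   # 5
query_len = len(prompt_token_ids) + len(cont_token_ids) - num_computed
assert query_len == 1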
def assert_logprobs_dict_allclose(

View File

@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[

View File

@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata):
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
# Maximum query length in the batch.
max_query_len: Optional[int]
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
@ -303,6 +310,7 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
decode_query_len=0,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@ -331,7 +339,8 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
decode_query_len=self.decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
@ -461,9 +470,6 @@ class FlashAttentionMetadataBuilder(
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
@ -518,6 +524,11 @@ class FlashAttentionMetadataBuilder(
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
decode_query_len = max(decode_query_lens)
else:
decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
@ -586,6 +597,7 @@ class FlashAttentionMetadataBuilder(
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
decode_query_len=decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
@ -786,8 +798,12 @@ class FlashAttentionImpl(AttentionImpl):
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
_, num_head, head_dim = decode_query.shape
decode_query = decode_query.reshape(-1,
decode_meta.decode_query_len,
num_head, head_dim)
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
decode_query,
key_cache,
value_cache,
block_table=decode_meta.block_tables,
@ -796,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
)
if prefill_output is None:
assert decode_output is not None
@ -804,5 +820,11 @@ class FlashAttentionImpl(AttentionImpl):
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert decode_meta is not None
assert decode_meta.decode_query_len == 1
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
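
A hedged shape sketch (plain torch, no vLLM or flash-attn imports) of the decode-query handling above: with speculative decoding each decode request contributes decode_query_len = k + 1 query tokens, so the flat token-major tensor is folded back into a per-request layout before the kv-cache kernel, replacing the old unsqueeze(1)/squeeze(1) pair that assumed one query token per request.

import torch

batch, decode_query_len, num_heads, head_dim = 4, 3, 8, 64
flat = torch.randn(batch * decode_query_len, num_heads, head_dim)
folded = flat.reshape(-1, decode_query_len, num_heads, head_dim)
assert folded.shape == (batch, decode_query_len, num_heads, head_dim)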

View File

@ -595,7 +595,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
@ -634,7 +633,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,

View File

@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

View File

@ -312,7 +312,8 @@ class CommonAttentionState(AttentionState):
slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=None,
max_query_len=1,
decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,

View File

@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].

View File

@ -1116,6 +1116,7 @@ class SpeculativeConfig:
speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_disable_mqa_scorer: Optional[bool],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
@ -1150,6 +1151,9 @@ class SpeculativeConfig:
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
scorer for the speculative model and fall back to batch
expansion for scoring.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
@ -1304,6 +1308,7 @@ class SpeculativeConfig:
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_mqa_scorer,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
@ -1400,6 +1405,7 @@ class SpeculativeConfig:
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
speculative_disable_mqa_scorer: Optional[bool],
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
@ -1446,6 +1452,7 @@ class SpeculativeConfig:
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0

View File

@ -162,6 +162,7 @@ class EngineArgs:
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_disable_mqa_scorer: Optional[bool] = False
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
@ -640,6 +641,12 @@ class EngineArgs:
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-mqa-scorer',
action='store_true',
help=
'If set to True, the MQA scorer will be disabled in speculative '
'decoding and the system will fall back to batch expansion.')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
@ -970,6 +977,7 @@ class EngineArgs:
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
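
A hedged usage sketch of the new option through the offline API (model names and extra kwargs are examples drawn from the tests in this PR, not a prescribed configuration); the equivalent server option is the --speculative-disable-mqa-scorer argument added above.

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",              # example target model
    speculative_model="JackFram/llama-68m",   # example draft model
    num_speculative_tokens=3,
    use_v2_block_manager=True,                # required for spec decode
    enforce_eager=True,
    speculative_disable_mqa_scorer=True,      # force batch-expansion scoring
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)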

View File

@ -1110,6 +1110,8 @@ class LLMEngine:
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
len(output),
is_first_step_output)
elif not is_async:
seq_group.update_num_computed_tokens(1)
if outputs:
for o in outputs:
@ -1133,8 +1135,16 @@ class LLMEngine:
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
output_token_num = self.output_processor.process_outputs(
seq_group, output, is_async)
if self.speculative_config:
# Subtract one because the decoding phase above
# already adds one computed token per decode
# step (the non-speculative path), so that
# token would otherwise be double counted.
seq_group.update_num_computed_tokens(output_token_num -
1)
if seq_group.is_finished():
finished_now.append(i)
@ -1251,11 +1261,12 @@ class LLMEngine:
# decodes after the very first step. Therefore,
# we skip the update to the num_computed_tokens
# here.
pass
seq_group.update_num_computed_tokens(1)
else:
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
else:
seq_group.update_num_computed_tokens(1)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample"
@ -1266,7 +1277,6 @@ class LLMEngine:
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
seq_group.update_num_computed_tokens(1)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import Callable, List, Optional
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@ -58,10 +58,14 @@ class SequenceGroupOutputProcessor(ABC):
@abstractmethod
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool) -> None:
is_async: bool) -> Optional[int]:
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
Return the number of new tokens generated in the sequence group.
The returned value is optional because it is only used by the
speculative decoding MQA scorer.
"""
pass
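
A worked example (assumed numbers) of how the engine uses this return value: the regular decode path always adds one computed token per step, and with speculative decoding the engine adds the remaining output_token_num - 1 on top, as in the LLMEngine change above.

# Illustrative accounting only, not engine code.
k = 3                            # speculative tokens proposed this step
accepted = 2                     # proposal tokens accepted by the verifier
assert accepted <= k
output_token_num = accepted + 1  # accepted tokens plus the bonus token

num_computed_tokens = 0
num_computed_tokens += 1                     # default decode-phase increment
num_computed_tokens += output_token_num - 1  # speculative correction
assert num_computed_tokens == output_token_num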

View File

@ -1,5 +1,5 @@
import functools
from typing import Callable, List
from typing import Callable, List, Optional
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def process_outputs(self,
sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None:
is_async: bool = False) -> Optional[int]:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
Returns:
The number of tokens appended to the sequence. This is optional
because only speculative decoding uses this return value.
"""
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINISHED_ABORTED
@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self._process_decode_and_stop(seq, sequence_group.sampling_params)
return None
else:
# Standard multi-step case
@ -121,8 +126,8 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
]
assert valid_samples
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
return self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None:
sampling_params: SamplingParams) -> int:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids))
if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens]
# Truncate any tokens after EOS. This is required as spec decode
@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break
# Incrementally append tokens to the sequence, as if we had only one new
@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
token_id=output_token_id,
logprobs=output_logprob,
)
seq.data.update_num_computed_tokens(1)
self._process_decode_and_stop(seq, sampling_params)
if seq.is_finished():
break
return len(output_token_ids)

View File

@ -912,7 +912,7 @@ def get_logprobs(
sampling_metadata: SamplingMetadata,
sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs.
"""Return sample logprobs and prompt logprobs.
The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and

View File

@ -146,7 +146,7 @@ class SamplingMetadata:
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
query_lens: List[int],
device: str,
pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None,
@ -194,7 +194,7 @@ class SamplingMetadata:
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
query_lens: List[int],
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
@ -284,7 +284,8 @@ def _prepare_seq_groups(
else:
# Decode
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
query_len = query_lens[i] if query_lens is not None else 1
sample_len = len(seq_ids) * query_len if do_sample else 0
if sampling_params.seed is not None and generators is not None:
generator = generators.get(seq_group_metadata.request_id)
@ -440,14 +441,14 @@ class SamplingTensors:
if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
assert sample_lens >= len(seq_ids)
temperatures += [temperature] * sample_lens
top_ps += [top_p] * sample_lens
top_ks += [top_k] * sample_lens
min_ps += [min_p] * sample_lens
presence_penalties += [p] * sample_lens
frequency_penalties += [f] * sample_lens
repetition_penalties += [r] * sample_lens
if do_penalties:
for seq_group in sampling_metadata.seq_groups:
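
A small worked example (illustrative numbers) of the change above: each sequence now contributes query_len sample slots during decode, so per-sequence sampling parameters are replicated sample_lens times rather than once per seq_id.

seq_ids = [7]               # one sequence in the group
query_len = 4               # k + 1 query tokens scored in one MQA pass
sample_len = len(seq_ids) * query_len
temperature = 0.7
temperatures = [temperature] * sample_len   # 4 entries instead of 1
assert len(temperatures) == sample_len == 4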

View File

@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.worker.worker_base import WorkerBase
SeqId = int
TargetSeqId = int
@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
"""
def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,

View File

@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner):
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata,

View File

@ -5,6 +5,7 @@ from typing import Optional, Set
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.worker.worker_base import WorkerBase
@dataclass
@ -74,6 +75,12 @@ class SpeculativeProposer(ABC):
class SpeculativeScorer(ABC):
def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size
@abstractmethod
def score_proposals(
self,

View File

@ -0,0 +1,80 @@
from vllm.sequence import (ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
SeqId = int
TargetSeqId = int
class MQAScorer(SpeculativeScorer):
def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
target_seq_group_metadata_list = []
target_seq_id_start = max(
get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
all_proposal_tokens = proposals.proposal_token_ids.tolist()
for i, seq_group_metadata in enumerate(
execute_model_req.seq_group_metadata_list):
seq_data_dict = seq_group_metadata.seq_data
assert len(seq_data_dict) == 1
seq_id = next(iter(seq_data_dict.keys()))
seq_data: SequenceData = seq_data_dict[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids()
output_token_ids = seq_data.get_output_token_ids()
proposal_token_ids = all_proposal_tokens[i]
new_output_token_ids = [*output_token_ids, *proposal_token_ids]
target_seq_id = target_seq_id_start + i
new_seq_data = SequenceData.from_seqs(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
)
new_seq_data.update_num_computed_tokens(
len(prompt_token_ids) + len(output_token_ids) - 1)
# The sequence must already have at least one output token
# because the MQA scorer is only used in the decoding stage.
assert len(output_token_ids) >= 1
new_seq_data_dict = {target_seq_id: new_seq_data}
new_seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data=new_seq_data_dict,
sampling_params=seq_group_metadata.sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
)
target_seq_group_metadata_list.append(new_seq_group_metadata)
target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list))
target_sampler_output = target_sampler_output[0]
bs, k = proposals.proposal_token_ids.shape
all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)
all_probs = target_sampler_output.sampled_token_probs.reshape(
bs, k + 1, self._vocab_size)
all_logprobs = target_sampler_output.logprobs.reshape(
bs, k + 1, self._vocab_size)
hidden_states = None
if target_sampler_output.hidden_states is not None:
hidden_states = target_sampler_output.hidden_states.reshape(
bs, (k + 1), -1)
return SpeculativeScores(probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
hidden_states=hidden_states)
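
A hedged shape walkthrough of the scorer above (numbers are examples): one target forward pass over the rewritten sequence groups yields bs * (k + 1) sampled positions, which are folded back into per-sequence scores exactly as in the reshape calls above.

import torch

bs, k, vocab_size = 2, 3, 11
flat_tokens = torch.arange(bs * (k + 1))            # flattened sampled_token_ids
flat_probs = torch.rand(bs * (k + 1), vocab_size)   # flattened sampled_token_probs

all_tokens = flat_tokens.reshape(bs, k + 1)
all_probs = flat_probs.reshape(bs, k + 1, vocab_size)
assert all_tokens.shape == (bs, k + 1)
assert all_probs.shape == (bs, k + 1, vocab_size)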

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
draft_token_acceptance_method=speculative_config.
@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
cls,
scorer_worker: Worker,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
)
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))
logger.info(
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name(
) != "flash-attn":
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")
if ngram_prompt_lookup_max > 0:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"NGramWorker does not support MQA scorer.")
if "model_config" in draft_worker_kwargs and \
draft_worker_kwargs["model_config"].max_model_len < \
scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len.")
if not scorer_worker.model_runner.model_config.enforce_eager:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_mqa_scorer: bool = False,
disable_logprobs: bool = False,
disable_log_stats: bool = False,
metrics_collector: Optional[AsyncMetricsCollector] = None,
@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_mqa_scorer: If set to True, disable the MQA scorer and use
the BatchExpansionTop1Scorer instead.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initialization.
self.scorer: SpeculativeScorer
self.disable_mqa_scorer = disable_mqa_scorer
# Hidden states from target model to pass to proposer
# in the subsequent step.
@ -270,10 +308,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
scorer_cls = BatchExpansionTop1Scorer
logger.info("[Speculative Decoding] Use batch "
"expansion for scoring proposals.")
else:
scorer_cls = MQAScorer
logger.info(
"[Speculative Decoding] Use MQA scorer for scoring proposals.")
self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
self._configure_model_sampler_for_spec_decode()
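
A condensed, hedged restatement of the auto-fallback conditions in create_worker above as a standalone predicate (names are illustrative; the real checks live in the code shown earlier):

def mqa_scorer_usable(attn_backend: str,
                      ngram_prompt_lookup_max: int,
                      draft_max_model_len: int,
                      target_max_model_len: int,
                      target_enforce_eager: bool) -> bool:
    # Mirrors when create_worker keeps the MQA scorer enabled.
    return (attn_backend == "flash-attn"
            and ngram_prompt_lookup_max == 0
            and draft_max_model_len >= target_max_model_len
            and target_enforce_eager)

assert mqa_scorer_usable("flash-attn", 0, 4096, 4096, True)
assert not mqa_scorer_usable("xformers", 0, 4096, 4096, True)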

View File

@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if inter_data.is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
seq_len = min(seq_len, context_len + token_chunk_size)
elif self.runner.scheduler_config.is_multi_step or \
self.runner.model_config.is_encoder_decoder_model:
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)
else:
context_len = seq_data.get_num_computed_tokens()
# Compute tokens.
if inter_data.is_prompt:
tokens = seq_data.get_token_ids()
if context_len != 0 or seq_len < len(tokens):
tokens = tokens[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = seq_data.get_last_token_id()
tokens = seq_data.get_token_ids()[context_len:seq_len]
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
if isinstance(tokens, list):
inter_data.input_tokens[seq_idx].extend(tokens)
else:
inter_data.input_tokens[seq_idx].append(tokens)
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None:
if inter_data.mrope_input_positions is None:
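
A worked example (assumed numbers) of the simplified decode path above: query_lens is now seq_len - context_len for decodes as well, so a speculative decode step with three scheduled query tokens slices three token ids and three positions instead of a single last token.

token_ids = list(range(10))      # sequence of length 10
context_len = 7                  # tokens already computed
seq_len = len(token_ids)

tokens = token_ids[context_len:seq_len]          # [7, 8, 9]
positions = list(range(context_len, seq_len))    # [7, 8, 9]
query_len = seq_len - context_len                # 3, no longer hard-coded to 1
assert tokens == positions == [7, 8, 9] and query_len == 3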