[Spec Decode] (1/2) Remove batch expansion (#8839)
Parent: 22f5851b80
Commit: 1570203864
@@ -208,7 +208,7 @@ steps:
    - tests/spec_decode
  commands:
    - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
    - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
    - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py

- label: LoRA Test %N # 15min each
  mirror_hardwares: [amd]
@@ -434,7 +434,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else [1] * batch_size,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
                                   max_output_len=32,
                                   seed=seed,
                                   temperature=0.0)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": MAIN_MODEL,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 3,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_disable_mqa_scorer": True,
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                    output_len: int, seed: int):
    """Verify that speculative decoding generates the same output
    with batch expansion scorer and mqa scorer.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
                                  temperature=0.0)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model_name": MAIN_MODEL,
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
        "speculative_disable_by_batch_size": 4
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_disable_mqa_scorer": True,
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                    output_len: int, seed: int):
    """Verify that speculative decoding generates the same output
    with batch expansion scorer and mqa scorer.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": MAIN_MODEL,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "speculative_model": SPEC_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_disable_mqa_scorer": True,
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
                    baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
                    output_len: int, seed: int):
    """Verify that speculative decoding generates the same output
    with batch expansion scorer and mqa scorer.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_disable_mqa_scorer": True,
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
                      per_test_common_llm_kwargs, baseline_llm_kwargs,
                      test_llm_kwargs, batch_size: int, output_len: int,
                      seed: int):
    """Verify that ngram speculative decoding generates the same output
    with batch expansion scorer and mqa scorer.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=output_len,
                                  seed=seed,
                                  temperature=0.0)
@@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    worker = create_worker(
tests/spec_decode/test_scorer.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import pytest
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.worker import Worker

from .utils import create_batch, create_worker


def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
                    device: str) -> SpeculativeProposals:
    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
                                device=device)
    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
    return SpeculativeProposals(proposal_token_ids, proposal_probs,
                                proposal_lens)


def assert_score_equal(score1: SpeculativeScores,
                       score2: SpeculativeScores) -> None:
    assert torch.allclose(score1.probs, score2.probs)
    assert torch.allclose(score1.logprobs, score2.logprobs)
    assert torch.equal(score1.token_ids, score2.token_ids)


@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scorer(model_name: str, batch_size: int, propose_len: int,
                device: str) -> None:
    """
    Check that the batch expansion scorer and the MQA scorer return
    the same scores.
    """
    seed = 0
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    scorer_worker = create_worker(Worker, model_name, block_size,
                                  num_gpu_blocks, seed)
    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
    scorer_worker.model_runner.model.sampler.\
        should_modify_greedy_probs_inplace = True

    vocab_size = scorer_worker.vocab_size
    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
    requests = ExecuteModelRequest(seq_group_metadata_list,
                                   num_lookahead_slots=propose_len)

    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                      vocab_size)
    batch_expansion_score = batch_expansion_scorer.score_proposals(
        requests, proposals)

    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
    mqa_score = mqa_scorer.score_proposals(requests, proposals)

    assert_score_equal(batch_expansion_score, mqa_score)
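A note on what the new test exercises: both scorers must return SpeculativeScores with per-sequence tensors of shape (batch, k + 1, vocab); the difference is only how the target-model forward pass is batched. A rough sketch of the bookkeeping, with illustrative numbers and no vLLM imports:

# Illustrative only: bs sequences, k proposed tokens each, vocab size V.
bs, k, V = 4, 3, 32_000

# Batch expansion scorer: every proposal position becomes its own
# single-token request, so the target model sees bs * (k + 1) rows,
# later folded back into (bs, k + 1, V).
expanded_rows = bs * (k + 1)          # 16 single-token decode requests

# MQA scorer: each sequence keeps its k + 1 query tokens in one request,
# so the target model sees bs requests with query_len = k + 1.
mqa_rows, query_len = bs, k + 1       # 4 requests, 4 query tokens each

assert expanded_rows == mqa_rows * query_len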
@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int,
                                      acceptance_sampler_method: str):
def test_batch_expansion_correctly_calls_target_model(
        k: int, batch_size: int, acceptance_sampler_method: str):
    """Verify SpecDecodeWorker calls the target model with correct
    inputs. Everything else is mocked out.
    inputs with batch expansion. Everything else is mocked out.
    """
    draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
    target_worker = mock_worker(use_spec=False)
@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
        target_worker,
        mock_spec_decode_sampler(acceptance_sampler_method),
        disable_logprobs=False,
        metrics_collector=metrics_collector)
        metrics_collector=metrics_collector,
        disable_mqa_scorer=True)
    worker.init_device()

    vocab_size = 32_000
@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
        for i, final_len in enumerate(final_prompt_lens)
    }

    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=len(cont_token_ids) == 0,
            seq_data={
                i: SequenceData.from_seqs(prompt_token_ids[:],
                                          cont_token_ids[:]),
            },
            sampling_params=SamplingParams(temperature=0.0, ),
            block_tables={i: block_allocations[i][:]},
        ) for i, (prompt_token_ids,
                  cont_token_ids) in enumerate(zip(prompts, continuations))
    ]
    seq_group_metadata_list = []
    for i, (prompt_token_ids,
            cont_token_ids) in enumerate(zip(prompts, continuations)):
        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
        data.update_num_computed_tokens(
            len(prompt_token_ids) + len(cont_token_ids) - 1)
        seq_data = {i: data}
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=str(i),
                is_prompt=len(cont_token_ids) == 0,
                seq_data=seq_data,
                sampling_params=SamplingParams(temperature=0.0),
                block_tables={i: block_allocations[i][:]},
            ))
    return seq_group_metadata_list


def assert_logprobs_dict_allclose(
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = None

    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
@@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata):
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    # Maximum query length in the batch.
    max_query_len: Optional[int]

    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
@@ -303,6 +310,7 @@ class FlashAttentionMetadata(AttentionMetadata):
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            decode_query_len=0,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
@@ -331,7 +339,8 @@ class FlashAttentionMetadata(AttentionMetadata):
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            decode_query_len=self.decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
@@ -461,9 +470,6 @@ class FlashAttentionMetadataBuilder(
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)
        else:
            assert query_len == 1, (
                "seq_len: {}, context_len: {}, query_len: {}".format(
                    seq_len, context_len, query_len))
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)

@@ -518,6 +524,11 @@ class FlashAttentionMetadataBuilder(
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            decode_query_len = max(decode_query_lens)
        else:
            decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
@@ -586,6 +597,7 @@ class FlashAttentionMetadataBuilder(
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            decode_query_len=decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
@@ -786,8 +798,12 @@ class FlashAttentionImpl(AttentionImpl):

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            _, num_head, head_dim = decode_query.shape
            decode_query = decode_query.reshape(-1,
                                                decode_meta.decode_query_len,
                                                num_head, head_dim)
            decode_output = torch.ops.vllm.flash_attn_with_kvcache(
                decode_query.unsqueeze(1),
                decode_query,
                key_cache,
                value_cache,
                block_table=decode_meta.block_tables,
@@ -796,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
                causal=True,
                alibi_slopes=self.alibi_slopes,
                softcap=self.logits_soft_cap,
            ).squeeze(1)
            )

        if prefill_output is None:
            assert decode_output is not None
@@ -804,5 +820,11 @@ class FlashAttentionImpl(AttentionImpl):
        if decode_output is None:
            assert prefill_output is not None
            return prefill_output.view(num_prefill_tokens, hidden_size)

        # Chunked prefill does not work with speculative decoding.
        # Therefore, the query length for decode should be 1 in chunked prefill.
        assert decode_meta is not None
        assert decode_meta.decode_query_len == 1
        decode_output = decode_output.squeeze(1)
        output = torch.cat([prefill_output, decode_output], dim=0)
        return output.view(num_tokens, hidden_size)
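The reshape above is the core of the multi-query decode path: with speculative decoding each decode request carries decode_query_len tokens instead of one. A shape-only sketch with illustrative sizes (plain PyTorch, no vLLM imports):

import torch

# Illustrative sizes: 4 decode sequences, each scoring k + 1 = 4 query tokens.
num_seqs, decode_query_len, num_heads, head_dim = 4, 4, 8, 64
decode_query = torch.randn(num_seqs * decode_query_len, num_heads, head_dim)

# Same reshape as in FlashAttentionImpl.forward: group the flat token stream
# back into (batch, query_len, heads, head_dim) before the kv-cache attention call.
decode_query = decode_query.reshape(-1, decode_query_len, num_heads, head_dim)
assert decode_query.shape == (num_seqs, decode_query_len, num_heads, head_dim)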
@@ -595,7 +595,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

@@ -634,7 +633,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                dtype=torch.int,
                device=device,
            )
            assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
@@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = None

    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

@@ -312,7 +312,8 @@ class CommonAttentionState(AttentionState):
            slot_mapping=self._graph_slot_mapping[:batch_size],
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=None,
            max_query_len=1,
            decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
@@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # Number of query tokens for each request in the batch.
    # Currently, we require that all requests have the same number of query
    # tokens during the decoding phase. When speculative decoding is enabled,
    # decode_query_len might be greater than 1. In all other cases, it is 1.
    decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
@@ -1116,6 +1116,7 @@ class SpeculativeConfig:
        speculative_model_quantization: Optional[str],
        speculative_draft_tensor_parallel_size: Optional[int],
        num_speculative_tokens: Optional[int],
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_max_model_len: Optional[int],
        enable_chunked_prefill: bool,
        use_v2_block_manager: bool,
@@ -1150,6 +1151,9 @@ class SpeculativeConfig:
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided. Will default to the number in the draft
                model config if present, otherwise is required.
            speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
                scorer for the speculative model and fall back to batch
                expansion for scoring.
            speculative_max_model_len (Optional[int]): The maximum model len of
                the speculative model. Used when testing the ability to skip
                speculation for some sequences.
@@ -1304,6 +1308,7 @@ class SpeculativeConfig:
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
            speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size,
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min,
@@ -1400,6 +1405,7 @@ class SpeculativeConfig:
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
        speculative_disable_mqa_scorer: Optional[bool],
        speculative_disable_by_batch_size: Optional[int],
        ngram_prompt_lookup_max: Optional[int],
        ngram_prompt_lookup_min: Optional[int],
@@ -1446,6 +1452,7 @@ class SpeculativeConfig:
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens
        self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
        self.speculative_disable_by_batch_size = \
            speculative_disable_by_batch_size
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
@@ -162,6 +162,7 @@ class EngineArgs:
    speculative_model_quantization: Optional[str] = None
    speculative_draft_tensor_parallel_size: Optional[int] = None
    num_speculative_tokens: Optional[int] = None
    speculative_disable_mqa_scorer: Optional[bool] = False
    speculative_max_model_len: Optional[int] = None
    speculative_disable_by_batch_size: Optional[int] = None
    ngram_prompt_lookup_max: Optional[int] = None
@@ -640,6 +641,12 @@ class EngineArgs:
            default=EngineArgs.num_speculative_tokens,
            help='The number of speculative tokens to sample from '
            'the draft model in speculative decoding.')
        parser.add_argument(
            '--speculative-disable-mqa-scorer',
            action='store_true',
            help=
            'If set to True, the MQA scorer will be disabled in speculative '
            'decoding and fall back to batch expansion')
        parser.add_argument(
            '--speculative-draft-tensor-parallel-size',
            '-spec-draft-tp',
@@ -970,6 +977,7 @@ class EngineArgs:
            speculative_draft_tensor_parallel_size = \
                self.speculative_draft_tensor_parallel_size,
            num_speculative_tokens=self.num_speculative_tokens,
            speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
            speculative_disable_by_batch_size=self.
            speculative_disable_by_batch_size,
            speculative_max_model_len=self.speculative_max_model_len,
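As a usage sketch (not part of the diff): the new flag reaches SpeculativeConfig through EngineArgs, so an offline run can opt back into batch-expansion scoring. Model names below are placeholders and the offline LLM entry point at this revision is assumed.

# Sketch only: any compatible draft/target pair works here.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,            # required for spec decode at this revision
    enforce_eager=True,
    speculative_disable_mqa_scorer=True,  # force BatchExpansionTop1Scorer
)
out = llm.generate(["The future of AI is"],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)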
@@ -1110,6 +1110,8 @@ class LLMEngine:
                    update_prefill_num_computed_tokens(seq_group, seq_group_meta,
                                                       len(output),
                                                       is_first_step_output)
                elif not is_async:
                    seq_group.update_num_computed_tokens(1)

            if outputs:
                for o in outputs:
@@ -1133,8 +1135,16 @@ class LLMEngine:
            else:
                self.output_processor.process_prompt_logprob(seq_group, output)
                if seq_group_meta.do_sample:
                    self.output_processor.process_outputs(
                    output_token_num = self.output_processor.process_outputs(
                        seq_group, output, is_async)
                    if self.speculative_config:
                        # Subtract one because the decoding phase always
                        # advances the number of computed tokens by one
                        # (with or without speculative decoding), so that
                        # token is already accounted for.
                        seq_group.update_num_computed_tokens(output_token_num -
                                                             1)

            if seq_group.is_finished():
                finished_now.append(i)
@@ -1251,11 +1261,12 @@ class LLMEngine:
                        # decodes after the very first step. Therefore,
                        # we skip the update to the num_computed_tokens
                        # here.
                        pass
                        seq_group.update_num_computed_tokens(1)
                    else:
                        seq_group.update_num_computed_tokens(
                            seq_group_metadata.token_chunk_size)

            else:
                seq_group.update_num_computed_tokens(1)
                if seq_group_metadata.do_sample:
                    assert len(sequence_group_outputs.samples) == 1, (
                        "Async output processor expects a single sample"
@@ -1266,7 +1277,6 @@ class LLMEngine:
                    assert len(seq_group.seqs) == 1
                    seq = seq_group.seqs[0]
                    seq.append_token_id(sample.output_token, sample.logprobs)
                    seq_group.update_num_computed_tokens(1)

    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import Callable, List, Optional

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@@ -58,10 +58,14 @@ class SequenceGroupOutputProcessor(ABC):
    @abstractmethod
    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
                        is_async: bool) -> Optional[int]:
        """Process new token ids for the sequence group. Handles logic such as
        detokenization, stop checking, and freeing/forking sequences in the
        scheduler.

        Return the number of new tokens generated in the sequence group.
        The returned value is optional because it is only used by the
        speculative decoding MQA scorer.
        """
        pass

@@ -1,5 +1,5 @@
import functools
from typing import Callable, List
from typing import Callable, List, Optional

from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    def process_outputs(self,
                        sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool = False) -> None:
                        is_async: bool = False) -> Optional[int]:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
@@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                tokens from the previous step. If this is true, then
                no tokens need to be appended since it is already done
                externally (before the next schedule() call)

        Returns:
            The number of tokens appended to the sequence. This is optional
            because only speculative decode uses this return value.
        """
        # Sequences can be in RUNNING or FINISHED_ABORTED state
        # once scheduled, as a sequence is moved to FINISHED_ABORTED
@@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
            # was already appended, so we only need to do the rest of the
            # postprocessor: Detokenization + stopping logic
            self._process_decode_and_stop(seq, sequence_group.sampling_params)
            return None
        else:
            # Standard multi-step case

@@ -121,8 +126,8 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
        ]
        assert valid_samples

        self._process_seq_outputs(seq, valid_samples,
                                  sequence_group.sampling_params)
        return self._process_seq_outputs(seq, valid_samples,
                                         sequence_group.sampling_params)

    def _process_decode_and_stop(self, seq: Sequence,
                                 sampling_params: SamplingParams) -> None:
@@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
                             sampling_params: SamplingParams) -> int:
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]

@@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
        remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
                                                         len(output_token_ids))
        if remaining_tokens < 0:
            valid_samples = valid_samples[:remaining_tokens]
            output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
@@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
        for i in range(len(output_token_ids)):
            if output_token_ids[i] == eos_token_id:
                output_token_ids = output_token_ids[:i + 1]
                valid_samples = valid_samples[:i + 1]
                break

        # Incrementally append tokens to the sequence, as if we had only one new
@@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                token_id=output_token_id,
                logprobs=output_logprob,
            )
            seq.data.update_num_computed_tokens(1)

            self._process_decode_and_stop(seq, sampling_params)

            if seq.is_finished():
                break
        return len(output_token_ids)
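Taken together with the LLMEngine change above, the returned token count drives the computed-token bookkeeping. A rough arithmetic sketch of why one is subtracted, with made-up numbers:

# Illustrative only: why LLMEngine subtracts 1 from the processor's return value.
num_computed_before = 10     # tokens already counted for this sequence
appended_this_step = 3       # tokens the output processor just appended
# The regular decode path has already advanced the count by 1 for this step,
# so only the remaining appended tokens need to be added on top.
num_computed_after = num_computed_before + 1 + (appended_this_step - 1)
assert num_computed_after == num_computed_before + appended_this_step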
@@ -912,7 +912,7 @@ def get_logprobs(
    sampling_metadata: SamplingMetadata,
    sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample lobprobs and prompt logprobs.
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
@@ -146,7 +146,7 @@ class SamplingMetadata:
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: Optional[List[int]],
        query_lens: List[int],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
@@ -194,7 +194,7 @@ class SamplingMetadata:
def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: Optional[List[int]],
    query_lens: List[int],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
@@ -284,7 +284,8 @@ def _prepare_seq_groups(
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0
            query_len = query_lens[i] if query_lens is not None else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)
@@ -440,14 +441,14 @@ class SamplingTensors:

            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens == len(seq_ids)
                temperatures += [temperature] * len(seq_ids)
                top_ps += [top_p] * len(seq_ids)
                top_ks += [top_k] * len(seq_ids)
                min_ps += [min_p] * len(seq_ids)
                presence_penalties += [p] * len(seq_ids)
                frequency_penalties += [f] * len(seq_ids)
                repetition_penalties += [r] * len(seq_ids)
                assert sample_lens >= len(seq_ids)
                temperatures += [temperature] * sample_lens
                top_ps += [top_p] * sample_lens
                top_ks += [top_k] * sample_lens
                min_ps += [min_p] * sample_lens
                presence_penalties += [p] * sample_lens
                frequency_penalties += [f] * sample_lens
                repetition_penalties += [r] * sample_lens

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
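The sample_len change above is what lets a decode sequence carry more than one sampled position; a toy illustration with made-up values:

# Illustrative only: per-sequence sampling bookkeeping with multi-query scoring.
seq_ids = [7]            # one sequence in the group
query_len = 4            # k + 1 positions scored in this decode step
do_sample = True
sample_len = len(seq_ids) * query_len if do_sample else 0   # 4; previously always len(seq_ids) == 1
# Sampling tensors are then broadcast per sampled position rather than per sequence:
temperature = 0.7
temperatures = [temperature] * sample_len                    # 4 entries instead of 1
assert sample_len >= len(seq_ids)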
@@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.worker.worker_base import WorkerBase

SeqId = int
TargetSeqId = int
@@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
    of topk/tree.
    """

    def __init__(self, scorer_worker: WorkerBase, device: str,
                 vocab_size: int):
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
@@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner):
            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _gpu_advance_step(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
@@ -5,6 +5,7 @@ from typing import Optional, Set
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.worker.worker_base import WorkerBase


@dataclass
@@ -74,6 +75,12 @@ class SpeculativeProposer(ABC):

class SpeculativeScorer(ABC):

    def __init__(self, scorer_worker: WorkerBase, device: str,
                 vocab_size: int):
        self._scorer_worker = scorer_worker
        self._device = device
        self._vocab_size = vocab_size

    @abstractmethod
    def score_proposals(
        self,
vllm/spec_decode/mqa_scorer.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from vllm.sequence import (ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)

SeqId = int
TargetSeqId = int


class MQAScorer(SpeculativeScorer):

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        target_seq_group_metadata_list = []
        target_seq_id_start = max(
            get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
        all_proposal_tokens = proposals.proposal_token_ids.tolist()
        for i, seq_group_metadata in enumerate(
                execute_model_req.seq_group_metadata_list):
            seq_data_dict = seq_group_metadata.seq_data
            assert len(seq_data_dict) == 1
            seq_id = next(iter(seq_data_dict.keys()))

            seq_data: SequenceData = seq_data_dict[seq_id]
            prompt_token_ids = seq_data.get_prompt_token_ids()
            output_token_ids = seq_data.get_output_token_ids()
            proposal_token_ids = all_proposal_tokens[i]
            new_output_token_ids = [*output_token_ids, *proposal_token_ids]

            target_seq_id = target_seq_id_start + i
            new_seq_data = SequenceData.from_seqs(
                prompt_token_ids=prompt_token_ids,
                output_token_ids=new_output_token_ids,
            )
            new_seq_data.update_num_computed_tokens(
                len(prompt_token_ids) + len(output_token_ids) - 1)

            # Ensure that the new sequence has at least one token
            # because we only use mqa scorer in the decoding stage.
            assert len(output_token_ids) >= 1
            new_seq_data_dict = {target_seq_id: new_seq_data}

            new_seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group_metadata.request_id,
                is_prompt=seq_group_metadata.is_prompt,
                seq_data=new_seq_data_dict,
                sampling_params=seq_group_metadata.sampling_params,
                block_tables={
                    target_seq_id: seq_group_metadata.block_tables[seq_id],
                },
                lora_request=None,
                token_chunk_size=1,
            )
            target_seq_group_metadata_list.append(new_seq_group_metadata)

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))

        target_sampler_output = target_sampler_output[0]

        bs, k = proposals.proposal_token_ids.shape
        all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)

        all_probs = target_sampler_output.sampled_token_probs.reshape(
            bs, k + 1, self._vocab_size)
        all_logprobs = target_sampler_output.logprobs.reshape(
            bs, k + 1, self._vocab_size)

        hidden_states = None
        if target_sampler_output.hidden_states is not None:
            hidden_states = target_sampler_output.hidden_states.reshape(
                bs, (k + 1), -1)
        return SpeculativeScores(probs=all_probs,
                                 token_ids=all_tokens,
                                 logprobs=all_logprobs,
                                 hidden_states=hidden_states)
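A worked example of the sequence framing done in score_proposals above, using toy token ids (not taken from the diff):

# Illustrative only: how MQAScorer frames one sequence for the target model.
prompt = [11, 12, 13]        # prompt token ids
generated = [21]             # tokens accepted so far
proposed = [31, 32, 33]      # k = 3 draft tokens to score
new_output = generated + proposed                   # [21, 31, 32, 33]
num_computed = len(prompt) + len(generated) - 1     # everything but the last accepted token
query_len = len(prompt) + len(new_output) - num_computed
assert query_len == len(proposed) + 1               # the target model scores k + 1 positions at once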
@@ -1,6 +1,6 @@
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import torch

@@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_mqa_scorer: bool,
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
@@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        logger.info("Configuring SpecDecodeWorker with sampler=%s",
                    type(spec_decode_sampler))
        logger.info(
            "[Speculative Decoding] Configuring"
            " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

        if not disable_mqa_scorer:
            if scorer_worker.model_runner.attn_backend.get_name(
            ) != "flash-attn":
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "MQA is only available with flash attn backend.")

            if ngram_prompt_lookup_max > 0:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "NGramWorker does not support MQA scorer.")

            if "model_config" in draft_worker_kwargs and \
                draft_worker_kwargs["model_config"].max_model_len < \
                    scorer_worker.model_config.max_model_len:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "draft model max_model_len is smaller than the target "
                    "model max_model_len.")

            if not scorer_worker.model_runner.model_config.enforce_eager:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "target model is not running in eager mode.")

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_mqa_scorer=disable_mqa_scorer,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
@@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_mqa_scorer: bool = False,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
@@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                types of sampler namely RejectionSampler and
                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
                instance of RejectionSampler or TypicalAcceptanceSampler.
            disable_mqa_scorer: If set to True, disable the MQA scorer and use
                the BatchExpansionTop1Scorer instead.
            disable_logprobs: If set to True, token log probabilities will
                not be output in both the draft worker and the target worker.
                If set to False, log probabilities will be output by both.
@@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Lazy initialization.
        self.scorer: SpeculativeScorer
        self.disable_mqa_scorer = disable_mqa_scorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
@@ -270,10 +308,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        self._metrics.init_gpu_tensors(self.rank)
        self.spec_decode_sampler.init_gpu_tensors(self.rank)

        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)
        scorer_cls: Type[SpeculativeScorer]
        if self.disable_mqa_scorer:
            scorer_cls = BatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
        else:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                 device=self.device,
                                 vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

@@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

        # Compute context length (the number of tokens that are
        # already computed) and sequence length (total number of tokens).

        seq_len = seq_data.get_len()
        if inter_data.is_prompt:
            context_len = seq_data.get_num_computed_tokens()
        else:
            # get_num_computed_tokens is incorrect for spec decoding.
            # So, we should have a special logic here.
            # TODO(sang): Fix it.
            seq_len = min(seq_len, context_len + token_chunk_size)
        elif self.runner.scheduler_config.is_multi_step or \
                self.runner.model_config.is_encoder_decoder_model:
            context_len = seq_len - 1
            seq_len = min(seq_len, context_len + token_chunk_size)
        else:
            context_len = seq_data.get_num_computed_tokens()

        # Compute tokens.
        if inter_data.is_prompt:
            tokens = seq_data.get_token_ids()
            if context_len != 0 or seq_len < len(tokens):
                tokens = tokens[context_len:seq_len]
        else:
            # Optimization. get_token_ids requires the entire copy of
            # tokens.
            tokens = seq_data.get_last_token_id()
        tokens = seq_data.get_token_ids()[context_len:seq_len]

        inter_data.seq_lens[seq_idx] = seq_len
        inter_data.orig_seq_lens[seq_idx] = seq_len
        inter_data.context_lens[seq_idx] = context_len

        if isinstance(tokens, list):
            inter_data.input_tokens[seq_idx].extend(tokens)
        else:
            inter_data.input_tokens[seq_idx].append(tokens)

        if (seq_len - context_len) == 1:
            inter_data.input_positions[seq_idx].append(seq_len - 1)
        else:
            inter_data.input_positions[seq_idx].extend(
                range(context_len, seq_len))

        inter_data.query_lens[
            seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
        inter_data.input_tokens[seq_idx].extend(tokens)
        inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
        inter_data.query_lens[seq_idx] = seq_len - context_len

        if seq_data.mrope_position_delta is not None:
            if inter_data.mrope_input_positions is None:
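The unified slicing above is what allows decode requests to carry multiple query tokens instead of a hard-coded single token. A toy walk-through with made-up values:

# Illustrative only: the slicing that replaces the prompt/decode special cases.
token_ids = [11, 12, 13, 21, 31, 32]   # prompt + generated + proposed tokens
context_len = 3                         # tokens already computed
seq_len = len(token_ids)
tokens = token_ids[context_len:seq_len]            # [21, 31, 32] — can exceed 1 during decode now
positions = list(range(context_len, seq_len))      # [3, 4, 5]
query_len = seq_len - context_len                  # 3, no longer forced to 1 for decode requests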