[Bugfix] Fix for Spec model TP + Chunked Prefill (#10232)
Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com>
Signed-off-by: Sourashis Roy <sroy@roblox.com>
Co-authored-by: Sourashis Roy <sroy@roblox.com>
Parent: 1f6584ee85
Commit: db66e018ea
@@ -118,7 +118,7 @@ Feature x Feature
      -
      -
    * - :ref:`SD <spec_decode>`
-     - ✗
+     - ✅
      - ✅
      - ✗
      - ✅
@@ -413,6 +413,45 @@ def test_chunked_prefill_preempt():
     assert out.num_batched_tokens == max_num_batched_tokens
 
 
+@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
+def test_chunked_prefill_spec_prefill(num_scheduler_steps):
+    """Verify that num_lookahead_slots is set appropriately for an
+    all-prefill batch, depending on whether multi-step scheduling
+    is enabled or not."""
+    block_size = 4
+    max_seqs = 30
+    max_model_len = 200
+    max_num_batched_tokens = 30
+    num_lookahead_slots = 4
+    scheduler_config = SchedulerConfig(
+        "generate",
+        max_num_batched_tokens,
+        max_seqs,
+        max_model_len,
+        enable_chunked_prefill=True,
+        num_lookahead_slots=num_lookahead_slots,
+        num_scheduler_steps=num_scheduler_steps,
+    )
+    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+    cache_config.num_cpu_blocks = 16
+    cache_config.num_gpu_blocks = 16
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    _, seq_group = create_dummy_prompt("1",
+                                       prompt_length=30,
+                                       block_size=block_size)
+    scheduler.add_seq_group(seq_group)
+    _, out = schedule_and_update_computed_tokens(scheduler)
+    # The request is chunked; only the
+    # prefill is scheduled now.
+    assert len(out.scheduled_seq_groups) == 1
+    assert out.num_prefill_groups == 1
+    assert out.num_batched_tokens == max_num_batched_tokens
+    print(out.num_lookahead_slots)
+    assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
+                                       num_lookahead_slots)
+
+
 def test_chunked_prefill_max_seqs():
     block_size = 4
     max_seqs = 2
@@ -50,49 +50,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
     with pytest.raises(ValueError, match="cannot be larger than"):
         get_output_from_llm_generator(test_llm_generator, prompts,
                                       sampling_params)
-
-
-@pytest.mark.parametrize("common_llm_kwargs",
-                         [{
-                             "model": "meta-llama/Llama-2-7b-chat-hf",
-                             "speculative_model": "JackFram/llama-68m",
-                             "num_speculative_tokens": 5,
-                             "enable_chunked_prefill": "True",
-                         }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "tensor_parallel_size": 2,
-        "speculative_draft_tensor_parallel_size": 2,
-    },
-    {
-        "tensor_parallel_size": 4,
-        "speculative_draft_tensor_parallel_size": 4,
-    },
-    {
-        "tensor_parallel_size": 8,
-        "speculative_draft_tensor_parallel_size": 8,
-    },
-])
-@pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
-        test_llm_generator):
-    """Verify that speculative decoding fails if chunked prefill is enabled for
-    draft model with tensor parallelism of more than 1.
-    """
-    output_len = 128
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-    ]
-
-    sampling_params = SamplingParams(
-        max_tokens=output_len,
-        ignore_eos=True,
-        temperature=temperature,
-    )
-
-    with pytest.raises(ValueError, match="with tensor parallel size 1"):
-        get_output_from_llm_generator(test_llm_generator, prompts,
-                                      sampling_params)
@@ -115,3 +115,60 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
                                      max_output_len=32,
                                      seed=seed,
                                      temperature=0.0)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [[
+        # Skip cuda graph recording for fast test.
+        "--enforce-eager",
+        "--tensor_parallel_size",
+        "2",
+
+        # precision
+        "--dtype",
+        "bfloat16",
+    ]])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [["--enable-chunked-prefill", "False"],
+     [
+         "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4",
+         "--max-num-seqs", "4"
+     ]])
+@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
+@pytest.mark.parametrize("model, test_llm_kwargs",
+                         [("JackFram/llama-68m", [
+                             "--speculative-model",
+                             "JackFram/llama-68m",
+                             "--num_speculative-tokens",
+                             "3",
+                         ]),
+                          ("JackFram/llama-68m", [
+                              "--speculative-model",
+                              "JackFram/llama-68m",
+                              "--num_speculative-tokens",
+                              "3",
+                              "--speculative-draft-tensor-parallel-size",
+                              "1",
+                          ])])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
+                                         per_test_common_llm_kwargs,
+                                         baseline_llm_kwargs, test_llm_kwargs,
+                                         batch_size: int, seed: int):
+    """Verify spec decode works well with same and different TP size for
+    the draft model with chunked prefill.
+    """
+    run_equality_correctness_test_tp(model,
+                                     common_llm_kwargs,
+                                     per_test_common_llm_kwargs,
+                                     baseline_llm_kwargs,
+                                     test_llm_kwargs,
+                                     batch_size,
+                                     max_output_len=32,
+                                     seed=seed,
+                                     temperature=0.0)
@@ -867,7 +867,8 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
     target_group_metadata_list = prefill + decodes
     execute_model_req = ExecuteModelRequest(
         seq_group_metadata_list=target_group_metadata_list,
-        num_lookahead_slots=k)
+        # For prefill only batches we expect num_lookahead_slots = 0.
+        num_lookahead_slots=k if n_decodes > 0 else 0)
 
     target_token_ids = torch.randint(low=0,
                                      high=vocab_size,
@@ -1409,16 +1409,6 @@ class SpeculativeConfig:
                 draft_hf_config
             )
 
-            if (enable_chunked_prefill and \
-                    speculative_draft_tensor_parallel_size != 1):
-                # TODO - Investigate why the error reported in
-                # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
-                # is happening and re-enable it.
-                raise ValueError(
-                    "Chunked prefill and speculative decoding can be enabled "
-                    "simultaneously only for draft models with tensor "
-                    "parallel size 1.")
-
             draft_model_config.max_model_len = (
                 SpeculativeConfig._maybe_override_draft_max_model_len(
                     speculative_max_model_len,
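A minimal usage sketch (not part of this diff): with the ValueError above removed, chunked prefill can be combined with a draft model whose tensor parallel size is greater than 1. The sketch assumes vLLM's offline LLM API and reuses the model names and TP sizes from the removed xfail test above; adjust for the GPUs you actually have.

# Hedged sketch, not code from this commit; kwargs mirror the test parametrization.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    tensor_parallel_size=2,
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    # Previously rejected when enable_chunked_prefill=True; allowed after this fix.
    speculative_draft_tensor_parallel_size=2,
    enable_chunked_prefill=True,
)
params = SamplingParams(max_tokens=128, ignore_eos=True, temperature=0.0)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)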
@@ -1201,15 +1201,25 @@ class Scheduler:
         # Update swapped requests.
         self.swapped.extend(running_scheduled.swapped_out)
+        # Put prefills first due to Attention backend ordering assumption.
+        scheduled_seq_groups = (prefills.seq_groups +
+                                running_scheduled.prefill_seq_groups +
+                                swapped_in.prefill_seq_groups +
+                                running_scheduled.decode_seq_groups +
+                                swapped_in.decode_seq_groups)
+        num_prefill_groups = (len(prefills.seq_groups) +
+                              len(swapped_in.prefill_seq_groups) +
+                              len(running_scheduled.prefill_seq_groups))
+        # If all prompts, then we set num_lookahead_slots to 0
+        # this allows us to go through the `no_spec` path in
+        # `spec_decode_worker.py`
+        all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
+        num_lookahead_slots = (0 if
+                               (all_prefills
+                                and not self.scheduler_config.is_multi_step)
+                               else running_scheduled.num_lookahead_slots)
         return SchedulerOutputs(
-            scheduled_seq_groups=(prefills.seq_groups +
-                                  running_scheduled.prefill_seq_groups +
-                                  swapped_in.prefill_seq_groups +
-                                  running_scheduled.decode_seq_groups +
-                                  swapped_in.decode_seq_groups),
-            num_prefill_groups=(len(prefills.seq_groups) +
-                                len(swapped_in.prefill_seq_groups) +
-                                len(running_scheduled.prefill_seq_groups)),
+            scheduled_seq_groups=scheduled_seq_groups,
+            num_prefill_groups=num_prefill_groups,
             num_batched_tokens=budget.num_batched_tokens +
             budget.num_cached_tokens,
             blocks_to_swap_in=swapped_in.blocks_to_swap_in,
@@ -1218,7 +1228,7 @@ class Scheduler:
             swapped_in.blocks_to_copy,
             ignored_seq_groups=prefills.ignored_seq_groups +
             swapped_in.infeasible_seq_groups,
-            num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            num_lookahead_slots=num_lookahead_slots,
             running_queue_size=len(self.running),
             preempted=(len(running_scheduled.preempted) +
                        len(running_scheduled.swapped_out)),
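The net effect of the scheduler change above: if every group scheduled in a step is a prefill and multi-step scheduling is off, the step reports zero lookahead slots so the spec-decode worker takes its `no_spec` path. A simplified, hypothetical restatement (not code from this commit), with the real logic being the hunk above:

# Hypothetical helper distilling the num_lookahead_slots decision.
def pick_num_lookahead_slots(num_scheduled_groups: int,
                             num_prefill_groups: int,
                             scheduled_lookahead_slots: int,
                             is_multi_step: bool) -> int:
    all_prefills = (num_scheduled_groups == num_prefill_groups)
    if all_prefills and not is_multi_step:
        # Prompt-only, single-step scheduling: no lookahead slots needed.
        return 0
    return scheduled_lookahead_slots

# All-prefill batch, single-step -> 0 slots (cf. test_chunked_prefill_spec_prefill
# with num_scheduler_steps=1); with multi-step scheduling the slots are kept.
assert pick_num_lookahead_slots(1, 1, 4, is_multi_step=False) == 0
assert pick_num_lookahead_slots(1, 1, 4, is_multi_step=True) == 4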
@@ -408,7 +408,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         disable_all_speculation = self._should_disable_all_speculation(
             execute_model_req)
         num_lookahead_slots = execute_model_req.num_lookahead_slots
+        all_prompt = True
+        atleast_one_prompt = False
+        all_zero_spec_tokens = True
+        for sgm in execute_model_req.seq_group_metadata_list:
+            all_prompt = all_prompt and sgm.is_prompt
+            atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
+            all_zero_spec_tokens = all_zero_spec_tokens and (
+                sgm.num_speculative_tokens == 0)
+
+        if all_prompt and execute_model_req.seq_group_metadata_list:
+            assert num_lookahead_slots == 0, (
+                "Prompt only runs should have num_lookahead_slots equal to 0. "
+                "This should never happen, please file a bug at "
+                "https://github.com/vllm-project/vllm/issues")
         # Speculative decoding is disabled in the following cases:
         # 1. Prefill phase: Speculative decoding is not
         #    used during the prefill phase.
@@ -419,11 +432,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         #    In any of these cases, the proposer and scorer workers
         #    are called normally.
         # We expect `num_speculative_tokens` to be None for prefills.
-        no_spec = all(
-            sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list
-        ) or num_lookahead_slots == 0 or disable_all_speculation or all(
-            sgm.num_speculative_tokens == 0
-            for sgm in execute_model_req.seq_group_metadata_list)
+        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
+                   or all_zero_spec_tokens)
 
         # Broadcast how many lookahead slots are scheduled for this step, and
         # whether all speculation is disabled, to all non-driver workers.
@@ -442,6 +452,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             num_lookahead_slots=num_lookahead_slots,
             no_spec=no_spec,
             disable_all_speculation=disable_all_speculation,
+            # When both chunked prefill and speculative decoding are enabled
+            # it is possible that the same batch contains both prefill
+            # and decodes. If that happens in the scorer we run the batch
+            # as one single forward pass. However, in the proposer we
+            # run them as 2 different batches - one for prefill and
+            # the other for decodes. The variable indicates to the non-driver
+            # worker that there are prefills as part of the speculative batch
+            # and hence it needs to run an extra prefill forward pass.
+            run_spec_proposer_for_prefill=atleast_one_prompt,
         )
         broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
 
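The comment block above carries the core reasoning of the fix: a chunked-prefill batch can mix prefills and decodes, the scorer runs them in one forward pass, but the proposer needs a separate prefill pass on every rank. A small hypothetical restatement (not code from this commit) of how the broadcast flags fall out of the batch composition:

# Hypothetical restatement of the driver-side flag computation, with a plain
# named tuple standing in for vLLM's SequenceGroupMetadata.
from typing import List, NamedTuple, Optional

class FakeSeqGroup(NamedTuple):
    is_prompt: bool
    num_speculative_tokens: Optional[int]  # expected to be None for prefills

def classify_batch(seq_groups: List[FakeSeqGroup],
                   num_lookahead_slots: int,
                   disable_all_speculation: bool) -> dict:
    atleast_one_prompt = any(g.is_prompt for g in seq_groups)
    all_zero_spec_tokens = all(g.num_speculative_tokens == 0
                               for g in seq_groups)
    no_spec = (num_lookahead_slots == 0 or disable_all_speculation
               or all_zero_spec_tokens)
    return {"no_spec": no_spec,
            "run_spec_proposer_for_prefill": atleast_one_prompt}

# Mixed batch: one prefill plus one decode with 3 speculative tokens.
mixed = [FakeSeqGroup(True, None), FakeSeqGroup(False, 3)]
print(classify_batch(mixed, num_lookahead_slots=3,
                     disable_all_speculation=False))
# -> {'no_spec': False, 'run_spec_proposer_for_prefill': True}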
@@ -653,6 +672,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
 
         if not data["no_spec"]:
             self.scorer_worker.execute_model()
+        if data["run_spec_proposer_for_prefill"]:
+            self.proposer_worker.execute_model()
 
         return True
 