[Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (#4451)
parent b8afa8b95a
commit 0d62fe58db
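For context, the assertion named in the title sits in the scheduler's decode branch and can fire once a multi-sequence group (SamplingParams.n > 1) is preempted and rescheduled under KV-cache pressure, which a large max_tokens makes likely. Below is a hedged repro sketch of that failure mode; the model, prompt, and memory settings are illustrative assumptions, not part of this commit.

from vllm import LLM, SamplingParams

# A small KV cache plus many long, multi-sequence requests encourages preemption.
llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, swap_space=4)
params = SamplingParams(n=4, max_tokens=1024, temperature=1.0)

# Before this fix, a preempted group with n > 1 sequences could trip the
# scheduler's `assert num_new_tokens == 1` when it was scheduled for decode again.
outputs = llm.generate(["Tell me a very long story."] * 64, params)
print(len(outputs))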
@@ -17,6 +17,7 @@ steps:
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
 
 - label: Core Test
   command: pytest -v -s core
@@ -55,7 +55,6 @@ def test_models(
     )
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
-    print(vllm_outputs[0])
 
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]
tests/basic_correctness/test_preemption.py (new file, 138 lines)
@@ -0,0 +1,138 @@
"""Compare the short outputs of HF and vLLM when using greedy sampling.

VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test.

Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
pytest tests/basic_correctness/test_preemption.py`.
"""
import pytest

from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                 ENABLE_ARTIFICIAL_PREEMPT)

MODELS = [
    "facebook/opt-125m",
]

assert ENABLE_ARTIFICIAL_PREEMPT is True, (
    "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
    "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
    "tests/basic_correctness/test_preemption.py`")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_chunked_prefill_recompute(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
) -> None:
    """Ensure that chunked prefill works with preemption."""
    max_num_seqs = min(chunked_prefill_token_size, 256)
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_batched_tokens = chunked_prefill_token_size

    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_seqs=max_num_seqs,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
            ARTIFICIAL_PREEMPTION_MAX_CNT)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """By default, recompute preemption is enabled."""

    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
            ARTIFICIAL_PREEMPTION_MAX_CNT)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    """Using beam search enables swapping."""
    example_prompts = example_prompts[:1]
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
                                               max_tokens)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
    vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
                                                   max_tokens)
    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
            ARTIFICIAL_PREEMPTION_MAX_CNT)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, _ = hf_outputs[i]
        vllm_output_ids, _ = vllm_outputs[i]
        assert len(hf_output_ids) == len(vllm_output_ids)
        for j in range(len(hf_output_ids)):
            assert hf_output_ids[j] == vllm_output_ids[j], (
                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                f"vLLM: {vllm_output_ids}")
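A note on the artificial_preempt_cnt assertions in the tests above, as I read the scheduler change further down: the counter starts at ARTIFICIAL_PREEMPTION_MAX_CNT when artificial preemption is enabled and is decremented once per injected preemption, so asserting it is strictly below the maximum verifies that at least one preemption was actually exercised during the run. A toy illustration of that invariant (the class and method names here are mine, not vLLM's):

ARTIFICIAL_PREEMPTION_MAX_CNT = 500

class TinyScheduler:
    """Minimal stand-in for the counter bookkeeping only."""

    def __init__(self, enabled: bool) -> None:
        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
                                       if enabled else 0)

    def inject_preemption(self) -> None:
        self.artificial_preempt_cnt -= 1

sched = TinyScheduler(enabled=True)
sched.inject_preemption()
# This is the shape of the check each test performs.
assert sched.artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT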
@@ -296,6 +296,7 @@ class VllmRunner:
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
+        swap_space=4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -303,7 +304,7 @@ class VllmRunner:
             tokenizer=tokenizer_name,
             trust_remote_code=True,
             dtype=dtype,
-            swap_space=0,
+            swap_space=swap_space,
             disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
             max_model_len=max_model_len,
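The two hunks above thread a new swap_space argument through the VllmRunner test helper instead of the previously hard-coded swap_space=0. A hedged usage sketch (the test name, prompts, and token count are illustrative assumptions, not from this commit):

# Hypothetical test using the vllm_runner fixture with the new knob,
# in the same style as test_swap above.
def test_generates_with_cpu_swap(vllm_runner, example_prompts):
    vllm_model = vllm_runner("facebook/opt-125m", dtype="half", swap_space=10)
    outputs = vllm_model.generate_greedy(example_prompts, 32)
    assert len(outputs) == len(example_prompts)
    del vllm_model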
@@ -33,7 +33,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    exception_secret = 'artifical stop'
+    exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
 
     execute_model_data, _, _ = create_batch(batch_size, k)
@@ -101,7 +101,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
         proposal_probs=proposal_probs,
         proposal_lens=proposal_lens)
 
-    exception_secret = 'artifical stop'
+    exception_secret = 'artificial stop'
     target_worker.execute_model.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
@@ -197,7 +197,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
 
     target_worker.execute_model.return_value = [target_output[0]]
 
-    exception_secret = 'artifical stop'
+    exception_secret = 'artificial stop'
     rejection_sampler.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
@@ -1,4 +1,6 @@
 import enum
+import os
+import random
 import time
 from collections import deque
 from dataclasses import dataclass, field
@@ -15,6 +17,13 @@ from vllm.utils import merge_dicts
 
 logger = init_logger(__name__)
 
+# Test-only. If configured, decode is preempted with
+# ARTIFICIAL_PREEMPTION_PROB% probability.
+ENABLE_ARTIFICIAL_PREEMPT = bool(
+    os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))  # noqa
+ARTIFICIAL_PREEMPTION_PROB = 0.5
+ARTIFICIAL_PREEMPTION_MAX_CNT = 500
+
 
 class PreemptionMode(enum.Enum):
     """Preemption modes.
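One subtlety in the gate above: bool(os.getenv(...)) is plain string truthiness, so any non-empty value enables artificial preemption, including "0"; only leaving the variable unset (or empty) keeps it off. A quick self-contained check of that behavior:

import os

os.environ["VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT"] = "0"
print(bool(os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)))  # True: non-empty string

os.environ["VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT"] = ""
print(bool(os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)))  # False: empty string

del os.environ["VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT"]
print(bool(os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)))  # False: falls back to the default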
@@ -286,6 +295,13 @@ class Scheduler:
         # Latency of the last prompt step
         self.last_prompt_latency = 0.0
 
+        # The following field is test-only. It is used to inject artificial
+        # preemption.
+        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
+        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
+                                       if self.enable_artificial_preemption
+                                       else 0)
+
     @property
     def lora_enabled(self) -> bool:
         return bool(self.lora_config)
@@ -386,15 +402,13 @@ class Scheduler:
         # groups to preempt.
         now = time.time()
         running_queue = policy.sort_by_priority(now, running_queue)
 
         while running_queue:
             seq_group = running_queue[0]
             num_running_tokens = self._get_num_new_tokens(
                 seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
 
-            # We can have up to 1 running prefill at any given time in running
-            # queue, which means we can guarantee chunk size is at least 1.
-            assert num_running_tokens != 0
+            if num_running_tokens == 0:
+                break
 
             running_queue.popleft()
             while not self._can_append_slots(seq_group):
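Why the old assert becomes an early break, as I understand it: with chunked prefill the per-step token budget caps each chunk, so _get_num_new_tokens can now legitimately report 0 once the budget is exhausted (the docstring change in the last hunk below spells this out), and the loop simply stops scheduling further RUNNING groups for this step. A simplified sketch of that capping, stated as an assumption about the mechanism rather than as vLLM code:

def chunk_size(num_new_tokens: int, remaining_token_budget: int) -> int:
    # Chunked prefill schedules at most the remaining budget this step.
    return min(num_new_tokens, remaining_token_budget)

assert chunk_size(16, 7) == 7  # partially scheduled chunk
assert chunk_size(16, 0) == 0  # budget exhausted -> the scheduler breaks out of the loop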
@@ -449,9 +463,6 @@ class Scheduler:
             if curr_loras is not None and seq_group.lora_int_id > 0:
                 curr_loras.add(seq_group.lora_int_id)
 
-        # Make sure all queues are updated.
-        assert len(running_queue) == 0
-
         return running_queue, SchedulerRunningOutputs(
             decode_seq_groups=decode_seq_groups,
             prefill_seq_groups=prefill_seq_groups,
@@ -545,7 +556,6 @@ class Scheduler:
                     ScheduledSequenceGroup(seq_group,
                                            token_chunk_size=num_new_tokens))
             else:
-                assert num_new_tokens == 1
                 decode_seq_groups.append(
                     ScheduledSequenceGroup(seq_group, token_chunk_size=1))
             budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
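The dropped assert num_new_tokens == 1 is the crux of the bug in the title: a decoding sequence group created with SamplingParams.n > 1 (or beam search) holds several running sequences, each contributing one token per decode step, so the group-level count is n rather than 1. The group is still scheduled with token_chunk_size=1 while the budget is charged the full num_new_tokens. A toy illustration (the fake classes are mine, not vLLM's):

class FakeSeq:
    def get_num_new_tokens(self) -> int:
        return 1  # one new token per sequence per decode step

def group_num_new_tokens(seqs) -> int:
    # Mirrors the summation in _get_num_new_tokens in the last hunk below.
    return sum(seq.get_num_new_tokens() for seq in seqs)

print(group_num_new_tokens([FakeSeq() for _ in range(4)]))  # 4 for n=4, which tripped the old assert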
@@ -868,6 +878,13 @@ class Scheduler:
         """Determine whether or not we have enough space in the KV cache to
         continue generation of the sequence group.
         """
+        # It is True only for testing case to trigger artificial preemption.
+        if (self.enable_artificial_preemption
+                and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
+                and self.artificial_preempt_cnt > 0):
+            self.artificial_preempt_cnt -= 1
+            return False
+
         # Appending slots only occurs in decoding.
         is_prefill = False
 
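To poke at the injected preemption branch above in isolation, here is a standalone sketch that mirrors its logic; the constant comes from the hunk, but the function wrapper is mine:

import random

ARTIFICIAL_PREEMPTION_PROB = 0.5

def should_artificially_preempt(enabled: bool, artificial_preempt_cnt: int) -> bool:
    """Coin flip while the per-run counter is still positive, as in
    _can_append_slots above."""
    return (enabled
            and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
            and artificial_preempt_cnt > 0)

# Roughly half of the calls report a preemption while the counter lasts.
hits = sum(should_artificially_preempt(True, 500) for _ in range(1000))
print(hits)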
@@ -1116,11 +1133,14 @@ class Scheduler:
         if `enable_chunking` is True. If a sequence group has multiple
         sequences (e.g., running beam search), it means it is in decoding
         phase, so chunking doesn't happen.
+
+        Returns 0 if the new token cannot be computed due to token budget.
         """
         num_new_tokens = 0
         seqs = seq_group.get_seqs(status=status)
         for seq in seqs:
             num_new_tokens += seq.get_num_new_tokens()
+        assert num_new_tokens > 0
         # Chunk if a running request cannot fit in.
         # If number of seq > 1, it means it is doing beam search in a
         # decode phase. Do not chunk in that case.