[Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (#4451)
This commit is contained in:
parent
b8afa8b95a
commit
0d62fe58db
@ -17,6 +17,7 @@ steps:
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Core Test
|
||||
command: pytest -v -s core
|
||||
|
@ -55,7 +55,6 @@ def test_models(
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
print(vllm_outputs[0])
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
|
138
tests/basic_correctness/test_preemption.py
Normal file
138
tests/basic_correctness/test_preemption.py
Normal file
@ -0,0 +1,138 @@
|
||||
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test.
|
||||
|
||||
Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
|
||||
pytest tests/basic_correctness/test_preemption.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
|
||||
ENABLE_ARTIFICIAL_PREEMPT)
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
]
|
||||
|
||||
assert ENABLE_ARTIFICIAL_PREEMPT is True, (
|
||||
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
|
||||
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
|
||||
"tests/basic_correctness/test_preemption.py`")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
|
||||
def test_chunked_prefill_recompute(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
) -> None:
|
||||
"""Ensure that chunked prefill works with preemption."""
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
def test_preemption(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
"""By default, recompute preemption is enabled"""
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
@pytest.mark.parametrize("beam_width", [4])
|
||||
def test_swap(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
"""Use beam search enables swapping."""
|
||||
example_prompts = example_prompts[:1]
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
|
||||
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, _ = hf_outputs[i]
|
||||
vllm_output_ids, _ = vllm_outputs[i]
|
||||
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||
for j in range(len(hf_output_ids)):
|
||||
assert hf_output_ids[j] == vllm_output_ids[j], (
|
||||
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
|
||||
f"vLLM: {vllm_output_ids}")
|
@ -296,6 +296,7 @@ class VllmRunner:
|
||||
tensor_parallel_size: int = 1,
|
||||
block_size: int = 16,
|
||||
enable_chunked_prefill: bool = False,
|
||||
swap_space=4,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.model = LLM(
|
||||
@ -303,7 +304,7 @@ class VllmRunner:
|
||||
tokenizer=tokenizer_name,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
swap_space=0,
|
||||
swap_space=swap_space,
|
||||
disable_log_stats=disable_log_stats,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
max_model_len=max_model_len,
|
||||
|
@ -33,7 +33,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
|
||||
exception_secret = 'artifical stop'
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
execute_model_data, _, _ = create_batch(batch_size, k)
|
||||
@ -101,7 +101,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
exception_secret = 'artifical stop'
|
||||
exception_secret = 'artificial stop'
|
||||
target_worker.execute_model.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
@ -197,7 +197,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
exception_secret = 'artifical stop'
|
||||
exception_secret = 'artificial stop'
|
||||
rejection_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
|
@ -1,4 +1,6 @@
|
||||
import enum
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
@ -15,6 +17,13 @@ from vllm.utils import merge_dicts
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Test-only. If configured, decode is preempted with
|
||||
# ARTIFICIAL_PREEMPTION_PROB% probability.
|
||||
ENABLE_ARTIFICIAL_PREEMPT = bool(
|
||||
os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa
|
||||
ARTIFICIAL_PREEMPTION_PROB = 0.5
|
||||
ARTIFICIAL_PREEMPTION_MAX_CNT = 500
|
||||
|
||||
|
||||
class PreemptionMode(enum.Enum):
|
||||
"""Preemption modes.
|
||||
@ -286,6 +295,13 @@ class Scheduler:
|
||||
# Latency of the last prompt step
|
||||
self.last_prompt_latency = 0.0
|
||||
|
||||
# The following field is test-only. It is used to inject artificial
|
||||
# preemption.
|
||||
self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
|
||||
self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
|
||||
if self.enable_artificial_preemption
|
||||
else 0)
|
||||
|
||||
@property
|
||||
def lora_enabled(self) -> bool:
|
||||
return bool(self.lora_config)
|
||||
@ -386,15 +402,13 @@ class Scheduler:
|
||||
# groups to preempt.
|
||||
now = time.time()
|
||||
running_queue = policy.sort_by_priority(now, running_queue)
|
||||
|
||||
while running_queue:
|
||||
seq_group = running_queue[0]
|
||||
num_running_tokens = self._get_num_new_tokens(
|
||||
seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
|
||||
|
||||
# We can have up to 1 running prefill at any given time in running
|
||||
# queue, which means we can guarantee chunk size is at least 1.
|
||||
assert num_running_tokens != 0
|
||||
if num_running_tokens == 0:
|
||||
break
|
||||
|
||||
running_queue.popleft()
|
||||
while not self._can_append_slots(seq_group):
|
||||
@ -449,9 +463,6 @@ class Scheduler:
|
||||
if curr_loras is not None and seq_group.lora_int_id > 0:
|
||||
curr_loras.add(seq_group.lora_int_id)
|
||||
|
||||
# Make sure all queues are updated.
|
||||
assert len(running_queue) == 0
|
||||
|
||||
return running_queue, SchedulerRunningOutputs(
|
||||
decode_seq_groups=decode_seq_groups,
|
||||
prefill_seq_groups=prefill_seq_groups,
|
||||
@ -545,7 +556,6 @@ class Scheduler:
|
||||
ScheduledSequenceGroup(seq_group,
|
||||
token_chunk_size=num_new_tokens))
|
||||
else:
|
||||
assert num_new_tokens == 1
|
||||
decode_seq_groups.append(
|
||||
ScheduledSequenceGroup(seq_group, token_chunk_size=1))
|
||||
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
|
||||
@ -868,6 +878,13 @@ class Scheduler:
|
||||
"""Determine whether or not we have enough space in the KV cache to
|
||||
continue generation of the sequence group.
|
||||
"""
|
||||
# It is True only for testing case to trigger artificial preemption.
|
||||
if (self.enable_artificial_preemption
|
||||
and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
|
||||
and self.artificial_preempt_cnt > 0):
|
||||
self.artificial_preempt_cnt -= 1
|
||||
return False
|
||||
|
||||
# Appending slots only occurs in decoding.
|
||||
is_prefill = False
|
||||
|
||||
@ -1116,11 +1133,14 @@ class Scheduler:
|
||||
if `enable_chunking` is True. If a sequence group has multiple
|
||||
sequences (e.g., running beam search), it means it is in decoding
|
||||
phase, so chunking doesn't happen.
|
||||
|
||||
Returns 0 if the new token cannot be computed due to token budget.
|
||||
"""
|
||||
num_new_tokens = 0
|
||||
seqs = seq_group.get_seqs(status=status)
|
||||
for seq in seqs:
|
||||
num_new_tokens += seq.get_num_new_tokens()
|
||||
assert num_new_tokens > 0
|
||||
# Chunk if a running request cannot fit in.
|
||||
# If number of seq > 1, it means it is doing beam search in a
|
||||
# decode phase. Do not chunk in that case.
|
||||
|
Loading…
x
Reference in New Issue
Block a user