[Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (#4451)

SangBin Cho 2024-05-02 11:24:13 +09:00 committed by GitHub
parent b8afa8b95a
commit 0d62fe58db
6 changed files with 172 additions and 13 deletions
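Background on the failure: with parallel sampling (SamplingParams.n > 1), a sequence group in the decode phase holds several running sequences, so the scheduler counts n new tokens for the group rather than 1; once a large max_tokens forces preemption and rescheduling, the old `assert num_new_tokens == 1` (removed in the scheduler diff below) fires. The sketch below is a hypothetical reproduction against the offline LLM API, not part of this commit; the model, prompt, gpu_memory_utilization, and batch size are illustrative.

# Hypothetical reproduction sketch (not from this commit); values are illustrative.
from vllm import LLM, SamplingParams

# A small KV cache makes preemption likely once many long requests are in flight.
llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3)
params = SamplingParams(n=4, max_tokens=2048)  # n > 1 together with long generations
outputs = llm.generate(["Tell me a very long story."] * 64, params)
for out in outputs:
    assert len(out.outputs) == 4  # one completion per parallel sample once the fix is in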

View File

@@ -17,6 +17,7 @@ steps:
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
 - label: Core Test
   command: pytest -v -s core

View File

@@ -55,7 +55,6 @@ def test_models(
     )
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
-    print(vllm_outputs[0])
     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]

View File

@@ -0,0 +1,138 @@
"""Compare the short outputs of HF and vLLM when using greedy sampling.
VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test.
Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
pytest tests/basic_correctness/test_preemption.py`.
"""
import pytest

from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                 ENABLE_ARTIFICIAL_PREEMPT)

MODELS = [
    "facebook/opt-125m",
]

assert ENABLE_ARTIFICIAL_PREEMPT is True, (
    "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
    "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
    "tests/basic_correctness/test_preemption.py`")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_chunked_prefill_recompute(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
) -> None:
    """Ensure that chunked prefill works with preemption."""
    max_num_seqs = min(chunked_prefill_token_size, 256)
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_batched_tokens = chunked_prefill_token_size

    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        max_num_batched_tokens=max_num_batched_tokens,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_seqs=max_num_seqs,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
            ARTIFICIAL_PREEMPTION_MAX_CNT)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_preemption(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """By default, recompute preemption is enabled"""
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
    )
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
            ARTIFICIAL_PREEMPTION_MAX_CNT)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype, swap_space=10)
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")

View File

@@ -296,6 +296,7 @@ class VllmRunner:
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
+        swap_space=4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -303,7 +304,7 @@
             tokenizer=tokenizer_name,
             trust_remote_code=True,
             dtype=dtype,
-            swap_space=0,
+            swap_space=swap_space,
             disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
             max_model_len=max_model_len,

View File

@@ -33,7 +33,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
-    exception_secret = 'artifical stop'
+    exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
     execute_model_data, _, _ = create_batch(batch_size, k)
@@ -101,7 +101,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
         proposal_probs=proposal_probs,
         proposal_lens=proposal_lens)
-    exception_secret = 'artifical stop'
+    exception_secret = 'artificial stop'
     target_worker.execute_model.side_effect = ValueError(exception_secret)
     with pytest.raises(ValueError, match=exception_secret):
@@ -197,7 +197,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     target_worker.execute_model.return_value = [target_output[0]]
-    exception_secret = 'artifical stop'
+    exception_secret = 'artificial stop'
     rejection_sampler.side_effect = ValueError(exception_secret)
     with pytest.raises(ValueError, match=exception_secret):

View File

@@ -1,4 +1,6 @@
 import enum
+import os
+import random
 import time
 from collections import deque
 from dataclasses import dataclass, field
@@ -15,6 +17,13 @@ from vllm.utils import merge_dicts
 logger = init_logger(__name__)
+# Test-only. If configured, decode is preempted with
+# probability ARTIFICIAL_PREEMPTION_PROB.
+ENABLE_ARTIFICIAL_PREEMPT = bool(
+    os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))  # noqa
+ARTIFICIAL_PREEMPTION_PROB = 0.5
+ARTIFICIAL_PREEMPTION_MAX_CNT = 500
+
 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -286,6 +295,13 @@ class Scheduler:
         # Latency of the last prompt step
         self.last_prompt_latency = 0.0
+        # The following field is test-only. It is used to inject artificial
+        # preemption.
+        self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT
+        self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT
+                                       if self.enable_artificial_preemption
+                                       else 0)
+
     @property
     def lora_enabled(self) -> bool:
         return bool(self.lora_config)
@@ -386,15 +402,13 @@
         # groups to preempt.
         now = time.time()
         running_queue = policy.sort_by_priority(now, running_queue)
         while running_queue:
             seq_group = running_queue[0]
             num_running_tokens = self._get_num_new_tokens(
                 seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
-            # We can have up to 1 running prefill at any given time in running
-            # queue, which means we can guarantee chunk size is at least 1.
-            assert num_running_tokens != 0
+            if num_running_tokens == 0:
+                break
             running_queue.popleft()
             while not self._can_append_slots(seq_group):
@@ -449,9 +463,6 @@
                 if curr_loras is not None and seq_group.lora_int_id > 0:
                     curr_loras.add(seq_group.lora_int_id)
-        # Make sure all queues are updated.
-        assert len(running_queue) == 0
         return running_queue, SchedulerRunningOutputs(
             decode_seq_groups=decode_seq_groups,
             prefill_seq_groups=prefill_seq_groups,
@@ -545,7 +556,6 @@
                 ScheduledSequenceGroup(seq_group,
                                        token_chunk_size=num_new_tokens))
             else:
-                assert num_new_tokens == 1
                 decode_seq_groups.append(
                     ScheduledSequenceGroup(seq_group, token_chunk_size=1))
             budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
@@ -868,6 +878,13 @@ class Scheduler:
         """Determine whether or not we have enough space in the KV cache to
         continue generation of the sequence group.
         """
+        # Test-only path: when artificial preemption is enabled, randomly
+        # report that no slots can be appended in order to trigger preemption.
+        if (self.enable_artificial_preemption
+                and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
+                and self.artificial_preempt_cnt > 0):
+            self.artificial_preempt_cnt -= 1
+            return False
+
         # Appending slots only occurs in decoding.
         is_prefill = False
@@ -1116,11 +1133,14 @@
         if `enable_chunking` is True. If a sequence group has multiple
         sequences (e.g., running beam search), it means it is in decoding
         phase, so chunking doesn't happen.
+
+        Returns 0 if the new token cannot be computed due to token budget.
         """
         num_new_tokens = 0
         seqs = seq_group.get_seqs(status=status)
         for seq in seqs:
             num_new_tokens += seq.get_num_new_tokens()
+        assert num_new_tokens > 0
         # Chunk if a running request cannot fit in.
         # If number of seq > 1, it means it is doing beam search in a
         # decode phase. Do not chunk in that case.
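
To make the documented contract concrete: summing per-sequence new tokens gives n for a decode-phase group running n parallel sequences (never 1 when n > 1, hence the relaxed assertion), while a lone prefill may be chunked down to the remaining token budget, possibly to 0, which the running-queue loop above now treats as a signal to stop scheduling rather than asserting. The standalone sketch below mirrors that rule; it is an illustration, not vLLM code, and the names (per_seq_new_tokens, remaining_budget) are invented for clarity.

# Standalone sketch of the token-counting rule described above (illustrative names).
def num_new_tokens_for_group(per_seq_new_tokens, enable_chunking, remaining_budget):
    total = sum(per_seq_new_tokens)  # one entry per running sequence in the group
    assert total > 0
    # Only a single-sequence (prefill) group is chunked; with an exhausted
    # budget the result can be 0, which callers treat as "stop scheduling".
    if enable_chunking and len(per_seq_new_tokens) == 1:
        total = min(total, remaining_budget)
    return total

assert num_new_tokens_for_group([1, 1, 1, 1], True, 100) == 4  # decode, n=4: never chunked
assert num_new_tokens_for_group([512], True, 0) == 0           # prefill with no budget left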