diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 11cda053..641f366d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -17,6 +17,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Core Test command: pytest -v -s core diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index d83416eb..47d582c7 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -55,7 +55,6 @@ def test_models( ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model - print(vllm_outputs[0]) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py new file mode 100644 index 00000000..1adfc7dd --- /dev/null +++ b/tests/basic_correctness/test_preemption.py @@ -0,0 +1,138 @@ +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test. + +Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 +pytest tests/basic_correctness/test_preemption.py`. +""" +import pytest + +from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, + ENABLE_ARTIFICIAL_PREEMPT) + +MODELS = [ + "facebook/opt-125m", +] + +assert ENABLE_ARTIFICIAL_PREEMPT is True, ( + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "tests/basic_correctness/test_preemption.py`") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_chunked_prefill_recompute( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + """Ensure that chunked prefill works with preemption.""" + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """By default, recompute preemption is enabled""" + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Use beam search enables swapping.""" + example_prompts = example_prompts[:1] + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype, swap_space=10) + vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, _ = hf_outputs[i] + vllm_output_ids, _ = vllm_outputs[i] + assert len(hf_output_ids) == len(vllm_output_ids) + for j in range(len(hf_output_ids)): + assert hf_output_ids[j] == vllm_output_ids[j], ( + f"Test{i} output{j}:\nHF: {hf_output_ids}\n" + f"vLLM: {vllm_output_ids}") diff --git a/tests/conftest.py b/tests/conftest.py index 5c50fc2d..67132691 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -296,6 +296,7 @@ class VllmRunner: tensor_parallel_size: int = 1, block_size: int = 16, enable_chunked_prefill: bool = False, + swap_space=4, **kwargs, ) -> None: self.model = LLM( @@ -303,7 +304,7 @@ class VllmRunner: tokenizer=tokenizer_name, trust_remote_code=True, dtype=dtype, - swap_space=0, + swap_space=swap_space, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, max_model_len=max_model_len, diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index d24d726c..91315df9 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -33,7 +33,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) execute_model_data, _, _ = create_batch(batch_size, k) @@ -101,7 +101,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): proposal_probs=proposal_probs, proposal_lens=proposal_lens) - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): @@ -197,7 +197,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): target_worker.execute_model.return_value = [target_output[0]] - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 024b7e70..b17b6cc7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,4 +1,6 @@ import enum +import os +import random import time from collections import deque from dataclasses import dataclass, field @@ -15,6 +17,13 @@ from vllm.utils import merge_dicts logger = init_logger(__name__) +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + class PreemptionMode(enum.Enum): """Preemption modes. @@ -286,6 +295,13 @@ class Scheduler: # Latency of the last prompt step self.last_prompt_latency = 0.0 + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + @property def lora_enabled(self) -> bool: return bool(self.lora_config) @@ -386,15 +402,13 @@ class Scheduler: # groups to preempt. now = time.time() running_queue = policy.sort_by_priority(now, running_queue) - while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - # We can have up to 1 running prefill at any given time in running - # queue, which means we can guarantee chunk size is at least 1. - assert num_running_tokens != 0 + if num_running_tokens == 0: + break running_queue.popleft() while not self._can_append_slots(seq_group): @@ -449,9 +463,6 @@ class Scheduler: if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - # Make sure all queues are updated. - assert len(running_queue) == 0 - return running_queue, SchedulerRunningOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, @@ -545,7 +556,6 @@ class Scheduler: ScheduledSequenceGroup(seq_group, token_chunk_size=num_new_tokens)) else: - assert num_new_tokens == 1 decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) @@ -868,6 +878,13 @@ class Scheduler: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + # Appending slots only occurs in decoding. is_prefill = False @@ -1116,11 +1133,14 @@ class Scheduler: if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. + + Returns 0 if the new token cannot be computed due to token budget. """ num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) for seq in seqs: num_new_tokens += seq.get_num_new_tokens() + assert num_new_tokens > 0 # Chunk if a running request cannot fit in. # If number of seq > 1, it means it is doing beam search in a # decode phase. Do not chunk in that case.