[V1] Scheduler Refactoring [1/N] - Add Scheduler Interface (#15250)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Woosuk Kwon 2025-03-20 17:50:43 -07:00 committed by GitHub
parent 06dd08256f
commit 0c6f5023c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 182 additions and 45 deletions

View File

@ -6,7 +6,7 @@ from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler as V1Scheduler
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

View File

@ -6,7 +6,8 @@ import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

View File

@ -3,8 +3,8 @@ import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

View File

@ -1695,7 +1695,7 @@ class EngineArgs:
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls:
self.scheduler_cls = "vllm.v1.core.scheduler.Scheduler"
self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
# When no user override, set the default values based on the usage
# context.

View File

@ -17,7 +17,7 @@ from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)

View File

@ -17,7 +17,7 @@ from vllm.platforms import current_platform
from vllm.utils import cdiv
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

View File

@ -212,7 +212,7 @@ except ImportError:
from flash_attn import flash_attn_varlen_func
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

View File

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
class SchedulerInterface(ABC):
@abstractmethod
def schedule(self) -> "SchedulerOutput":
"""Schedule the requests to process in this scheduling step.
The scheduling decision is made at the iteration level. Each scheduling
step corresponds to a single forward pass of the model. Therefore, this
method is called repeatedly by a busy loop in the engine.
Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
that specifies how many tokens to process for each request in this
scheduling step. For example, num_tokens can be as large as the number
of prompt tokens for new requests, or it can be 1 for the requests that
are auto-regressively generating new tokens one by one. Otherwise, it
can be somewhere in between in case of chunked prefills, prefix caching,
speculative decoding, etc.
Additionally, the scheduler also returns useful data about each request
or the batch as a whole. The model runner will use this information in
preparing inputs to the model.
Returns:
A SchedulerOutput object containing information about the scheduled
requests.
"""
raise NotImplementedError
@abstractmethod
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> "EngineCoreOutputs":
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
requests. The model runner output includes generated token ids, draft
token ids for next step, etc. The scheduler uses this information to
update its states, checks the finished requests, and returns the output
for each request.
Returns:
A EngineCoreOutputs object containing the outputs for each request.
"""
raise NotImplementedError
@abstractmethod
def add_request(self, request: "Request") -> None:
"""Add a new request to the scheduler's internal queue.
Args:
request: The new request being added.
"""
raise NotImplementedError
@abstractmethod
def finish_requests(
self,
request_ids: Union[str, Iterable[str]],
finished_status: "RequestStatus",
) -> None:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing.
This method is called in two cases:
1. When the request is aborted by the client.
2. When the frontend process detects a stop string of the request after
de-tokenizing its generated tokens.
Args:
request_ids: A single or a list of request IDs.
finished_status: The finished status of the given requests.
"""
raise NotImplementedError
@abstractmethod
def get_num_unfinished_requests(self) -> int:
"""Number of unfinished requests in the scheduler's internal queue."""
raise NotImplementedError
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests in the scheduler's
internal queue."""
return self.get_num_unfinished_requests() > 0
@abstractmethod
def has_finished_requests(self) -> bool:
"""Returns True if there are finished requests that need to be cleared.
NOTE: This is different from `not self.has_unfinished_requests()`.
The scheduler maintains an internal list of the requests finished in the
previous step. This list is returned from the next call to schedule(),
to be sent to the model runner in the next step to clear cached states
for these finished requests.
This method checks if this internal list of finished requests is
non-empty. This information is useful for DP attention.
"""
raise NotImplementedError
def has_requests(self) -> bool:
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
raise NotImplementedError
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.
"""
raise NotImplementedError
@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging.
The SchedulerStats object is created for every scheduling step.
"""
raise NotImplementedError

View File

@ -13,8 +13,10 @@ from vllm.logger import init_logger
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats
@ -25,7 +27,7 @@ from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__)
class Scheduler:
class Scheduler(SchedulerInterface):
def __init__(
self,
@ -602,7 +604,7 @@ class Scheduler:
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = self._check_stop(request)
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
break
@ -648,25 +650,6 @@ class Scheduler:
scheduler_stats=self.make_stats(),
)
def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request
@ -715,17 +698,9 @@ class Scheduler:
def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)
def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
def has_requests(self):
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)

View File

@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False

View File

@ -22,8 +22,8 @@ from vllm.transformers_utils.config import (
from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.scheduler import Scheduler as V1Scheduler
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer

View File

@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING:
import xgrammar as xgr
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")

View File

@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
class Worker(WorkerBase):

View File

@ -37,7 +37,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)

View File

@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput