[V1] Scheduler Refactoring [1/N] - Add Scheduler Interface (#15250)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent 06dd08256f
commit 0c6f5023c3
@@ -6,7 +6,7 @@ from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import Scheduler as V1Scheduler
+from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine


@@ -6,7 +6,8 @@ import pytest
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
@@ -3,8 +3,8 @@ import pytest

 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
-                                           SchedulerOutput)
+from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
+                                       SchedulerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

@@ -1695,7 +1695,7 @@ class EngineArgs:
         # V1 should use the new scheduler by default.
         # Swap it only if this arg is set to the original V0 default
         if self.scheduler_cls == EngineArgs.scheduler_cls:
-            self.scheduler_cls = "vllm.v1.core.scheduler.Scheduler"
+            self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

         # When no user override, set the default values based on the usage
         # context.
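Note (illustrative, not part of this commit): `scheduler_cls` is a dotted path string, so a custom V1 scheduler can be plugged in as long as the named class satisfies the new interface. The sketch below shows one way such a string could be resolved and sanity-checked; `load_scheduler_cls` is a hypothetical helper, and treating `SchedulerInterface` as a hard subclass requirement is an assumption, not the engine's actual wiring.

# Hypothetical helper: resolve a scheduler class path and check the contract.
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.core.sched.interface import SchedulerInterface

def load_scheduler_cls(path: str = "vllm.v1.core.sched.scheduler.Scheduler"):
    scheduler_cls = resolve_obj_by_qualname(path)
    # Assumed convention: custom schedulers implement SchedulerInterface.
    assert issubclass(scheduler_cls, SchedulerInterface)
    return scheduler_cls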
@@ -17,7 +17,7 @@ from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase

 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler import SchedulerOutput
+    from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.outputs import ModelRunnerOutput

 logger = init_logger(__name__)
@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
 from vllm.utils import cdiv

 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler_output import SchedulerOutput
+    from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

@@ -212,7 +212,7 @@ except ImportError:
     from flash_attn import flash_attn_varlen_func

 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler_output import SchedulerOutput
+    from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

vllm/v1/core/sched/__init__.py (new file, 0 lines)
vllm/v1/core/sched/interface.py (new file, 139 lines)
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.engine import EngineCoreOutputs
    from vllm.v1.metrics.stats import SchedulerStats
    from vllm.v1.outputs import ModelRunnerOutput
    from vllm.v1.request import Request, RequestStatus


class SchedulerInterface(ABC):

    @abstractmethod
    def schedule(self) -> "SchedulerOutput":
        """Schedule the requests to process in this scheduling step.

        The scheduling decision is made at the iteration level. Each scheduling
        step corresponds to a single forward pass of the model. Therefore, this
        method is called repeatedly by a busy loop in the engine.

        Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
        that specifies how many tokens to process for each request in this
        scheduling step. For example, num_tokens can be as large as the number
        of prompt tokens for new requests, or it can be 1 for the requests that
        are auto-regressively generating new tokens one by one. Otherwise, it
        can be somewhere in between in case of chunked prefills, prefix caching,
        speculative decoding, etc.

        Additionally, the scheduler also returns useful data about each request
        or the batch as a whole. The model runner will use this information in
        preparing inputs to the model.

        Returns:
            A SchedulerOutput object containing information about the scheduled
            requests.
        """
        raise NotImplementedError

    @abstractmethod
    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> "EngineCoreOutputs":
        """Update the scheduler state based on the model runner output.

        This method is called after the model runner has processed the scheduled
        requests. The model runner output includes generated token ids, draft
        token ids for next step, etc. The scheduler uses this information to
        update its states, checks the finished requests, and returns the output
        for each request.

        Returns:
            An EngineCoreOutputs object containing the outputs for each request.
        """
        raise NotImplementedError

    @abstractmethod
    def add_request(self, request: "Request") -> None:
        """Add a new request to the scheduler's internal queue.

        Args:
            request: The new request being added.
        """
        raise NotImplementedError

    @abstractmethod
    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: "RequestStatus",
    ) -> None:
        """Finish the requests in the scheduler's internal queue. If the request
        is not in the queue, this method will do nothing.

        This method is called in two cases:
        1. When the request is aborted by the client.
        2. When the frontend process detects a stop string of the request after
           de-tokenizing its generated tokens.

        Args:
            request_ids: A single or a list of request IDs.
            finished_status: The finished status of the given requests.
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_unfinished_requests(self) -> int:
        """Number of unfinished requests in the scheduler's internal queue."""
        raise NotImplementedError

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests in the scheduler's
        internal queue."""
        return self.get_num_unfinished_requests() > 0

    @abstractmethod
    def has_finished_requests(self) -> bool:
        """Returns True if there are finished requests that need to be cleared.
        NOTE: This is different from `not self.has_unfinished_requests()`.

        The scheduler maintains an internal list of the requests finished in the
        previous step. This list is returned from the next call to schedule(),
        to be sent to the model runner in the next step to clear cached states
        for these finished requests.

        This method checks if this internal list of finished requests is
        non-empty. This information is useful for DP attention.
        """
        raise NotImplementedError

    def has_requests(self) -> bool:
        """Returns True if there are unfinished requests, or finished requests
        not yet returned in SchedulerOutputs."""
        return self.has_unfinished_requests() or self.has_finished_requests()

    @abstractmethod
    def get_num_unscheduled_requests(self) -> int:
        """Number of requests that are not being processed by the executor."""
        raise NotImplementedError

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache for KV cache.

        This is particularly required when the model weights are live-updated.
        """
        raise NotImplementedError

    @abstractmethod
    def make_stats(self) -> Optional["SchedulerStats"]:
        """Make a SchedulerStats object for logging.

        The SchedulerStats object is created for every scheduling step.
        """
        raise NotImplementedError
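Note (illustrative, not part of this commit): the docstrings above describe schedule() being driven by a busy loop in the engine. The sketch below shows roughly how any SchedulerInterface implementation slots into such a loop. `scheduler` and `model_runner` are assumed objects, and the assumption that the model runner exposes execute_model(scheduler_output) is ours; the real loop lives in the V1 engine core and differs in detail.

def step(scheduler, model_runner):
    """One iteration of a hypothetical engine busy loop (sketch only)."""
    if not scheduler.has_requests():
        return None  # nothing to schedule and no finished state left to clear
    scheduler_output = scheduler.schedule()  # {req_id: num_tokens} plus batch metadata
    model_runner_output = model_runner.execute_model(scheduler_output)
    # Feed sampled/draft tokens back so the scheduler can update request state
    # and emit per-request outputs for the frontend.
    return scheduler.update_from_output(scheduler_output, model_runner_output)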
@@ -13,8 +13,10 @@ from vllm.logger import init_logger
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
-from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
-                                           SchedulerOutput)
+from vllm.v1.core.sched.interface import SchedulerInterface
+from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
+                                       SchedulerOutput)
+from vllm.v1.core.sched.utils import check_stop
 from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
                             EngineCoreOutputs)
 from vllm.v1.metrics.stats import SchedulerStats
@@ -25,7 +27,7 @@ from vllm.v1.structured_output import StructuredOutputManager
 logger = init_logger(__name__)


-class Scheduler:
+class Scheduler(SchedulerInterface):

     def __init__(
         self,
@@ -602,7 +604,7 @@ class Scheduler:

            # Check for stop and update request state.
            # This must be called before we make the EngineCoreOutput.
-           stopped = self._check_stop(request)
+           stopped = check_stop(request, self.max_model_len)
            if stopped:
                self._free_request(request)
                break
@@ -648,25 +650,6 @@ class Scheduler:
             scheduler_stats=self.make_stats(),
         )

-    def _check_stop(self, request: Request) -> bool:
-        if (request.num_tokens >= self.max_model_len
-                or request.num_output_tokens >= request.max_tokens):
-            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            return True
-
-        sampling_params = request.sampling_params
-        last_token_id = request.output_token_ids[-1]
-        if (not sampling_params.ignore_eos
-                and last_token_id == request.eos_token_id):
-            request.status = RequestStatus.FINISHED_STOPPED
-            return True
-
-        if last_token_id in (sampling_params.stop_token_ids or ()):
-            request.status = RequestStatus.FINISHED_STOPPED
-            request.stop_reason = last_token_id
-            return True
-        return False
-
     def add_request(self, request: Request) -> None:
         self.waiting.append(request)
         self.requests[request.request_id] = request
@@ -715,17 +698,9 @@ class Scheduler:
     def get_num_unfinished_requests(self) -> int:
         return len(self.waiting) + len(self.running)

-    def has_unfinished_requests(self) -> bool:
-        return self.get_num_unfinished_requests() > 0
-
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0

-    def has_requests(self):
-        """Returns True if there are unfinished requests, or finished requests
-        not yet returned in SchedulerOutputs."""
-        return self.has_unfinished_requests() or self.has_finished_requests()
-
     def get_num_unscheduled_requests(self) -> int:
         """Number of requests that are not being processed by the executor."""
         return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
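Note (toy numbers, not vLLM state): the counting above works out as follows for a scheduler with 3 waiting and 2 running requests, of which 2 have already been handed to the executor this step.

waiting, running = 3, 2
scheduled_req_ids = {"req-4", "req-5"}  # hypothetical IDs scheduled this step

num_unfinished = waiting + running                          # 5 requests still in flight
num_unscheduled = num_unfinished - len(scheduled_req_ids)   # 3 not being processed by the executor
assert (num_unfinished, num_unscheduled) == (5, 3)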
vllm/v1/core/sched/utils.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.request import Request, RequestStatus


def check_stop(request: Request, max_model_len: int) -> bool:
    if (request.num_tokens >= max_model_len
            or request.num_output_tokens >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    sampling_params = request.sampling_params
    last_token_id = request.output_token_ids[-1]
    if (not sampling_params.ignore_eos
            and last_token_id == request.eos_token_id):
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True
    return False
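Usage sketch for check_stop (illustrative only, not part of this commit): the object below is a SimpleNamespace stand-in carrying just the attributes the helper reads, with made-up token values; real callers pass a vllm.v1.request.Request.

from types import SimpleNamespace

from vllm.v1.core.sched.utils import check_stop
from vllm.v1.request import RequestStatus

# Stand-in request whose last generated token happens to be the EOS token.
request = SimpleNamespace(
    num_tokens=8,
    num_output_tokens=4,
    max_tokens=16,
    eos_token_id=2,
    output_token_ids=[5, 7, 9, 2],
    sampling_params=SimpleNamespace(ignore_eos=False, stop_token_ids=None),
    status=RequestStatus.RUNNING,
)
assert check_stop(request, max_model_len=2048) is True
assert request.status == RequestStatus.FINISHED_STOPPED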
@@ -22,8 +22,8 @@ from vllm.transformers_utils.config import (
 from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                         zmq_socket_ctx)
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
-from vllm.v1.core.scheduler import Scheduler as V1Scheduler
-from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
@@ -45,7 +45,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 if TYPE_CHECKING:
     import xgrammar as xgr

-    from vllm.v1.core.scheduler_output import SchedulerOutput
+    from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")

@@ -28,7 +28,7 @@ from vllm.v1.worker.worker_base import WorkerBase
 logger = init_logger(__name__)

 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler_output import SchedulerOutput
+    from vllm.v1.core.sched.output import SchedulerOutput


 class Worker(WorkerBase):
@@ -37,7 +37,7 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler import SchedulerOutput
+    from vllm.v1.core.sched.output import SchedulerOutput

 logger = init_logger(__name__)

@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput