[Misc] Avoid direct access of global mm_registry in compute_encoder_budget (#15621)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-03-28 01:52:00 +08:00; committed by GitHub
parent 66aa4c0bf4
commit 13ac9cab21
4 changed files with 19 additions and 7 deletions
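
In short, compute_encoder_budget used to reach directly for the process-wide MULTIMODAL_REGISTRY singleton; this commit threads the registry through as an explicit parameter instead (with the global as the Scheduler default), i.e. ordinary dependency injection. A minimal, self-contained sketch of the pattern, with hypothetical names (GLOBAL_REGISTRY, compute) that are not part of the commit:

# Before: the function silently read a module-level singleton.
# After: the collaborator is a parameter whose default is that same
# global, so existing call sites keep working unchanged.
GLOBAL_REGISTRY = {"image": 576}  # stand-in for MULTIMODAL_REGISTRY


def compute(modality: str, registry: dict[str, int] = GLOBAL_REGISTRY) -> int:
    # The function works against whatever registry it is handed,
    # which makes stubbing in tests trivial.
    return registry.get(modality, 0)


assert compute("image") == 576                        # default: the global
assert compute("image", registry={"image": 1}) == 1   # injected stub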


@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING
 
 from vllm.logger import init_logger
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import MultiModalRegistry
 from vllm.v1.request import Request
 
 if TYPE_CHECKING:
@@ -67,6 +67,7 @@ class EncoderCacheManager:
 def compute_encoder_budget(
     model_config: "ModelConfig",
     scheduler_config: "SchedulerConfig",
+    mm_registry: MultiModalRegistry,
 ) -> tuple[int, int]:
     """Compute the encoder cache budget based on the model and scheduler
     configurations.
@@ -74,6 +75,7 @@
     Args:
         model_config: Model configuration.
         scheduler_config: Scheduler configuration.
+        mm_registry: Provides information about the token cost.
 
     Returns:
         - Compute budget for encoder execution, in unit of number of tokens
@@ -89,7 +91,11 @@
     (
         encoder_compute_budget,
         encoder_cache_size,
-    ) = _compute_encoder_budget_multimodal(model_config, scheduler_config)
+    ) = _compute_encoder_budget_multimodal(
+        model_config,
+        scheduler_config,
+        mm_registry,
+    )
 
     return encoder_compute_budget, encoder_cache_size
@@ -97,6 +103,7 @@
 def _compute_encoder_budget_multimodal(
     model_config: "ModelConfig",
     scheduler_config: "SchedulerConfig",
+    mm_registry: MultiModalRegistry,
 ) -> tuple[int, int]:
     """Compute the encoder cache budget based on the model and scheduler
     configurations for a multimodal model.
@@ -104,6 +111,7 @@ def _compute_encoder_budget_multimodal(
     Args:
         model_config: Model configuration.
         scheduler_config: Scheduler configuration.
+        mm_registry: Provides information about the token cost.
 
     Returns:
         - Compute budget for encoder execution, in unit of number of tokens
@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal(
           in the input sequence.
     """
 
-    max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
-        model_config)
+    max_tokens_by_modality_dict = mm_registry \
+        .get_max_tokens_per_item_by_nonzero_modality(model_config)
 
     if not max_tokens_by_modality_dict:
         logger.warning(
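
Taken together, the hunks above change the public call shape of compute_encoder_budget. A minimal sketch of the new usage, assuming only the import paths shown elsewhere in this commit's diffs; build_encoder_budget is a hypothetical wrapper, not part of the commit:

from vllm.config import ModelConfig, SchedulerConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget


def build_encoder_budget(
    model_config: ModelConfig,
    scheduler_config: SchedulerConfig,
    mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> tuple[int, int]:
    # The registry is now an explicit argument; tests can pass a stub
    # instead of monkey-patching the module-level singleton.
    return compute_encoder_budget(
        model_config=model_config,
        scheduler_config=scheduler_config,
        mm_registry=mm_registry,
    )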


@@ -10,6 +10,7 @@ from typing import Optional, Union
 from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -38,6 +39,7 @@ class Scheduler(SchedulerInterface):
         speculative_config: Optional[SpeculativeConfig],
         log_stats: bool,
         structured_output_manager: StructuredOutputManager,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
@@ -93,6 +95,7 @@ class Scheduler(SchedulerInterface):
         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
             model_config=model_config,
             scheduler_config=scheduler_config,
+            mm_registry=mm_registry,
         )
 
         # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
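
Worth noting (our reading, not stated in the commit): defaulting mm_registry to the global MULTIMODAL_REGISTRY keeps every existing Scheduler(...) call site source-compatible while opening a seam for injection. A hedged sketch of how a test might use that seam; make_test_scheduler is hypothetical, the Scheduler module path is assumed, and the no-argument MultiModalRegistry() constructor is an assumption (it appears to be how the global itself is built):

from vllm.multimodal import MultiModalRegistry
from vllm.v1.core.sched.scheduler import Scheduler  # module path assumed


def make_test_scheduler(**kwargs) -> Scheduler:
    # Give each test its own registry so nothing leaks through the
    # process-wide singleton; all other arguments pass through as-is.
    kwargs.setdefault("mm_registry", MultiModalRegistry())
    return Scheduler(**kwargs)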


@@ -137,6 +137,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
             model_config=model_config,
             scheduler_config=scheduler_config,
+            mm_registry=self.mm_registry,
         )
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
@@ -1439,9 +1440,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # NOTE: Currently model is profiled with a single non-text
             # modality with the max possible input tokens even when
             # it supports multiple.
-            max_tokens_by_modality_dict = (
-                MULTIMODAL_REGISTRY.
-                get_max_tokens_per_item_by_nonzero_modality(self.model_config))
+            max_tokens_by_modality_dict = self.mm_registry \
+                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
             dummy_data_modality, max_tokens_per_mm_item = max(
                 max_tokens_by_modality_dict.items(), key=lambda item: item[1])
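
For context on the max(...) selection above: get_max_tokens_per_item_by_nonzero_modality returns a mapping from modality name to the worst-case encoder token count for a single item, and profiling picks the most expensive modality. A toy illustration with made-up numbers:

# Illustrative values only; real numbers come from the model's
# multimodal processor via the registry.
max_tokens_by_modality_dict = {"image": 576, "audio": 188}

# Pick the modality whose single item costs the most encoder tokens.
dummy_data_modality, max_tokens_per_mm_item = max(
    max_tokens_by_modality_dict.items(), key=lambda item: item[1])

assert dummy_data_modality == "image"
assert max_tokens_per_mm_item == 576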


@@ -109,6 +109,7 @@ class TPUModelRunner:
         encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
             model_config=model_config,
             scheduler_config=scheduler_config,
+            mm_registry=self.mm_registry,
         )
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size