From d848800e884f581eeed9f154d6c2aeb38eac24de Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 9 Jan 2025 12:48:12 +0800 Subject: [PATCH] [Misc] Move `print_*_once` from utils to logger (#11298) Signed-off-by: DarkLight1337 Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> Co-authored-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- .github/workflows/lint-and-deploy.yaml | 1 + vllm/attention/backends/torch_sdpa.py | 9 ++- vllm/attention/backends/xformers.py | 8 ++- vllm/config.py | 9 ++- vllm/entrypoints/chat_utils.py | 7 +-- vllm/inputs/preprocess.py | 20 ++++--- vllm/inputs/registry.py | 4 +- vllm/logger.py | 57 +++++++++++++++++-- vllm/lora/peft_helper.py | 6 +- vllm/lora/punica_wrapper/punica_selector.py | 8 ++- vllm/model_executor/custom_op.py | 3 +- .../compressed_tensors_moe.py | 8 ++- .../model_executor/layers/quantization/fp8.py | 5 +- .../layers/quantization/kv_cache.py | 6 +- .../quantization/utils/marlin_utils_fp8.py | 6 +- .../model_loader/weight_utils.py | 8 +-- vllm/model_executor/models/chameleon.py | 6 +- vllm/model_executor/models/olmoe.py | 6 +- vllm/model_executor/models/qwen2_moe.py | 6 +- vllm/model_executor/models/vision.py | 6 +- vllm/utils.py | 12 ---- 21 files changed, 129 insertions(+), 72 deletions(-) diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index ee768db6..556b60d2 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -64,6 +64,7 @@ jobs: run: | export AWS_ACCESS_KEY_ID=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin + sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - name: curl test diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c14f7754..ca1c4618 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -13,9 +13,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.ipex_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.utils import make_tensor_with_pad, print_warning_once +from vllm.logger import init_logger +from vllm.utils import make_tensor_with_pad from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder +logger = init_logger(__name__) + class TorchSDPABackend(AttentionBackend): @@ -396,8 +399,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): raise ValueError( "Torch SPDA does not support block-sparse attention.") 
if logits_soft_cap is not None: - print_warning_once("Torch SPDA does not support logits soft cap. " - "Outputs may be slightly off.") + logger.warning_once("Torch SPDA does not support logits soft cap. " + "Outputs may be slightly off.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 694c7cc1..8c8ca852 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -17,7 +17,9 @@ from vllm.attention.backends.utils import ( is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) -from vllm.utils import print_warning_once +from vllm.logger import init_logger + +logger = init_logger(__name__) class XFormersBackend(AttentionBackend): @@ -385,8 +387,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): raise ValueError( "XFormers does not support block-sparse attention.") if logits_soft_cap is not None: - print_warning_once("XFormers does not support logits soft cap. " - "Outputs may be slightly off.") + logger.warning_once("XFormers does not support logits soft cap. " + "Outputs may be slightly off.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/config.py b/vllm/config.py index 6dabeb38..19609085 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -32,8 +32,7 @@ from vllm.transformers_utils.config import ( from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, print_warning_once, random_uuid, - resolve_obj_by_qualname) + get_cpu_memory, random_uuid, resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -314,7 +313,7 @@ class ModelConfig: sliding_window_len_min = get_min_sliding_window( self.hf_text_config.sliding_window) - print_warning_once( + logger.warning_once( f"{self.hf_text_config.model_type} has interleaved " "attention, which is currently not supported by the " "XFORMERS backend. Disabling sliding window and capping " @@ -2758,7 +2757,7 @@ class CompilationConfig(BaseModel): def model_post_init(self, __context: Any) -> None: if not self.enable_reshape and self.enable_fusion: - print_warning_once( + logger.warning_once( "Fusion enabled but reshape elimination disabled." "RMSNorm + quant (fp8) fusion might not work") @@ -3151,7 +3150,7 @@ class VllmConfig: self.scheduler_config.chunked_prefill_enabled and \ self.model_config.dtype == torch.float32 and \ current_platform.get_device_capability() == (7, 5): - print_warning_once( + logger.warning_once( "Turing devices tensor cores do not support float32 matmul. 
" "To workaround this limitation, vLLM will set 'ieee' input " "precision for chunked prefill triton kernels.") diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index a492d549..923c7459 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -35,7 +35,6 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import MediaConnector from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -985,14 +984,14 @@ def apply_mistral_chat_template( **kwargs: Any, ) -> List[int]: if chat_template is not None: - print_warning_once( + logger.warning_once( "'chat_template' cannot be overridden for mistral tokenizer.") if "add_generation_prompt" in kwargs: - print_warning_once( + logger.warning_once( "'add_generation_prompt' is not supported for mistral tokenizer, " "so it will be ignored.") if "continue_final_message" in kwargs: - print_warning_once( + logger.warning_once( "'continue_final_message' is not supported for mistral tokenizer, " "so it will be ignored.") diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 3e92d582..a738ffe1 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,7 +10,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from vllm.utils import print_info_once, print_warning_once from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) @@ -68,21 +67,24 @@ class InputPreprocessor: ''' if not self.model_config.is_encoder_decoder: - print_warning_once("Using None for decoder start token id because " - "this is not an encoder/decoder model.") + logger.warning_once( + "Using None for decoder start token id because " + "this is not an encoder/decoder model.") return None if (self.model_config is None or self.model_config.hf_config is None): - print_warning_once("Using None for decoder start token id because " - "model config is not available.") + logger.warning_once( + "Using None for decoder start token id because " + "model config is not available.") return None dec_start_token_id = getattr(self.model_config.hf_config, 'decoder_start_token_id', None) if dec_start_token_id is None: - print_warning_once("Falling back on for decoder start token " - "id because decoder start token id is not " - "available.") + logger.warning_once( + "Falling back on for decoder start token " + "id because decoder start token id is not " + "available.") dec_start_token_id = self.get_bos_token_id() return dec_start_token_id @@ -231,7 +233,7 @@ class InputPreprocessor: # updated to use the new multi-modal processor can_process_multimodal = self.mm_registry.has_processor(model_config) if not can_process_multimodal: - print_info_once( + logger.info_once( "Your model uses the legacy input pipeline instead of the new " "multi-modal processor. Please note that the legacy pipeline " "will be removed in a future release. 
For more details, see: " diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index b22b3f15..aad0dfab 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -12,7 +12,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, - print_warning_once, resolve_mm_processor_kwargs) + resolve_mm_processor_kwargs) from .data import ProcessorInputs, SingletonInputs from .parse import is_encoder_decoder_inputs @@ -352,7 +352,7 @@ class InputRegistry: num_tokens = dummy_data.seq_data.prompt_token_ids if len(num_tokens) < seq_len: if is_encoder_data: - print_warning_once( + logger.warning_once( f"Expected at least {seq_len} dummy encoder tokens for " f"profiling, but found {len(num_tokens)} tokens instead.") else: diff --git a/vllm/logger.py b/vllm/logger.py index 538db0dc..cac174f7 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -4,11 +4,12 @@ import json import logging import os import sys -from functools import partial +from functools import lru_cache, partial from logging import Logger from logging.config import dictConfig from os import path -from typing import Dict, Optional +from types import MethodType +from typing import Any, Optional, cast import vllm.envs as envs @@ -49,8 +50,44 @@ DEFAULT_LOGGING_CONFIG = { } +@lru_cache +def _print_info_once(logger: Logger, msg: str) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.info(msg, stacklevel=2) + + +@lru_cache +def _print_warning_once(logger: Logger, msg: str) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.warning(msg, stacklevel=2) + + +class _VllmLogger(Logger): + """ + Note: + This class is just to provide type information. + We actually patch the methods directly on the :class:`logging.Logger` + instance to avoid conflicting with other libraries such as + `intel_extension_for_pytorch.utils._logger`. + """ + + def info_once(self, msg: str) -> None: + """ + As :meth:`info`, but subsequent calls with the same message + are silently dropped. + """ + _print_info_once(self, msg) + + def warning_once(self, msg: str) -> None: + """ + As :meth:`warning`, but subsequent calls with the same message + are silently dropped. + """ + _print_warning_once(self, msg) + + def _configure_vllm_root_logger() -> None: - logging_config: Dict = {} + logging_config = dict[str, Any]() if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: raise RuntimeError( @@ -84,12 +121,22 @@ def _configure_vllm_root_logger() -> None: dictConfig(logging_config) -def init_logger(name: str) -> Logger: +def init_logger(name: str) -> _VllmLogger: """The main purpose of this function is to ensure that loggers are retrieved in such a way that we can be sure the root vllm logger has already been configured.""" - return logging.getLogger(name) + logger = logging.getLogger(name) + + methods_to_patch = { + "info_once": _print_info_once, + "warning_once": _print_warning_once, + } + + for method_name, method in methods_to_patch.items(): + setattr(logger, method_name, MethodType(method, logger)) + + return cast(_VllmLogger, logger) # The root logger is initialized when the module is imported. 
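The vllm/logger.py hunk above is the core of this patch. As a standalone sketch (not part of the diff), the pattern it introduces boils down to the following: a module-level lru_cache keyed on the (logger, msg) pair, so a given message is emitted at most once per logger, plus types.MethodType to bind the helper onto the stdlib Logger instance rather than subclassing it, which is what lets init_logger keep returning objects from logging.getLogger. The names _print_warning_once and init_logger mirror the hunk; the "demo" logger name and the message text are arbitrary.

import logging
from functools import lru_cache
from types import MethodType

logging.basicConfig(level=logging.WARNING)


@lru_cache
def _print_warning_once(logger: logging.Logger, msg: str) -> None:
    # stacklevel=2 attributes the record to the caller of warning_once()
    logger.warning(msg, stacklevel=2)


def init_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    # Patch the instance, not the Logger class, to avoid clashing with
    # libraries (e.g. intel_extension_for_pytorch) that wrap Logger themselves.
    logger.warning_once = MethodType(_print_warning_once, logger)
    return logger


logger = init_logger("demo")
for _ in range(3):
    logger.warning_once("Torch SPDA does not support logits soft cap.")
# Only the first call emits a record; the repeated calls hit the lru_cache.
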
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index ddd42ae9..dacfb9eb 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -4,7 +4,9 @@ import math from dataclasses import MISSING, dataclass, field, fields from typing import Literal, Optional, Union -from vllm.utils import print_info_once +from vllm.logger import init_logger + +logger = init_logger(__name__) @dataclass @@ -42,7 +44,7 @@ class PEFTHelper: def __post_init__(self): self._validate_features() if self.use_rslora: - print_info_once("Loading LoRA weights trained with rsLoRA.") + logger.info_once("Loading LoRA weights trained with rsLoRA.") self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r) else: self.vllm_lora_scaling_factor = self.lora_alpha / self.r diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index cd64878d..9791d492 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -1,19 +1,21 @@ +from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import print_info_once from .punica_base import PunicaWrapperBase +logger = init_logger(__name__) + def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: if current_platform.is_cuda_alike(): # Lazy import to avoid ImportError from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU - print_info_once("Using PunicaWrapperGPU.") + logger.info_once("Using PunicaWrapperGPU.") return PunicaWrapperGPU(*args, **kwargs) elif current_platform.is_hpu(): # Lazy import to avoid ImportError from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU - print_info_once("Using PunicaWrapperHPU.") + logger.info_once("Using PunicaWrapperHPU.") return PunicaWrapperHPU(*args, **kwargs) else: raise NotImplementedError diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index fddc8bad..401606e8 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -5,7 +5,6 @@ import torch.nn as nn from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -91,7 +90,7 @@ class CustomOp(nn.Module): compilation_config = get_current_vllm_config().compilation_config custom_ops = compilation_config.custom_ops if not hasattr(cls, "name"): - print_warning_once( + logger.warning_once( f"Custom op {cls.__name__} was not registered, " f"which means it won't appear in the op registry. 
" f"It will be enabled/disabled based on the global settings.") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 5fd6b017..4fb8fd84 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationStrategy import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( @@ -16,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import print_warning_once + +logger = init_logger(__name__) class GPTQMarlinState(Enum): @@ -142,10 +144,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): "activation scales are None.") if (not all_close_1d(layer.w13_input_scale) or not all_close_1d(layer.w2_input_scale)): - print_warning_once( + logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") + "for each layer.") layer.w13_input_scale = torch.nn.Parameter( layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2fe22903..a1be45a4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -28,7 +28,6 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter, PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import print_warning_once ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -539,10 +538,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): "activation scales are None.") if (not all_close_1d(layer.w13_input_scale) or not all_close_1d(layer.w2_input_scale)): - print_warning_once( + logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer. 
") + "for each layer.") layer.w13_input_scale = torch.nn.Parameter( layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index d79536d1..a74f5415 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -1,8 +1,10 @@ import torch +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.utils import print_warning_once + +logger = init_logger(__name__) class BaseKVCacheMethod(QuantizeMethodBase): @@ -67,7 +69,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): layer._v_scale = v_scale if (layer._k_scale == 1.0 and layer._v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype): - print_warning_once( + logger.warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " "may cause accuracy issues. Please make sure k/v_scale " "scaling factors are available in the fp8 checkpoint.") diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 8b3dfaae..245fe923 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -3,11 +3,13 @@ from typing import Optional import torch import vllm._custom_ops as ops +from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import print_warning_once from .marlin_utils import marlin_make_workspace, marlin_permute_scales +logger = init_logger(__name__) + def is_fp8_marlin_supported(): return current_platform.has_device_capability(80) @@ -47,7 +49,7 @@ def apply_fp8_marlin_linear( def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, strategy: str = "tensor") -> None: - print_warning_once( + logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " "be used leveraging the Marlin kernel. This may degrade " diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index a2c991cf..11d5fd71 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) from vllm.model_executor.layers.quantization.schema import QuantParamSchema from vllm.platforms import current_platform -from vllm.utils import PlaceholderModule, print_warning_once +from vllm.utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer @@ -673,7 +673,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: None: If the remapped name is not found in params_dict. """ if name.endswith(".kv_scale"): - print_warning_once( + logger.warning_once( "DEPRECATED. Found kv_scale in the checkpoint. " "This format is deprecated in favor of separate k_scale and " "v_scale tensors and will be removed in a future release. 
" @@ -682,7 +682,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: # NOTE: we remap the deprecated kv_scale to k_scale remapped_name = name.replace(".kv_scale", ".attn.k_scale") if remapped_name not in params_dict: - print_warning_once( + logger.warning_once( f"Found kv_scale in the checkpoint (e.g. {name}), " "but not found the expected name in the model " f"(e.g. {remapped_name}). kv_scale is " @@ -695,7 +695,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: if name.endswith(scale_name): remapped_name = name.replace(scale_name, f".attn{scale_name}") if remapped_name not in params_dict: - print_warning_once( + logger.warning_once( f"Found {scale_name} in the checkpoint (e.g. {name}), " "but not found the expected name in the model " f"(e.g. {remapped_name}). {scale_name} is " diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index acff9268..452fe727 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -11,6 +11,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -35,13 +36,14 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) +logger = init_logger(__name__) + class ChameleonImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -1111,7 +1113,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") if remapped_kv_scale_name not in params_dict: - print_warning_once( + logger.warning_once( "Found kv scale in the checkpoint (e.g. " f"{name}), but not found the expected name in " f"the model (e.g. {remapped_kv_scale_name}). 
" diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 5d9091cf..fbe5d1ae 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -20,6 +20,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -34,13 +35,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.utils import print_warning_once from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +logger = init_logger(__name__) + class OlmoeMoE(nn.Module): """A tensor-parallel MoE implementation for Olmoe that shards each expert @@ -446,7 +448,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") if remapped_kv_scale_name not in params_dict: - print_warning_once( + logger.warning_once( "Found kv scale in the checkpoint " f"(e.g. {name}), but not found the expected " f"name in the model " diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index ba70243c..95de6c21 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -34,6 +34,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -50,13 +51,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.utils import print_warning_once from .interfaces import SupportsPP from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +logger = init_logger(__name__) + class Qwen2MoeMLP(nn.Module): @@ -524,7 +526,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") if remapped_kv_scale_name not in params_dict: - print_warning_once( + logger.warning_once( "Found kv scale in the checkpoint " f"(e.g. 
{name}), but not found the expected " f"name in the model " diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index e6a9e153..a1395982 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -7,8 +7,10 @@ from transformers import PretrainedConfig import vllm.envs as envs from vllm.attention.selector import (backend_name_to_enum, get_global_forced_attn_backend) +from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform -from vllm.utils import print_warning_once + +logger = init_logger(__name__) _C = TypeVar("_C", bound=PretrainedConfig) @@ -87,7 +89,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: if is_flash_attn_2_available(): selected_backend = _Backend.FLASH_ATTN else: - print_warning_once( + logger.warning_once( "Current `vllm-flash-attn` has a bug inside vision module, " "so we use xformers backend instead. You can run " "`pip install flash-attn` to use flash-attention backend.") diff --git a/vllm/utils.py b/vllm/utils.py index c09cae70..a92b77ef 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -696,18 +696,6 @@ def create_kv_caches_with_random( return key_caches, value_caches -@lru_cache -def print_info_once(msg: str) -> None: - # Set the stacklevel to 2 to print the caller's line info - logger.info(msg, stacklevel=2) - - -@lru_cache -def print_warning_once(msg: str) -> None: - # Set the stacklevel to 2 to print the caller's line info - logger.warning(msg, stacklevel=2) - - @lru_cache(maxsize=None) def is_pin_memory_available() -> bool: from vllm.platforms import current_platform
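
One behavioural detail of the new helpers that the call-site hunks above depend on but the diff does not state explicitly: the lru_cache in vllm/logger.py is keyed on the fully formatted message string, so call sites that interpolate values with f-strings still log once per distinct message, while exact repeats are silently dropped. A short usage sketch against the new init_logger API (assumes a vLLM build that includes this patch; the token counts are made up for illustration):

from vllm.logger import init_logger

logger = init_logger(__name__)

for seq_len in (2048, 2048, 4096):
    # Mirrors the pattern in vllm/inputs/registry.py above: the message text
    # varies with seq_len, so each distinct value gets its own cache entry.
    logger.warning_once(
        f"Expected at least {seq_len} dummy encoder tokens for profiling.")
# Two warnings are emitted: the repeated 2048 message is dropped as a
# duplicate, while the 4096 message is new and is logged.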