[Misc] Move print_*_once from utils to logger (#11298)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
Co-authored-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
Cyrus Leung 2025-01-09 12:48:12 +08:00 committed by GitHub
parent 730e9592e9
commit d848800e88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 129 additions and 72 deletions

View File

@@ -64,6 +64,7 @@ jobs:
run: |
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
- name: curl test

View File

@@ -13,9 +13,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import make_tensor_with_pad, print_warning_once
from vllm.logger import init_logger
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
logger = init_logger(__name__)
class TorchSDPABackend(AttentionBackend):
@@ -396,8 +399,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise ValueError(
"Torch SPDA does not support block-sparse attention.")
if logits_soft_cap is not None:
print_warning_once("Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off.")
logger.warning_once("Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)

View File

@@ -17,7 +17,9 @@ from vllm.attention.backends.utils import (
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.utils import print_warning_once
from vllm.logger import init_logger
logger = init_logger(__name__)
class XFormersBackend(AttentionBackend):
@@ -385,8 +387,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
raise ValueError(
"XFormers does not support block-sparse attention.")
if logits_soft_cap is not None:
print_warning_once("XFormers does not support logits soft cap. "
"Outputs may be slightly off.")
logger.warning_once("XFormers does not support logits soft cap. "
"Outputs may be slightly off.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)

View File

@@ -32,8 +32,7 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
get_cpu_memory, random_uuid, resolve_obj_by_qualname)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -314,7 +313,7 @@ class ModelConfig:
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)
print_warning_once(
logger.warning_once(
f"{self.hf_text_config.model_type} has interleaved "
"attention, which is currently not supported by the "
"XFORMERS backend. Disabling sliding window and capping "
@@ -2758,7 +2757,7 @@ class CompilationConfig(BaseModel):
def model_post_init(self, __context: Any) -> None:
if not self.enable_reshape and self.enable_fusion:
print_warning_once(
logger.warning_once(
"Fusion enabled but reshape elimination disabled."
"RMSNorm + quant (fp8) fusion might not work")
@@ -3151,7 +3150,7 @@ class VllmConfig:
self.scheduler_config.chunked_prefill_enabled and \
self.model_config.dtype == torch.float32 and \
current_platform.get_device_capability() == (7, 5):
print_warning_once(
logger.warning_once(
"Turing devices tensor cores do not support float32 matmul. "
"To workaround this limitation, vLLM will set 'ieee' input "
"precision for chunked prefill triton kernels.")

View File

@@ -35,7 +35,6 @@ from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once
logger = init_logger(__name__)
@@ -985,14 +984,14 @@ def apply_mistral_chat_template(
**kwargs: Any,
) -> List[int]:
if chat_template is not None:
print_warning_once(
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
print_warning_once(
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
print_warning_once(
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")

View File

@@ -10,7 +10,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_info_once, print_warning_once
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
@@ -68,21 +67,24 @@ class InputPreprocessor:
'''
if not self.model_config.is_encoder_decoder:
print_warning_once("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
logger.warning_once(
"Using None for decoder start token id because "
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
print_warning_once("Using None for decoder start token id because "
"model config is not available.")
logger.warning_once(
"Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
if dec_start_token_id is None:
print_warning_once("Falling back on <BOS> for decoder start token "
"id because decoder start token id is not "
"available.")
logger.warning_once(
"Falling back on <BOS> for decoder start token "
"id because decoder start token id is not "
"available.")
dec_start_token_id = self.get_bos_token_id()
return dec_start_token_id
@@ -231,7 +233,7 @@ class InputPreprocessor:
# updated to use the new multi-modal processor
can_process_multimodal = self.mm_registry.has_processor(model_config)
if not can_process_multimodal:
print_info_once(
logger.info_once(
"Your model uses the legacy input pipeline instead of the new "
"multi-modal processor. Please note that the legacy pipeline "
"will be removed in a future release. For more details, see: "

View File

@@ -12,7 +12,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
print_warning_once, resolve_mm_processor_kwargs)
resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
@@ -352,7 +352,7 @@ class InputRegistry:
num_tokens = dummy_data.seq_data.prompt_token_ids
if len(num_tokens) < seq_len:
if is_encoder_data:
print_warning_once(
logger.warning_once(
f"Expected at least {seq_len} dummy encoder tokens for "
f"profiling, but found {len(num_tokens)} tokens instead.")
else:

View File

@@ -4,11 +4,12 @@ import json
import logging
import os
import sys
from functools import partial
from functools import lru_cache, partial
from logging import Logger
from logging.config import dictConfig
from os import path
from typing import Dict, Optional
from types import MethodType
from typing import Any, Optional, cast
import vllm.envs as envs
@@ -49,8 +50,44 @@ DEFAULT_LOGGING_CONFIG = {
}
@lru_cache
def _print_info_once(logger: Logger, msg: str) -> None:
# Set the stacklevel to 2 to print the original caller's line info
logger.info(msg, stacklevel=2)
@lru_cache
def _print_warning_once(logger: Logger, msg: str) -> None:
# Set the stacklevel to 2 to print the original caller's line info
logger.warning(msg, stacklevel=2)
class _VllmLogger(Logger):
"""
Note:
This class is just to provide type information.
We actually patch the methods directly on the :class:`logging.Logger`
instance to avoid conflicting with other libraries such as
`intel_extension_for_pytorch.utils._logger`.
"""
def info_once(self, msg: str) -> None:
"""
As :meth:`info`, but subsequent calls with the same message
are silently dropped.
"""
_print_info_once(self, msg)
def warning_once(self, msg: str) -> None:
"""
As :meth:`warning`, but subsequent calls with the same message
are silently dropped.
"""
_print_warning_once(self, msg)
def _configure_vllm_root_logger() -> None:
logging_config: Dict = {}
logging_config = dict[str, Any]()
if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
raise RuntimeError(
@@ -84,12 +121,22 @@ def _configure_vllm_root_logger() -> None:
dictConfig(logging_config)
def init_logger(name: str) -> Logger:
def init_logger(name: str) -> _VllmLogger:
"""The main purpose of this function is to ensure that loggers are
retrieved in such a way that we can be sure the root vllm logger has
already been configured."""
return logging.getLogger(name)
logger = logging.getLogger(name)
methods_to_patch = {
"info_once": _print_info_once,
"warning_once": _print_warning_once,
}
for method_name, method in methods_to_patch.items():
setattr(logger, method_name, MethodType(method, logger))
return cast(_VllmLogger, logger)
# The root logger is initialized when the module is imported.
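For reference, here is a minimal usage sketch (not part of the diff) of the logger methods patched in by init_logger above. Because _print_info_once and _print_warning_once are wrapped in lru_cache, repeated calls with the same message on the same logger are silently dropped, and stacklevel=2 attributes the log record to the caller's line rather than to logger.py:

from vllm.logger import init_logger

logger = init_logger(__name__)

# First call emits a WARNING attributed to this line (stacklevel=2).
logger.warning_once("Torch SPDA does not support logits soft cap. "
                    "Outputs may be slightly off.")

# A second call with the identical message hits the lru_cache entry in
# _print_warning_once and is silently dropped.
logger.warning_once("Torch SPDA does not support logits soft cap. "
                    "Outputs may be slightly off.")

# info_once behaves the same way at INFO level.
logger.info_once("Using PunicaWrapperGPU.")

The cache key is the (logger, message) pair, so the same message emitted through a different logger instance is still logged once per logger.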

View File

@@ -4,7 +4,9 @@ import math
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union
from vllm.utils import print_info_once
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
@@ -42,7 +44,7 @@ class PEFTHelper:
def __post_init__(self):
self._validate_features()
if self.use_rslora:
print_info_once("Loading LoRA weights trained with rsLoRA.")
logger.info_once("Loading LoRA weights trained with rsLoRA.")
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
else:
self.vllm_lora_scaling_factor = self.lora_alpha / self.r

View File

@@ -1,19 +1,21 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import print_info_once
from .punica_base import PunicaWrapperBase
logger = init_logger(__name__)
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
if current_platform.is_cuda_alike():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
print_info_once("Using PunicaWrapperGPU.")
logger.info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)
elif current_platform.is_hpu():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
print_info_once("Using PunicaWrapperHPU.")
logger.info_once("Using PunicaWrapperHPU.")
return PunicaWrapperHPU(*args, **kwargs)
else:
raise NotImplementedError

View File

@@ -5,7 +5,6 @@ import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
logger = init_logger(__name__)
@@ -91,7 +90,7 @@ class CustomOp(nn.Module):
compilation_config = get_current_vllm_config().compilation_config
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
print_warning_once(
logger.warning_once(
f"Custom op {cls.__name__} was not registered, "
f"which means it won't appear in the op registry. "
f"It will be enabled/disabled based on the global settings.")

View File

@@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationStrategy
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -16,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
logger = init_logger(__name__)
class GPTQMarlinState(Enum):
@@ -142,10 +144,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"activation scales are None.")
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
print_warning_once(
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
"for each layer.")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(

View File

@@ -28,7 +28,6 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -539,10 +538,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"activation scales are None.")
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
print_warning_once(
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
"for each layer.")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(

View File

@@ -1,8 +1,10 @@
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.utils import print_warning_once
logger = init_logger(__name__)
class BaseKVCacheMethod(QuantizeMethodBase):
@@ -67,7 +69,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
layer._v_scale = v_scale
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype):
print_warning_once(
logger.warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint.")

View File

@@ -3,11 +3,13 @@ from typing import Optional
import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
logger = init_logger(__name__)
def is_fp8_marlin_supported():
return current_platform.has_device_capability(80)
@@ -47,7 +49,7 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
strategy: str = "tensor") -> None:
print_warning_once(
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "

View File

@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform
from vllm.utils import PlaceholderModule, print_warning_once
from vllm.utils import PlaceholderModule
try:
from runai_model_streamer import SafetensorsStreamer
@@ -673,7 +673,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
None: If the remapped name is not found in params_dict.
"""
if name.endswith(".kv_scale"):
print_warning_once(
logger.warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
@@ -682,7 +682,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
if remapped_name not in params_dict:
print_warning_once(
logger.warning_once(
f"Found kv_scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). kv_scale is "
@@ -695,7 +695,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
if name.endswith(scale_name):
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
print_warning_once(
logger.warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). {scale_name} is "

View File

@@ -11,6 +11,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -35,13 +36,14 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
class ChameleonImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
@@ -1111,7 +1113,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
logger.warning_once(
"Found kv scale in the checkpoint (e.g. "
f"{name}), but not found the expected name in "
f"the model (e.g. {remapped_kv_scale_name}). "

View File

@@ -20,6 +20,7 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -34,13 +35,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class OlmoeMoE(nn.Module):
"""A tensor-parallel MoE implementation for Olmoe that shards each expert
@@ -446,7 +448,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "

View File

@@ -34,6 +34,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -50,13 +51,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class Qwen2MoeMLP(nn.Module):
@@ -524,7 +526,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
logger.warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "

View File

@@ -7,8 +7,10 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.utils import print_warning_once
logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig)
@@ -87,7 +89,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN
else:
print_warning_once(
logger.warning_once(
"Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.")

View File

@@ -696,18 +696,6 @@ def create_kv_caches_with_random(
return key_caches, value_caches
@lru_cache
def print_info_once(msg: str) -> None:
# Set the stacklevel to 2 to print the caller's line info
logger.info(msg, stacklevel=2)
@lru_cache
def print_warning_once(msg: str) -> None:
# Set the stacklevel to 2 to print the caller's line info
logger.warning(msg, stacklevel=2)
@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
from vllm.platforms import current_platform