[Misc] Move print_*_once from utils to logger (#11298)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
Co-authored-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
parent 730e9592e9
commit d848800e88

.github/workflows/lint-and-deploy.yaml (vendored)
@@ -64,6 +64,7 @@ jobs:
         run: |
           export AWS_ACCESS_KEY_ID=minioadmin
           export AWS_SECRET_ACCESS_KEY=minioadmin
+          sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
           helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
 
       - name: curl test
@@ -13,9 +13,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
-from vllm.utils import make_tensor_with_pad, print_warning_once
+from vllm.logger import init_logger
+from vllm.utils import make_tensor_with_pad
 from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
 
+logger = init_logger(__name__)
+
 
 class TorchSDPABackend(AttentionBackend):
 
@@ -396,7 +399,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
            raise ValueError(
                "Torch SPDA does not support block-sparse attention.")
        if logits_soft_cap is not None:
-           print_warning_once("Torch SPDA does not support logits soft cap. "
+           logger.warning_once("Torch SPDA does not support logits soft cap. "
                               "Outputs may be slightly off.")
        self.num_heads = num_heads
        self.head_size = head_size
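The same call-site migration repeats throughout the remaining hunks: each module drops its print_*_once import from vllm.utils, creates a module-level logger via init_logger(__name__), and calls logger.warning_once(...) or logger.info_once(...) instead. A minimal sketch of the pattern, summarizing the hunks above and below rather than adding any new change:

# Before: helper imported from vllm.utils
# from vllm.utils import print_warning_once
# print_warning_once("Torch SPDA does not support logits soft cap. "
#                    "Outputs may be slightly off.")

# After: per-module logger whose *_once methods are patched on by init_logger
from vllm.logger import init_logger

logger = init_logger(__name__)
logger.warning_once("Torch SPDA does not support logits soft cap. "
                    "Outputs may be slightly off.")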
@@ -17,7 +17,9 @@ from vllm.attention.backends.utils import (
     is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
-from vllm.utils import print_warning_once
+from vllm.logger import init_logger
 
+logger = init_logger(__name__)
+
 
 class XFormersBackend(AttentionBackend):
@@ -385,7 +387,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        if logits_soft_cap is not None:
-           print_warning_once("XFormers does not support logits soft cap. "
+           logger.warning_once("XFormers does not support logits soft cap. "
                               "Outputs may be slightly off.")
        self.num_heads = num_heads
        self.head_size = head_size
@@ -32,8 +32,7 @@ from vllm.transformers_utils.config import (
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3
 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
-                        get_cpu_memory, print_warning_once, random_uuid,
-                        resolve_obj_by_qualname)
+                        get_cpu_memory, random_uuid, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -314,7 +313,7 @@ class ModelConfig:
            sliding_window_len_min = get_min_sliding_window(
                self.hf_text_config.sliding_window)
 
-           print_warning_once(
+           logger.warning_once(
                f"{self.hf_text_config.model_type} has interleaved "
                "attention, which is currently not supported by the "
                "XFORMERS backend. Disabling sliding window and capping "
@@ -2758,7 +2757,7 @@ class CompilationConfig(BaseModel):
 
     def model_post_init(self, __context: Any) -> None:
        if not self.enable_reshape and self.enable_fusion:
-           print_warning_once(
+           logger.warning_once(
                "Fusion enabled but reshape elimination disabled."
                "RMSNorm + quant (fp8) fusion might not work")
 
@@ -3151,7 +3150,7 @@ class VllmConfig:
                self.scheduler_config.chunked_prefill_enabled and \
                self.model_config.dtype == torch.float32 and \
                current_platform.get_device_capability() == (7, 5):
-           print_warning_once(
+           logger.warning_once(
                "Turing devices tensor cores do not support float32 matmul. "
                "To workaround this limitation, vLLM will set 'ieee' input "
                "precision for chunked prefill triton kernels.")
@@ -35,7 +35,6 @@ from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
 
@@ -985,14 +984,14 @@ def apply_mistral_chat_template(
     **kwargs: Any,
 ) -> List[int]:
     if chat_template is not None:
-        print_warning_once(
+        logger.warning_once(
            "'chat_template' cannot be overridden for mistral tokenizer.")
     if "add_generation_prompt" in kwargs:
-        print_warning_once(
+        logger.warning_once(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored.")
     if "continue_final_message" in kwargs:
-        print_warning_once(
+        logger.warning_once(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored.")
 
@@ -10,7 +10,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from vllm.utils import print_info_once, print_warning_once
 
 from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
                    PromptType, SingletonInputs, SingletonPrompt, token_inputs)
@@ -68,19 +67,22 @@ class InputPreprocessor:
        '''
 
        if not self.model_config.is_encoder_decoder:
-           print_warning_once("Using None for decoder start token id because "
+           logger.warning_once(
+               "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
            return None
 
        if (self.model_config is None or self.model_config.hf_config is None):
-           print_warning_once("Using None for decoder start token id because "
+           logger.warning_once(
+               "Using None for decoder start token id because "
                "model config is not available.")
            return None
 
        dec_start_token_id = getattr(self.model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
-           print_warning_once("Falling back on <BOS> for decoder start token "
+           logger.warning_once(
+               "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
            dec_start_token_id = self.get_bos_token_id()
@@ -231,7 +233,7 @@ class InputPreprocessor:
        # updated to use the new multi-modal processor
        can_process_multimodal = self.mm_registry.has_processor(model_config)
        if not can_process_multimodal:
-           print_info_once(
+           logger.info_once(
                "Your model uses the legacy input pipeline instead of the new "
                "multi-modal processor. Please note that the legacy pipeline "
                "will be removed in a future release. For more details, see: "
@@ -12,7 +12,7 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
-                        print_warning_once, resolve_mm_processor_kwargs)
+                        resolve_mm_processor_kwargs)
 
 from .data import ProcessorInputs, SingletonInputs
 from .parse import is_encoder_decoder_inputs
@@ -352,7 +352,7 @@ class InputRegistry:
        num_tokens = dummy_data.seq_data.prompt_token_ids
        if len(num_tokens) < seq_len:
            if is_encoder_data:
-               print_warning_once(
+               logger.warning_once(
                    f"Expected at least {seq_len} dummy encoder tokens for "
                    f"profiling, but found {len(num_tokens)} tokens instead.")
            else:
@@ -4,11 +4,12 @@ import json
 import logging
 import os
 import sys
-from functools import partial
+from functools import lru_cache, partial
 from logging import Logger
 from logging.config import dictConfig
 from os import path
-from typing import Dict, Optional
+from types import MethodType
+from typing import Any, Optional, cast
 
 import vllm.envs as envs
 
@@ -49,8 +50,44 @@ DEFAULT_LOGGING_CONFIG = {
 }
 
 
+@lru_cache
+def _print_info_once(logger: Logger, msg: str) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.info(msg, stacklevel=2)
+
+
+@lru_cache
+def _print_warning_once(logger: Logger, msg: str) -> None:
+    # Set the stacklevel to 2 to print the original caller's line info
+    logger.warning(msg, stacklevel=2)
+
+
+class _VllmLogger(Logger):
+    """
+    Note:
+        This class is just to provide type information.
+        We actually patch the methods directly on the :class:`logging.Logger`
+        instance to avoid conflicting with other libraries such as
+        `intel_extension_for_pytorch.utils._logger`.
+    """
+
+    def info_once(self, msg: str) -> None:
+        """
+        As :meth:`info`, but subsequent calls with the same message
+        are silently dropped.
+        """
+        _print_info_once(self, msg)
+
+    def warning_once(self, msg: str) -> None:
+        """
+        As :meth:`warning`, but subsequent calls with the same message
+        are silently dropped.
+        """
+        _print_warning_once(self, msg)
+
+
 def _configure_vllm_root_logger() -> None:
-    logging_config: Dict = {}
+    logging_config = dict[str, Any]()
 
     if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
        raise RuntimeError(
@@ -84,12 +121,22 @@ def _configure_vllm_root_logger() -> None:
     dictConfig(logging_config)
 
 
-def init_logger(name: str) -> Logger:
+def init_logger(name: str) -> _VllmLogger:
     """The main purpose of this function is to ensure that loggers are
     retrieved in such a way that we can be sure the root vllm logger has
     already been configured."""
 
-    return logging.getLogger(name)
+    logger = logging.getLogger(name)
+
+    methods_to_patch = {
+        "info_once": _print_info_once,
+        "warning_once": _print_warning_once,
+    }
+
+    for method_name, method in methods_to_patch.items():
+        setattr(logger, method_name, MethodType(method, logger))
+
+    return cast(_VllmLogger, logger)
 
 
 # The root logger is initialized when the module is imported.
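For reference, a small standalone sketch (not part of the diff) of how the patched methods deduplicate: the _print_*_once helpers are wrapped in functools.lru_cache, so the cache key is the (logger, msg) pair and a repeated message is a cache hit that never reaches the underlying Logger.info/Logger.warning call, while stacklevel=2 keeps the reported file and line at the original caller. The demo logger name below is hypothetical.

import logging
from functools import lru_cache
from types import MethodType


@lru_cache
def _warn_once(logger: logging.Logger, msg: str) -> None:
    # lru_cache memoizes on the (logger, msg) pair, so repeated calls with the
    # same message are cache hits and the warning is emitted only once.
    logger.warning(msg, stacklevel=2)


logging.basicConfig(level=logging.WARNING)
demo = logging.getLogger("demo")  # hypothetical logger name, for illustration only
demo.warning_once = MethodType(_warn_once, demo)

demo.warning_once("Turing devices tensor cores do not support float32 matmul.")  # emitted
demo.warning_once("Turing devices tensor cores do not support float32 matmul.")  # silently dropped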
@@ -4,7 +4,9 @@ import math
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union
 
-from vllm.utils import print_info_once
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 @dataclass
@@ -42,7 +44,7 @@ class PEFTHelper:
     def __post_init__(self):
        self._validate_features()
        if self.use_rslora:
-           print_info_once("Loading LoRA weights trained with rsLoRA.")
+           logger.info_once("Loading LoRA weights trained with rsLoRA.")
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
        else:
            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
@@ -1,19 +1,21 @@
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import print_info_once
 
 from .punica_base import PunicaWrapperBase
 
+logger = init_logger(__name__)
+
 
 def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
     if current_platform.is_cuda_alike():
        # Lazy import to avoid ImportError
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
-       print_info_once("Using PunicaWrapperGPU.")
+       logger.info_once("Using PunicaWrapperGPU.")
        return PunicaWrapperGPU(*args, **kwargs)
     elif current_platform.is_hpu():
        # Lazy import to avoid ImportError
        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
-       print_info_once("Using PunicaWrapperHPU.")
+       logger.info_once("Using PunicaWrapperHPU.")
        return PunicaWrapperHPU(*args, **kwargs)
     else:
        raise NotImplementedError
@@ -5,7 +5,6 @@ import torch.nn as nn
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
 
@@ -91,7 +90,7 @@ class CustomOp(nn.Module):
        compilation_config = get_current_vllm_config().compilation_config
        custom_ops = compilation_config.custom_ops
        if not hasattr(cls, "name"):
-           print_warning_once(
+           logger.warning_once(
                f"Custom op {cls.__name__} was not registered, "
                f"which means it won't appear in the op registry. "
                f"It will be enabled/disabled based on the global settings.")
@@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationStrategy
 
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -16,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+
+logger = init_logger(__name__)
 
 
 class GPTQMarlinState(Enum):
@@ -142,7 +144,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
-               print_warning_once(
+               logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
@@ -28,7 +28,6 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -539,7 +538,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
-               print_warning_once(
+               logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
@@ -1,8 +1,10 @@
 import torch
 
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.utils import print_warning_once
+
+logger = init_logger(__name__)
 
 
 class BaseKVCacheMethod(QuantizeMethodBase):
@@ -67,7 +69,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
        layer._v_scale = v_scale
        if (layer._k_scale == 1.0 and layer._v_scale == 1.0
                and "e5m2" not in layer.kv_cache_dtype):
-           print_warning_once(
+           logger.warning_once(
                "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                "may cause accuracy issues. Please make sure k/v_scale "
                "scaling factors are available in the fp8 checkpoint.")
@@ -3,11 +3,13 @@ from typing import Optional
 import torch
 
 import vllm._custom_ops as ops
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
 
 from .marlin_utils import marlin_make_workspace, marlin_permute_scales
 
+logger = init_logger(__name__)
+
 
 def is_fp8_marlin_supported():
     return current_platform.has_device_capability(80)
@@ -47,7 +49,7 @@ def apply_fp8_marlin_linear(
 
 def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
                                  strategy: str = "tensor") -> None:
-    print_warning_once(
+    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                      get_quantization_config)
 from vllm.model_executor.layers.quantization.schema import QuantParamSchema
 from vllm.platforms import current_platform
-from vllm.utils import PlaceholderModule, print_warning_once
+from vllm.utils import PlaceholderModule
 
 try:
     from runai_model_streamer import SafetensorsStreamer
@@ -673,7 +673,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
        None: If the remapped name is not found in params_dict.
     """
     if name.endswith(".kv_scale"):
-        print_warning_once(
+        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
@@ -682,7 +682,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
-           print_warning_once(
+           logger.warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
@@ -695,7 +695,7 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
        if name.endswith(scale_name):
            remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
-               print_warning_once(
+               logger.warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
                    "but not found the expected name in the model "
                    f"(e.g. {remapped_name}). {scale_name} is "
@@ -11,6 +11,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -35,13 +36,14 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import print_warning_once
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
 
+logger = init_logger(__name__)
+
 
 class ChameleonImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -1111,7 +1113,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
-                       print_warning_once(
+                       logger.warning_once(
                            "Found kv scale in the checkpoint (e.g. "
                            f"{name}), but not found the expected name in "
                            f"the model (e.g. {remapped_kv_scale_name}). "
@@ -20,6 +20,7 @@ from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -34,13 +35,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.utils import print_warning_once
 
 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
+logger = init_logger(__name__)
+
 
 class OlmoeMoE(nn.Module):
     """A tensor-parallel MoE implementation for Olmoe that shards each expert
@@ -446,7 +448,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
-                       print_warning_once(
+                       logger.warning_once(
                            "Found kv scale in the checkpoint "
                            f"(e.g. {name}), but not found the expected "
                            f"name in the model "
@@ -34,6 +34,7 @@ from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -50,13 +51,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.utils import print_warning_once
 
 from .interfaces import SupportsPP
 from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
+logger = init_logger(__name__)
+
 
 class Qwen2MoeMLP(nn.Module):
 
@@ -524,7 +526,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
-                       print_warning_once(
+                       logger.warning_once(
                            "Found kv scale in the checkpoint "
                            f"(e.g. {name}), but not found the expected "
                            f"name in the model "
@@ -7,8 +7,10 @@ from transformers import PretrainedConfig
 import vllm.envs as envs
 from vllm.attention.selector import (backend_name_to_enum,
                                      get_global_forced_attn_backend)
+from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import print_warning_once
+
+logger = init_logger(__name__)
 
 _C = TypeVar("_C", bound=PretrainedConfig)
 
@@ -87,7 +89,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
            if is_flash_attn_2_available():
                selected_backend = _Backend.FLASH_ATTN
            else:
-               print_warning_once(
+               logger.warning_once(
                    "Current `vllm-flash-attn` has a bug inside vision module, "
                    "so we use xformers backend instead. You can run "
                    "`pip install flash-attn` to use flash-attention backend.")
@@ -696,18 +696,6 @@ def create_kv_caches_with_random(
     return key_caches, value_caches
 
 
-@lru_cache
-def print_info_once(msg: str) -> None:
-    # Set the stacklevel to 2 to print the caller's line info
-    logger.info(msg, stacklevel=2)
-
-
-@lru_cache
-def print_warning_once(msg: str) -> None:
-    # Set the stacklevel to 2 to print the caller's line info
-    logger.warning(msg, stacklevel=2)
-
-
 @lru_cache(maxsize=None)
 def is_pin_memory_available() -> bool:
     from vllm.platforms import current_platform