[CI] Disable non-lazy string operation on logging (#4326)

Co-authored-by: Danny Guinther <dguinther@neuralmagic.com>
SangBin Cho 2024-04-26 16:16:58 +09:00 committed by GitHub
parent 2f30e7c72f
commit a88081bf76
GPG Key ID: B5690EEEBB952194
31 changed files with 176 additions and 149 deletions
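The point of the change: with printf-style placeholders, `logging` stores the format string and its arguments on the `LogRecord` and only interpolates them if a handler actually emits the record, whereas an f-string is always evaluated at the call site. A minimal sketch of the two spellings (standalone example, not taken verbatim from the diff):

    import logging

    logger = logging.getLogger("example")
    logger.setLevel(logging.WARNING)

    num_jobs = 8

    # Eager: the f-string is built even though DEBUG records are discarded.
    logger.debug(f"Using MAX_JOBS={num_jobs} as the number of jobs.")

    # Lazy: formatting is deferred; "%d" is only applied if the record is emitted.
    logger.debug("Using MAX_JOBS=%d as the number of jobs.", num_jobs)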

View File

@@ -98,9 +98,10 @@ autodoc_mock_imports = [
 for mock_target in autodoc_mock_imports:
     if mock_target in sys.modules:
         logger.info(
-            f"Potentially problematic mock target ({mock_target}) found; "
+            "Potentially problematic mock target (%s) found; "
             "autodoc_mock_imports cannot mock modules that have already "
-            "been loaded into sys.modules when the sphinx build starts.")
+            "been loaded into sys.modules when the sphinx build starts.",
+            mock_target)
 
 class MockedClassDocumenter(autodoc.ClassDocumenter):

View File

@@ -32,6 +32,7 @@ select = [
     "SIM",
     # isort
     # "I",
+    "G",
 ]
 ignore = [
     # star imports
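The new `"G"` entry turns on ruff's flake8-logging-format rules, so the conversion is enforced in CI rather than cleaned up once; rule G004, for example, flags f-strings passed to logging calls. A small illustration (the file path and logger name here are made up):

    import logging

    logger = logging.getLogger(__name__)
    path = "/tmp/p2p_cache.json"

    logger.info(f"reading GPU P2P access cache from {path}")   # flagged by G004
    logger.info("reading GPU P2P access cache from %s", path)  # passes the G rules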

View File

@@ -63,7 +63,7 @@ class cmake_build_ext(build_ext):
         num_jobs = os.environ.get("MAX_JOBS", None)
         if num_jobs is not None:
             num_jobs = int(num_jobs)
-            logger.info(f"Using MAX_JOBS={num_jobs} as the number of jobs.")
+            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
         else:
             try:
                 # os.sched_getaffinity() isn't universally available, so fall
@@ -81,8 +81,9 @@ class cmake_build_ext(build_ext):
         nvcc_threads = os.getenv("NVCC_THREADS", None)
         if nvcc_threads is not None:
             nvcc_threads = int(nvcc_threads)
-            logger.info(f"Using NVCC_THREADS={nvcc_threads} as the number"
-                        " of nvcc threads.")
+            logger.info(
+                "Using NVCC_THREADS=%d as the number of nvcc threads.",
+                nvcc_threads)
         else:
             nvcc_threads = 1
         num_jobs = max(1, num_jobs // nvcc_threads)

View File

@@ -167,9 +167,9 @@ class ModelConfig:
                     f"supported in ROCm.")
         if self.quantization != "marlin":
             logger.warning(
-                f"{self.quantization} quantization is not fully "
+                "%s quantization is not fully "
                 "optimized yet. The speed can be slower than "
-                "non-quantized models.")
+                "non-quantized models.", self.quantization)
 
     def _verify_cuda_graph(self) -> None:
         if self.max_context_len_to_capture is None:
@@ -360,7 +360,7 @@ class CacheConfig:
         if cpu_memory_usage > 0.7 * total_cpu_memory:
             raise ValueError("Too large swap space. " + msg)
         elif cpu_memory_usage > 0.4 * total_cpu_memory:
-            logger.warning("Possibly too large swap space. " + msg)
+            logger.warning("Possibly too large swap space. %s", msg)
 
 @dataclass
@@ -898,8 +898,8 @@ class LoRAConfig:
                 "awq", "gptq"
         ]:
             # TODO support marlin and squeezellm
-            logger.warning(f"{model_config.quantization} quantization is not "
-                           "tested with LoRA yet.")
+            logger.warning("%s quantization is not tested with LoRA yet.",
+                           model_config.quantization)
 
     def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
         if scheduler_config.max_num_batched_tokens > 65528:
@@ -1008,7 +1008,7 @@ def _get_and_verify_dtype(
             pass
         else:
             # Casting between float16 and bfloat16 is allowed with a warning.
-            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
 
     return torch_dtype
@@ -1051,8 +1051,8 @@ def _get_and_verify_max_len(
         logger.warning(
             "The model's config.json does not contain any of the following "
             "keys to determine the original maximum length of the model: "
-            f"{possible_keys}. Assuming the model's maximum length is "
-            f"{default_max_len}.")
+            "%s. Assuming the model's maximum length is %d.", possible_keys,
+            default_max_len)
         derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)

View File

@@ -617,8 +617,9 @@ class Scheduler:
             if num_new_tokens > self.prompt_limit:
                 logger.warning(
-                    f"Input prompt ({num_new_tokens} tokens) is too long"
-                    f" and exceeds limit of {self.prompt_limit}")
+                    "Input prompt (%d tokens) is too long"
+                    " and exceeds limit of %d", num_new_tokens,
+                    self.prompt_limit)
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
@@ -631,8 +632,9 @@ class Scheduler:
                 break
             elif can_allocate == AllocStatus.NEVER:
                 logger.warning(
-                    f"Input prompt ({num_new_tokens} tokens) is too long"
-                    f" and exceeds the capacity of block_manager")
+                    "Input prompt (%d tokens) is too long"
+                    " and exceeds the capacity of block_manager",
+                    num_new_tokens)
                 for seq in waiting_seqs:
                     seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)

View File

@@ -37,7 +37,7 @@ def init_custom_ar() -> None:
         return
     if world_size not in _SUPPORTED_WORLD_SIZES:
-        logger.warn(
+        logger.warning(
             "Custom allreduce is disabled due to an unsupported world size: "
             "%d. Supported world sizes: %s. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.", world_size,
@@ -47,7 +47,7 @@ def init_custom_ar() -> None:
     # note: num dev can be larger than world_size if we're only using
     # first few GPUs
     if num_dev < world_size:
-        logger.warn(
+        logger.warning(
             "Cannot test GPU P2P because not all GPUs are visible to the "
             "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
             " is set.")
@@ -62,7 +62,7 @@ def init_custom_ar() -> None:
     # this checks hardware and driver support for NVLink
     full_nvlink = _is_full_nvlink(device_ids)
     if world_size > 2 and not full_nvlink:
-        logger.warn(
+        logger.warning(
             "Custom allreduce is disabled because it's not supported on more"
             " than two PCIe-only GPUs. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")
@@ -71,7 +71,7 @@ def init_custom_ar() -> None:
     # this is expensive to compute at the first time
     # then we cache the result
     if not _can_p2p(rank, world_size):
-        logger.warn(
+        logger.warning(
             "Custom allreduce is disabled because your platform lacks GPU P2P"
             " capability or P2P test failed. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")

View File

@@ -43,15 +43,16 @@ try:
     nccl = ctypes.CDLL(so_file)
 except Exception as e:
     logger.error(
-        f"Failed to load NCCL library from {so_file} ."
+        "Failed to load NCCL library from %s ."
         "It is expected if you are not running on NVIDIA/AMD GPUs."
         "Otherwise, the nccl library might not exist, be corrupted "
-        f"or it does not support the current platform {platform.platform()}."
-        f"One solution is to download libnccl2 version 2.18 from "
-        f"https://developer.download.nvidia.com/compute/cuda/repos/ "
-        f"and extract the libnccl.so.2 file. If you already have the "
-        f"library, please set the environment variable VLLM_NCCL_SO_PATH"
-        " to point to the correct nccl library path.")
+        "or it does not support the current platform %s."
+        "One solution is to download libnccl2 version 2.18 from "
+        "https://developer.download.nvidia.com/compute/cuda/repos/ "
+        "and extract the libnccl.so.2 file. If you already have the "
+        "library, please set the environment variable VLLM_NCCL_SO_PATH"
+        " to point to the correct nccl library path.", so_file,
+        platform.platform())
     raise e
 
 # === export types and functions from nccl to Python ===

View File

@@ -14,7 +14,7 @@ try:
 except Exception as e:
     # in non-NVIDIA environments, we can't import the nccl module
     # e.g. when running on machines with AMD GPUs
-    logger.info(f"Failed to import NCCL library: {e}")
+    logger.info("Failed to import NCCL library: %s", e)
     logger.info("It is expected if you are not running on NVIDIA GPUs.")
     pass
@@ -40,7 +40,7 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
 def init_process_group(group: Optional[ProcessGroup] = None) -> None:
     assert not is_initialized()
     global comm
-    logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
+    logger.info("vLLM is using nccl==%s", ncclGetVersion())
     comm = NCCLCommunicator(group=group)

View File

@@ -57,8 +57,10 @@ def init_distributed_environment(
     local_rank: int = -1,
     backend: str = "nccl",
 ):
-    logger.debug(f"{world_size=} {rank=} {local_rank=} "
-                 f"{distributed_init_method=} {backend=}")
+    logger.debug(
+        "world_size=%d rank=%d local_rank=%d "
+        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
+        distributed_init_method, backend)
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "

View File

@@ -112,7 +112,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
             and (not os.path.exists(path)):
         # only the local master process (with local_rank == 0) can
         # enter this block to calculate the cache
-        logger.info(f"generating GPU P2P access cache for in {path}")
+        logger.info("generating GPU P2P access cache for in %s", path)
         cache = {}
         for _i in range(num_dev):
             for _j in range(num_dev):
@@ -126,7 +126,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
     if is_distributed:
         cpu_world_group = get_cpu_world_group()
         dist.barrier(cpu_world_group)
-    logger.info(f"reading GPU P2P access cache from {path}")
+    logger.info("reading GPU P2P access cache from %s", path)
     with open(path, "r") as f:
         cache = json.load(f)
     _gpu_p2p_access_cache = cache

View File

@@ -117,7 +117,7 @@ class RequestTracker:
         self._request_streams[request_id].put(request_output)
         if request_output.finished:
             if verbose:
-                logger.info(f"Finished request {request_id}.")
+                logger.info("Finished request %s.", request_id)
             self.abort_request(request_id)
 
     def process_exception(self,
@@ -128,7 +128,7 @@ class RequestTracker:
         """Propagate an exception from the engine."""
         self._request_streams[request_id].put(exception)
         if verbose:
-            logger.info(f"Finished request {request_id}.")
+            logger.info("Finished request %s.", request_id)
         self.abort_request(request_id)
 
     def add_request(self, request_id: str,
@@ -151,7 +151,7 @@ class RequestTracker:
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
-            logger.info(f"Aborted request {request_id}.")
+            logger.info("Aborted request %s.", request_id)
 
         self._finished_requests.put_nowait(request_id)
@@ -521,11 +521,11 @@ class AsyncLLMEngine:
             if shortened_token_ids is not None:
                 shortened_token_ids = shortened_token_ids[:self.
                                                           max_log_len]
-            logger.info(f"Received request {request_id}: "
-                        f"prompt: {shortened_prompt!r}, "
-                        f"sampling_params: {sampling_params}, "
-                        f"prompt_token_ids: {shortened_token_ids}, "
-                        f"lora_request: {lora_request}.")
+            logger.info(
+                "Received request %s: prompt: %r, "
+                "sampling_params: %s, prompt_token_ids: %s, "
+                "lora_request: %s.", request_id, shortened_prompt,
+                sampling_params, shortened_token_ids, lora_request)
 
         if not self.is_running:
             if self.start_engine_loop:
@@ -717,4 +717,4 @@ class AsyncLLMEngine:
             raise RuntimeError("Engine is dead.") from e
         else:
             await self.engine.check_health_async()
-        logger.debug(f"Health check took {time.perf_counter()-t}s")
+        logger.debug("Health check took %fs", time.perf_counter() - t)

View File

@@ -96,29 +96,38 @@ class LLMEngine:
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
     ) -> None:
         logger.info(
-            f"Initializing an LLM engine (v{vllm.__version__}) with config: "
-            f"model={model_config.model!r}, "
-            f"speculative_config={speculative_config!r}, "
-            f"tokenizer={model_config.tokenizer!r}, "
-            f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
-            f"tokenizer_mode={model_config.tokenizer_mode}, "
-            f"revision={model_config.revision}, "
-            f"tokenizer_revision={model_config.tokenizer_revision}, "
-            f"trust_remote_code={model_config.trust_remote_code}, "
-            f"dtype={model_config.dtype}, "
-            f"max_seq_len={model_config.max_model_len}, "
-            f"download_dir={load_config.download_dir!r}, "
-            f"load_format={load_config.load_format}, "
-            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
-            f"disable_custom_all_reduce="
-            f"{parallel_config.disable_custom_all_reduce}, "
-            f"quantization={model_config.quantization}, "
-            f"enforce_eager={model_config.enforce_eager}, "
-            f"kv_cache_dtype={cache_config.cache_dtype}, "
-            f"quantization_param_path={model_config.quantization_param_path}, "
-            f"device_config={device_config.device}, "
-            f"decoding_config={decoding_config!r}, "
-            f"seed={model_config.seed})")
+            "Initializing an LLM engine (v%s) with config: "
+            "model=%r, speculative_config=%r, tokenizer=%r, "
+            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
+            "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
+            "max_seq_len=%d, download_dir=%r, load_format=%s, "
+            "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
+            "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
+            "quantization_param_path=%s, device_config=%s, "
+            "decoding_config=%r, seed=%d)",
+            vllm.__version__,
+            model_config.model,
+            speculative_config,
+            model_config.tokenizer,
+            model_config.skip_tokenizer_init,
+            model_config.tokenizer_mode,
+            model_config.revision,
+            model_config.tokenizer_revision,
+            model_config.trust_remote_code,
+            model_config.dtype,
+            model_config.max_model_len,
+            load_config.download_dir,
+            load_config.load_format,
+            parallel_config.tensor_parallel_size,
+            parallel_config.disable_custom_all_reduce,
+            model_config.quantization,
+            model_config.enforce_eager,
+            cache_config.cache_dtype,
+            model_config.quantization_param_path,
+            device_config.device,
+            decoding_config,
+            model_config.seed,
+        )
         # TODO(woosuk): Print more configs in debug mode.
 
         self.model_config = model_config
@@ -237,8 +246,10 @@ class LLMEngine:
         if self.cache_config.num_gpu_blocks_override is not None:
             num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
-            logger.info(f"Overriding {num_gpu_blocks=} with "
-                        f"{num_gpu_blocks_override=}")
+            logger.info(
+                "Overriding num_gpu_blocks=%d with "
+                "num_gpu_blocks_override=%d", num_gpu_blocks,
+                num_gpu_blocks_override)
             num_gpu_blocks = num_gpu_blocks_override
 
         self.cache_config.num_gpu_blocks = num_gpu_blocks

View File

@@ -227,14 +227,19 @@ class StatLogger:
             # Log to stdout.
             logger.info(
-                f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
-                f"Avg generation throughput: "
-                f"{generation_throughput:.1f} tokens/s, "
-                f"Running: {stats.num_running} reqs, "
-                f"Swapped: {stats.num_swapped} reqs, "
-                f"Pending: {stats.num_waiting} reqs, "
-                f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, "
-                f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%")
+                "Avg prompt throughput: %.1f tokens/s, "
+                "Avg generation throughput: %.1f tokens/s, "
+                "Running: %d reqs, Swapped: %d reqs, "
+                "Pending: %d reqs, GPU KV cache usage: %.1f%%, "
+                "CPU KV cache usage: %.1f%%",
+                prompt_throughput,
+                generation_throughput,
+                stats.num_running,
+                stats.num_swapped,
+                stats.num_waiting,
+                stats.gpu_cache_usage * 100,
+                stats.cpu_cache_usage * 100,
+            )
 
             # Reset tracked stats for next interval.
             self.num_prompt_tokens = []
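One detail worth noting in the cache-usage strings above: because `logging` formats the message with the `%` operator, a literal percent sign must be written as `%%`; a lone `%` followed by another character is read as a conversion specifier and fails when the record is emitted. Minimal sketch with an arbitrary value:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("example")

    gpu_cache_usage = 0.123
    # Renders as "GPU KV cache usage: 12.3%".
    logger.info("GPU KV cache usage: %.1f%%", gpu_cache_usage * 100)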

View File

@@ -148,8 +148,8 @@ if __name__ == "__main__":
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")
 
-    logger.info(f"vLLM API server version {vllm.__version__}")
-    logger.info(f"args: {args}")
+    logger.info("vLLM API server version %s", vllm.__version__)
+    logger.info("args: %s", args)
 
     if args.served_model_name is not None:
         served_model_names = args.served_model_name

View File

@@ -57,8 +57,7 @@ class OpenAIServingChat(OpenAIServing):
                 tokenize=False,
                 add_generation_prompt=request.add_generation_prompt)
         except Exception as e:
-            logger.error(
-                f"Error in applying chat template from request: {str(e)}")
+            logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
 
         request_id = f"cmpl-{random_uuid()}"
@@ -338,11 +337,11 @@ class OpenAIServingChat(OpenAIServing):
                 tokenizer.chat_template = codecs.decode(
                     chat_template, "unicode_escape")
 
-                logger.info(
-                    f"Using supplied chat template:\n{tokenizer.chat_template}")
+                logger.info("Using supplied chat template:\n%s",
+                            tokenizer.chat_template)
             elif tokenizer.chat_template is not None:
-                logger.info(
-                    f"Using default chat template:\n{tokenizer.chat_template}")
+                logger.info("Using default chat template:\n%s",
+                            tokenizer.chat_template)
             else:
                 logger.warning(
                     "No chat template provided. Chat API will not work.")

View File

@@ -69,7 +69,7 @@ class CPUExecutor(ExecutorBase):
         # NOTE: `cpu block` for CPU backend is located on CPU memory but is
         # referred as `gpu block`. Because we want to reuse the existing block
         # management procedure.
-        logger.info(f"# CPU blocks: {num_gpu_blocks}")
+        logger.info("# CPU blocks: %d", num_gpu_blocks)
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
     def execute_model(self,

View File

@@ -116,8 +116,8 @@ class GPUExecutor(ExecutorBase):
         # NOTE: This is logged in the executor because there can be >1 worker
         # with other executors. We could log in the engine level, but work
         # remains to abstract away the device for non-GPU configurations.
-        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
-                    f"# CPU blocks: {num_cpu_blocks}")
+        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
+                    num_cpu_blocks)
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

View File

@@ -214,8 +214,8 @@ class RayGPUExecutor(ExecutorBase):
         # NOTE: We log here to avoid multiple logs when number of workers is
         # greater than one. We could log in the engine, but not all executors
         # have GPUs.
-        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
-                    f"# CPU blocks: {num_cpu_blocks}")
+        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
+                    num_cpu_blocks)
 
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

View File

@@ -43,9 +43,9 @@ try:
             return output
 
 except ImportError as e:
-    logger.warning(f"Failed to import Ray with {e!r}. "
-                   "For distributed inference, please install Ray with "
-                   "`pip install ray`.")
+    logger.warning(
+        "Failed to import Ray with %r. For distributed inference, "
+        "please install Ray with `pip install ray`.", e)
     ray = None  # type: ignore
     RayWorkerWrapper = None  # type: ignore

View File

@@ -126,7 +126,7 @@ def enable_trace_function_call(log_file_path: str,
         "VLLM_TRACE_FUNCTION is enabled. It will record every"
         " function executed by Python. This will slow down the code. It "
         "is suggested to be used for debugging hang or crashes only.")
-    logger.info(f"Trace frame log is saved to {log_file_path}")
+    logger.info("Trace frame log is saved to %s", log_file_path)
     if root_dir is None:
         # by default, this is the vllm root directory
         root_dir = os.path.dirname(os.path.dirname(__file__))

View File

@@ -345,8 +345,8 @@ class LoRAModelManager:
         index, _ = first_free_slot
         self._active_loras[lora_id] = None
         lora_model = self._registered_loras[lora_id]
-        logger.debug(
-            f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
+        logger.debug("Activating LoRA. int id: %d, slot index: %d",
+                     lora_model.id, index)
         self.lora_index_to_id[index] = lora_model.id
         for module_name, module in self.modules.items():
             module_lora = lora_model.get_lora(module_name)
@@ -567,7 +567,7 @@ class LoRALRUCache(LRUCache[LoRAModel]):
         self.deactivate_lora_fn = deactivate_lora_fn
 
     def _on_remove(self, key: int, value: LoRAModel):
-        logger.debug(f"Removing LoRA. int id: {key}")
+        logger.debug("Removing LoRA. int id: %d", key)
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)

View File

@@ -296,8 +296,8 @@ def get_moe_configs(E: int, N: int,
         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info(
-                f"Using configuration from {config_file_path} for MoE layer.")
+            logger.info("Using configuration from %s for MoE layer.",
+                        config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}

View File

@@ -334,10 +334,10 @@ class TensorizerAgent:
         per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
         after_mem = get_mem_usage()
         deserializer.close()
-        logger.info(f"Deserialized {total_bytes_str} in "
-                    f"{end - start:0.2f}s, {per_second}/s")
-        logger.info(f"Memory usage before: {before_mem}")
-        logger.info(f"Memory usage after: {after_mem}")
+        logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
+                    end - start, per_second)
+        logger.info("Memory usage before: %s", before_mem)
+        logger.info("Memory usage after: %s", after_mem)
         self._check_tensors_on_meta_device()
         self._resize_lora_embeddings()

View File

@@ -190,7 +190,7 @@ def download_weights_from_hf(model_name_or_path: str,
             allow_patterns = [pattern]
             break
 
-    logger.info(f"Using model weights format {allow_patterns}")
+    logger.info("Using model weights format %s", allow_patterns)
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):
@@ -310,17 +310,17 @@ def kv_cache_scales_loader(
         return layer_scales_map.items()
 
     except FileNotFoundError:
-        logger.error(f"File or directory '{filename}' not found.")
+        logger.error("File or directory '%s' not found.", filename)
     except json.JSONDecodeError:
-        logger.error(f"Error decoding JSON in file '{filename}'.")
+        logger.error("Error decoding JSON in file '%s'.", filename)
     except Exception as e:
-        logger.error(f"An error occurred while reading '{filename}': {e}")
+        logger.error("An error occurred while reading '%s': %s", filename, e)
     # This section is reached if and only if any of the excepts are hit
     # Return an empty iterable (list) => no KV cache scales are loaded
    # which ultimately defaults to 1.0 scales
-    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
-                   f"for all layers in TP rank {tp_rank} "
-                   "as an error occurred during loading.")
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.", tp_rank)
     return []
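In these handlers the caught exception is passed as a lazy `%s` argument, so it is only stringified if the error record is actually emitted. When a full traceback is wanted rather than just the message, the standard library also offers `logger.exception`, which logs at ERROR level and attaches the traceback; the diff keeps plain `error` calls, so the sketch below is only a general illustration with an invented file name:

    import logging

    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger("example")

    filename = "kv_cache_scales.json"  # illustrative path
    try:
        open(filename)
    except FileNotFoundError as e:
        # Message-only, formatted lazily:
        logger.error("File or directory '%s' not found. (%s)", filename, e)
        # Alternative when the traceback is useful:
        logger.exception("Failed to read '%s'", filename)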

View File

@@ -91,8 +91,8 @@ class ModelRegistry:
                     "ROCm for now.")
             if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                 logger.warning(
-                    f"Model architecture {model_arch} is partially supported "
-                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+                    "Model architecture %s is partially supported by ROCm: %s",
+                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
 
         module_name, model_cls_name = _MODELS[model_arch]
         module = importlib.import_module(
@@ -107,9 +107,9 @@ class ModelRegistry:
     def register_model(model_arch: str, model_cls: Type[nn.Module]):
         if model_arch in _MODELS:
             logger.warning(
-                f"Model architecture {model_arch} is already registered, "
-                "and will be overwritten by the new model "
-                f"class {model_cls.__name__}.")
+                "Model architecture %s is already registered, and will be "
+                "overwritten by the new model class %s.", model_arch,
+                model_cls.__name__)
 
         global _OOT_MODELS
         _OOT_MODELS[model_arch] = model_cls

View File

@@ -55,10 +55,10 @@ def _get_gemma_act_fn(
             "in the config JSON file when it was initially released. "
             "Changing the activation function to approximate GeLU "
             "(`gelu_pytorch_tanh`). If you want to use the legacy "
-            f"`{hidden_act}`, edit the config JSON to set "
-            f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
+            "`%s`, edit the config JSON to set "
+            "`hidden_activation=%s` instead of `hidden_act`. "
             "See https://github.com/huggingface/transformers/pull/29402 "
-            "for more details.")
+            "for more details.", hidden_act, hidden_act)
         return GeluAndMul(approximate="tanh")
     elif hidden_activation == "gelu_pytorch_tanh":
         return GeluAndMul(approximate="tanh")

View File

@@ -183,7 +183,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             "speculative decoding "
             "requires non-None seq_group_metadata_list")
 
-        logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
+        logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
+                    num_lookahead_slots)
 
         # If no spec tokens, call the proposer and scorer workers normally.
         # Used for prefill.

View File

@@ -72,9 +72,10 @@ class DbrxAttentionConfig(PretrainedConfig):
             and config_dict["model_type"] != cls.model_type
         ):
             logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
+                "You are using a model of type %s to instantiate a model of "
+                "type %s. This is not supported for all configurations of "
+                "models and can yield errors.",
+                config_dict["model_type"], cls.model_type)
 
         return cls.from_dict(config_dict, **kwargs)
@@ -151,9 +152,9 @@ class DbrxFFNConfig(PretrainedConfig):
             and config_dict["model_type"] != cls.model_type
         ):
             logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
-            )
+                "You are using a model of type %s to instantiate a model of "
+                "type %s. This is not supported for all "
+                "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type)
 
         return cls.from_dict(config_dict, **kwargs)

View File

@@ -138,9 +138,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
         # No tokenizer was found in the LoRA folder,
         # use base model tokenizer
         logger.warning(
-            f"No tokenizer found in {lora_request.lora_local_path}, "
-            "using base model tokenizer instead. "
-            f"(Exception: {str(e)})")
+            "No tokenizer found in %s, using base model tokenizer instead. "
+            "(Exception: %s)", lora_request.lora_local_path, e)
         tokenizer = None
     return tokenizer

View File

@@ -289,8 +289,9 @@ def get_open_port() -> int:
 def update_environment_variables(envs: Dict[str, str]):
     for k, v in envs.items():
         if k in os.environ and os.environ[k] != v:
-            logger.warning(f"Overwriting environment variable {k} "
-                           f"from '{os.environ[k]}' to '{v}'")
+            logger.warning(
+                "Overwriting environment variable %s "
+                "from '%s' to '%s'", k, os.environ[k], v)
         os.environ[k] = v
@@ -310,11 +311,12 @@ def get_nvcc_cuda_version() -> Optional[Version]:
     if not cuda_home:
         cuda_home = '/usr/local/cuda'
         if os.path.isfile(cuda_home + '/bin/nvcc'):
-            logger.info(f'CUDA_HOME is not found in the environment. '
-                        f'Using {cuda_home} as CUDA_HOME.')
+            logger.info(
+                'CUDA_HOME is not found in the environment. '
+                'Using %s as CUDA_HOME.', cuda_home)
         else:
-            logger.warning(
-                f'Not found nvcc in {cuda_home}. Skip cuda version check!')
+            logger.warning('Not found nvcc in %s. Skip cuda version check!',
+                           cuda_home)
             return None
     nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
                                           universal_newlines=True)
@@ -599,8 +601,8 @@ def find_nccl_library():
     # manually load the nccl library
     if so_file:
         logger.info(
-            f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}"
-        )
+            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
+            so_file)
     else:
         if torch.version.cuda is not None:
             so_file = vllm_nccl_path or find_library("libnccl.so.2")
@@ -608,7 +610,7 @@ def find_nccl_library():
             so_file = find_library("librccl.so.1")
         else:
             raise ValueError("NCCL only supports CUDA and ROCm backends.")
-        logger.info(f"Found nccl from library {so_file}")
+        logger.info("Found nccl from library %s", so_file)
     return so_file

View File

@@ -170,8 +170,8 @@ class ModelRunner:
             )
 
         self.model_memory_usage = m.consumed_memory
-        logger.info(f"Loading model weights took "
-                    f"{self.model_memory_usage / float(2**30):.4f} GB")
+        logger.info("Loading model weights took %.4f GB",
+                    self.model_memory_usage / float(2**30))
 
         if self.lora_config:
             assert hasattr(self.model, "supported_lora_modules"
@@ -196,18 +196,19 @@ class ModelRunner:
                     self.model.load_kv_cache_scales(
                         self.model_config.quantization_param_path)
                 else:
-                    raise RuntimeError("Using FP8 KV cache and scaling "
-                                       "factors provided but model "
-                                       f"{self.model.__class__} does not "
-                                       "support loading scaling factors.")
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided "
+                        f"but model {self.model.__class__} does not "
+                        "support loading scaling factors.")
             else:
-                logger.warn("Using FP8 KV cache but no scaling factors "
-                            "provided. Defaulting to scaling factors of 1.0. "
-                            "This may lead to less accurate results!")
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!")
         elif self.model_config.quantization_param_path is not None:
-            logger.warn("KV cache scaling factors provided, "
-                        "but the KV cache data type is not FP8. "
-                        "KV cache scaling factors will not be used.")
+            logger.warning("KV cache scaling factors provided, "
+                           "but the KV cache data type is not FP8. "
+                           "KV cache scaling factors will not be used.")
 
     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size
@@ -1054,7 +1055,7 @@ class ModelRunner:
         end_time = time.perf_counter()
         elapsed_time = end_time - start_time
         # This usually takes < 10 seconds.
-        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
+        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
 
     def __del__(self) -> None:
         # Delete the CUDA graphs before deleting the pynccl communicator.