Add more Prometheus metrics (#2764)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
parent
9c7306ac11
commit
bf480c5302
@ -873,6 +873,289 @@
|
||||
],
|
||||
"title": "Cache Utilization",
|
||||
"type": "timeseries"
|
||||
},
|
||||
{
|
||||
"type": "heatmap",
|
||||
"title": "Request Prompt Length",
|
||||
"description": "Heatmap of request prompt length",
|
||||
"gridPos": {
|
||||
"x": 0,
|
||||
"y": 24,
|
||||
"w": 12,
|
||||
"h": 8
|
||||
},
|
||||
"datasource": {
|
||||
"uid": "prometheus",
|
||||
"type": "prometheus"
|
||||
},
|
||||
"id": 12,
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"refId": "A",
|
||||
"expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))",
|
||||
"range": true,
|
||||
"instant": false,
|
||||
"editorMode": "builder",
|
||||
"legendFormat": "{{le}}",
|
||||
"useBackend": false,
|
||||
"disableTextWrap": false,
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"format": "heatmap"
|
||||
}
|
||||
],
|
||||
"options": {
|
||||
"calculate": false,
|
||||
"yAxis": {
|
||||
"axisPlacement": "left",
|
||||
"reverse": false,
|
||||
"unit": "none",
|
||||
"axisLabel": "Prompt Length"
|
||||
},
|
||||
"rowsFrame": {
|
||||
"layout": "auto",
|
||||
"value": "Request count"
|
||||
},
|
||||
"color": {
|
||||
"mode": "scheme",
|
||||
"fill": "dark-orange",
|
||||
"scale": "exponential",
|
||||
"exponent": 0.5,
|
||||
"scheme": "Spectral",
|
||||
"steps": 64,
|
||||
"reverse": false,
|
||||
"min": 0
|
||||
},
|
||||
"cellGap": 1,
|
||||
"filterValues": {
|
||||
"le": 1e-9
|
||||
},
|
||||
"tooltip": {
|
||||
"show": true,
|
||||
"yHistogram": true
|
||||
},
|
||||
"legend": {
|
||||
"show": true
|
||||
},
|
||||
"exemplars": {
|
||||
"color": "rgba(255,0,255,0.7)"
|
||||
},
|
||||
"cellValues": {
|
||||
"unit": "none"
|
||||
}
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"hideFrom": {
|
||||
"tooltip": false,
|
||||
"viz": false,
|
||||
"legend": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"pluginVersion": "10.2.0"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"uid": "prometheus",
|
||||
"type": "prometheus"
|
||||
},
|
||||
"type": "heatmap",
|
||||
"title": "Request Generation Length",
|
||||
"description": "Heatmap of request generation length",
|
||||
"gridPos": {
|
||||
"x": 12,
|
||||
"y": 24,
|
||||
"w": 12,
|
||||
"h": 8
|
||||
},
|
||||
"id": 13,
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"refId": "A",
|
||||
"expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))",
|
||||
"range": true,
|
||||
"instant": false,
|
||||
"editorMode": "builder",
|
||||
"legendFormat": "{{le}}",
|
||||
"useBackend": false,
|
||||
"disableTextWrap": false,
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"format": "heatmap"
|
||||
}
|
||||
],
|
||||
"options": {
|
||||
"calculate": false,
|
||||
"yAxis": {
|
||||
"axisPlacement": "left",
|
||||
"reverse": false,
|
||||
"unit": "none",
|
||||
"axisLabel": "Generation Length"
|
||||
},
|
||||
"rowsFrame": {
|
||||
"layout": "auto",
|
||||
"value": "Request count"
|
||||
},
|
||||
"color": {
|
||||
"mode": "scheme",
|
||||
"fill": "dark-orange",
|
||||
"scale": "exponential",
|
||||
"exponent": 0.5,
|
||||
"scheme": "Spectral",
|
||||
"steps": 64,
|
||||
"reverse": false,
|
||||
"min": 0
|
||||
},
|
||||
"cellGap": 1,
|
||||
"filterValues": {
|
||||
"le": 1e-9
|
||||
},
|
||||
"tooltip": {
|
||||
"show": true,
|
||||
"yHistogram": true
|
||||
},
|
||||
"legend": {
|
||||
"show": true
|
||||
},
|
||||
"exemplars": {
|
||||
"color": "rgba(255,0,255,0.7)"
|
||||
},
|
||||
"cellValues": {
|
||||
"unit": "none"
|
||||
}
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"hideFrom": {
|
||||
"tooltip": false,
|
||||
"viz": false,
|
||||
"legend": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"pluginVersion": "10.2.0"
|
||||
},
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"fieldConfig": {
|
||||
"defaults": {
|
||||
"custom": {
|
||||
"drawStyle": "line",
|
||||
"lineInterpolation": "linear",
|
||||
"barAlignment": 0,
|
||||
"lineWidth": 1,
|
||||
"fillOpacity": 0,
|
||||
"gradientMode": "none",
|
||||
"spanNulls": false,
|
||||
"insertNulls": false,
|
||||
"showPoints": "auto",
|
||||
"pointSize": 5,
|
||||
"stacking": {
|
||||
"mode": "none",
|
||||
"group": "A"
|
||||
},
|
||||
"axisPlacement": "auto",
|
||||
"axisLabel": "",
|
||||
"axisColorMode": "text",
|
||||
"axisBorderShow": false,
|
||||
"scaleDistribution": {
|
||||
"type": "linear"
|
||||
},
|
||||
"axisCenteredZero": false,
|
||||
"hideFrom": {
|
||||
"tooltip": false,
|
||||
"viz": false,
|
||||
"legend": false
|
||||
},
|
||||
"thresholdsStyle": {
|
||||
"mode": "off"
|
||||
}
|
||||
},
|
||||
"color": {
|
||||
"mode": "palette-classic"
|
||||
},
|
||||
"mappings": [],
|
||||
"thresholds": {
|
||||
"mode": "absolute",
|
||||
"steps": [
|
||||
{
|
||||
"color": "green",
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"color": "red",
|
||||
"value": 80
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"overrides": []
|
||||
},
|
||||
"gridPos": {
|
||||
"h": 8,
|
||||
"w": 12,
|
||||
"x": 0,
|
||||
"y": 32
|
||||
},
|
||||
"id": 11,
|
||||
"options": {
|
||||
"tooltip": {
|
||||
"mode": "single",
|
||||
"sort": "none"
|
||||
},
|
||||
"legend": {
|
||||
"showLegend": true,
|
||||
"displayMode": "list",
|
||||
"placement": "bottom",
|
||||
"calcs": []
|
||||
}
|
||||
},
|
||||
"targets": [
|
||||
{
|
||||
"datasource": {
|
||||
"type": "prometheus",
|
||||
"uid": "prometheus"
|
||||
},
|
||||
"disableTextWrap": false,
|
||||
"editorMode": "builder",
|
||||
"expr": "sum by(finished_reason) (increase(vllm:request_success_total{model_name=\"$model_name\"}[$__rate_interval]))",
|
||||
"fullMetaSearch": false,
|
||||
"includeNullMetadata": true,
|
||||
"instant": false,
|
||||
"interval": "",
|
||||
"legendFormat": "__auto",
|
||||
"range": true,
|
||||
"refId": "A",
|
||||
"useBackend": false
|
||||
}
|
||||
],
|
||||
"title": "Finish Reason",
|
||||
"description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.",
|
||||
"type": "timeseries"
|
||||
}
|
||||
],
|
||||
"refresh": "",
|
||||
|
@ -12,6 +12,7 @@ openai
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.0 # Required for OpenAI server.
|
||||
prometheus_client >= 0.18.0
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken == 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer == 0.9.8
|
||||
outlines == 0.0.34 # Requires torch >= 2.1.0
|
||||
|
@ -320,7 +320,7 @@ class Scheduler:
|
||||
for seq_group in state_queue:
|
||||
if not request_ids:
|
||||
# Using 'break' here may add two extra iterations,
|
||||
# but is acceptable to reduce complexity .
|
||||
# but is acceptable to reduce complexity.
|
||||
break
|
||||
if seq_group.request_id in request_ids:
|
||||
# Appending aborted group into pending list.
|
||||
|
@ -22,7 +22,8 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
|
||||
SequenceGroup, SequenceGroupMetadata)
|
||||
SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
||||
get_tokenizer_group)
|
||||
@ -217,7 +218,8 @@ class LLMEngine:
|
||||
if self.log_stats:
|
||||
self.stat_logger = StatLogger(
|
||||
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
||||
labels=dict(model_name=model_config.model))
|
||||
labels=dict(model_name=model_config.model),
|
||||
max_model_len=self.model_config.max_model_len)
|
||||
self.stat_logger.info("cache_config", self.cache_config)
|
||||
|
||||
# Create sequence output processor, e.g. for beam search or
|
||||
@ -619,59 +621,109 @@ class LLMEngine:
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
# KV Cache Usage in %.
|
||||
# System State
|
||||
# Scheduler State
|
||||
num_running_sys = len(self.scheduler.running)
|
||||
num_swapped_sys = len(self.scheduler.swapped)
|
||||
num_waiting_sys = len(self.scheduler.waiting)
|
||||
|
||||
# KV Cache Usage in %
|
||||
num_total_gpu = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
|
||||
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
|
||||
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
|
||||
|
||||
num_total_cpu = self.cache_config.num_cpu_blocks
|
||||
cpu_cache_usage = 0.
|
||||
cpu_cache_usage_sys = 0.
|
||||
if num_total_cpu > 0:
|
||||
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
|
||||
)
|
||||
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
|
||||
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
|
||||
|
||||
# Scheduler State
|
||||
num_running = len(self.scheduler.running)
|
||||
num_swapped = len(self.scheduler.swapped)
|
||||
num_waiting = len(self.scheduler.waiting)
|
||||
# Iteration stats
|
||||
num_prompt_tokens_iter = 0
|
||||
num_generation_tokens_iter = 0
|
||||
time_to_first_tokens_iter: List[float] = []
|
||||
time_per_output_tokens_iter: List[float] = []
|
||||
|
||||
# Iteration stats if we have scheduler output.
|
||||
num_prompt_tokens = 0
|
||||
num_generation_tokens = 0
|
||||
time_to_first_tokens = []
|
||||
time_per_output_tokens = []
|
||||
time_e2e_requests = []
|
||||
# Request stats
|
||||
# Latency
|
||||
time_e2e_requests: List[float] = []
|
||||
# Metadata
|
||||
num_prompt_tokens_requests: List[int] = []
|
||||
num_generation_tokens_requests: List[int] = []
|
||||
best_of_requests: List[int] = []
|
||||
n_requests: List[int] = []
|
||||
finished_reason_requests: List[str] = []
|
||||
|
||||
# NOTE: This loop assumes prefill seq_groups are before
|
||||
# decode seq_groups in scheduled_seq_groups.
|
||||
if scheduler_outputs is not None:
|
||||
prompt_run = scheduler_outputs.num_prefill_groups > 0
|
||||
num_generation_tokens_from_prefill_groups = 0.
|
||||
if scheduler_outputs.num_prefill_groups > 0 and len(
|
||||
scheduler_outputs.scheduled_seq_groups
|
||||
) != scheduler_outputs.num_prefill_groups:
|
||||
print("DETECTED CHUNKED")
|
||||
|
||||
# Number of Tokens.
|
||||
if prompt_run:
|
||||
num_prompt_tokens = sum(
|
||||
len(scheduled_seq_group.seq_group.prompt_token_ids)
|
||||
for scheduled_seq_group in
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
num_generation_tokens = sum(
|
||||
scheduled_seq_group.seq_group.num_seqs()
|
||||
for scheduled_seq_group in
|
||||
scheduler_outputs.scheduled_seq_groups)
|
||||
else:
|
||||
num_generation_tokens = scheduler_outputs.num_batched_tokens
|
||||
|
||||
# Latency Timings.
|
||||
time_last_iters = []
|
||||
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
for idx, scheduled_seq_group in enumerate(
|
||||
scheduler_outputs.scheduled_seq_groups):
|
||||
group_was_prefill = idx < scheduler_outputs.num_prefill_groups
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
# Time since last token.
|
||||
# (n.b. updates seq_group.metrics.last_token_time)
|
||||
time_last_iters.append(seq_group.get_last_latency(now))
|
||||
# Time since arrival for all finished requests.
|
||||
|
||||
# NOTE: a seq_group that completed all of its prefill tokens
|
||||
# in the last iteration will have seq_group.is_prefill() = False
|
||||
# with group_was_prefill = True
|
||||
if group_was_prefill:
|
||||
# Number of prompt tokens.
|
||||
num_prompt_tokens_iter += (
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
|
||||
# If the seq_group just finished the prefill state
|
||||
# get TTFT.
|
||||
if not seq_group.is_prefill():
|
||||
latency = seq_group.get_last_latency(now)
|
||||
time_to_first_tokens_iter.append(latency)
|
||||
|
||||
# One generation token per finished prefill.
|
||||
num_generation_tokens_from_prefill_groups += (
|
||||
seq_group.num_seqs())
|
||||
else:
|
||||
# TPOTs.
|
||||
latency = seq_group.get_last_latency(now)
|
||||
time_per_output_tokens_iter.append(latency)
|
||||
|
||||
# Because of chunked prefill, we can have a single sequence
|
||||
# group that does multiple prompt_runs. To prevent logging
|
||||
# the same metadata more than once per request, we standardize
|
||||
# on logging request level information for finished requests,
|
||||
# which can only happen once.
|
||||
if seq_group.is_finished():
|
||||
# Latency timings
|
||||
time_e2e_requests.append(now -
|
||||
seq_group.metrics.arrival_time)
|
||||
|
||||
time_to_first_tokens = time_last_iters if prompt_run else []
|
||||
time_per_output_tokens = [] if prompt_run else time_last_iters
|
||||
# Metadata
|
||||
num_prompt_tokens_requests.append(
|
||||
len(seq_group.prompt_token_ids))
|
||||
num_generation_tokens_requests.extend([
|
||||
seq.get_output_len()
|
||||
for seq in seq_group.get_finished_seqs()
|
||||
])
|
||||
best_of_requests.append(seq_group.sampling_params.best_of)
|
||||
n_requests.append(seq_group.sampling_params.n)
|
||||
finished_reason_requests.extend([
|
||||
SequenceStatus.get_finished_reason(seq.status)
|
||||
for seq in seq_group.get_finished_seqs()
|
||||
])
|
||||
|
||||
# Number of generation tokens.
|
||||
# num_batched_tokens equals the number of prompt_tokens plus the
|
||||
# number of decode_tokens in a single iteration. So,
|
||||
# num_generation_tokens = num_batched_tokens - num_prompt_tokens
|
||||
# + num_generation_tokens_from_prefill_groups (since we generate
|
||||
# one token on prefills on iters where the prefill finishes).
|
||||
num_generation_tokens_iter = (
|
||||
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
|
||||
num_generation_tokens_from_prefill_groups)
|
||||
|
||||
# Spec decode, if enabled, emits specialized metrics from the worker in
|
||||
# sampler output.
|
||||
@ -683,17 +735,32 @@ class LLMEngine:
|
||||
|
||||
return Stats(
|
||||
now=now,
|
||||
num_running=num_running,
|
||||
num_swapped=num_swapped,
|
||||
num_waiting=num_waiting,
|
||||
gpu_cache_usage=gpu_cache_usage,
|
||||
cpu_cache_usage=cpu_cache_usage,
|
||||
num_prompt_tokens=num_prompt_tokens,
|
||||
num_generation_tokens=num_generation_tokens,
|
||||
time_to_first_tokens=time_to_first_tokens,
|
||||
time_per_output_tokens=time_per_output_tokens,
|
||||
time_e2e_requests=time_e2e_requests,
|
||||
|
||||
# System stats
|
||||
# Scheduler State
|
||||
num_running_sys=num_running_sys,
|
||||
num_swapped_sys=num_swapped_sys,
|
||||
num_waiting_sys=num_waiting_sys,
|
||||
# KV Cache Usage in %
|
||||
gpu_cache_usage_sys=gpu_cache_usage_sys,
|
||||
cpu_cache_usage_sys=cpu_cache_usage_sys,
|
||||
|
||||
# Iteration stats
|
||||
num_prompt_tokens_iter=num_prompt_tokens_iter,
|
||||
num_generation_tokens_iter=num_generation_tokens_iter,
|
||||
time_to_first_tokens_iter=time_to_first_tokens_iter,
|
||||
time_per_output_tokens_iter=time_per_output_tokens_iter,
|
||||
spec_decode_metrics=spec_decode_metrics,
|
||||
|
||||
# Request stats
|
||||
# Latency
|
||||
time_e2e_requests=time_e2e_requests,
|
||||
# Metadata
|
||||
num_prompt_tokens_requests=num_prompt_tokens_requests,
|
||||
num_generation_tokens_requests=num_generation_tokens_requests,
|
||||
best_of_requests=best_of_requests,
|
||||
n_requests=n_requests,
|
||||
finished_reason_requests=finished_reason_requests,
|
||||
)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
|
@ -1,6 +1,8 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Counter as CollectionsCounter
|
||||
from typing import Dict, List, Optional, Protocol, Union
|
||||
|
||||
import numpy as np
|
||||
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
||||
@ -21,8 +23,9 @@ disable_created_metrics()
|
||||
|
||||
# begin-metrics-definitions
|
||||
class Metrics:
|
||||
labelname_finish_reason = "finished_reason"
|
||||
|
||||
def __init__(self, labelnames: List[str]):
|
||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||
# Unregister any existing vLLM collectors
|
||||
for collector in list(REGISTRY._collector_to_names):
|
||||
if hasattr(collector, "_name") and "vllm" in collector._name:
|
||||
@ -34,18 +37,20 @@ class Metrics:
|
||||
documentation='information of cache_config')
|
||||
|
||||
# System stats
|
||||
# Scheduler State
|
||||
self.gauge_scheduler_running = Gauge(
|
||||
name="vllm:num_requests_running",
|
||||
documentation="Number of requests currently running on GPU.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_swapped = Gauge(
|
||||
name="vllm:num_requests_swapped",
|
||||
documentation="Number of requests swapped to CPU.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_waiting = Gauge(
|
||||
name="vllm:num_requests_waiting",
|
||||
documentation="Number of requests waiting to be processed.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_swapped = Gauge(
|
||||
name="vllm:num_requests_swapped",
|
||||
documentation="Number of requests swapped to CPU.",
|
||||
labelnames=labelnames)
|
||||
# KV Cache Usage in %
|
||||
self.gauge_gpu_cache_usage = Gauge(
|
||||
name="vllm:gpu_cache_usage_perc",
|
||||
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
|
||||
@ -55,7 +60,7 @@ class Metrics:
|
||||
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames)
|
||||
|
||||
# Raw stats from last model iteration
|
||||
# Iteration stats
|
||||
self.counter_prompt_tokens = Counter(
|
||||
name="vllm:prompt_tokens_total",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
@ -80,18 +85,51 @@ class Metrics:
|
||||
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
|
||||
1.0, 2.5
|
||||
])
|
||||
self.histogram_e2e_request_latency = Histogram(
|
||||
|
||||
# Request stats
|
||||
# Latency
|
||||
self.histogram_e2e_time_request = Histogram(
|
||||
name="vllm:e2e_request_latency_seconds",
|
||||
documentation="Histogram of end to end request latency in seconds.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
|
||||
# Metadata
|
||||
self.histogram_num_prompt_tokens_request = Histogram(
|
||||
name="vllm:request_prompt_tokens",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_num_generation_tokens_request = Histogram(
|
||||
name="vllm:request_generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_best_of_request = Histogram(
|
||||
name="vllm:request_params_best_of",
|
||||
documentation="Histogram of the best_of request parameter.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
)
|
||||
self.histogram_n_request = Histogram(
|
||||
name="vllm:request_params_n",
|
||||
documentation="Histogram of the n request parameter.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
)
|
||||
self.counter_request_success = Counter(
|
||||
name="vllm:request_success",
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=labelnames + [Metrics.labelname_finish_reason])
|
||||
|
||||
# Legacy metrics
|
||||
# Deprecated in favor of vllm:prompt_tokens_total
|
||||
self.gauge_avg_prompt_throughput = Gauge(
|
||||
name="vllm:avg_prompt_throughput_toks_per_s",
|
||||
documentation="Average prefill throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
)
|
||||
# Deprecated in favor of vllm:generation_tokens_total
|
||||
self.gauge_avg_generation_throughput = Gauge(
|
||||
name="vllm:avg_generation_throughput_toks_per_s",
|
||||
documentation="Average generation throughput in tokens/s.",
|
||||
@ -102,24 +140,57 @@ class Metrics:
|
||||
# end-metrics-definitions
|
||||
|
||||
|
||||
def build_1_2_5_buckets(max_value: int):
|
||||
"""
|
||||
Builds a list of buckets with increasing powers of 10 multiplied by
|
||||
mantissa values (1, 2, 5) until the value exceeds the specified maximum.
|
||||
|
||||
Example:
|
||||
>>> build_1_2_5_buckets(100)
|
||||
[1, 2, 5, 10, 20, 50, 100]
|
||||
"""
|
||||
mantissa_lst = [1, 2, 5]
|
||||
exponent = 0
|
||||
buckets = []
|
||||
while True:
|
||||
for m in mantissa_lst:
|
||||
value = m * 10**exponent
|
||||
if value <= max_value:
|
||||
buckets.append(value)
|
||||
else:
|
||||
return buckets
|
||||
exponent += 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Stats:
|
||||
"""Created by LLMEngine for use by StatLogger."""
|
||||
now: float
|
||||
|
||||
# System stats.
|
||||
num_running: int
|
||||
num_waiting: int
|
||||
num_swapped: int
|
||||
gpu_cache_usage: float
|
||||
cpu_cache_usage: float
|
||||
# System stats (should have _sys suffix)
|
||||
# Scheduler State
|
||||
num_running_sys: int
|
||||
num_waiting_sys: int
|
||||
num_swapped_sys: int
|
||||
# KV Cache Usage in %
|
||||
gpu_cache_usage_sys: float
|
||||
cpu_cache_usage_sys: float
|
||||
|
||||
# Raw stats from last model iteration.
|
||||
num_prompt_tokens: int
|
||||
num_generation_tokens: int
|
||||
time_to_first_tokens: List[float]
|
||||
time_per_output_tokens: List[float]
|
||||
# Iteration stats (should have _iter suffix)
|
||||
num_prompt_tokens_iter: int
|
||||
num_generation_tokens_iter: int
|
||||
time_to_first_tokens_iter: List[float]
|
||||
time_per_output_tokens_iter: List[float]
|
||||
|
||||
# Request stats (should have _requests suffix)
|
||||
# Latency
|
||||
time_e2e_requests: List[float]
|
||||
# Metadata
|
||||
num_prompt_tokens_requests: List[int]
|
||||
num_generation_tokens_requests: List[int]
|
||||
best_of_requests: List[int]
|
||||
n_requests: List[int]
|
||||
finished_reason_requests: List[str]
|
||||
|
||||
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
|
||||
|
||||
@ -133,7 +204,8 @@ class SupportsMetricsInfo(Protocol):
|
||||
class StatLogger:
|
||||
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
|
||||
|
||||
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
|
||||
def __init__(self, local_interval: float, labels: Dict[str, str],
|
||||
max_model_len: int) -> None:
|
||||
# Metadata for logging locally.
|
||||
self.last_local_log = time.time()
|
||||
self.local_interval = local_interval
|
||||
@ -144,7 +216,8 @@ class StatLogger:
|
||||
|
||||
# Prometheus metrics
|
||||
self.labels = labels
|
||||
self.metrics = Metrics(labelnames=list(labels.keys()))
|
||||
self.metrics = Metrics(labelnames=list(labels.keys()),
|
||||
max_model_len=max_model_len)
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
if type == "cache_config":
|
||||
@ -158,34 +231,66 @@ class StatLogger:
|
||||
return elapsed_time > self.local_interval
|
||||
|
||||
def _log_prometheus(self, stats: Stats) -> None:
|
||||
# Set system stat gauges.
|
||||
self.metrics.gauge_scheduler_running.labels(**self.labels).set(
|
||||
stats.num_running)
|
||||
self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
|
||||
stats.num_swapped)
|
||||
self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
|
||||
stats.num_waiting)
|
||||
self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
|
||||
stats.gpu_cache_usage)
|
||||
self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
|
||||
stats.cpu_cache_usage)
|
||||
# System state data
|
||||
self._log_gauge(self.metrics.gauge_scheduler_running,
|
||||
stats.num_running_sys)
|
||||
self._log_gauge(self.metrics.gauge_scheduler_swapped,
|
||||
stats.num_swapped_sys)
|
||||
self._log_gauge(self.metrics.gauge_scheduler_waiting,
|
||||
stats.num_waiting_sys)
|
||||
self._log_gauge(self.metrics.gauge_gpu_cache_usage,
|
||||
stats.gpu_cache_usage_sys)
|
||||
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
|
||||
stats.cpu_cache_usage_sys)
|
||||
|
||||
# Add to token counters.
|
||||
self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
|
||||
stats.num_prompt_tokens)
|
||||
self.metrics.counter_generation_tokens.labels(**self.labels).inc(
|
||||
stats.num_generation_tokens)
|
||||
# Iteration level data
|
||||
self._log_counter(self.metrics.counter_prompt_tokens,
|
||||
stats.num_prompt_tokens_iter)
|
||||
self._log_counter(self.metrics.counter_generation_tokens,
|
||||
stats.num_generation_tokens_iter)
|
||||
self._log_histogram(self.metrics.histogram_time_to_first_token,
|
||||
stats.time_to_first_tokens_iter)
|
||||
self._log_histogram(self.metrics.histogram_time_per_output_token,
|
||||
stats.time_per_output_tokens_iter)
|
||||
|
||||
# Observe request level latencies in histograms.
|
||||
for ttft in stats.time_to_first_tokens:
|
||||
self.metrics.histogram_time_to_first_token.labels(
|
||||
**self.labels).observe(ttft)
|
||||
for tpot in stats.time_per_output_tokens:
|
||||
self.metrics.histogram_time_per_output_token.labels(
|
||||
**self.labels).observe(tpot)
|
||||
for e2e in stats.time_e2e_requests:
|
||||
self.metrics.histogram_e2e_request_latency.labels(
|
||||
**self.labels).observe(e2e)
|
||||
# Request level data
|
||||
# Latency
|
||||
self._log_histogram(self.metrics.histogram_e2e_time_request,
|
||||
stats.time_e2e_requests)
|
||||
# Metadata
|
||||
finished_reason_counter = CollectionsCounter(
|
||||
stats.finished_reason_requests)
|
||||
self._log_counter_labels(self.metrics.counter_request_success,
|
||||
finished_reason_counter,
|
||||
Metrics.labelname_finish_reason)
|
||||
self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
|
||||
stats.num_prompt_tokens_requests)
|
||||
self._log_histogram(
|
||||
self.metrics.histogram_num_generation_tokens_request,
|
||||
stats.num_generation_tokens_requests)
|
||||
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
|
||||
self._log_histogram(self.metrics.histogram_best_of_request,
|
||||
stats.best_of_requests)
|
||||
|
||||
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to gauge.
|
||||
gauge.labels(**self.labels).set(data)
|
||||
|
||||
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to counter.
|
||||
counter.labels(**self.labels).inc(data)
|
||||
|
||||
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
|
||||
label_key: str) -> None:
|
||||
# Convenience function for collection counter of labels.
|
||||
for label, count in data.items():
|
||||
counter.labels(**{**self.labels, label_key: label}).inc(count)
|
||||
|
||||
def _log_histogram(self, histogram: Histogram,
|
||||
data: Union[List[int], List[float]]) -> None:
|
||||
# Convenience function for logging list to histogram.
|
||||
for datum in data:
|
||||
histogram.labels(**self.labels).observe(datum)
|
||||
|
||||
def _log_prometheus_interval(self, prompt_throughput: float,
|
||||
generation_throughput: float) -> None:
|
||||
@ -210,8 +315,8 @@ class StatLogger:
|
||||
self._log_prometheus(stats)
|
||||
|
||||
# Save tracked stats for token counters.
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens)
|
||||
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
|
||||
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
|
||||
|
||||
# Log locally every local_interval seconds.
|
||||
if self._local_interval_elapsed(stats.now):
|
||||
@ -234,11 +339,11 @@ class StatLogger:
|
||||
"CPU KV cache usage: %.1f%%",
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
stats.num_running,
|
||||
stats.num_swapped,
|
||||
stats.num_waiting,
|
||||
stats.gpu_cache_usage * 100,
|
||||
stats.cpu_cache_usage * 100,
|
||||
stats.num_running_sys,
|
||||
stats.num_swapped_sys,
|
||||
stats.num_waiting_sys,
|
||||
stats.gpu_cache_usage_sys * 100,
|
||||
stats.cpu_cache_usage_sys * 100,
|
||||
)
|
||||
|
||||
# Reset tracked stats for next interval.
|
||||
|
@ -442,15 +442,27 @@ class SequenceGroup:
|
||||
def lora_int_id(self) -> int:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
def get_last_latency(self, now: float) -> float:
|
||||
"""Gets last token latency for Request level timings."""
|
||||
def get_last_latency(self, now: float) -> Optional[float]:
|
||||
"""Sets the last token time for Request level timings."""
|
||||
# If still in prefill phase, raise Error.
|
||||
if self.is_prefill():
|
||||
raise ValueError(
|
||||
"seq_group.get_last_latency() should not be called "
|
||||
"if the seq_group is in prefill phase.")
|
||||
|
||||
# Otherwise return token latency.
|
||||
latency = now - self.metrics.last_token_time
|
||||
self.metrics.last_token_time = now
|
||||
return latency
|
||||
|
||||
def maybe_set_first_token_time(self, time: float) -> None:
|
||||
"""Sets the first token time for Request level timings."""
|
||||
if self.metrics.first_token_time is None:
|
||||
# Note: in a case where a sequence_group is swapped and
|
||||
# recomputed, the time between iterations is counted
|
||||
# in TPOT, rather than recalculating TTFT (since from the )
|
||||
# POV of the user, there is simply a long generation delay.
|
||||
if (self.metrics.first_token_time is None
|
||||
and self.get_seqs()[0].get_output_len() == 1):
|
||||
self.metrics.first_token_time = time
|
||||
|
||||
def maybe_set_first_scheduled_time(self, time: float) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user