Add more Prometheus metrics (#2764)

Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Ronen Schaffer 2024-04-29 01:59:33 +03:00 committed by GitHub
parent 9c7306ac11
commit bf480c5302
6 changed files with 576 additions and 108 deletions


@@ -873,6 +873,289 @@
],
"title": "Cache Utilization",
"type": "timeseries"
},
{
"type": "heatmap",
"title": "Request Prompt Length",
"description": "Heatmap of request prompt length",
"gridPos": {
"x": 0,
"y": 24,
"w": 12,
"h": 8
},
"datasource": {
"uid": "prometheus",
"type": "prometheus"
},
"id": 12,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"refId": "A",
"expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))",
"range": true,
"instant": false,
"editorMode": "builder",
"legendFormat": "{{le}}",
"useBackend": false,
"disableTextWrap": false,
"fullMetaSearch": false,
"includeNullMetadata": true,
"format": "heatmap"
}
],
"options": {
"calculate": false,
"yAxis": {
"axisPlacement": "left",
"reverse": false,
"unit": "none",
"axisLabel": "Prompt Length"
},
"rowsFrame": {
"layout": "auto",
"value": "Request count"
},
"color": {
"mode": "scheme",
"fill": "dark-orange",
"scale": "exponential",
"exponent": 0.5,
"scheme": "Spectral",
"steps": 64,
"reverse": false,
"min": 0
},
"cellGap": 1,
"filterValues": {
"le": 1e-9
},
"tooltip": {
"show": true,
"yHistogram": true
},
"legend": {
"show": true
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"cellValues": {
"unit": "none"
}
},
"fieldConfig": {
"defaults": {
"custom": {
"scaleDistribution": {
"type": "linear"
},
"hideFrom": {
"tooltip": false,
"viz": false,
"legend": false
}
}
},
"overrides": []
},
"pluginVersion": "10.2.0"
},
{
"datasource": {
"uid": "prometheus",
"type": "prometheus"
},
"type": "heatmap",
"title": "Request Generation Length",
"description": "Heatmap of request generation length",
"gridPos": {
"x": 12,
"y": 24,
"w": 12,
"h": 8
},
"id": 13,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"refId": "A",
"expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))",
"range": true,
"instant": false,
"editorMode": "builder",
"legendFormat": "{{le}}",
"useBackend": false,
"disableTextWrap": false,
"fullMetaSearch": false,
"includeNullMetadata": true,
"format": "heatmap"
}
],
"options": {
"calculate": false,
"yAxis": {
"axisPlacement": "left",
"reverse": false,
"unit": "none",
"axisLabel": "Generation Length"
},
"rowsFrame": {
"layout": "auto",
"value": "Request count"
},
"color": {
"mode": "scheme",
"fill": "dark-orange",
"scale": "exponential",
"exponent": 0.5,
"scheme": "Spectral",
"steps": 64,
"reverse": false,
"min": 0
},
"cellGap": 1,
"filterValues": {
"le": 1e-9
},
"tooltip": {
"show": true,
"yHistogram": true
},
"legend": {
"show": true
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"cellValues": {
"unit": "none"
}
},
"fieldConfig": {
"defaults": {
"custom": {
"scaleDistribution": {
"type": "linear"
},
"hideFrom": {
"tooltip": false,
"viz": false,
"legend": false
}
}
},
"overrides": []
},
"pluginVersion": "10.2.0"
},
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"fieldConfig": {
"defaults": {
"custom": {
"drawStyle": "line",
"lineInterpolation": "linear",
"barAlignment": 0,
"lineWidth": 1,
"fillOpacity": 0,
"gradientMode": "none",
"spanNulls": false,
"insertNulls": false,
"showPoints": "auto",
"pointSize": 5,
"stacking": {
"mode": "none",
"group": "A"
},
"axisPlacement": "auto",
"axisLabel": "",
"axisColorMode": "text",
"axisBorderShow": false,
"scaleDistribution": {
"type": "linear"
},
"axisCenteredZero": false,
"hideFrom": {
"tooltip": false,
"viz": false,
"legend": false
},
"thresholdsStyle": {
"mode": "off"
}
},
"color": {
"mode": "palette-classic"
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 32
},
"id": 11,
"options": {
"tooltip": {
"mode": "single",
"sort": "none"
},
"legend": {
"showLegend": true,
"displayMode": "list",
"placement": "bottom",
"calcs": []
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "prometheus"
},
"disableTextWrap": false,
"editorMode": "builder",
"expr": "sum by(finished_reason) (increase(vllm:request_success_total{model_name=\"$model_name\"}[$__rate_interval]))",
"fullMetaSearch": false,
"includeNullMetadata": true,
"instant": false,
"interval": "",
"legendFormat": "__auto",
"range": true,
"refId": "A",
"useBackend": false
}
],
"title": "Finish Reason",
"description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.",
"type": "timeseries"
}
],
"refresh": "",


@@ -12,6 +12,7 @@ openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.8
outlines == 0.0.34 # Requires torch >= 2.1.0
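
The only change here is the new prometheus-fastapi-instrumentator dependency. As an illustration of what that library typically provides (the bare FastAPI app and /ping route below are invented for the sketch; how the vLLM OpenAI server actually wires it up is in a file not shown in this diff):

from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

app = FastAPI()

# Collects default HTTP metrics (request counts, latencies) and serves them at
# /metrics together with anything registered via prometheus_client.
Instrumentator().instrument(app).expose(app)

@app.get("/ping")
def ping() -> dict:
    # Placeholder endpoint so the instrumented app has something to measure.
    return {"status": "ok"}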


@@ -320,7 +320,7 @@ class Scheduler:
for seq_group in state_queue:
if not request_ids:
# Using 'break' here may add two extra iterations,
# but is acceptable to reduce complexity .
# but is acceptable to reduce complexity.
break
if seq_group.request_id in request_ids:
# Appending aborted group into pending list.


@@ -22,7 +22,8 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata)
SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -217,7 +218,8 @@ class LLMEngine:
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.model))
labels=dict(model_name=model_config.model),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or
@@ -619,59 +621,109 @@ class LLMEngine:
"""
now = time.time()
# KV Cache Usage in %.
# System State
# Scheduler State
num_running_sys = len(self.scheduler.running)
num_swapped_sys = len(self.scheduler.swapped)
num_waiting_sys = len(self.scheduler.waiting)
# KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage = 0.
cpu_cache_usage_sys = 0.
if num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Scheduler State
num_running = len(self.scheduler.running)
num_swapped = len(self.scheduler.swapped)
num_waiting = len(self.scheduler.waiting)
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []
# Iteration stats if we have scheduler output.
num_prompt_tokens = 0
num_generation_tokens = 0
time_to_first_tokens = []
time_per_output_tokens = []
time_e2e_requests = []
# Request stats
# Latency
time_e2e_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
best_of_requests: List[int] = []
n_requests: List[int] = []
finished_reason_requests: List[str] = []
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None:
prompt_run = scheduler_outputs.num_prefill_groups > 0
num_generation_tokens_from_prefill_groups = 0.
if scheduler_outputs.num_prefill_groups > 0 and len(
scheduler_outputs.scheduled_seq_groups
) != scheduler_outputs.num_prefill_groups:
print("DETECTED CHUNKED")
# Number of Tokens.
if prompt_run:
num_prompt_tokens = sum(
len(scheduled_seq_group.seq_group.prompt_token_ids)
for scheduled_seq_group in
scheduler_outputs.scheduled_seq_groups)
num_generation_tokens = sum(
scheduled_seq_group.seq_group.num_seqs()
for scheduled_seq_group in
scheduler_outputs.scheduled_seq_groups)
else:
num_generation_tokens = scheduler_outputs.num_batched_tokens
# Latency Timings.
time_last_iters = []
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
for idx, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
group_was_prefill = idx < scheduler_outputs.num_prefill_groups
seq_group = scheduled_seq_group.seq_group
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
time_last_iters.append(seq_group.get_last_latency(now))
# Time since arrival for all finished requests.
# NOTE: a seq_group that completed all of its prefill tokens
# in the last iteration will have seq_group.is_prefill() = False
# with group_was_prefill = True
if group_was_prefill:
# Number of prompt tokens.
num_prompt_tokens_iter += (
scheduled_seq_group.token_chunk_size)
# If the seq_group just finished the prefill state
# get TTFT.
if not seq_group.is_prefill():
latency = seq_group.get_last_latency(now)
time_to_first_tokens_iter.append(latency)
# One generation token per finished prefill.
num_generation_tokens_from_prefill_groups += (
seq_group.num_seqs())
else:
# TPOTs.
latency = seq_group.get_last_latency(now)
time_per_output_tokens_iter.append(latency)
# Because of chunked prefill, we can have a single sequence
# group that does multiple prompt_runs. To prevent logging
# the same metadata more than once per request, we standardize
# on logging request level information for finished requests,
# which can only happen once.
if seq_group.is_finished():
# Latency timings
time_e2e_requests.append(now -
seq_group.metrics.arrival_time)
time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters
# Metadata
num_prompt_tokens_requests.append(
len(seq_group.prompt_token_ids))
num_generation_tokens_requests.extend([
seq.get_output_len()
for seq in seq_group.get_finished_seqs()
])
best_of_requests.append(seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
for seq in seq_group.get_finished_seqs()
])
# Number of generation tokens.
# num_batched_tokens equals the number of prompt_tokens plus the
# number of decode_tokens in a single iteration. So,
# num_generation_tokens = num_batched_tokens - num_prompt_tokens
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter = (
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
num_generation_tokens_from_prefill_groups)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
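
To make the token accounting above concrete, here is a small worked example with made-up numbers (not from a real run): one chunked-prefill group that finishes its prefill this iteration with an 8-token chunk and a single sequence, plus three decode groups with one sequence each.

# Hypothetical iteration: the scheduler batched 8 prompt tokens + 3 decode tokens.
num_batched_tokens = 8 + 3
num_prompt_tokens_iter = 8                      # token_chunk_size of the prefill group
num_generation_tokens_from_prefill_groups = 1   # the finished prefill emits one token

num_generation_tokens_iter = (num_batched_tokens - num_prompt_tokens_iter +
                              num_generation_tokens_from_prefill_groups)

# 3 decode tokens plus the first token produced by the finished prefill.
assert num_generation_tokens_iter == 4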
@@ -683,17 +735,32 @@ class LLMEngine:
return Stats(
now=now,
num_running=num_running,
num_swapped=num_swapped,
num_waiting=num_waiting,
gpu_cache_usage=gpu_cache_usage,
cpu_cache_usage=cpu_cache_usage,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=num_generation_tokens,
time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests,
# System stats
# Scheduler State
num_running_sys=num_running_sys,
num_swapped_sys=num_swapped_sys,
num_waiting_sys=num_waiting_sys,
# KV Cache Usage in %
gpu_cache_usage_sys=gpu_cache_usage_sys,
cpu_cache_usage_sys=cpu_cache_usage_sys,
# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
# Request stats
# Latency
time_e2e_requests=time_e2e_requests,
# Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
best_of_requests=best_of_requests,
n_requests=n_requests,
finished_reason_requests=finished_reason_requests,
)
def add_lora(self, lora_request: LoRARequest) -> bool:
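
The net effect of this file's changes is that each iteration now produces a much richer Stats object. A usage sketch of the new schema (assumes vllm is importable at this commit and that vllm.engine.metrics exposes Stats and StatLogger; every value below is invented for illustration):

import time

from vllm.engine.metrics import StatLogger, Stats

stat_logger = StatLogger(local_interval=5.0,
                         labels={"model_name": "demo-model"},
                         max_model_len=2048)

stats = Stats(
    now=time.time(),
    # Scheduler state
    num_running_sys=2,
    num_waiting_sys=1,
    num_swapped_sys=0,
    # KV cache usage
    gpu_cache_usage_sys=0.35,
    cpu_cache_usage_sys=0.0,
    # Iteration stats
    num_prompt_tokens_iter=8,
    num_generation_tokens_iter=4,
    time_to_first_tokens_iter=[0.21],
    time_per_output_tokens_iter=[0.045, 0.047, 0.046],
    # Request stats (one request finished this iteration)
    time_e2e_requests=[1.8],
    num_prompt_tokens_requests=[8],
    num_generation_tokens_requests=[37],
    best_of_requests=[1],
    n_requests=[1],
    finished_reason_requests=["stop"],
)

# Updates the Prometheus gauges/counters/histograms defined in metrics.py and,
# at most every local_interval seconds, the human-readable stdout summary.
stat_logger.log(stats)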


@@ -1,6 +1,8 @@
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Protocol, Union
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -21,8 +23,9 @@ disable_created_metrics()
# begin-metrics-definitions
class Metrics:
labelname_finish_reason = "finished_reason"
def __init__(self, labelnames: List[str]):
def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
@@ -34,18 +37,20 @@ class Metrics:
documentation='information of cache_config')
# System stats
# Scheduler State
self.gauge_scheduler_running = Gauge(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = Gauge(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
@@ -55,7 +60,7 @@ class Metrics:
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
# Raw stats from last model iteration
# Iteration stats
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
@@ -80,18 +85,51 @@ class Metrics:
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
self.histogram_e2e_request_latency = Histogram(
# Request stats
# Latency
self.histogram_e2e_time_request = Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata
self.histogram_num_prompt_tokens_request = Histogram(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = Histogram(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = Histogram(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = Histogram(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.counter_request_success = Counter(
name="vllm:request_success",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Legacy metrics
# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
@@ -102,24 +140,57 @@ class Metrics:
# end-metrics-definitions
def build_1_2_5_buckets(max_value: int):
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values (1, 2, 5) until the value exceeds the specified maximum.
Example:
>>> build_1_2_5_buckets(100)
[1, 2, 5, 10, 20, 50, 100]
"""
mantissa_lst = [1, 2, 5]
exponent = 0
buckets = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
if value <= max_value:
buckets.append(value)
else:
return buckets
exponent += 1
@dataclass
class Stats:
"""Created by LLMEngine for use by StatLogger."""
now: float
# System stats.
num_running: int
num_waiting: int
num_swapped: int
gpu_cache_usage: float
cpu_cache_usage: float
# System stats (should have _sys suffix)
# Scheduler State
num_running_sys: int
num_waiting_sys: int
num_swapped_sys: int
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Raw stats from last model iteration.
num_prompt_tokens: int
num_generation_tokens: int
time_to_first_tokens: List[float]
time_per_output_tokens: List[float]
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float]
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
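
Since the bucket helper above feeds the two request-length histograms, a quick self-contained check of its output may help; the function body is copied from the diff and the max_model_len values are illustrative only.

from typing import List

def build_1_2_5_buckets(max_value: int) -> List[int]:
    # Same logic as above: climb the 1-2-5 ladder until a value would exceed max_value.
    mantissa_lst = [1, 2, 5]
    exponent = 0
    buckets: List[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10**exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        exponent += 1

assert build_1_2_5_buckets(100) == [1, 2, 5, 10, 20, 50, 100]
# A hypothetical 4096-token model gets buckets up to 4000; 5000 would exceed the limit.
assert build_1_2_5_buckets(4096) == [
    1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 4000]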
@@ -133,7 +204,8 @@ class SupportsMetricsInfo(Protocol):
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
def __init__(self, local_interval: float, labels: Dict[str, str],
max_model_len: int) -> None:
# Metadata for logging locally.
self.last_local_log = time.time()
self.local_interval = local_interval
@@ -144,7 +216,8 @@ class StatLogger:
# Prometheus metrics
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))
self.metrics = Metrics(labelnames=list(labels.keys()),
max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
@@ -158,34 +231,66 @@ class StatLogger:
return elapsed_time > self.local_interval
def _log_prometheus(self, stats: Stats) -> None:
# Set system stat gauges.
self.metrics.gauge_scheduler_running.labels(**self.labels).set(
stats.num_running)
self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
stats.num_swapped)
self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
stats.num_waiting)
self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
stats.gpu_cache_usage)
self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
stats.cpu_cache_usage)
# System state data
self._log_gauge(self.metrics.gauge_scheduler_running,
stats.num_running_sys)
self._log_gauge(self.metrics.gauge_scheduler_swapped,
stats.num_swapped_sys)
self._log_gauge(self.metrics.gauge_scheduler_waiting,
stats.num_waiting_sys)
self._log_gauge(self.metrics.gauge_gpu_cache_usage,
stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys)
# Add to token counters.
self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
stats.num_prompt_tokens)
self.metrics.counter_generation_tokens.labels(**self.labels).inc(
stats.num_generation_tokens)
# Iteration level data
self._log_counter(self.metrics.counter_prompt_tokens,
stats.num_prompt_tokens_iter)
self._log_counter(self.metrics.counter_generation_tokens,
stats.num_generation_tokens_iter)
self._log_histogram(self.metrics.histogram_time_to_first_token,
stats.time_to_first_tokens_iter)
self._log_histogram(self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter)
# Observe request level latencies in histograms.
for ttft in stats.time_to_first_tokens:
self.metrics.histogram_time_to_first_token.labels(
**self.labels).observe(ttft)
for tpot in stats.time_per_output_tokens:
self.metrics.histogram_time_per_output_token.labels(
**self.labels).observe(tpot)
for e2e in stats.time_e2e_requests:
self.metrics.histogram_e2e_request_latency.labels(
**self.labels).observe(e2e)
# Request level data
# Latency
self._log_histogram(self.metrics.histogram_e2e_time_request,
stats.time_e2e_requests)
# Metadata
finished_reason_counter = CollectionsCounter(
stats.finished_reason_requests)
self._log_counter_labels(self.metrics.counter_request_success,
finished_reason_counter,
Metrics.labelname_finish_reason)
self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
stats.num_prompt_tokens_requests)
self._log_histogram(
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests)
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
label_key: str) -> None:
# Convenience function for collection counter of labels.
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram: Histogram,
data: Union[List[int], List[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
@@ -210,8 +315,8 @@ class StatLogger:
self._log_prometheus(stats)
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens)
self.num_generation_tokens.append(stats.num_generation_tokens)
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now):
@@ -234,11 +339,11 @@ class StatLogger:
"CPU KV cache usage: %.1f%%",
prompt_throughput,
generation_throughput,
stats.num_running,
stats.num_swapped,
stats.num_waiting,
stats.gpu_cache_usage * 100,
stats.cpu_cache_usage * 100,
stats.num_running_sys,
stats.num_swapped_sys,
stats.num_waiting_sys,
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval.
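
One of the less obvious helpers above is _log_counter_labels, which fans a collections.Counter of finish reasons out into per-label counter increments. A standalone sketch with invented data (the real label values come from SequenceStatus.get_finished_reason, e.g. "stop" and "length"):

from typing import Counter as CollectionsCounter

from prometheus_client import CollectorRegistry, Counter

registry = CollectorRegistry()
counter_request_success = Counter(
    name="vllm:request_success",
    documentation="Count of successfully processed requests.",
    labelnames=["model_name", "finished_reason"],
    registry=registry,
)

labels = {"model_name": "demo-model"}
# Three requests finished: two hit a stop/EOS token, one hit the length limit.
finished_reason_counter = CollectionsCounter(["stop", "stop", "length"])

# Mirrors StatLogger._log_counter_labels: one labeled inc() per distinct reason.
for reason, count in finished_reason_counter.items():
    counter_request_success.labels(**{**labels, "finished_reason": reason}).inc(count)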


@@ -442,15 +442,27 @@ class SequenceGroup:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
def get_last_latency(self, now: float) -> float:
"""Gets last token latency for Request level timings."""
def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
raise ValueError(
"seq_group.get_last_latency() should not be called "
"if the seq_group is in prefill phase.")
# Otherwise return token latency.
latency = now - self.metrics.last_token_time
self.metrics.last_token_time = now
return latency
def maybe_set_first_token_time(self, time: float) -> None:
"""Sets the first token time for Request level timings."""
if self.metrics.first_token_time is None:
# Note: in a case where a sequence_group is swapped and
# recomputed, the time between iterations is counted
# in TPOT, rather than recalculating TTFT (since from the
# POV of the user, there is simply a long generation delay).
if (self.metrics.first_token_time is None
and self.get_seqs()[0].get_output_len() == 1):
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None: