Port metrics from aioprometheus to prometheus_client (#2730)

commit ef978fe411
parent f7c1234990

@@ -72,7 +72,7 @@ html_theme_options = {
 # Mock out external dependencies here.
 autodoc_mock_imports = [
-    "torch", "transformers", "psutil", "aioprometheus", "sentencepiece",
+    "torch", "transformers", "psutil", "prometheus_client", "sentencepiece",
     "vllm.cuda_utils", "vllm._C"
 ]

@@ -6,4 +6,4 @@ neuronx-cc
 fastapi
 uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
-aioprometheus[starlette]
+prometheus_client

@@ -10,4 +10,4 @@ transformers >= 4.38.0 # Required for Gemma.
 fastapi
 uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
-aioprometheus[starlette]
+prometheus_client

@@ -9,7 +9,7 @@ xformers == 0.0.23.post1 # Required for CUDA 12.1.
 fastapi
 uvicorn[standard]
 pydantic >= 2.0 # Required for OpenAI server.
-aioprometheus[starlette]
+prometheus_client
 pynvml == 11.5.0
 triton >= 2.1.0
 cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.

@@ -165,6 +165,7 @@ class VllmRunner:
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
+        **kwargs,
     ) -> None:
         self.model = LLM(
             model=model_name,

@@ -174,6 +175,7 @@ class VllmRunner:
             swap_space=0,
             disable_log_stats=disable_log_stats,
             tensor_parallel_size=tensor_parallel_size,
+            **kwargs,
         )

     def generate(

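For reference, not part of the diff: the new **kwargs parameter lets a test pass extra engine arguments (such as gpu_memory_utilization=0.4, used by the metrics tests below) through VllmRunner straight into the LLM constructor. A minimal standalone sketch of that forwarding pattern, with hypothetical FakeEngine/Runner classes standing in for the real ones:

    class FakeEngine:
        """Stand-in for vllm.LLM in this sketch."""

        def __init__(self, model: str, disable_log_stats: bool = True, **kwargs):
            self.model = model
            self.disable_log_stats = disable_log_stats
            self.extra = kwargs  # e.g. gpu_memory_utilization


    class Runner:
        """Stand-in for the test-suite VllmRunner in this sketch."""

        def __init__(self, model_name: str, disable_log_stats: bool = True, **kwargs):
            # Whatever the test passes (e.g. gpu_memory_utilization=0.4) is
            # forwarded unchanged to the engine constructor.
            self.engine = FakeEngine(model=model_name,
                                     disable_log_stats=disable_log_stats,
                                     **kwargs)


    runner = Runner("facebook/opt-125m",
                    disable_log_stats=False,
                    gpu_memory_utilization=0.4)
    assert runner.engine.extra == {"gpu_memory_utilization": 0.4}
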
@@ -1,5 +1,4 @@
 import pytest
-import vllm.engine.metrics

 MODELS = [
     "facebook/opt-125m",

@@ -16,10 +15,10 @@ def test_metric_counter_prompt_tokens(
     dtype: str,
     max_tokens: int,
 ) -> None:
-    # Reset metric
-    vllm.engine.metrics.counter_prompt_tokens.set_value({}, 0)
-
-    vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             disable_log_stats=False,
+                             gpu_memory_utilization=0.4)
     tokenizer = vllm_model.model.get_tokenizer()
     prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
     # This test needs at least 2 prompts in a batch of different lengths to verify their token count is correct despite padding.

@@ -29,7 +28,9 @@ def test_metric_counter_prompt_tokens(
     vllm_prompt_token_count = sum(prompt_token_counts)

     _ = vllm_model.generate_greedy(example_prompts, max_tokens)
-    metric_count = vllm.engine.metrics.counter_prompt_tokens.get_value({})
+    stat_logger = vllm_model.model.llm_engine.stat_logger
+    metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
+        **stat_logger.labels)._value.get()

     assert vllm_prompt_token_count == metric_count, (
         f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"

@@ -46,13 +47,15 @@ def test_metric_counter_generation_tokens(
     dtype: str,
     max_tokens: int,
 ) -> None:
-    # Reset metric
-    vllm.engine.metrics.counter_generation_tokens.set_value({}, 0)
-
-    vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             disable_log_stats=False,
+                             gpu_memory_utilization=0.4)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     tokenizer = vllm_model.model.get_tokenizer()
-    metric_count = vllm.engine.metrics.counter_generation_tokens.get_value({})
+    stat_logger = vllm_model.model.llm_engine.stat_logger
+    metric_count = stat_logger.metrics.counter_generation_tokens.labels(
+        **stat_logger.labels)._value.get()
     vllm_generation_count = 0
     for i in range(len(example_prompts)):
         vllm_output_ids, vllm_output_str = vllm_outputs[i]

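For reference, not part of the diff: the old tests reset and read a module-level aioprometheus counter via set_value({}, 0) and get_value({}); with prometheus_client the value lives on a labelled child, which is why the tests above go through stat_logger.metrics...labels(...)._value.get(). A standalone sketch of that read-back pattern, using a private CollectorRegistry so nothing global is touched:

    from prometheus_client import CollectorRegistry, Counter

    registry = CollectorRegistry()  # private registry keeps the example isolated
    labels = {"model_name": "facebook/opt-125m"}

    counter_prompt_tokens = Counter(
        name="vllm:prompt_tokens_total",
        documentation="Number of prefill tokens processed.",
        labelnames=list(labels.keys()),
        registry=registry)

    counter_prompt_tokens.labels(**labels).inc(37)
    counter_prompt_tokens.labels(**labels).inc(5)

    # ._value is an implementation detail of prometheus_client, but it is a
    # convenient way to assert on the raw count in a test.
    assert counter_prompt_tokens.labels(**labels)._value.get() == 42
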
@@ -128,7 +128,8 @@ class LLMEngine:
         # Metric Logging.
         if self.log_stats:
             self.stat_logger = StatLogger(
-                local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
+                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
+                labels=dict(model_name=model_config.model))

         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:

@@ -1,66 +1,94 @@
 from vllm.logger import init_logger
-from aioprometheus import Counter, Gauge, Histogram
+from prometheus_client import Counter, Gauge, Histogram, REGISTRY, disable_created_metrics

 import time
 import numpy as np
-from typing import List
+from typing import Dict, List
 from dataclasses import dataclass

 logger = init_logger(__name__)

-labels = {}
-
-
-def add_global_metrics_labels(**kwargs):
-    labels.update(kwargs)
-
+disable_created_metrics()

 # The begin-* and end* here are used by the documentation generator
 # to extract the metrics definitions.


 # begin-metrics-definitions
-gauge_avg_prompt_throughput = Gauge("vllm:avg_prompt_throughput_toks_per_s",
-                                    "Average prefill throughput in tokens/s.")
-gauge_avg_generation_throughput = Gauge(
-    "vllm:avg_generation_throughput_toks_per_s",
-    "Average generation throughput in tokens/s.")
-counter_prompt_tokens = Counter("vllm:prompt_tokens_total",
-                                "Number of prefill tokens processed.")
-counter_generation_tokens = Counter("vllm:generation_tokens_total",
-                                    "Number of generation tokens processed.")
-
-gauge_scheduler_running = Gauge(
-    "vllm:num_requests_running",
-    "Number of requests currently running on GPU.")
-gauge_scheduler_swapped = Gauge("vllm:num_requests_swapped",
-                                "Number of requests swapped to CPU.")
-gauge_scheduler_waiting = Gauge("vllm:num_requests_waiting",
-                                "Number of requests waiting to be processed.")
-
-gauge_gpu_cache_usage = Gauge(
-    "vllm:gpu_cache_usage_perc",
-    "GPU KV-cache usage. 1 means 100 percent usage.")
-gauge_cpu_cache_usage = Gauge(
-    "vllm:cpu_cache_usage_perc",
-    "CPU KV-cache usage. 1 means 100 percent usage.")
-
-histogram_time_to_first_token = Histogram(
-    "vllm:time_to_first_token_seconds",
-    "Histogram of time to first token in seconds.",
-    buckets=[
-        0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0,
-        2.5, 5.0, 7.5, 10.0
-    ])
-histogram_time_per_output_tokens = Histogram(
-    "vllm:time_per_output_token_seconds",
-    "Histogram of time per output token in seconds.",
-    buckets=[
-        0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5
-    ])
-histogram_e2e_request_latency = Histogram(
-    "vllm:e2e_request_latency_seconds",
-    "Histogram of end to end request latency in seconds.",
-    buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+class Metrics:
+
+    def __init__(self, labelnames: List[str]):
+        # Unregister any existing vLLM collectors
+        for collector in list(REGISTRY._collector_to_names):
+            if hasattr(collector, "_name") and "vllm" in collector._name:
+                REGISTRY.unregister(collector)
+
+        # System stats
+        self.gauge_scheduler_running = Gauge(
+            name="vllm:num_requests_running",
+            documentation="Number of requests currently running on GPU.",
+            labelnames=labelnames)
+        self.gauge_scheduler_swapped = Gauge(
+            name="vllm:num_requests_swapped",
+            documentation="Number of requests swapped to CPU.",
+            labelnames=labelnames)
+        self.gauge_scheduler_waiting = Gauge(
+            name="vllm:num_requests_waiting",
+            documentation="Number of requests waiting to be processed.",
+            labelnames=labelnames)
+        self.gauge_gpu_cache_usage = Gauge(
+            name="vllm:gpu_cache_usage_perc",
+            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
+            labelnames=labelnames)
+        self.gauge_cpu_cache_usage = Gauge(
+            name="vllm:cpu_cache_usage_perc",
+            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
+            labelnames=labelnames)
+
+        # Raw stats from last model iteration
+        self.counter_prompt_tokens = Counter(
+            name="vllm:prompt_tokens_total",
+            documentation="Number of prefill tokens processed.",
+            labelnames=labelnames)
+        self.counter_generation_tokens = Counter(
+            name="vllm:generation_tokens_total",
+            documentation="Number of generation tokens processed.",
+            labelnames=labelnames)
+        self.histogram_time_to_first_token = Histogram(
+            name="vllm:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labelnames,
+            buckets=[
+                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
+                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
+            ])
+        self.histogram_time_per_output_token = Histogram(
+            name="vllm:time_per_output_token_seconds",
+            documentation="Histogram of time per output token in seconds.",
+            labelnames=labelnames,
+            buckets=[
+                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
+                1.0, 2.5
+            ])
+        self.histogram_e2e_request_latency = Histogram(
+            name="vllm:e2e_request_latency_seconds",
+            documentation="Histogram of end to end request latency in seconds.",
+            labelnames=labelnames,
+            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+
+        # Legacy metrics
+        self.gauge_avg_prompt_throughput = Gauge(
+            name="vllm:avg_prompt_throughput_toks_per_s",
+            documentation="Average prefill throughput in tokens/s.",
+            labelnames=labelnames,
+        )
+        self.gauge_avg_generation_throughput = Gauge(
+            name="vllm:avg_generation_throughput_toks_per_s",
+            documentation="Average generation throughput in tokens/s.",
+            labelnames=labelnames,
+        )
 # end-metrics-definitions

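For reference, not part of the diff: a tiny standalone sketch of the prometheus_client conventions the new Metrics class relies on. Metrics are declared once with labelnames and bound to concrete label values later via .labels(), and disable_created_metrics() keeps the extra *_created series out of the exposition output. A throwaway CollectorRegistry keeps the example off the global REGISTRY:

    from prometheus_client import (CollectorRegistry, Counter, Gauge,
                                   disable_created_metrics, generate_latest)

    # Suppress the extra "*_created" series prometheus_client would otherwise
    # emit alongside counters and histograms.
    disable_created_metrics()

    registry = CollectorRegistry()  # throwaway registry for the example
    labelnames = ["model_name"]

    counter_prompt_tokens = Counter(
        name="vllm:prompt_tokens_total",
        documentation="Number of prefill tokens processed.",
        labelnames=labelnames,
        registry=registry)
    gauge_scheduler_running = Gauge(
        name="vllm:num_requests_running",
        documentation="Number of requests currently running on GPU.",
        labelnames=labelnames,
        registry=registry)

    counter_prompt_tokens.labels(model_name="facebook/opt-125m").inc(128)
    gauge_scheduler_running.labels(model_name="facebook/opt-125m").set(2)

    # Exposition-format text, the same content a /metrics endpoint serves.
    print(generate_latest(registry).decode())

generate_latest() renders the same text the /metrics endpoint serves, so it is a quick way to inspect which series a change like this produces.
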
@@ -87,7 +115,7 @@ class Stats:
 class StatLogger:
     """StatLogger is used LLMEngine to log to Promethus and Stdout."""

-    def __init__(self, local_interval: float) -> None:
+    def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
         # Metadata for logging locally.
         self.last_local_log = time.monotonic()
         self.local_interval = local_interval

@@ -96,6 +124,10 @@ class StatLogger:
         self.num_prompt_tokens: List[int] = []
         self.num_generation_tokens: List[int] = []

+        # Prometheus metrics
+        self.labels = labels
+        self.metrics = Metrics(labelnames=list(labels.keys()))
+
     def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
         return float(np.sum(tracked_stats) / (now - self.last_local_log))

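For reference, not part of the diff: the label dictionary flows in one direction. LLMEngine passes labels=dict(model_name=...), StatLogger stores it and hands the key names to Metrics as labelnames, and each later update binds the values with .labels(**self.labels). A condensed sketch of that flow outside vLLM:

    from prometheus_client import CollectorRegistry, Gauge

    registry = CollectorRegistry()
    labels = dict(model_name="facebook/opt-125m")  # what LLMEngine passes in

    # StatLogger turns the dict keys into Prometheus label names ...
    gauge_scheduler_running = Gauge(
        name="vllm:num_requests_running",
        documentation="Number of requests currently running on GPU.",
        labelnames=list(labels.keys()),
        registry=registry)

    # ... and _log_prometheus binds the values on every update.
    gauge_scheduler_running.labels(**labels).set(3)
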
@@ -105,23 +137,33 @@ class StatLogger:

     def _log_prometheus(self, stats: Stats) -> None:
         # Set system stat gauges.
-        gauge_scheduler_running.set(labels, stats.num_running)
-        gauge_scheduler_swapped.set(labels, stats.num_swapped)
-        gauge_scheduler_waiting.set(labels, stats.num_waiting)
-        gauge_gpu_cache_usage.set(labels, stats.gpu_cache_usage)
-        gauge_cpu_cache_usage.set(labels, stats.cpu_cache_usage)
+        self.metrics.gauge_scheduler_running.labels(**self.labels).set(
+            stats.num_running)
+        self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
+            stats.num_swapped)
+        self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
+            stats.num_waiting)
+        self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
+            stats.gpu_cache_usage)
+        self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
+            stats.cpu_cache_usage)

         # Add to token counters.
-        counter_prompt_tokens.add(labels, stats.num_prompt_tokens)
-        counter_generation_tokens.add(labels, stats.num_generation_tokens)
+        self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
+            stats.num_prompt_tokens)
+        self.metrics.counter_generation_tokens.labels(**self.labels).inc(
+            stats.num_generation_tokens)

         # Observe request level latencies in histograms.
         for ttft in stats.time_to_first_tokens:
-            histogram_time_to_first_token.observe(labels, ttft)
+            self.metrics.histogram_time_to_first_token.labels(
+                **self.labels).observe(ttft)
         for tpot in stats.time_per_output_tokens:
-            histogram_time_per_output_tokens.observe(labels, tpot)
+            self.metrics.histogram_time_per_output_token.labels(
+                **self.labels).observe(tpot)
         for e2e in stats.time_e2e_requests:
-            histogram_e2e_request_latency.observe(labels, e2e)
+            self.metrics.histogram_e2e_request_latency.labels(
+                **self.labels).observe(e2e)

     def _log_prometheus_interval(self, prompt_throughput: float,
                                  generation_throughput: float) -> None:

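For reference, not part of the diff: the call-pattern change in this hunk is mechanical. aioprometheus took the label dict as the first positional argument (gauge.set(labels, v), counter.add(labels, v), histogram.observe(labels, v)), while prometheus_client binds the labels first and then mutates the resulting child. A standalone sketch:

    from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram

    registry = CollectorRegistry()
    labels = {"model_name": "facebook/opt-125m"}
    names = list(labels.keys())

    gauge = Gauge(name="vllm:num_requests_running",
                  documentation="Number of requests currently running on GPU.",
                  labelnames=names,
                  registry=registry)
    counter = Counter(name="vllm:generation_tokens_total",
                      documentation="Number of generation tokens processed.",
                      labelnames=names,
                      registry=registry)
    histogram = Histogram(name="vllm:time_to_first_token_seconds",
                          documentation="Histogram of time to first token in seconds.",
                          labelnames=names,
                          registry=registry)

    gauge.labels(**labels).set(4)             # was: gauge.set(labels, 4)
    counter.labels(**labels).inc(256)         # was: counter.add(labels, 256)
    histogram.labels(**labels).observe(0.08)  # was: histogram.observe(labels, 0.08)
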
@@ -130,8 +172,10 @@ class StatLogger:
         # Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens
         # Which log raw data and calculate summaries using rate() on the grafana/prometheus side.
         # See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
-        gauge_avg_prompt_throughput.set(labels, prompt_throughput)
-        gauge_avg_generation_throughput.set(labels, generation_throughput)
+        self.metrics.gauge_avg_prompt_throughput.labels(
+            **self.labels).set(prompt_throughput)
+        self.metrics.gauge_avg_generation_throughput.labels(
+            **self.labels).set(generation_throughput)

     def log(self, stats: Stats) -> None:
         """Called by LLMEngine.

@@ -6,8 +6,7 @@ import os
 import importlib
 import inspect

-from aioprometheus import MetricsMiddleware
-from aioprometheus.asgi.starlette import metrics
+from prometheus_client import make_asgi_app
 import fastapi
 import uvicorn
 from http import HTTPStatus

@@ -18,7 +17,6 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response

 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.metrics import add_global_metrics_labels
 from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
 from vllm.logger import init_logger
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

@@ -141,8 +139,9 @@ def parse_args():
     return parser.parse_args()


-app.add_middleware(MetricsMiddleware)  # Trace HTTP server metrics
-app.add_route("/metrics", metrics)  # Exposes HTTP metrics
+# Add prometheus asgi middleware to route /metrics requests
+metrics_app = make_asgi_app()
+app.mount("/metrics", metrics_app)


 @app.exception_handler(RequestValidationError)

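For reference, not part of the diff: instead of aioprometheus's MetricsMiddleware and Starlette metrics endpoint, the server now mounts the ASGI app returned by make_asgi_app(), which serves the global registry in exposition format at /metrics. A minimal standalone FastAPI sketch of the same wiring (host/port values are hypothetical):

    import fastapi
    import uvicorn
    from prometheus_client import make_asgi_app

    app = fastapi.FastAPI()

    # make_asgi_app() returns an ASGI app that renders the global registry in
    # exposition format; mounting it serves GET /metrics.
    metrics_app = make_asgi_app()
    app.mount("/metrics", metrics_app)

    if __name__ == "__main__":
        # Hypothetical host/port, just to make the sketch runnable.
        uvicorn.run(app, host="0.0.0.0", port=8000)
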
@@ -242,9 +241,6 @@ if __name__ == "__main__":
     openai_serving_completion = OpenAIServingCompletion(
         engine, served_model, args.lora_modules)

-    # Register labels for metrics
-    add_global_metrics_labels(model_name=engine_args.model)
-
     app.root_path = args.root_path
     uvicorn.run(app,
                 host=args.host,