[Misc] Add OpenTelemetry support (#4687)
This PR adds basic support for OpenTelemetry distributed tracing. It includes the changes needed to enable tracing and improve monitoring: requests can carry W3C trace context, and the engine exports a span with request metadata for every finished request. I've also added a markdown guide with screenshots showing how to use this feature; you can find it in examples/production_monitoring/Otel.md.
This commit is contained in:
parent
13db4369d9
commit
7879f24dcc
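
Before the diff, here is a minimal sketch (not part of the change itself) of how tracing can be enabled from the offline Python API; it mirrors the new test in tests/tracing/test_tracing.py, and the endpoint value is only an example that assumes a local OTLP gRPC collector such as the Jaeger container described in Otel.md below.

```python
# Minimal sketch: enable tracing via the offline API (mirrors tests/tracing/test_tracing.py).
# Assumes an OTLP gRPC collector is listening locally on port 4317 (example endpoint).
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    otlp_traces_endpoint="grpc://localhost:4317",  # example endpoint, not prescribed by the PR
)
outputs = llm.generate(["San Francisco is a"],
                       sampling_params=SamplingParams(max_tokens=16))
# One span per finished request is exported by the engine
# (see LLMEngine.create_trace_span further down in the diff).
print(outputs[0].request_id)
```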
@@ -159,6 +159,15 @@ steps:
   #mirror_hardwares: [amd]
   command: pytest -v -s quantization

+- label: Tracing Test
+  commands:
+    - "pip install \
+        opentelemetry-sdk \
+        opentelemetry-api \
+        opentelemetry-exporter-otlp \
+        opentelemetry-semantic-conventions-ai"
+    - pytest -v -s tracing
+
 - label: Benchmarks
   working_dir: "/vllm-workspace/.buildkite"
   mirror_hardwares: [amd]

@@ -20,26 +20,29 @@ def main(args: argparse.Namespace):

     # NOTE(woosuk): If the request cannot be processed in a single batch,
     # the engine will automatically process the request in multiple batches.
-    llm = LLM(model=args.model,
-              speculative_model=args.speculative_model,
-              num_speculative_tokens=args.num_speculative_tokens,
-              tokenizer=args.tokenizer,
-              quantization=args.quantization,
-              tensor_parallel_size=args.tensor_parallel_size,
-              trust_remote_code=args.trust_remote_code,
-              dtype=args.dtype,
-              enforce_eager=args.enforce_eager,
-              kv_cache_dtype=args.kv_cache_dtype,
-              quantization_param_path=args.quantization_param_path,
-              device=args.device,
-              ray_workers_use_nsight=args.ray_workers_use_nsight,
-              use_v2_block_manager=args.use_v2_block_manager,
-              enable_chunked_prefill=args.enable_chunked_prefill,
-              download_dir=args.download_dir,
-              block_size=args.block_size,
-              gpu_memory_utilization=args.gpu_memory_utilization,
-              load_format=args.load_format,
-              distributed_executor_backend=args.distributed_executor_backend)
+    llm = LLM(
+        model=args.model,
+        speculative_model=args.speculative_model,
+        num_speculative_tokens=args.num_speculative_tokens,
+        tokenizer=args.tokenizer,
+        quantization=args.quantization,
+        tensor_parallel_size=args.tensor_parallel_size,
+        trust_remote_code=args.trust_remote_code,
+        dtype=args.dtype,
+        enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
+        quantization_param_path=args.quantization_param_path,
+        device=args.device,
+        ray_workers_use_nsight=args.ray_workers_use_nsight,
+        use_v2_block_manager=args.use_v2_block_manager,
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        download_dir=args.download_dir,
+        block_size=args.block_size,
+        gpu_memory_utilization=args.gpu_memory_utilization,
+        load_format=args.load_format,
+        distributed_executor_backend=args.distributed_executor_backend,
+        otlp_traces_endpoint=args.otlp_traces_endpoint,
+    )

     sampling_params = SamplingParams(
         n=args.n,

@@ -254,5 +257,10 @@ if __name__ == '__main__':
         help='Backend to use for distributed serving. When more than 1 GPU '
         'is used, will be automatically set to "ray" if installed '
         'or "mp" (multiprocessing) otherwise.')
+    parser.add_argument(
+        '--otlp-traces-endpoint',
+        type=str,
+        default=None,
+        help='Target URL to which OpenTelemetry traces will be sent.')
     args = parser.parse_args()
     main(args)
examples/production_monitoring/Otel.md (new file, 82 lines)
@@ -0,0 +1,82 @@
# Setup OpenTelemetry POC

1. Install OpenTelemetry packages:
```
pip install \
  opentelemetry-sdk \
  opentelemetry-api \
  opentelemetry-exporter-otlp \
  opentelemetry-semantic-conventions-ai
```

1. Start Jaeger in a docker container:
```
# From: https://www.jaegertracing.io/docs/1.57/getting-started/
docker run --rm --name jaeger \
  -e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \
  -p 6831:6831/udp \
  -p 6832:6832/udp \
  -p 5778:5778 \
  -p 16686:16686 \
  -p 4317:4317 \
  -p 4318:4318 \
  -p 14250:14250 \
  -p 14268:14268 \
  -p 14269:14269 \
  -p 9411:9411 \
  jaegertracing/all-in-one:1.57
```

1. In a new shell, export Jaeger IP:
```
export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger)
export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317
```
Then set vLLM's service name for OpenTelemetry, enable insecure connections to Jaeger and run vLLM:
```
export OTEL_SERVICE_NAME="vllm-server"
export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true
python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
```

1. In a new shell, send requests with trace context from a dummy client:
```
export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger)
export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317
export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true
export OTEL_SERVICE_NAME="client-service"
python dummy_client.py
```

1. Open the Jaeger web UI: http://localhost:16686/

   In the search pane, select the `vllm-server` service and hit `Find Traces`. You should get a list of traces, one for each request.
   

1. Clicking on a trace will show its spans and their tags. In this demo, each trace has 2 spans: one from the dummy client containing the prompt text, and one from vLLM containing metadata about the request.
   

## Exporter Protocol
OpenTelemetry supports either `grpc` or `http/protobuf` as the transport protocol for trace data in the exporter.
By default, `grpc` is used. To set `http/protobuf` as the protocol, configure the `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` environment variable as follows:
```
export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf
export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces
python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
```

## Instrumentation of FastAPI
OpenTelemetry allows automatic instrumentation of FastAPI.
1. Install the instrumentation library
```
pip install opentelemetry-instrumentation-fastapi
```

1. Run vLLM with `opentelemetry-instrument`
```
opentelemetry-instrument python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m"
```

1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI.


examples/production_monitoring/dummy_client.py (new file, 35 lines)
@@ -0,0 +1,35 @@
import requests
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
    OTLPSpanExporter)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (BatchSpanProcessor,
                                            ConsoleSpanExporter)
from opentelemetry.trace import SpanKind, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import (
    TraceContextTextMapPropagator)

trace_provider = TracerProvider()
set_tracer_provider(trace_provider)

trace_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
trace_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))

tracer = trace_provider.get_tracer("dummy-client")

url = "http://localhost:8000/v1/completions"
with tracer.start_as_current_span("client-span", kind=SpanKind.CLIENT) as span:
    prompt = "San Francisco is a"
    span.set_attribute("prompt", prompt)
    headers = {}
    TraceContextTextMapPropagator().inject(headers)
    payload = {
        "model": "facebook/opt-125m",
        "prompt": prompt,
        "max_tokens": 10,
        "best_of": 20,
        "n": 3,
        "use_beam_search": "true",
        "temperature": 0.0,
        # "stream": True,
    }
    response = requests.post(url, headers=headers, json=payload)
tests/tracing/__init__.py (new file, empty)

tests/tracing/test_tracing.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import os
import threading
from concurrent import futures
from typing import Callable, Dict, Iterable, Literal

import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
    TraceServiceServicer, add_TraceServiceServicer_to_server)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
    OTEL_EXPORTER_OTLP_TRACES_INSECURE)

from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes

FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
                    'array_value']


def decode_value(value: AnyValue):
    field_decoders: Dict[FieldName, Callable] = {
        "bool_value": (lambda v: v.bool_value),
        "string_value": (lambda v: v.string_value),
        "int_value": (lambda v: v.int_value),
        "double_value": (lambda v: v.double_value),
        "array_value":
        (lambda v: [decode_value(item) for item in v.array_value.values]),
    }
    for field, decoder in field_decoders.items():
        if value.HasField(field):
            return decoder(value)
    raise ValueError(f"Couldn't decode value: {value}")


def decode_attributes(attributes: Iterable[KeyValue]):
    return {kv.key: decode_value(kv.value) for kv in attributes}


class FakeTraceService(TraceServiceServicer):

    def __init__(self):
        self.request = None
        self.evt = threading.Event()

    def Export(self, request, context):
        self.request = request
        self.evt.set()
        return ExportTraceServiceResponse()


@pytest.fixture
def trace_service():
    """Fixture to set up a fake gRPC trace service"""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    service = FakeTraceService()
    add_TraceServiceServicer_to_server(service, server)
    server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    server.start()

    yield service

    server.stop(None)


def test_traces(trace_service):
    os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"

    sampling_params = SamplingParams(temperature=0.01,
                                     top_p=0.1,
                                     max_tokens=256)
    model = "facebook/opt-125m"
    llm = LLM(
        model=model,
        otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
    )
    prompts = ["This is a short prompt"]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    timeout = 5
    if not trace_service.evt.wait(timeout):
        raise TimeoutError(
            f"The fake trace service didn't receive a trace within "
            f"the {timeout} seconds timeout")

    attributes = decode_attributes(trace_service.request.resource_spans[0].
                                   scope_spans[0].spans[0].attributes)
    assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_TEMPERATURE) == sampling_params.temperature
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
    assert attributes.get(
        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
    assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
    assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
        outputs[0].prompt_token_ids)
    completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
    assert attributes.get(
        SpanAttributes.LLM_USAGE_COMPLETION_TOKENS) == completion_tokens
    metrics = outputs[0].metrics
    assert attributes.get(
        SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
    ttft = metrics.first_token_time - metrics.arrival_time
    assert attributes.get(
        SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
    e2e_time = metrics.finished_time - metrics.arrival_time
    assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
@@ -10,6 +10,7 @@ from transformers import PretrainedConfig, PreTrainedTokenizerBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
+from vllm.tracing import is_otel_installed
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
                         is_hip, is_neuron, is_tpu, is_xpu)

@@ -1371,6 +1372,17 @@ class DecodingConfig:
                              f"must be one of {valid_guided_backends}")


+@dataclass
+class ObservabilityConfig:
+    """Configuration for observability."""
+    otlp_traces_endpoint: Optional[str] = None
+
+    def __post_init__(self):
+        if not is_otel_installed() and self.otlp_traces_endpoint is not None:
+            raise ValueError("OpenTelemetry packages must be installed before "
+                             "configuring 'otlp_traces_endpoint'")
+
+
 @dataclass(frozen=True)
 class EngineConfig:
     """Dataclass which contains all engine-related configuration. This

@@ -1387,6 +1399,7 @@ class EngineConfig:
     vision_language_config: Optional[VisionLanguageConfig]
     speculative_config: Optional[SpeculativeConfig]
     decoding_config: Optional[DecodingConfig]
+    observability_config: Optional[ObservabilityConfig]

     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
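
As a side note (not part of the diff), a small sketch of how the new ObservabilityConfig is expected to behave based on its __post_init__ above: constructing it with an endpoint fails fast when the OpenTelemetry packages are missing. The endpoint value is only an example.

```python
# Illustrative sketch only, based on the ObservabilityConfig added above.
from vllm.config import ObservabilityConfig

try:
    observability_config = ObservabilityConfig(
        otlp_traces_endpoint="grpc://localhost:4317")  # example endpoint
except ValueError:
    # Raised by __post_init__ when the OpenTelemetry packages are not installed.
    print("Install opentelemetry-sdk/api/exporter-otlp to enable tracing")
```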
@@ -7,8 +7,9 @@ from typing import List, Optional, Tuple, Union

 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         TokenizerPoolConfig, VisionLanguageConfig)
+                         ObservabilityConfig, ParallelConfig, SchedulerConfig,
+                         SpeculativeConfig, TokenizerPoolConfig,
+                         VisionLanguageConfig)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import str_to_int_tuple

@@ -101,6 +102,8 @@ class EngineArgs:
     qlora_adapter_name_or_path: Optional[str] = None

+    otlp_traces_endpoint: Optional[str] = None
+
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model

@@ -599,6 +602,13 @@ class EngineArgs:
             type=str,
             default=None,
             help='Name or path of the QLoRA adapter.')
+
+        parser.add_argument(
+            '--otlp-traces-endpoint',
+            type=str,
+            default=None,
+            help='Target URL to which OpenTelemetry traces will be sent.')
+
         return parser

     @classmethod

@@ -757,6 +767,9 @@ class EngineArgs:
         decoding_config = DecodingConfig(
             guided_decoding_backend=self.guided_decoding_backend)

+        observability_config = ObservabilityConfig(
+            otlp_traces_endpoint=self.otlp_traces_endpoint)
+
         if (model_config.get_sliding_window() is not None
                 and scheduler_config.chunked_prefill_enabled
                 and not scheduler_config.use_v2_block_manager):

@@ -764,16 +777,19 @@ class EngineArgs:
                 "Chunked prefill is not supported with sliding window. "
                 "Set --disable-sliding-window to disable sliding window.")

-        return EngineConfig(model_config=model_config,
-                            cache_config=cache_config,
-                            parallel_config=parallel_config,
-                            scheduler_config=scheduler_config,
-                            device_config=device_config,
-                            lora_config=lora_config,
-                            vision_language_config=vision_language_config,
-                            speculative_config=speculative_config,
-                            load_config=load_config,
-                            decoding_config=decoding_config)
+        return EngineConfig(
+            model_config=model_config,
+            cache_config=cache_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            lora_config=lora_config,
+            vision_language_config=vision_language_config,
+            speculative_config=speculative_config,
+            load_config=load_config,
+            decoding_config=decoding_config,
+            observability_config=observability_config,
+        )


 @dataclass
@@ -244,6 +244,9 @@ class _AsyncLLMEngine(LLMEngine):
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)

+        # Tracing
+        self.do_tracing(scheduler_outputs)
+
         if not request_outputs:
             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in

@@ -285,6 +288,7 @@ class _AsyncLLMEngine(LLMEngine):
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "

@@ -301,6 +305,7 @@ class _AsyncLLMEngine(LLMEngine):
             params=params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            trace_headers=trace_headers,
         )

     async def check_health_async(self) -> None:

@@ -556,6 +561,7 @@ class AsyncLLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> AsyncStream:
         if self.log_requests:
             if isinstance(inputs, str):

@@ -597,6 +603,7 @@ class AsyncLLMEngine:
             params=params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            trace_headers=trace_headers,
         )

         return stream

@@ -607,6 +614,7 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.

@@ -621,6 +629,7 @@ class AsyncLLMEngine:
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.

         Yields:
             The output `RequestOutput` objects from the LLMEngine

@@ -674,6 +683,7 @@ class AsyncLLMEngine:
             inputs,
             sampling_params,
             lora_request=lora_request,
+            trace_headers=trace_headers,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)

@@ -683,6 +693,7 @@ class AsyncLLMEngine:
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.

@@ -697,6 +708,7 @@ class AsyncLLMEngine:
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.

         Yields:
             The output `EmbeddingRequestOutput` objects from the LLMEngine

@@ -748,6 +760,7 @@ class AsyncLLMEngine:
             inputs,
             pooling_params,
             lora_request=lora_request,
+            trace_headers=trace_headers,
         ):
             yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

@@ -758,6 +771,7 @@ class AsyncLLMEngine:
         params: Union[SamplingParams, PoolingParams],
         *,
         lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or
         PoolingParams."""

@@ -769,6 +783,7 @@ class AsyncLLMEngine:
             params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            trace_headers=trace_headers,
         )

         try:

@@ -848,3 +863,10 @@ class AsyncLLMEngine:
         else:
             await self.engine.check_health_async()
         logger.debug("Health check took %fs", time.perf_counter() - t)
+
+    async def is_tracing_enabled(self) -> bool:
+        if self.engine_use_ray:
+            return await self.engine.is_tracing_enabled.remote(  # type: ignore
+            )
+        else:
+            return self.engine.is_tracing_enabled()
@@ -1,14 +1,14 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
+from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union

 from transformers import GenerationConfig, PreTrainedTokenizer

 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
-                         LoRAConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, SpeculativeConfig,
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VisionLanguageConfig)
 from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                  SchedulerOutputs)

@@ -31,6 +31,8 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                            PoolerOutput, SamplerOutput, Sequence,
                            SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
+from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
+                          init_tracer)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)

@@ -154,6 +156,7 @@ class LLMEngine:
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
         decoding_config: Optional[DecodingConfig],
+        observability_config: Optional[ObservabilityConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,

@@ -168,7 +171,8 @@ class LLMEngine:
             "disable_custom_all_reduce=%s, quantization=%s, "
             "enforce_eager=%s, kv_cache_dtype=%s, "
             "quantization_param_path=%s, device_config=%s, "
-            "decoding_config=%r, seed=%d, served_model_name=%s)",
+            "decoding_config=%r, observability_config=%r, "
+            "seed=%d, served_model_name=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,

@@ -192,6 +196,7 @@ class LLMEngine:
             model_config.quantization_param_path,
             device_config.device,
             decoding_config,
+            observability_config,
             model_config.seed,
             model_config.served_model_name,
         )

@@ -207,6 +212,8 @@ class LLMEngine:
         self.speculative_config = speculative_config
         self.load_config = load_config
         self.decoding_config = decoding_config or DecodingConfig()
+        self.observability_config = observability_config or ObservabilityConfig(
+        )
         self.log_stats = log_stats

         if not self.model_config.skip_tokenizer_init:
@@ -288,6 +295,12 @@ class LLMEngine:
                 max_model_len=self.model_config.max_model_len)
             self.stat_logger.info("cache_config", self.cache_config)

+        self.tracer = None
+        if self.observability_config.otlp_traces_endpoint:
+            self.tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+
         # Create sequence output processor, e.g. for beam search or
         # speculative decoding.
         self.output_processor = (

@@ -444,6 +457,7 @@ class LLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> None:
         # Create the sequences.
         block_size = self.cache_config.block_size

@@ -461,6 +475,7 @@ class LLMEngine:
                 params,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
+                trace_headers=trace_headers,
             )
         elif isinstance(params, PoolingParams):
             seq_group = self._create_sequence_group_with_pooling(

@@ -507,6 +522,7 @@ class LLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> None:
         """Add a request to the engine's request pool.

@@ -524,6 +540,7 @@ class LLMEngine:
                 :class:`~vllm.PoolingParams` for pooling.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+            trace_headers: OpenTelemetry trace headers.

         Details:
             - Set arrival_time to the current time if it is None.

@@ -565,6 +582,7 @@ class LLMEngine:
             params=params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            trace_headers=trace_headers,
         )

     def _create_sequence_group_with_sampling(

@@ -574,6 +592,7 @@ class LLMEngine:
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs

@@ -595,11 +614,14 @@ class LLMEngine:
             self.generation_config_fields)

         # Create the sequence group.
-        seq_group = SequenceGroup(request_id=request_id,
-                                  seqs=[seq],
-                                  arrival_time=arrival_time,
-                                  sampling_params=sampling_params,
-                                  lora_request=lora_request)
+        seq_group = SequenceGroup(
+            request_id=request_id,
+            seqs=[seq],
+            arrival_time=arrival_time,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+            trace_headers=trace_headers,
+        )

         return seq_group

@@ -793,6 +815,9 @@ class LLMEngine:
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)

+        # Tracing
+        self.do_tracing(scheduler_outputs)
+
         if not request_outputs:
             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in
@@ -986,3 +1011,62 @@ class LLMEngine:

     def check_health(self) -> None:
         self.model_executor.check_health()
+
+    def is_tracing_enabled(self) -> bool:
+        return self.tracer is not None
+
+    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
+        if self.tracer is None:
+            return
+
+        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+            seq_group = scheduled_seq_group.seq_group
+            if seq_group.is_finished():
+                self.create_trace_span(seq_group)
+
+    def create_trace_span(self, seq_group: SequenceGroup) -> None:
+        if self.tracer is None or seq_group.sampling_params is None:
+            return
+        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)
+
+        trace_context = extract_trace_context(seq_group.trace_headers)
+
+        with self.tracer.start_as_current_span(
+                "llm_request",
+                kind=SpanKind.SERVER,
+                context=trace_context,
+                start_time=arrival_time_nano_seconds) as seq_span:
+            metrics = seq_group.metrics
+            ttft = metrics.first_token_time - metrics.arrival_time
+            e2e_time = metrics.finished_time - metrics.arrival_time
+            # attribute names are based on
+            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
+            seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
+                                   self.model_config.model)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
+                                   seq_group.request_id)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
+                                   seq_group.sampling_params.temperature)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
+                                   seq_group.sampling_params.top_p)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+                                   seq_group.sampling_params.max_tokens)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
+                                   seq_group.sampling_params.best_of)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
+                                   seq_group.sampling_params.n)
+            seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
+                                   seq_group.num_seqs())
+            seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+                                   len(seq_group.prompt_token_ids))
+            seq_span.set_attribute(
+                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+                sum([
+                    seq.get_output_len()
+                    for seq in seq_group.get_finished_seqs()
+                ]))
+            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
+                                   metrics.time_in_queue)
+            seq_span.set_attribute(
+                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
+            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
@@ -31,6 +31,8 @@ from vllm.multimodal.utils import (async_get_and_parse_image,
                                    get_full_image_text_prompt)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
+from vllm.tracing import (contains_trace_headers, extract_trace_headers,
+                          log_tracing_disabled_warning)
 from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -267,11 +269,20 @@ class OpenAIServingChat(OpenAIServing):
             if image_data is not None:
                 inputs["multi_modal_data"] = image_data

+            is_tracing_enabled = await self.engine.is_tracing_enabled()
+            trace_headers = None
+            if is_tracing_enabled and raw_request:
+                trace_headers = extract_trace_headers(raw_request.headers)
+            if not is_tracing_enabled and raw_request and contains_trace_headers(
+                    raw_request.headers):
+                log_tracing_disabled_warning()
+
             result_generator = self.engine.generate(
                 inputs,
                 sampling_params,
                 request_id,
                 lora_request,
+                trace_headers=trace_headers,
             )
             # Streaming response
             if request.stream:
@@ -24,6 +24,8 @@ from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
+from vllm.tracing import (contains_trace_headers, extract_trace_headers,
+                          log_tracing_disabled_warning)
 from vllm.utils import merge_async_iterators, random_uuid

 logger = init_logger(__name__)

@@ -125,6 +127,14 @@ class OpenAIServingCompletion(OpenAIServing):
                                                   truncate_prompt_tokens)
                 prompt_ids, prompt_text = prompt_formats

+                is_tracing_enabled = await self.engine.is_tracing_enabled()
+                trace_headers = None
+                if is_tracing_enabled:
+                    trace_headers = extract_trace_headers(raw_request.headers)
+                if not is_tracing_enabled and contains_trace_headers(
+                        raw_request.headers):
+                    log_tracing_disabled_warning()
+
                 generator = self.engine.generate(
                     {
                         "prompt": prompt_text,

@@ -133,6 +143,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     sampling_params,
                     f"{request_id}-{i}",
                     lora_request=lora_request,
+                    trace_headers=trace_headers,
                 )

                 generators.append(generator)
@@ -414,6 +414,7 @@ class SequenceGroup:
             for an embedding model.
         encoder_seq: Optional, the single encoder sequence. Should be None
             unless you are working with an encoder/decoder model.
+        trace_headers: OpenTelemetry trace headers.
     """

     def __init__(

@@ -426,6 +427,7 @@ class SequenceGroup:
         embeddings: Optional[List[float]] = None,
         pooling_params: Optional[PoolingParams] = None,
         encoder_seq: Optional[Sequence] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}

@@ -441,6 +443,7 @@ class SequenceGroup:
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.encoder_seq = encoder_seq
+        self.trace_headers = trace_headers

     @property
     def prompt(self) -> Optional[str]:
vllm/tracing.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import os
from typing import Mapping, Optional

from vllm.logger import init_logger
from vllm.utils import run_once

TRACE_HEADERS = ["traceparent", "tracestate"]

logger = init_logger(__name__)

_is_otel_installed = False
try:
    from opentelemetry.context.context import Context
    from opentelemetry.sdk.environment_variables import (
        OTEL_EXPORTER_OTLP_TRACES_PROTOCOL)
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
    from opentelemetry.semconv.ai import SpanAttributes as BaseSpanAttributes
    from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider
    from opentelemetry.trace.propagation.tracecontext import (
        TraceContextTextMapPropagator)
    _is_otel_installed = True
except ImportError:

    class Context:  # type: ignore
        pass

    class BaseSpanAttributes:  # type: ignore
        pass

    class SpanKind:  # type: ignore
        pass

    class Tracer:  # type: ignore
        pass


def is_otel_installed() -> bool:
    return _is_otel_installed


def init_tracer(instrumenting_module_name: str,
                otlp_traces_endpoint: str) -> Optional[Tracer]:
    assert is_otel_installed(), ("OpenTelemetry packages must be installed "
                                 "prior to initializing a tracer")
    trace_provider = TracerProvider()

    span_exporter = get_span_exporter(otlp_traces_endpoint)
    trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    set_tracer_provider(trace_provider)

    tracer = trace_provider.get_tracer(instrumenting_module_name)
    return tracer


def get_span_exporter(endpoint):
    protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
    if protocol == "grpc":
        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
            OTLPSpanExporter)
    elif protocol == "http/protobuf":
        from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
            OTLPSpanExporter)
    else:
        raise ValueError(
            f"Unsupported OTLP protocol '{protocol}' is configured")

    return OTLPSpanExporter(endpoint=endpoint)


def extract_trace_context(
        headers: Optional[Mapping[str, str]]) -> Optional[Context]:
    if is_otel_installed():
        headers = headers or {}
        return TraceContextTextMapPropagator().extract(headers)
    else:
        return None


def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
    return {h: headers[h] for h in TRACE_HEADERS if h in headers}


class SpanAttributes(BaseSpanAttributes):
    # The following span attribute names are added here because they are
    # missing from the Semantic Conventions for LLM.
    LLM_REQUEST_ID = "gen_ai.request.id"
    LLM_REQUEST_BEST_OF = "gen_ai.request.best_of"
    LLM_REQUEST_N = "gen_ai.request.n"
    LLM_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
    LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
    LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
    LLM_LATENCY_E2E = "gen_ai.latency.e2e"


def contains_trace_headers(headers: Mapping[str, str]) -> bool:
    return any(h in headers for h in TRACE_HEADERS)


@run_once
def log_tracing_disabled_warning() -> None:
    logger.warning(
        "Received a request with trace context but tracing is disabled")
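
To make the header handling above concrete, here is a small illustration (not part of the diff) of how extract_trace_headers and contains_trace_headers treat incoming HTTP headers; the traceparent value is a made-up example in the W3C Trace Context format.

```python
# Illustration only: behavior of the helpers added in vllm/tracing.py.
from vllm.tracing import contains_trace_headers, extract_trace_headers

headers = {
    "content-type": "application/json",
    # Hypothetical W3C traceparent: version-traceid-spanid-flags
    "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
}

assert contains_trace_headers(headers)
# Only the W3C trace headers ("traceparent"/"tracestate") are kept and passed
# to the engine as `trace_headers`; everything else is dropped.
print(extract_trace_headers(headers))
```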
@@ -763,3 +763,15 @@ def cuda_device_count_stateless() -> int:
     # after https://github.com/pytorch/pytorch/pull/122815 is released.

     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
+
+
+# From: https://stackoverflow.com/a/4104188/2749989
+def run_once(f):
+
+    def wrapper(*args, **kwargs) -> Any:
+        if not wrapper.has_run:  # type: ignore[attr-defined]
+            wrapper.has_run = True  # type: ignore[attr-defined]
+            return f(*args, **kwargs)
+
+    wrapper.has_run = False  # type: ignore[attr-defined]
+    return wrapper