Commit a7414f7

Merge pull request vllm-project#1 from RichardoMrMu/feat-trace-v1-aftermerge
feat:trace v1
2 parents 0c600b9 + e0bb716 commit a7414f7

6 files changed: +94, -11 lines changed

vllm/tracing.py

Lines changed: 5 additions & 0 deletions
@@ -119,6 +119,11 @@ class SpanAttributes:
     # forward, block/sync across workers, cpu-gpu sync time and sampling time.
     GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
         "gen_ai.latency.time_in_model_execute")
+    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \
+        "gen_ai.latency.time_in_model_prefill"
+    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
+    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \
+        "gen_ai.latency.time_in_model_inference"


 def contains_trace_headers(headers: Mapping[str, str]) -> bool:
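For orientation, these constants are plain OpenTelemetry attribute keys (e.g. "gen_ai.latency.time_in_model_prefill"). A minimal usage sketch with the init_tracer helper from the same module; the collector endpoint and the 0.123 s value are assumptions:

```python
from vllm.tracing import SpanAttributes, init_tracer

# Assumption: an OTLP collector is reachable at this endpoint.
tracer = init_tracer("vllm.llm_engine", "http://localhost:4317")

with tracer.start_as_current_span("llm_request") as span:
    # Record a prefill latency (in seconds) under the new attribute key.
    span.set_attribute(
        SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, 0.123)
```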

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -860,9 +860,9 @@ def update_from_output(
                 stop_reason=request.stop_reason,
                 events=request.take_events(),
                 kv_transfer_params=kv_transfer_params,
+                trace_headers=request.trace_headers,
                 num_cached_tokens=request.num_cached_tokens,
             ))
-
         else:
             # Invariant: EngineCore returns no partial prefill outputs.
             assert not prompt_logprobs_tensors

vllm/v1/engine/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -3,7 +3,7 @@

 import enum
 import time
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Any, Optional, Union

 import msgspec

@@ -70,6 +70,8 @@ class EngineCoreRequest(
     current_wave: int = 0
     priority: int = 0

+    trace_headers: Optional[Mapping[str, str]] = None
+

 class EngineCoreEventType(enum.IntEnum):
     """The type of engine core request event."""

@@ -115,6 +117,7 @@ class EngineCoreOutput(
     events: Optional[list[EngineCoreEvent]] = None
     kv_transfer_params: Optional[dict[str, Any]] = None

+    trace_headers: Optional[Mapping[str, str]] = None
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0

@@ -141,7 +144,7 @@ class EngineCoreOutputs(
         omit_defaults=True,  # type: ignore[call-arg]
         gc=False):  # type: ignore[call-arg]

-    #NOTE(Nick): We could consider ways to make this more compact,
+    # NOTE(Nick): We could consider ways to make this more compact,
     # e.g. columnwise layout

     engine_index: int = 0
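The new trace_headers field carries the W3C Trace Context headers received with the request, so the scheduler output and the output processor can link their spans back to the caller's trace. A representative value; the IDs below are illustrative only:

```python
# Illustrative only: the shape of EngineCoreRequest.trace_headers /
# EngineCoreOutput.trace_headers (W3C Trace Context).
trace_headers = {
    "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
}
```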

vllm/v1/engine/output_processor.py

Lines changed: 78 additions & 5 deletions
@@ -2,15 +2,19 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import asyncio
+import time
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, Optional, Union, cast

 import torch

+from vllm.config import ObservabilityConfig
 from vllm.outputs import (CompletionOutput, PoolingOutput,
                           PoolingRequestOutput, RequestOutput)
 from vllm.sampling_params import RequestOutputKind
+from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
+                          init_tracer)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason

@@ -274,16 +278,26 @@ def _new_pooling_output(
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""

-    def __init__(
-        self,
-        tokenizer: TokenizerGroup,
-        log_stats: bool,
-    ):
+    def __init__(self,
+                 tokenizer: TokenizerGroup,
+                 log_stats: bool,
+                 observability_config: Optional[ObservabilityConfig] = None):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()
+        self.observability_config = observability_config
+
+        self.tracer = None
+        if (self.observability_config is not None
+                and self.observability_config.otlp_traces_endpoint):
+            self.tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+
+    def is_tracing_enabled(self) -> bool:
+        return self.tracer is not None

     def get_num_unfinished_requests(self):
         return len(self.request_states)

@@ -440,6 +454,65 @@ def process_outputs(
             reqs_to_abort=reqs_to_abort,
         )

+    def do_tracing(self, engine_core_output: EngineCoreOutput,
+                   req_state: RequestState,
+                   iteration_stats: Optional[IterationStats]):
+        if (engine_core_output.finish_reason is None or iteration_stats is None
+                or req_state is None or req_state.stats is None
+                or self.tracer is None):
+            return
+        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
+
+        trace_context = extract_trace_context(engine_core_output.trace_headers)
+        with self.tracer.start_as_current_span(
+                "llm_request",
+                kind=SpanKind.SERVER,
+                context=trace_context,
+                start_time=arrival_time_nano_seconds) as span:
+            metrics = req_state.stats
+            ttft = metrics.first_token_ts - metrics.arrival_time
+            e2e_time = time.time() - metrics.arrival_time
+            # Queued interval is from first QUEUED event to first SCHEDULED
+            queued_time = metrics.scheduled_ts - metrics.queued_ts
+
+            # Prefill interval is from first SCHEDULED to first NEW_TOKEN
+            # Any preemptions during prefill is included in the interval
+            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
+
+            # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
+            # Any preemptions during decode are included
+            decode_time = metrics.last_token_ts - metrics.first_token_ts
+
+            # Inference interval is from first SCHEDULED to last NEW_TOKEN
+            # Any preemptions during prefill or decode are included
+            inference_time = metrics.last_token_ts - metrics.scheduled_ts
+            span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
+                               self.tokenizer.tokenizer_id)
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
+                               req_state.request_id)
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
+                               req_state.max_tokens_param)
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
+                               len(req_state.prompt_token_ids))
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
+                               metrics.num_generation_tokens)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               metrics.queued_ts - metrics.arrival_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               queued_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
+                prefill_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
+                decode_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
+                inference_time)
+
     def _update_stats_from_output(self, req_state: RequestState,
                                   engine_core_output: EngineCoreOutput,
                                   engine_core_timestamp: Optional[float],
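A hedged wiring sketch for the new constructor parameter: passing an ObservabilityConfig with an OTLP endpoint makes the processor create a tracer, which do_tracing then uses for finished requests. The endpoint is an assumption, and `tokenizer` stands in for an already-built TokenizerGroup:

```python
from vllm.config import ObservabilityConfig
from vllm.v1.engine.output_processor import OutputProcessor

# Assumptions: `tokenizer` is an existing TokenizerGroup and an OTLP
# collector listens at the endpoint below.
observability_config = ObservabilityConfig(
    otlp_traces_endpoint="http://localhost:4317")

output_processor = OutputProcessor(tokenizer,
                                   log_stats=True,
                                   observability_config=observability_config)
assert output_processor.is_tracing_enabled()
```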

vllm/v1/engine/processor.py

Lines changed: 0 additions & 2 deletions
@@ -225,8 +225,6 @@ def process_inputs(
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
         self._validate_params(params, lora_request)
-        if trace_headers is not None:
-            raise ValueError("V1 does not support tracing yet.")
         if prompt_adapter_request is not None:
             raise ValueError("V1 does not support prompt_adapter_request.")

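With this guard removed, trace headers supplied by a client can reach the V1 engine instead of raising an error. A hedged client-side sketch that uses the OpenTelemetry SDK to generate a traceparent header; the server address, the model name, and whether the front end forwards the header into EngineCoreRequest.trace_headers are assumptions about the deployment:

```python
import requests
from opentelemetry import trace
from opentelemetry.propagate import inject
from opentelemetry.sdk.trace import TracerProvider

# A real TracerProvider is needed so spans get valid trace/span IDs.
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("client")

with tracer.start_as_current_span("client_request"):
    headers: dict[str, str] = {}
    inject(headers)  # adds the W3C "traceparent" header for the current span

    # Assumptions: a vLLM OpenAI-compatible server on localhost:8000 serving
    # a model registered as "my-model".
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        headers=headers,
        json={"model": "my-model", "prompt": "Hello", "max_tokens": 16},
    )
    print(resp.json())
```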

vllm/v1/request.py

Lines changed: 5 additions & 1 deletion
@@ -3,6 +3,7 @@

 import enum
 import time
+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any, Optional, Union

 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange

@@ -36,6 +37,7 @@ def __init__(
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
         priority: int = 0,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index

@@ -98,7 +100,8 @@ def __init__(
         # they should also be updated simultaneously.
         self.output_token_ids = ConstantList(self._output_token_ids)
         self.all_token_ids = ConstantList(self._all_token_ids)
-
+        # trace_headers
+        self.trace_headers = trace_headers
         # State
         # The number of tokens with prefix cache hits.
         self.num_cached_tokens = -1

@@ -131,6 +134,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             if request.sampling_params else None,
             cache_salt=request.cache_salt,
             priority=request.priority,
+            trace_headers=request.trace_headers,
         )

     def append_output_token_ids(
