Commit b459a6e

Author: Zhang Haotong
Commit message: fix
Signed-off-by: Zhang Haotong <[email protected]>
1 parent 6e9fe8b commit b459a6e

7 files changed: +88 -36 lines

tensorrt_llm/executor/result.py

Lines changed: 38 additions & 7 deletions

```diff
@@ -316,7 +316,7 @@ def _handle_response(self,
         else:
             self._outputs[0]._postprocess_result = response.res
         if response.metrics:
-            self.metrics_dict = response.metrics
+            self.metrics_dict.update(response.metrics)

         if response.error:
             if self._background_error_handler is not None and (
@@ -391,7 +391,7 @@ def record_stats(self,
                 stats, len(output.token_ids), self.sampling_params.n > 1)
             if processed_metrics_stat:
                 metrics_stats.update(processed_metrics_stat)
-        self.metrics_dict = metrics_stats
+        self.metrics_dict.update(metrics_stats)

     def do_tracing(
         self,
@@ -410,20 +410,29 @@ def do_tracing(
         trace_context = tracing.extract_trace_context(self.trace_headers)
         sampling_params = self.sampling_params

-        # TODO: Add request arrival time
-        arrival_time = time.time() - metrics_dict.get(MetricNames.E2E, -1)
+        # Since arrival_time and other timing metrics are based on different
+        # time origins, we need to apply corrections to align them with
+        # absolute timestamps.
+        time_correction = 0
+        arrival_timestamp = metrics_dict.get(MetricNames.ARRIVAL_TIMESTAMP, 0)
+        arrival_time = req_perf_metrics_dict.get(
+            RequestEventTiming.ARRIVAL_TIME, 0)
+        if arrival_timestamp > 0:
+            time_correction = arrival_timestamp - arrival_time
+        else:
+            time_correction = time.time() - metrics_dict.get(
+                MetricNames.E2E, -1) - arrival_time
+
         with tracing.global_otlp_tracer().start_as_current_span(
             "llm_request",
             kind=tracing.SpanKind.SERVER,
             context=trace_context,
-            start_time=int(arrival_time * 1e9),
+            start_time=int((arrival_time + time_correction) * 1e9),
         ) as span:

             def safe_set_attr(span, attr, value):
                 if value is not None:
                     span.set_attribute(attr, value)

-            e2e_time = metrics_dict.get(MetricNames.E2E, -1)
             safe_set_attr(span,
                           tracing.SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                           sampling_params.temperature)
@@ -451,14 +460,36 @@ def safe_set_attr(span, attr, value):
                 span, tracing.SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                 metrics_dict.get(MetricNames.TTFT, -1))
             safe_set_attr(span, tracing.SpanAttributes.GEN_AI_LATENCY_E2E,
-                          e2e_time)
+                          metrics_dict.get(MetricNames.E2E, -1))
             safe_set_attr(span,
                           tracing.SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                           metrics_dict.get(MetricNames.REQUEST_QUEUE_TIME, -1))
             safe_set_attr(
                 span, tracing.SpanAttributes.GEN_AI_RESPONSE_FINISH_REASONS,
                 json.dumps([output.finish_reason])
                 if output.finish_reason else None)
+            safe_set_attr(
+                span,
+                tracing.SpanAttributes.GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME,
+                req_perf_metrics_dict.get(
+                    RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) -
+                req_perf_metrics_dict.get(
+                    RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0))
+
+            if req_perf_metrics_dict.get(
+                    RequestEventTiming.KV_CACHE_TRANSFER_START,
+                    0) and req_perf_metrics_dict.get(
+                        RequestEventTiming.KV_CACHE_TRANSFER_END, 0):
+                tracing.add_event(
+                    tracing.SpanEvents.KV_CACHE_TRANSFER_START,
+                    timestamp=int((req_perf_metrics_dict.get(
+                        RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0) +
+                                   time_correction) * 1e9))
+                tracing.add_event(
+                    tracing.SpanEvents.KV_CACHE_TRANSFER_END,
+                    timestamp=int((req_perf_metrics_dict.get(
+                        RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) +
+                                   time_correction) * 1e9))


 class DetokenizedGenerationResultBase(GenerationResultBase):
```
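For reference, the correction logic introduced above can be exercised in isolation. A minimal, runnable sketch with made-up metric values, where plain string keys stand in for the `MetricNames` and `RequestEventTiming` enums:

```python
import time

# Made-up values standing in for the real metrics dictionaries:
# "arrival_time" is measured on the executor's internal clock, while
# "arrival_timestamp" (when recorded) is an absolute time.time() value.
metrics_dict = {"arrival_timestamp": 1_700_000_000.0, "e2e": 0.85}
req_perf_metrics_dict = {"arrival_time": 12.34}

arrival_timestamp = metrics_dict.get("arrival_timestamp", 0)
arrival_time = req_perf_metrics_dict.get("arrival_time", 0)

if arrival_timestamp > 0:
    # Preferred path: shift the executor clock so the request's arrival
    # lands exactly on the recorded absolute timestamp.
    time_correction = arrival_timestamp - arrival_time
else:
    # Fallback: approximate the absolute arrival as "now minus E2E latency".
    time_correction = time.time() - metrics_dict.get("e2e", -1) - arrival_time

# Any executor-clock reading t is then exported as (t + time_correction),
# converted to the integer nanoseconds OpenTelemetry expects.
span_start_ns = int((arrival_time + time_correction) * 1e9)
```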

tensorrt_llm/executor/worker.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -1062,7 +1062,15 @@ def _get_metrics_dict(
             req_perf_metrics.timing_metrics.first_scheduled_time.
             total_seconds(),
             RequestEventTiming.LAST_TOKEN_TIME:
-            req_perf_metrics.timing_metrics.last_token_time.total_seconds()
+            req_perf_metrics.timing_metrics.last_token_time.total_seconds(),
+            RequestEventTiming.KV_CACHE_TRANSFER_START:
+            req_perf_metrics.timing_metrics.kv_cache_transfer_start.
+            total_seconds(),
+            RequestEventTiming.KV_CACHE_TRANSFER_END:
+            req_perf_metrics.timing_metrics.kv_cache_transfer_end.
+            total_seconds(),
+            RequestEventTiming.KV_CACHE_SIZE:
+            req_perf_metrics.timing_metrics.kv_cache_size,
         }
         return metrics_dict
```
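The timing fields returned by the bindings are duration-like objects, which is why each one is flattened with `total_seconds()` before being stored. A small illustration using `datetime.timedelta` stand-ins (the values are invented):

```python
from datetime import timedelta

# Invented stand-ins for req_perf_metrics.timing_metrics fields.
kv_cache_transfer_start = timedelta(seconds=12.50)
kv_cache_transfer_end = timedelta(seconds=12.62)

metrics_dict = {
    "kv_cache_transfer_start": kv_cache_transfer_start.total_seconds(),
    "kv_cache_transfer_end": kv_cache_transfer_end.total_seconds(),
}

# do_tracing() derives the transfer duration from the two flattened values.
transfer_time = (metrics_dict["kv_cache_transfer_end"] -
                 metrics_dict["kv_cache_transfer_start"])  # 0.12 seconds
```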

tensorrt_llm/llmapi/llm.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -17,6 +17,7 @@
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.inputs.registry import DefaultInputProcessor
 from tensorrt_llm.llmapi import tracing
+from tensorrt_llm.metrics.enums import MetricNames

 from .._utils import nvtx_range_debug
 from ..bindings import executor as tllm
@@ -449,6 +450,10 @@ def generate_async(
             scheduling_params=scheduling_params,
         )

+        if sampling_params.return_perf_metrics:
+            result.metrics_dict.update(
+                {MetricNames.ARRIVAL_TIMESTAMP: time.time()})
+
         return RequestOutput._from_generation_result(result, prompt,
                                                      self.tokenizer)
```
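A hedged usage sketch: with `return_perf_metrics` enabled on `SamplingParams`, `generate_async` stamps the absolute arrival timestamp into the result's `metrics_dict` at submission time. The model path is a placeholder, and the exact result-retrieval call may differ by version:

```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder model
sampling_params = SamplingParams(max_tokens=32, return_perf_metrics=True)

# generate_async() now records MetricNames.ARRIVAL_TIMESTAMP in
# result.metrics_dict; do_tracing() later uses it for clock alignment.
output = llm.generate_async("Hello", sampling_params=sampling_params).result()
print(output.metrics_dict)
```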

tensorrt_llm/llmapi/tracing.py

Lines changed: 26 additions & 11 deletions

```diff
@@ -2,10 +2,10 @@

 __all__ = [
     'SpanAttributes', 'SpanKind', 'contains_trace_headers',
-    'extract_trace_context', 'extract_trace_headers', 'get_span_exporter',
-    'global_otlp_tracer', 'init_tracer', 'insufficient_request_metrics_warning',
-    'is_otel_available', 'is_tracing_enabled', 'log_tracing_disabled_warning',
-    'set_global_otlp_tracer'
+    'extract_trace_context', 'get_span_exporter', 'global_otlp_tracer',
+    'init_tracer', 'insufficient_request_metrics_warning', 'is_otel_available',
+    'is_tracing_enabled', 'log_tracing_disabled_warning',
+    'set_global_otlp_tracer', 'extract_trace_headers'
 ]

 import functools
@@ -98,16 +98,23 @@ def extract_trace_context(
     return None


-def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
-    # Return only recognized trace headers with normalized lowercase keys
-    lower_map = {k.lower(): v for k, v in headers.items()}
-    return {h: lower_map[h] for h in TRACE_HEADERS if h in lower_map}
+def extract_trace_headers(
+        headers: Mapping[str, str]) -> Optional[Mapping[str, str]]:
+    if is_tracing_enabled():
+        # Return only recognized trace headers with normalized lowercase keys
+        lower_map = {k.lower(): v for k, v in headers.items()}
+        return {h: lower_map[h] for h in TRACE_HEADERS if h in lower_map}
+    if contains_trace_headers(headers):
+        log_tracing_disabled_warning()
+    return None


 def inject_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
-    trace_headers = extract_trace_headers(headers) if not headers else {}
-    TraceContextTextMapPropagator().inject(trace_headers)
-    return trace_headers
+    if is_tracing_enabled():
+        trace_headers = extract_trace_headers(headers) if not headers else {}
+        TraceContextTextMapPropagator().inject(trace_headers)
+        return trace_headers
+    return None


 def global_otlp_tracer() -> Tracer:
@@ -138,9 +145,17 @@ class SpanAttributes:
     GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
     GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
     GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
+    GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME = "gen_ai.latency.kv_cache_transfer_time"
     GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"


+class SpanEvents:
+    KV_CACHE_TRANSFER_START = "kv_cache_transfer_start"
+    KV_CACHE_TRANSFER_END = "kv_cache_transfer_end"
+    CTX_SERVER_SELECTED = "ctx_server.selected"
+    GEN_SERVER_SELECTED = "gen_server.selected"
+
+
 def contains_trace_headers(headers: Mapping[str, str]) -> bool:
     lower_keys = {k.lower() for k in headers.keys()}
     return any(h in lower_keys for h in TRACE_HEADERS)
```
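The net effect of the `extract_trace_headers` change is that callers no longer need the enabled/disabled case analysis themselves (see the `openai_server.py` hunk below). A self-contained sketch of the new control flow, with a boolean standing in for `is_tracing_enabled()` and the usual W3C header names assumed for `TRACE_HEADERS`:

```python
from typing import Mapping, Optional

TRACE_HEADERS = ("traceparent", "tracestate")  # assumed header set
TRACING_ENABLED = True  # stands in for is_tracing_enabled()


def extract_trace_headers(
        headers: Mapping[str, str]) -> Optional[Mapping[str, str]]:
    if TRACING_ENABLED:
        # Return only recognized trace headers with lowercase keys.
        lower_map = {k.lower(): v for k, v in headers.items()}
        return {h: lower_map[h] for h in TRACE_HEADERS if h in lower_map}
    if any(h in {k.lower() for k in headers} for h in TRACE_HEADERS):
        print("warning: trace headers received but tracing is disabled")
    return None


print(extract_trace_headers({
    "Traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
    "Content-Type": "application/json",
}))
# -> {'traceparent': '00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01'}
```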

tensorrt_llm/metrics/enums.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -6,10 +6,14 @@ class MetricNames(Enum):
     TPOT = "tpot"
     E2E = "e2e"
     REQUEST_QUEUE_TIME = "request_queue_time"
+    ARRIVAL_TIMESTAMP = 'arrival_timestamp'


 class RequestEventTiming(Enum):
     ARRIVAL_TIME = "arrival_time"
     FIRST_TOKEN_TIME = "first_token_time"  # nosec: B105
     FIRST_SCHEDULED_TIME = "first_scheduled_time"
     LAST_TOKEN_TIME = "last_token_time"  # nosec: B105
+    KV_CACHE_TRANSFER_START = "kv_cache_transfer_start"
+    KV_CACHE_TRANSFER_END = "kv_cache_transfer_end"
+    KV_CACHE_SIZE = "kv_cache_size"
```

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -257,6 +257,9 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio
         if need_ctx:
             ctx_req = copy.deepcopy(req)
             ctx_server, _ = await self.ctx_router.get_next_server(ctx_req)
+            # TODO: rename event to something more descriptive
+            tracing.add_event(tracing.SpanEvents.CTX_SERVER_SELECTED,
+                              attributes={"server": str(ctx_server)})
+
             # TODO: add ctx_server info into generation request for pre-registration
             ctx_response = await self._send_context_request(ctx_server, ctx_req, trace_headers)
@@ -277,13 +280,11 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio
         else:
             ctx_response = None

-        # TODO: rename event to something more descriptive
-        tracing.add_event('picking generation server')
-
         # Pick a generation server if haven't reserved one, and send request
         if gen_server is None:
             gen_server, _ = await self.gen_router.get_next_server(req)
         logger.debug("Sending request to gen server: %s", gen_server)
+        tracing.add_event(tracing.SpanEvents.GEN_SERVER_SELECTED,
+                          attributes={"server": str(gen_server)})

         if not req.stream:
             try:
```
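The two new events are ordinary OpenTelemetry span events. A minimal sketch of what `tracing.add_event` amounts to when a span is recording (the server URL is invented):

```python
from opentelemetry import trace

# Attach a structured event to whichever span is currently active; this is
# the underlying OpenTelemetry call behind server-selection events such as
# gen_server.selected.
span = trace.get_current_span()
span.add_event(
    "gen_server.selected",
    attributes={"server": "http://gen-worker-0:8001"},  # invented URL
)
```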

tensorrt_llm/serve/openai_server.py

Lines changed: 2 additions & 14 deletions

```diff
@@ -4,7 +4,6 @@
 import re
 import signal
 import traceback
-from collections.abc import Mapping
 from contextlib import asynccontextmanager
 from datetime import datetime
 from http import HTTPStatus
@@ -15,7 +14,6 @@
 from fastapi import FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-from starlette.datastructures import Headers
 from starlette.routing import Mount
 from transformers import AutoConfig, AutoProcessor

@@ -339,7 +337,7 @@ async def create_chat_response(
             postproc_args=postproc_args,
         )

-        trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers))
+        trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))

         promise = self.llm.generate_async(
             inputs=prompt,
@@ -476,7 +474,7 @@ async def generator_wrapper(generator: AsyncIterator[Any]):
             if request.stream else completion_response_post_processor,
             postproc_args=postproc_args,
         )
-        trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers))
+        trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))
        promise = self.llm.generate_async(
             inputs=prompt,
             sampling_params=sampling_params,
@@ -521,13 +519,3 @@ async def __call__(self, host, port):
             log_level="info",
             timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
         await uvicorn.Server(config).serve()
-
-    async def _get_trace_headers(
-        self,
-        headers: Headers,
-    ) -> Optional[Mapping[str, str]]:
-        if tracing.is_tracing_enabled():
-            return tracing.extract_trace_headers(headers)
-        if tracing.contains_trace_headers(headers):
-            tracing.log_tracing_disabled_warning()
-        return None
```
