|
1 | 1 | import asyncio |
| 2 | +from collections.abc import Mapping |
2 | 3 | import json |
3 | 4 | import weakref |
4 | 5 | from dataclasses import dataclass, field |
|
7 | 8 | Optional, TypeAlias, Union) |
8 | 9 | from weakref import WeakMethod |
9 | 10 |
|
| 11 | +from tensorrt_llm.llmapi.otel_tracing import SpanAttributes, SpanKind, extract_trace_context, global_otlp_tracer |
10 | 12 | import torch |
11 | 13 | import torch.nn.functional as F |
12 | 14 |
|
@@ -160,6 +162,7 @@ def __init__(self, |
160 | 162 |         self.decoding_iter = 0 |
161 | 163 |         self._done = False |
162 | 164 |         self.metrics_dict = {} |
| 165 | +        self.trace_headers: Optional[Mapping[str, str]] = None |
163 | 166 |
|
164 | 167 |         if has_event_loop(): |
165 | 168 |             self.aqueue = AsyncQueue() |
@@ -288,6 +291,7 @@ def _handle_sequence(self, |
288 | 291 |                 raise ValueError( |
289 | 292 |                     f"Unknown finish reason: {finish_reasons[src_idx]}") |
290 | 293 |         self.record_stats(output, req_perf_metrics_dict) |
| 294 | +        self.do_tracing(output, req_perf_metrics_dict) |
291 | 295 |
|
292 | 296 |     @nvtx_range_debug("handle_response", |
293 | 297 |                       color="red", |
@@ -388,6 +392,70 @@ def record_stats(self, |
388 | 392 |         metrics_stats.update(processed_metrics_stat) |
389 | 393 |         self.metrics_dict = metrics_stats |
390 | 394 |
|
| 395 | +    def do_tracing( |
| 396 | +        self, |
| 397 | +        output: CompletionOutput, |
| 398 | +        req_perf_metrics_dict: Optional[dict[str, float]] = None, |
| 399 | +    ): |
| 400 | +        if not global_otlp_tracer(): |
| 401 | +            return |
| 402 | + |
| 403 | +        metrics_dict = self.metrics_dict |
| 404 | +        if not metrics_dict or not req_perf_metrics_dict: |
| 405 | +            # Insufficient request metrics available; trace generation aborted. |
| 406 | +            return |
| 407 | + |
| 408 | +        trace_context = extract_trace_context(self.trace_headers) |
| 409 | +        sampling_params = self.sampling_params |
| 410 | +        with global_otlp_tracer().start_as_current_span( |
| 411 | +            "llm_request", |
| 412 | +            kind=SpanKind.SERVER, |
| 413 | +            context=trace_context, |
| 414 | +            start_time=int( |
| 415 | +                req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0) |
| 416 | +            ), |
| 417 | +        ) as span: |
| 418 | + |
| 419 | +            def safe_set_attr(span, attr, value): |
| 420 | +                if value is not None: |
| 421 | +                    span.set_attribute(attr, value) |
| 422 | + |
| 423 | +            e2e_time = metrics_dict.get(SupportedMetricNames.E2E, -1) |
| 424 | +            safe_set_attr( |
| 425 | +                span, |
| 426 | +                SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, |
| 427 | +                sampling_params.temperature, |
| 428 | +            ) |
| 429 | +            safe_set_attr( |
| 430 | +                span, SpanAttributes.GEN_AI_REQUEST_TOP_P, sampling_params.top_p |
| 431 | +            ) |
| 432 | +            safe_set_attr( |
| 433 | +                span, |
| 434 | +                SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, |
| 435 | +                sampling_params.max_tokens, |
| 436 | +            ) |
| 437 | +            safe_set_attr(span, SpanAttributes.GEN_AI_REQUEST_N, sampling_params.n) |
| 438 | +            safe_set_attr( |
| 439 | +                span, |
| 440 | +                SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, |
| 441 | +                self.postproc_params.postproc_args.num_prompt_tokens, |
| 442 | +            ) |
| 443 | +            safe_set_attr( |
| 444 | +                span, SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, output.length |
| 445 | +            ) |
| 446 | +            safe_set_attr( |
| 447 | +                span, |
| 448 | +                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, |
| 449 | +                metrics_dict.get(SupportedMetricNames.TTFT, -1), |
| 450 | +            ) |
| 451 | +            safe_set_attr(span, SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) |
| 452 | +            safe_set_attr(span, SpanAttributes.GEN_AI_REQUEST_ID, self.id) |
| 453 | +            safe_set_attr( |
| 454 | +                span, |
| 455 | +                SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, |
| 456 | +                metrics_dict.get(SupportedMetricNames.REQUEST_QUEUE_TIME, -1), |
| 457 | +            ) |
| 458 | + |
391 | 459 |
|
392 | 460 | class DetokenizedGenerationResultBase(GenerationResultBase): |
393 | 461 | ''' The base class for the generation result with detokenization support. ''' |
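
Note on the hook above: `do_tracing` reads `self.metrics_dict`, so it must run after `record_stats` has populated it, which is why the call is placed immediately after `record_stats` in `_handle_sequence`. The whole path is also a no-op unless a process-wide tracer has been registered. For reference only, here is a minimal sketch of how such an OTLP tracer could be configured with the stock `opentelemetry-sdk`; the service name and endpoint are illustrative assumptions, and `global_otlp_tracer()` is assumed to hand back a standard tracer registered elsewhere:

```python
# Hypothetical setup sketch; not TensorRT-LLM API.
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

provider = TracerProvider(
    resource=Resource.create({"service.name": "trtllm-server"}))  # illustrative name
# Export spans in batches to a local OTLP collector (default gRPC port 4317).
provider.add_span_processor(
    BatchSpanProcessor(OTLPSpanExporter(endpoint="localhost:4317", insecure=True)))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer("tensorrt_llm")
```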
@@ -498,6 +566,7 @@ def __init__( |
498 | 566 |         self.disaggregated_params = disaggregated_params |
499 | 567 |         # minimal sampling params needed for logprob calculation |
500 | 568 |         self._logprob_params = logprob_params |
| 569 | +        self.trace_headers = generation_request.trace_headers |
501 | 570 |
|
502 | 571 |         # for aborting the request |
503 | 572 |         self._executor: Optional[weakref.ReferenceType[ |
|
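
The `trace_headers` carried on the request (see `generation_request.trace_headers` above) are what `extract_trace_context` consumes, so a client that wants its spans linked to the server's `llm_request` span sends W3C trace-context headers. A hedged client-side sketch using the standard OpenTelemetry propagator; the URL and payload are illustrative:

```python
# Hypothetical client sketch: inject() fills in the W3C "traceparent"
# header from the current span so the server can parent its span to it.
import requests
from opentelemetry import trace
from opentelemetry.propagate import inject

tracer = trace.get_tracer(__name__)
with tracer.start_as_current_span("client_request"):
    headers: dict[str, str] = {}
    inject(headers)  # adds "traceparent" (and "tracestate" if present)
    requests.post("http://localhost:8000/v1/completions",
                  json={"prompt": "Hello"}, headers=headers)
```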