2 changes: 2 additions & 0 deletions requirements.txt
@@ -29,6 +29,8 @@ nvidia-modelopt[torch]~=0.33.0
nvidia-nccl-cu12
nvidia-cuda-nvrtc-cu12
transformers==4.55.0
prometheus_client
prometheus_fastapi_instrumentator
pydantic>=2.9.1
pydantic-settings[yaml]
omegaconf
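
The two new Prometheus packages are presumably consumed by the serving layer. A minimal sketch of how prometheus_fastapi_instrumentator is typically wired into a FastAPI app; the app object and endpoint are illustrative and not taken from this diff:

# Hypothetical wiring sketch, not part of this change.
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

app = FastAPI()
# Adds default HTTP metrics and serves them at /metrics.
Instrumentator().instrument(app).expose(app)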
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -250,6 +250,12 @@ def deserialize(self):
self._result = tensorrt_llm.bindings.executor.deserialize_result(
self._result)

def get_result(self):
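        """Return the deserialized executor result, or None if deserialization yields nothing."""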
if tmp_res := tensorrt_llm.bindings.executor.deserialize_result(
self._result):
return tmp_res
return None


@dataclass
class LlmResponse:
15 changes: 15 additions & 0 deletions tensorrt_llm/_utils.py
@@ -20,6 +20,7 @@
import math
import os
import struct
import tempfile
import trace
import weakref
from contextlib import contextmanager
@@ -1112,3 +1113,17 @@ def is_multi_device_enable():
the number of devices
"""
return local_mpi_size() > 1


def set_prometheus_multiproc_dir() -> None:
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
global prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
logger.info("User set PROMETHEUS_MULTIPROC_DIR detected.")
prometheus_multiproc_dir = tempfile.TemporaryDirectory(
dir=os.environ["PROMETHEUS_MULTIPROC_DIR"])
else:
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
logger.info(
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
22 changes: 15 additions & 7 deletions tensorrt_llm/executor/postproc_worker.py
@@ -3,7 +3,7 @@
from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional)
Optional, Union)

import zmq
import zmq.asyncio
@@ -18,7 +18,7 @@

if TYPE_CHECKING:
from .result import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase)
GenerationResultBase, ResponseWrapper)

__all__ = [
"PostprocWorker",
@@ -57,7 +57,7 @@ class PostprocWorker:

@dataclass
class Input:
rsp: "tllm.Response"
rsp: Union["tllm.Response", "ResponseWrapper"]

# The information necessary for creating a GenerationResult in the first Input for each request
sampling_params: Optional[SamplingParams] = None
@@ -69,6 +69,7 @@ class Output(NamedTuple):
res: Any
is_final: bool
error: str = ""
metrics: Optional[dict[str, float]] = None

def __init__(
self,
@@ -118,7 +119,9 @@ def default_record_creator(
streaming=inp.streaming,
tokenizer=tokenizer)

async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
async def _handle_input(
        self, input: "PostprocWorker.Input"
) -> tuple[Any, Optional[dict[str, float]]]:
''' Handle a single response from await_response worker. '''
if input.rsp.result.context_logits is not None or \
input.rsp.result.generation_logits is not None:
@@ -139,6 +142,7 @@ async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
record._handle_response(input.rsp) # inplace
# Let the result_handler determine the final output dtype.
# NOTE: This will change the CompletionOutput._postprocess_result
metrics_dict = record.metrics_dict
if postproc_params := record.postproc_params:
result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
args.tokenizer = self._tokenizer
@@ -150,7 +154,7 @@

# TODO: Keep only the diff token_ids and text in streaming mode when
# result_handler is not set
return out
return out, metrics_dict

async def _batched_put(self):
''' Batched IPC send. '''
@@ -173,8 +177,12 @@ async def handle_single_input(inp: PostprocWorker.Input,
client_id = inp.rsp.client_id
is_final = inp.rsp.result.is_final if is_llm_response(
inp.rsp) else True
res = await self._handle_input(inp)
batch.append(PostprocWorker.Output(client_id, res, is_final))
res, metrics = await self._handle_input(inp)
batch.append(
PostprocWorker.Output(client_id=client_id,
res=res,
is_final=is_final,
metrics=metrics))
if is_final:
self._records.pop(client_id)

85 changes: 79 additions & 6 deletions tensorrt_llm/executor/result.py
@@ -15,6 +15,7 @@
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue
from ..metrics import MetricNames, MetricsCollector, RequestEventTiming
from ..sampling_params import LogprobParams, SamplingParams
from .utils import ErrorResponse, has_event_loop, is_llm_response

@@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple):


class ResponseWrapper:
"""Wrapper of runtime response with optional outputs computed post runtime.
"""
1. Wrapper of runtime response with optional outputs computed post runtime.
2. A workaround to pass around RequestPerfMetrics.
"""

def __init__(self,
response: Union["PostprocWorker.Output", tllm.Response],
logprobs: Optional[LogProbsResult] = None):
logprobs: Optional[LogProbsResult] = None,
request_perf_metrics: Optional[dict[str, float]] = None):
self._response = response
self.logprobs = logprobs
self.request_perf_metrics = request_perf_metrics

@property
def _is_llm_response(self):
@@ -68,6 +73,14 @@ def __getattr__(self, name):
response = object.__getattribute__(self, '_response')
return getattr(response, name)

def __getstate__(self):
return (self._response, self.logprobs, self.request_perf_metrics)

def __setstate__(self, state):
self._response = state[0]
self.logprobs = state[1]
self.request_perf_metrics = state[2]
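
Because __getattr__ forwards unknown attribute lookups to the wrapped response, the explicit __getstate__/__setstate__ hooks keep pickling well-defined when the wrapper crosses the IPC boundary. A minimal round-trip sketch; the string payload is only a stand-in for a real tllm.Response:

import pickle

from tensorrt_llm.executor.result import ResponseWrapper

# Any picklable object works as the wrapped payload for this illustration.
wrapper = ResponseWrapper("dummy-response",
                          logprobs=None,
                          request_perf_metrics={"ttft": 0.04})
restored = pickle.loads(pickle.dumps(wrapper))
assert restored.request_perf_metrics == {"ttft": 0.04}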


@dataclass(slots=True)
class CompletionOutput:
Expand Down Expand Up @@ -146,6 +159,7 @@ def __init__(self,
self.disaggregated_params = None
self.decoding_iter = 0
self._done = False
self.metrics_dict = {}

if has_event_loop():
self.aqueue = AsyncQueue()
@@ -201,7 +215,9 @@ def _handle_sequence(self,
finish_reasons,
response_tensors,
sequence_index,
logprobs_result=None):
logprobs_result=None,
req_perf_metrics_dict: Optional[dict[str,
float]] = None):
""" Handle a single sequence in the response. """

seq_idx = sequence_index
@@ -271,14 +287,17 @@
else:
raise ValueError(
f"Unknown finish reason: {finish_reasons[src_idx]}")
self.record_stats(output, req_perf_metrics_dict)

@nvtx_range_debug("handle_response",
color="red",
category="GenerationResultBase")
def _handle_response(self,
response: Union["PostprocWorker.Output", tllm.Response,
ResponseWrapper, ErrorResponse]):
req_perf_metrics_dict = None
if isinstance(response, ResponseWrapper):
req_perf_metrics_dict = response.request_perf_metrics
logprobs_result = response.logprobs
response = response._response
else:
@@ -291,6 +310,8 @@ def _handle_response(self,
self._outputs[0] = response.res
else:
self._outputs[0]._postprocess_result = response.res
if response.metrics:
self.metrics_dict = response.metrics

if response.error:
if self._background_error_handler is not None and (
@@ -303,7 +324,8 @@
handler(response.error_msg)

response_result = response.result
if hasattr(response_result, "_result"):
if hasattr(response_result, "_result") and isinstance(
response_result._result, bytes):
response_result.deserialize()

self._done = response_result.is_final
@@ -322,11 +344,12 @@ def _handle_response(self,
if self.sampling_params.use_beam_search:
for beam_idx, _ in enumerate(response_result.output_token_ids):
self._handle_sequence(finish_reasons, response_result,
beam_idx, logprobs_result)
beam_idx, logprobs_result,
req_perf_metrics_dict)
else:
self._handle_sequence(finish_reasons, response_result,
response_result.sequence_index,
logprobs_result)
logprobs_result, req_perf_metrics_dict)

if response_result.context_logits is not None:
self._context_logits = response_result.context_logits
@@ -342,6 +365,29 @@ def _handle_response(self,
else:
raise ValueError(f"Unknown response type: {response}")

def record_stats(self,
output: CompletionOutput,
stats: Optional[dict[str, float]] = None) -> None:
"""Record the stats of the generation result.

Args:
output (CompletionOutput): The output of the generation result.
stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None.
"""
if not stats:
return
metrics_stats = {}
if output.finish_reason:
metrics_stats.update({
MetricsCollector.labelname_finish_reason:
output.finish_reason
})
processed_metrics_stat = _process_req_perf_metrics(
stats, len(output.token_ids), self.sampling_params.n > 1)
if processed_metrics_stat:
metrics_stats.update(processed_metrics_stat)
self.metrics_dict = metrics_stats


class DetokenizedGenerationResultBase(GenerationResultBase):
''' The base class for the generation result with detokenization support. '''
@@ -688,3 +734,30 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,

return LogProbsResult(prompt=prompt_logprobs,
generation=generation_logprobs)


def _process_req_perf_metrics(
req_perf_metrics_dict: Optional[dict[str, float]],
output_length: int,
is_multiple_response: bool = False) -> dict[MetricNames, float]:
stat = {}
if not req_perf_metrics_dict:
return stat
ttft = req_perf_metrics_dict.get(RequestEventTiming.FIRST_TOKEN_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
e2e = req_perf_metrics_dict.get(RequestEventTiming.LAST_TOKEN_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
request_queue_time = req_perf_metrics_dict.get(RequestEventTiming.FIRST_SCHEDULED_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
stat = {
MetricNames.TTFT: ttft,
MetricNames.E2E: e2e,
MetricNames.REQUEST_QUEUE_TIME: request_queue_time
}
if output_length > 1 and not is_multiple_response:
tpot = (req_perf_metrics_dict.get(
RequestEventTiming.LAST_TOKEN_TIME, 0) - req_perf_metrics_dict.get(
RequestEventTiming.FIRST_TOKEN_TIME, 0)) / (output_length - 1)
stat.update({MetricNames.TPOT: tpot})
stat = dict(filter(lambda item: item[1] > 0, stat.items()))
return stat
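
A worked example of the arithmetic above, using hypothetical timestamps (seconds relative to arrival) purely to make the derived metrics concrete:

# Hypothetical timings, not from any real run.
arrival, first_scheduled, first_token, last_token = 0.0, 0.010, 0.050, 0.450
output_length = 9  # generated tokens in this sequence

ttft = first_token - arrival                             # 0.050 s, time to first token
e2e = last_token - arrival                               # 0.450 s, end-to-end latency
request_queue_time = first_scheduled - arrival           # 0.010 s, time spent queued
tpot = (last_token - first_token) / (output_length - 1)  # 0.050 s per output token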
47 changes: 41 additions & 6 deletions tensorrt_llm/executor/worker.py
@@ -25,6 +25,7 @@
clear_sched_affinity, print_colored_debug,
print_traceback_on_error)
from ..lora_manager import LoraConfig, LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
@@ -899,10 +900,8 @@ def handle_for_worker(self, responses: List[tllm.Response]) -> None:
assert response is not None
queue = self.worker.return_queue(response.client_id)

logprobs_result = _get_logprobs(self.worker, response,
response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)

# For AsyncQueue.sync_q, we will batch the events to avoid too many
# event notifications, thus put without wait here.
@@ -940,10 +939,8 @@ def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
response = ErrorResponse(response.client_id, response.error_msg,
response.request_id)
else:
logprobs_result = _get_logprobs(self.worker, response,
response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)

_send_rsp(self.worker,
response,
@@ -1051,3 +1048,41 @@ def _send_rsp(
worker._pop_result(response.client_id)
else:
raise ValueError(f"Unknown response type: {response}")


def _get_metrics_dict(
response: tllm.Response) -> dict[RequestEventTiming, float]:
req_perf_metrics, metrics_dict = None, {}
res = response.result
if res:
if hasattr(res, '_result'):
if result := res.get_result():
req_perf_metrics = result.request_perf_metrics
else:
req_perf_metrics = res.request_perf_metrics
if req_perf_metrics and req_perf_metrics.timing_metrics:
metrics_dict = {
RequestEventTiming.ARRIVAL_TIME:
req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
RequestEventTiming.FIRST_TOKEN_TIME:
req_perf_metrics.timing_metrics.first_token_time.total_seconds(
),
RequestEventTiming.FIRST_SCHEDULED_TIME:
req_perf_metrics.timing_metrics.first_scheduled_time.
total_seconds(),
RequestEventTiming.LAST_TOKEN_TIME:
req_perf_metrics.timing_metrics.last_token_time.total_seconds()
}
return metrics_dict


def _maybe_wrap_response(
worker,
response: tllm.Response,
is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:

logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
req_perf_metrics = _get_metrics_dict(response)
if logprobs_result or req_perf_metrics:
response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
return response
2 changes: 1 addition & 1 deletion tensorrt_llm/llmapi/llm.py
@@ -548,7 +548,7 @@ def _prepare_sampling_params(
if sampling_params._stream_interval is None:
sampling_params._stream_interval = getattr(self.args,
"stream_interval", 1)

sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
return sampling_params

def _check_arguments(self, prompt_len: int, query_len: int,
4 changes: 4 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -1311,6 +1311,10 @@ class BaseLlmArgs(StrictBaseModel):
status="deprecated",
)

return_perf_metrics: bool = Field(default=False,
description="Return perf metrics.",
status="prototype")

_parallel_config: Optional[object] = PrivateAttr(default=None)
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
_speculative_model: Optional[str] = PrivateAttr(default=None)
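
A minimal end-to-end sketch of the new flag; the model name is a placeholder, and reading the collected values from metrics_dict on the returned result is an assumption based on the result.py changes above:

# Hypothetical usage; exact attribute names on the returned outputs may differ.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", return_perf_metrics=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
for out in outputs:
    print(out.outputs[0].text)
    print(getattr(out, "metrics_dict", {}))  # e.g. TTFT, E2E, REQUEST_QUEUE_TIME, TPOT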