diff --git a/requirements.txt b/requirements.txt index f6b201b1f57..e2582f50385 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,6 +29,8 @@ nvidia-modelopt[torch]~=0.33.0 nvidia-nccl-cu12 nvidia-cuda-nvrtc-cu12 transformers==4.55.0 +prometheus_client +prometheus_fastapi_instrumentator pydantic>=2.9.1 pydantic-settings[yaml] omegaconf diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index a068327b6db..db360f64a83 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -250,6 +250,12 @@ def deserialize(self): self._result = tensorrt_llm.bindings.executor.deserialize_result( self._result) + def get_result(self): + if tmp_res := tensorrt_llm.bindings.executor.deserialize_result( + self._result): + return tmp_res + return None + @dataclass class LlmResponse: diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 337caae5126..c68777b96a7 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -20,6 +20,7 @@ import math import os import struct +import tempfile import trace import weakref from contextlib import contextmanager @@ -1112,3 +1113,17 @@ def is_multi_device_enable(): the number of devices """ return local_mpi_size() > 1 + + +def set_prometheus_multiproc_dir() -> object: + # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266 + global prometheus_multiproc_dir + if "PROMETHEUS_MULTIPROC_DIR" in os.environ: + logger.info("User set PROMETHEUS_MULTIPROC_DIR detected.") + prometheus_multiproc_dir = tempfile.TemporaryDirectory( + dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]) + else: + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + logger.info( + f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 2e5a3cd2967..7dff3289185 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -3,7 +3,7 @@ from collections import deque from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, - Optional) + Optional, Union) import zmq import zmq.asyncio @@ -18,7 +18,7 @@ if TYPE_CHECKING: from .result import (DetokenizedGenerationResultBase, GenerationResult, - GenerationResultBase) + GenerationResultBase, ResponseWrapper) __all__ = [ "PostprocWorker", @@ -57,7 +57,7 @@ class PostprocWorker: @dataclass class Input: - rsp: "tllm.Response" + rsp: Union["tllm.Response", "ResponseWrapper"] # The information necessary for creating a GenerationResult in the first Input for each request sampling_params: Optional[SamplingParams] = None @@ -69,6 +69,7 @@ class Output(NamedTuple): res: Any is_final: bool error: str = "" + metrics: Optional[dict[str, float]] = None def __init__( self, @@ -118,7 +119,9 @@ def default_record_creator( streaming=inp.streaming, tokenizer=tokenizer) - async def _handle_input(self, input: "PostprocWorker.Input") -> Any: + async def _handle_input( + self, input: Union["PostprocWorker.Input", "ResponseWrapper"] + ) -> [Any, Optional[dict[str, float]]]: ''' Handle a single response from await_response worker. 
''' if input.rsp.result.context_logits is not None or \ input.rsp.result.generation_logits is not None: @@ -139,6 +142,7 @@ async def _handle_input(self, input: "PostprocWorker.Input") -> Any: record._handle_response(input.rsp) # inplace # Left the result_handler determine the final output dtype. # NOTE: This will change the CompletionOutput._postprocess_result + metrics_dict = record.metrics_dict if postproc_params := record.postproc_params: result_handler, args = postproc_params.post_processor, postproc_params.postproc_args args.tokenizer = self._tokenizer @@ -150,7 +154,7 @@ async def _handle_input(self, input: "PostprocWorker.Input") -> Any: # TODO: Keep only the diff token_ids and text in streaming mode when # result_handler is not set - return out + return out, metrics_dict async def _batched_put(self): ''' Batched IPC send. ''' @@ -173,8 +177,12 @@ async def handle_single_input(inp: PostprocWorker.Input, client_id = inp.rsp.client_id is_final = inp.rsp.result.is_final if is_llm_response( inp.rsp) else True - res = await self._handle_input(inp) - batch.append(PostprocWorker.Output(client_id, res, is_final)) + res, metrics = await self._handle_input(inp) + batch.append( + PostprocWorker.Output(client_id=client_id, + res=res, + is_final=is_final, + metrics=metrics)) if is_final: self._records.pop(client_id) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 0408a6c757c..2566a699aa4 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -15,6 +15,7 @@ from ..disaggregated_params import DisaggregatedParams from ..llmapi.tracer import global_tracer from ..llmapi.utils import AsyncQueue +from ..metrics import MetricNames, MetricsCollector, RequestEventTiming from ..sampling_params import LogprobParams, SamplingParams from .utils import ErrorResponse, has_event_loop, is_llm_response @@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple): class ResponseWrapper: - """Wrapper of runtime response with optional outputs computed post runtime. + """ + 1. Wrapper of runtime response with optional outputs computed post runtime. + 2. A workaround to pass around RequestPerfMetrics. """ def __init__(self, response: Union["PostprocWorker.Output", tllm.Response], - logprobs: Optional[LogProbsResult] = None): + logprobs: Optional[LogProbsResult] = None, + request_perf_metrics: Optional[dict[str, float]] = None): self._response = response self.logprobs = logprobs + self.request_perf_metrics = request_perf_metrics @property def _is_llm_response(self): @@ -68,6 +73,14 @@ def __getattr__(self, name): response = object.__getattribute__(self, '_response') return getattr(response, name) + def __getstate__(self): + return (self._response, self.logprobs, self.request_perf_metrics) + + def __setstate__(self, state): + self._response = state[0] + self.logprobs = state[1] + self.request_perf_metrics = state[2] + @dataclass(slots=True) class CompletionOutput: @@ -146,6 +159,7 @@ def __init__(self, self.disaggregated_params = None self.decoding_iter = 0 self._done = False + self.metrics_dict = {} if has_event_loop(): self.aqueue = AsyncQueue() @@ -201,7 +215,9 @@ def _handle_sequence(self, finish_reasons, response_tensors, sequence_index, - logprobs_result=None): + logprobs_result=None, + req_perf_metrics_dict: Optional[dict[str, + float]] = None): """ Handle a single sequence in the response. 
""" seq_idx = sequence_index @@ -271,6 +287,7 @@ def _handle_sequence(self, else: raise ValueError( f"Unknown finish reason: {finish_reasons[src_idx]}") + self.record_stats(output, req_perf_metrics_dict) @nvtx_range_debug("handle_response", color="red", @@ -278,7 +295,9 @@ def _handle_sequence(self, def _handle_response(self, response: Union["PostprocWorker.Output", tllm.Response, ResponseWrapper, ErrorResponse]): + req_perf_metrics_dict = None if isinstance(response, ResponseWrapper): + req_perf_metrics_dict = response.request_perf_metrics logprobs_result = response.logprobs response = response._response else: @@ -291,6 +310,8 @@ def _handle_response(self, self._outputs[0] = response.res else: self._outputs[0]._postprocess_result = response.res + if response.metrics: + self.metrics_dict = response.metrics if response.error: if self._background_error_handler is not None and ( @@ -303,7 +324,8 @@ def _handle_response(self, handler(response.error_msg) response_result = response.result - if hasattr(response_result, "_result"): + if hasattr(response_result, "_result") and isinstance( + response_result._result, bytes): response_result.deserialize() self._done = response_result.is_final @@ -322,11 +344,12 @@ def _handle_response(self, if self.sampling_params.use_beam_search: for beam_idx, _ in enumerate(response_result.output_token_ids): self._handle_sequence(finish_reasons, response_result, - beam_idx, logprobs_result) + beam_idx, logprobs_result, + req_perf_metrics_dict) else: self._handle_sequence(finish_reasons, response_result, response_result.sequence_index, - logprobs_result) + logprobs_result, req_perf_metrics_dict) if response_result.context_logits is not None: self._context_logits = response_result.context_logits @@ -342,6 +365,29 @@ def _handle_response(self, else: raise ValueError(f"Unknown response type: {response}") + def record_stats(self, + output: CompletionOutput, + stats: Optional[dict[str, float]] = None) -> None: + """Record the stats of the generation result. + + Args: + output (CompletionOutput): The output of the generation result. + stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None. + """ + if not stats: + return + metrics_stats = {} + if output.finish_reason: + metrics_stats.update({ + MetricsCollector.labelname_finish_reason: + output.finish_reason + }) + processed_metrics_stat = _process_req_perf_metrics( + stats, len(output.token_ids), self.sampling_params.n > 1) + if processed_metrics_stat: + metrics_stats.update(processed_metrics_stat) + self.metrics_dict = metrics_stats + class DetokenizedGenerationResultBase(GenerationResultBase): ''' The base class for the generation result with detokenization support. 
''' @@ -688,3 +734,30 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, return LogProbsResult(prompt=prompt_logprobs, generation=generation_logprobs) + + +def _process_req_perf_metrics( + req_perf_metrics_dict: Optional[dict[str, float]], + output_length: int, + is_multiple_response: bool = False) -> dict[MetricNames, float]: + stat = {} + if not req_perf_metrics_dict: + return stat + ttft = req_perf_metrics_dict.get(RequestEventTiming.FIRST_TOKEN_TIME, 0) - \ + req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0) + e2e = req_perf_metrics_dict.get(RequestEventTiming.LAST_TOKEN_TIME, 0) - \ + req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0) + request_queue_time = req_perf_metrics_dict.get(RequestEventTiming.FIRST_SCHEDULED_TIME, 0) - \ + req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0) + stat = { + MetricNames.TTFT: ttft, + MetricNames.E2E: e2e, + MetricNames.REQUEST_QUEUE_TIME: request_queue_time + } + if output_length > 1 and not is_multiple_response: + tpot = (req_perf_metrics_dict.get( + RequestEventTiming.LAST_TOKEN_TIME, 0) - req_perf_metrics_dict.get( + RequestEventTiming.FIRST_TOKEN_TIME, 0)) / (output_length - 1) + stat.update({MetricNames.TPOT: tpot}) + stat = dict(filter(lambda item: item[1] > 0, stat.items())) + return stat diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index db8d84fcc89..c3a827bb00b 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -25,6 +25,7 @@ clear_sched_affinity, print_colored_debug, print_traceback_on_error) from ..lora_manager import LoraConfig, LoraManager +from ..metrics import RequestEventTiming from ..prompt_adapter_manager import PromptAdapterManager from ..runtime import ModelConfig from ..runtime.model_runner import _engine_config_to_model_config @@ -899,10 +900,8 @@ def handle_for_worker(self, responses: List[tllm.Response]) -> None: assert response is not None queue = self.worker.return_queue(response.client_id) - logprobs_result = _get_logprobs(self.worker, response, + response = _maybe_wrap_response(self.worker, response, self.worker._is_pytorch_backend) - if logprobs_result: - response = ResponseWrapper(response, logprobs_result) # For AsyncQueue.sync_q, we will batch the events to avoid too many # event notifications, thus put without wait here. 
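The per-request latency metrics introduced in _process_req_perf_metrics above reduce to differences of the four executor timestamps. A standalone sketch of that arithmetic, with plain floats and string keys standing in for the RequestEventTiming/MetricNames values (the function name is illustrative and not part of this PR):

def derive_latency_metrics(arrival: float, first_scheduled: float,
                           first_token: float, last_token: float,
                           output_length: int) -> dict:
    # TTFT, E2E and queue time are plain timestamp differences; TPOT averages
    # the decode phase over the tokens generated after the first one
    # (_process_req_perf_metrics additionally skips TPOT when n > 1).
    metrics = {
        "ttft": first_token - arrival,
        "e2e": last_token - arrival,
        "request_queue_time": first_scheduled - arrival,
    }
    if output_length > 1:
        metrics["tpot"] = (last_token - first_token) / (output_length - 1)
    # Non-positive entries are dropped, matching the final filter above.
    return {key: value for key, value in metrics.items() if value > 0}
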
@@ -940,10 +939,8 @@ def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None: response = ErrorResponse(response.client_id, response.error_msg, response.request_id) else: - logprobs_result = _get_logprobs(self.worker, response, + response = _maybe_wrap_response(self.worker, response, self.worker._is_pytorch_backend) - if logprobs_result: - response = ResponseWrapper(response, logprobs_result) _send_rsp(self.worker, response, @@ -1051,3 +1048,41 @@ def _send_rsp( worker._pop_result(response.client_id) else: raise ValueError(f"Unknown response type: {response}") + + +def _get_metrics_dict( + response: tllm.Response) -> dict[RequestEventTiming, float]: + req_perf_metrics, metrics_dict = None, {} + res = response.result + if res: + if hasattr(res, '_result'): + if result := res.get_result(): + req_perf_metrics = result.request_perf_metrics + else: + req_perf_metrics = res.request_perf_metrics + if req_perf_metrics and req_perf_metrics.timing_metrics: + metrics_dict = { + RequestEventTiming.ARRIVAL_TIME: + req_perf_metrics.timing_metrics.arrival_time.total_seconds(), + RequestEventTiming.FIRST_TOKEN_TIME: + req_perf_metrics.timing_metrics.first_token_time.total_seconds( + ), + RequestEventTiming.FIRST_SCHEDULED_TIME: + req_perf_metrics.timing_metrics.first_scheduled_time. + total_seconds(), + RequestEventTiming.LAST_TOKEN_TIME: + req_perf_metrics.timing_metrics.last_token_time.total_seconds() + } + return metrics_dict + + +def _maybe_wrap_response( + worker, + response: tllm.Response, + is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]: + + logprobs_result = _get_logprobs(worker, response, is_pytorch_backend) + req_perf_metrics = _get_metrics_dict(response) + if logprobs_result or req_perf_metrics: + response = ResponseWrapper(response, logprobs_result, req_perf_metrics) + return response diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 12bb079eaf5..b578ba07211 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -548,7 +548,7 @@ def _prepare_sampling_params( if sampling_params._stream_interval is None: sampling_params._stream_interval = getattr(self.args, "stream_interval", 1) - + sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics return sampling_params def _check_arguments(self, prompt_len: int, query_len: int, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 1169a779be6..279d26999b2 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1311,6 +1311,10 @@ class BaseLlmArgs(StrictBaseModel): status="deprecated", ) + return_perf_metrics: bool = Field(default=False, + description="Return perf metrics.", + status="prototype") + _parallel_config: Optional[object] = PrivateAttr(default=None) _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None) _speculative_model: Optional[str] = PrivateAttr(default=None) diff --git a/tensorrt_llm/metrics/__init__.py b/tensorrt_llm/metrics/__init__.py new file mode 100644 index 00000000000..f68d9f698ac --- /dev/null +++ b/tensorrt_llm/metrics/__init__.py @@ -0,0 +1,4 @@ +from .collector import * +from .enums import * + +__all__ = ["MetricsCollector", "MetricNames", "RequestEventTiming"] diff --git a/tensorrt_llm/metrics/collector.py b/tensorrt_llm/metrics/collector.py new file mode 100644 index 00000000000..952529393c6 --- /dev/null +++ b/tensorrt_llm/metrics/collector.py @@ -0,0 +1,105 @@ +"""Utilities for Prometheus Metrics Collection.""" + 
+import time +from typing import Dict, Optional, Union + +from .enums import MetricNames + + +# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/engine/metrics.py#L30 +class MetricsCollector: + labelname_finish_reason = "finished_reason" + + def __init__(self, labels: Dict[str, str]) -> None: + from prometheus_client import Counter, Histogram + self.last_log_time = time.time() + self.labels = labels + + self.finish_reason_label = { + MetricsCollector.labelname_finish_reason: "unknown" + } + self.labels_with_finished_reason = { + **self.labels, + **self.finish_reason_label + } + + self.counter_request_success = Counter( + name="request_success_total", + documentation="Count of successfully processed requests.", + labelnames=self.labels_with_finished_reason.keys()) + + self.histogram_e2e_time_request = Histogram( + name="e2e_request_latency_seconds", + documentation="Histogram of end to end request latency in seconds.", + buckets=[ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + ], + labelnames=self.labels.keys()) + + self.histogram_time_to_first_token = Histogram( + name="time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, + 2560.0 + ], + labelnames=self.labels.keys()) + + self.histogram_time_per_output_token = Histogram( + name="time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ], + labelnames=self.labels.keys()) + + self.histogram_queue_time_request = Histogram( + name="request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + buckets=[ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + ], + labelnames=self.labels.keys()) + + def _label_merge(self, labels: Dict[str, str]) -> Dict[str, str]: + if labels is None or len(labels) == 0: + return self.labels + return {**self.labels, **labels} + + def _log_counter(self, counter, labels: Dict[str, str], + data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self._label_merge(labels)).inc(data) + + def _log_histogram(self, histogram, data: Union[int, float]) -> None: + # Convenience function for logging to histogram. 
+ histogram.labels(**self.labels).observe(data) + + def log_request_success(self, data: Union[int, float], + labels: Dict[str, str]) -> None: + self._log_counter(self.counter_request_success, labels, data) + self.last_log_time = time.time() + + def log_histogram(self, data: Optional[dict[str, float]]) -> None: + if e2e := data.get(MetricNames.E2E, 0): + self._log_histogram(self.histogram_e2e_time_request, e2e) + if ttft := data.get(MetricNames.TTFT, 0): + self._log_histogram(self.histogram_time_to_first_token, ttft) + if tpot := data.get(MetricNames.TPOT, 0): + self._log_histogram(self.histogram_time_per_output_token, tpot) + if request_queue_time := data.get(MetricNames.REQUEST_QUEUE_TIME, 0): + self._log_histogram(self.histogram_queue_time_request, + request_queue_time) + self.last_log_time = time.time() + + def log_metrics_dict(self, metrics_dict: dict[str, float]) -> None: + if finish_reason := metrics_dict.get( + MetricsCollector.labelname_finish_reason): + self.log_request_success( + 1, {MetricsCollector.labelname_finish_reason: finish_reason}) + self.log_histogram(metrics_dict) diff --git a/tensorrt_llm/metrics/enums.py b/tensorrt_llm/metrics/enums.py new file mode 100644 index 00000000000..5ce982281bc --- /dev/null +++ b/tensorrt_llm/metrics/enums.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class MetricNames(Enum): + TTFT = "ttft" + TPOT = "tpot" + E2E = "e2e" + REQUEST_QUEUE_TIME = "request_queue_time" + + +class RequestEventTiming(Enum): + ARRIVAL_TIME = "arrival_time" + FIRST_TOKEN_TIME = "first_token_time" # nosec: B105 + FIRST_SCHEDULED_TIME = "first_scheduled_time" + LAST_TOKEN_TIME = "last_token_time" # nosec: B105 diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index d90578ce36b..1b1e15ec625 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import asyncio import os +import re import signal import traceback from contextlib import asynccontextmanager @@ -13,6 +14,7 @@ from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.routing import Mount from transformers import AutoConfig, AutoProcessor from tensorrt_llm._tensorrt_engine import LLM @@ -25,6 +27,7 @@ from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.logger import logger +from tensorrt_llm.metrics.collector import MetricsCollector from tensorrt_llm.serve.chat_utils import (check_multiple_response, parse_chat_messages_coroutines) from tensorrt_llm.serve.metadata_server import create_metadata_server @@ -42,7 +45,7 @@ completion_stream_post_processor) from tensorrt_llm.version import __version__ as VERSION -from .._utils import nvtx_mark +from .._utils import nvtx_mark, set_prometheus_multiproc_dir # yapf: enale TIMEOUT_KEEP_ALIVE = 5 # seconds. 
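For orientation, a rough usage sketch of the MetricsCollector defined above, assuming the package layout added by this PR; the label values and timing numbers are placeholders (in the server they come from the OpenAI server setup and from each request's metrics_dict):

from tensorrt_llm._utils import set_prometheus_multiproc_dir
from tensorrt_llm.metrics import MetricNames, MetricsCollector

# PROMETHEUS_MULTIPROC_DIR has to be set before prometheus_client is imported,
# which is why the server calls set_prometheus_multiproc_dir() first.
set_prometheus_multiproc_dir()
collector = MetricsCollector({
    "model_name": "undefined",
    "engine_type": "undefined",
})

# Per finished request: the dict produced by GenerationResultBase.record_stats(),
# which mixes the finish-reason label with MetricNames-keyed timings in seconds.
collector.log_metrics_dict({
    MetricsCollector.labelname_finish_reason: "stop",
    MetricNames.E2E: 0.42,
    MetricNames.TTFT: 0.05,
    MetricNames.TPOT: 0.015,
    MetricNames.REQUEST_QUEUE_TIME: 0.01,
})
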
@@ -78,6 +81,13 @@ def __init__(self, self.model = model_dir.name else: self.model = model + self.metrics_collector = None + if self.llm.args.return_perf_metrics: + set_prometheus_multiproc_dir() + self.metrics_collector = MetricsCollector({ + "model_name": "undefined", + "engine_type": "undefined" + }) @asynccontextmanager async def lifespan(app: FastAPI): @@ -151,6 +161,32 @@ def register_routes(self): self.app.add_api_route("/v1/chat/completions", self.openai_chat, methods=["POST"]) + if self.llm.args.return_perf_metrics: + # register /prometheus/metrics + self.mount_metrics() + + def mount_metrics(self): + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + from prometheus_fastapi_instrumentator import Instrumentator + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + Instrumentator( + should_group_status_codes=False, + should_respect_env_var=True, + excluded_handlers=[ + ".*" + ], + registry=registry, + ).add().instrument(self.app).expose(self.app) + metrics_app = make_asgi_app(registry=registry) + metrics_route = Mount("/prometheus/metrics", metrics_app) + metrics_route.path_regex = re.compile("^/prometheus/metrics(?P.*)$") + self.app.routes.append(metrics_route) async def health(self) -> Response: return Response(status_code=200) @@ -228,6 +264,8 @@ async def chat_stream_generator( post_processor, args = postproc_params.post_processor, postproc_params.postproc_args async for res in promise: pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) + if res.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(res.metrics_dict) for pp_res in pp_results: yield pp_res yield "data: [DONE]\n\n" @@ -245,6 +283,8 @@ async def create_chat_response( # Add prompt_tokens_ids to the response if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": chat_response.prompt_token_ids = promise.prompt_token_ids + if promise.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(promise.metrics_dict) return chat_response try: @@ -337,6 +377,8 @@ async def completion_response(promise: RequestOutput, if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": # Include prompt token ids for context-only requests pp_result.prompt_token_ids = response.prompt_token_ids + if response.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(response.metrics_dict) return pp_result def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse: @@ -372,6 +414,8 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc pp_result = post_processor(output, args) else: pp_result = output.outputs[0]._postprocess_result + if output.finished and self.metrics_collector: + self.metrics_collector.log_metrics_dict(output.metrics_dict) for pp_res in pp_result: yield pp_res diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 33e3f55765e..7c6a203d0b4 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1497,6 +1497,13 @@ def test_openai_chat_with_logit_bias(llm_root, llm_venv, 
sampler: str): ]) +def test_openai_prometheus(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + llm_venv.run_cmd( + ["-m", "pytest", + str(test_root / "_test_openai_prometheus.py")]) + + def test_openai_lora(llm_root, llm_venv): test_root = unittest_path() / "llmapi" / "apps" llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")]) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 891649e5b9f..ce285faa799 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -25,6 +25,7 @@ l0_a10: - test_e2e.py::test_openai_chat_structural_tag_example - test_e2e.py::test_openai_chat_json_example - test_e2e.py::test_openai_chat_multimodal_example + - test_e2e.py::test_openai_prometheus - test_e2e.py::test_openai_lora - test_e2e.py::test_trtllm_serve_multimodal_example - test_e2e.py::test_trtllm_serve_lora_example diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 5ac588dc911..5a846dd7869 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -27,6 +27,10 @@ methods: annotation: Optional[int] default: null status: prototype + return_perf_metrics: + annotation: bool + default: False + status: prototype # Bindings and mirrored configs peft_cache_config: annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig] diff --git a/tests/unittest/api_stability/references/request_output.yaml b/tests/unittest/api_stability/references/request_output.yaml index 52e499dd147..7e3054cd5ef 100644 --- a/tests/unittest/api_stability/references/request_output.yaml +++ b/tests/unittest/api_stability/references/request_output.yaml @@ -11,4 +11,13 @@ methods: clear_logprob_params: parameters: {} return_annotation: None + record_stats: + parameters: + output: + annotation: tensorrt_llm.executor.result.CompletionOutput + default: inspect._empty + stats: + annotation: Optional[dict[str, float]] + default: None + return_annotation: None properties: {} diff --git a/tests/unittest/llmapi/apps/_test_openai_prometheus.py b/tests/unittest/llmapi/apps/_test_openai_prometheus.py new file mode 100644 index 00000000000..8a360668fd5 --- /dev/null +++ b/tests/unittest/llmapi/apps/_test_openai_prometheus.py @@ -0,0 +1,67 @@ +import logging +import os +import tempfile +from urllib.request import urlopen + +import pytest +import yaml + +from ..test_llm import get_model_path +from .openai_server import RemoteOpenAIServer + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"]) +def model_name(): + return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +@pytest.fixture(scope="module") +def temp_extra_llm_api_options_file(request): + temp_dir = tempfile.gettempdir() + temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml") + try: + extra_llm_api_options_dict = {"return_perf_metrics": True} + + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + + yield temp_file_path + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +@pytest.fixture(scope="module") +def server(model_name: str, + temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer: + model_path = get_model_path(model_name) + args = ["--backend", "pytorch", "--tp_size", "1"] + 
args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file]) + logger.info(f"Starting server, model: {model_name}, args: {args}") + with RemoteOpenAIServer(model_path, args) as remote_server: + yield remote_server + logger.info("Tests completed, shutting down server") + + +def test_metrics_endpoint(server: RemoteOpenAIServer): + + client = server.get_client() + client.completions.create( + model="Server", + prompt="Hello, my name is", + max_tokens=25, + stream=False, + ) + + response = urlopen(f'{server.url_root}/prometheus/metrics') + assert response.status is 200 + + data = response.read().decode("utf-8") + assert "request_success_total" in data + assert "e2e_request_latency_seconds" in data + assert "time_to_first_token_seconds" in data + assert "request_queue_time_seconds" in data diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index bb5d028b556..541965b588f 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -6,6 +6,7 @@ from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi.llm_args import PeftCacheConfig from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer +from tensorrt_llm.metrics import MetricNames from tensorrt_llm.sampling_params import SamplingParams # isort: off @@ -195,6 +196,27 @@ def test_llm_perf_metrics(): assert perf_metrics.last_iter == perf_metrics.iter +def test_llm_prometheus(): + test_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(max_tokens=10, temperature=0.8, top_p=0.95) + llm = LLM(model=llama_model_path, + return_perf_metrics=True, + kv_cache_config=global_kvcache_config) + for test_prompt in test_prompts: + request_output = llm.generate(test_prompt, sampling_params) + assert request_output.metrics_dict is not None + assert MetricNames.REQUEST_QUEUE_TIME in request_output.metrics_dict + assert MetricNames.TPOT in request_output.metrics_dict + assert MetricNames.TTFT in request_output.metrics_dict + assert MetricNames.E2E in request_output.metrics_dict + assert request_output.outputs is not None + + @pytest.mark.parametrize("streaming", [True, False]) def test_llm_with_postprocess_parallel_and_result_handler(streaming): run_llm_with_postprocess_parallel_and_result_handler(streaming,