diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 61fb858b00c..36c9869d13a 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -551,9 +551,9 @@ def get_req_stats(req: LlmRequest) -> RequestStats:
             req_stat.dis_serving_stats.kv_cache_size = req.kv_cache_size
             return req_stat
 
-        def get_queued_req_stats(req: RequestQueueItem) -> RequestStats:
+        def get_queued_req_stats(request_id: int) -> RequestStats:
             req_stat = RequestStats()
-            req_stat.id = req.id
+            req_stat.id = request_id
             req_stat.context_prefill_position = 0
             req_stat.num_generated_tokens = 0
             req_stat.avg_num_decoded_tokens_per_iter = 0
@@ -571,9 +571,10 @@ def get_queued_req_stats(req: RequestQueueItem) -> RequestStats:
             req_stats.append(req_stat)
 
         for req in list(self.request_queue.queue):
-            req_stat = get_queued_req_stats(req)
-            req_stat.stage = RequestStage.QUEUED
-            req_stats.append(req_stat)
+            if isinstance(req, RequestQueueItem):
+                req_stat = get_queued_req_stats(req.id)
+                req_stat.stage = RequestStage.QUEUED
+                req_stats.append(req_stat)
 
         for req in finished_requests:
             req_stat = get_req_stats(req)
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index 1d4842fd969..290273b5b64 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -1926,6 +1926,67 @@ def test_llm_get_stats(return_context_logits, enable_iter_req_stats):
                               enable_iter_req_stats=enable_iter_req_stats)
 
 
+def test_llm_get_queued_stats():
+
+    enable_iter_req_stats = True
+    use_overlap = False
+    tp_size = 1
+
+    num_requests = 10
+    repeated_prompts = ["A B C D E F G H I J K L M"] * num_requests
+
+    llm_args_extra = {}
+    sampling_args_extra = {}
+
+    import time
+
+    from tensorrt_llm._torch import LLM as LLM_torch
+
+    llm_args_extra.update(
+        dict(enable_iter_perf_stats=True,
+             enable_iter_req_stats=enable_iter_req_stats,
+             disable_overlap_scheduler=not use_overlap))
+    LLM_CLASS = LLM_torch
+
+    llm = LLM_CLASS(model=llama_model_path,
+                    kv_cache_config=global_kvcache_config,
+                    tensor_parallel_size=tp_size,
+                    fast_build=True,
+                    max_batch_size=1,
+                    **llm_args_extra)
+
+    max_tokens = 10
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     **sampling_args_extra)
+
+    max_tries = 10
+    has_queue_requests = False
+
+    while not has_queue_requests and max_tries > 0:
+        max_tries -= 1
+        # Generate outputs; max_batch_size=1 forces the remaining prompts to queue.
+        for output in llm.generate(repeated_prompts,
+                                   sampling_params=sampling_params):
+            print(output)
+
+        results = llm.get_stats(2)
+
+        for result in results:
+            if "requestStats" in result:
+                for requestStat in result["requestStats"]:
+                    if requestStat["stage"] == "QUEUED":
+                        has_queue_requests = True
+                        assert requestStat["numGeneratedTokens"] == 0
+
+        if not has_queue_requests:
+            print("No queued requests found, retrying...")
+            time.sleep(1)  # asyncio.sleep() here would return an un-awaited coroutine
+        else:
+            print("Found queued requests, exiting the retry loop.")
+
+    assert has_queue_requests
+
+
 def llm_get_stats_async_test_harness(tp_size: int = 1,
                                      return_context_logits: bool = False,
                                      pytorch_backend: bool = False,
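
Reviewer note, not part of the patch: a minimal sketch of how a monitoring client might consume the QUEUED entries this change exposes. It assumes an `llm` object constructed with enable_iter_perf_stats=True and enable_iter_req_stats=True (as in the test above), that llm.get_stats() yields dict-shaped iteration stats as the test does, and uses the same camelCase keys the test asserts on ("requestStats", "stage", "numGeneratedTokens"). The helper name count_queued_requests is hypothetical.

    def count_queued_requests(llm, timeout_s: int = 2) -> int:
        """Count request-stat entries still reported in the QUEUED stage."""
        queued = 0
        for iteration_stats in llm.get_stats(timeout_s):
            for req_stat in iteration_stats.get("requestStats", []):
                if req_stat["stage"] == "QUEUED":
                    queued += 1
        return queued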