diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index f65db385d23..5bdcae38613 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -538,9 +538,9 @@ def get_req_stats(req: LlmRequest) -> RequestStats:
                 req_stat.dis_serving_stats.kv_cache_size = req.kv_cache_size
             return req_stat
 
-        def get_queued_req_stats(req: LlmRequest) -> RequestStats:
+        def get_queued_req_stats(request_id: int) -> RequestStats:
             req_stat = RequestStats()
-            req_stat.id = req.request_id
+            req_stat.id = request_id
             req_stat.context_prefill_position = 0
             req_stat.num_generated_tokens = 0
             req_stat.avg_num_decoded_tokens_per_iter = 0
@@ -558,9 +558,10 @@ def get_queued_req_stats(req: LlmRequest) -> RequestStats:
             req_stats.append(req_stat)
 
         for req in list(self.request_queue.queue):
-            req_stat = get_queued_req_stats(req)
-            req.stage = RequestStage.QUEUED
-            req_stats.append(req_stat)
+            if isinstance(req, Tuple):
+                req_stat = get_queued_req_stats(req[0])
+                req_stat.stage = RequestStage.QUEUED
+                req_stats.append(req_stat)
 
         for req in finished_requests:
             req_stat = get_req_stats(req)
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index 41bb462824c..98bff93596d 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -1912,6 +1912,64 @@ def test_llm_get_stats(return_context_logits, enable_iter_req_stats):
                                enable_iter_req_stats=enable_iter_req_stats)
 
 
+def test_llm_get_queued_stats():
+
+    enable_iter_req_stats = True
+    use_overlap = False
+    tp_size = 1
+
+    num_requests = 3000
+    repeated_prompts = ["A B C D E F G H I J K L M"] * num_requests
+
+    llm_args_extra = {}
+    sampling_args_extra = {}
+
+    from tensorrt_llm._torch import LLM as LLM_torch
+    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
+
+    llm_args_extra["pytorch_backend_config"] = PyTorchConfig(
+        enable_iter_perf_stats=True,
+        enable_iter_req_stats=enable_iter_req_stats,
+        disable_overlap_scheduler=not use_overlap,
+    )
+
+    llm = LLM_torch(model=llama_model_path,
+                    kv_cache_config=global_kvcache_config,
+                    tensor_parallel_size=tp_size,
+                    fast_build=True,
+                    **llm_args_extra)
+
+    max_tokens = 10
+    sampling_params = SamplingParams(max_tokens=max_tokens,
+                                     **sampling_args_extra)
+
+    max_tries = 10
+    has_queue_requests = False
+
+    while not has_queue_requests and max_tries > 0:
+        max_tries -= 1
+        # Generate outputs, which will queue requests
+        outputs = llm.generate(repeated_prompts,
+                               sampling_params=sampling_params)
+
+        results = llm.get_stats(2)
+
+        for index, result in enumerate(results):
+            if "requestStats" in result:
+                for requestStat in result["requestStats"]:
+                    if requestStat["stage"] == "QUEUED":
+                        has_queue_requests = True
+                        assert requestStat["numGeneratedTokens"] == 0
+
+        if not has_queue_requests:
+            print("No queued requests found, retrying...")
+            time.sleep(1)
+        else:
+            print("Found queued requests, breaking out of the loop.")
+
+    assert has_queue_requests
+
+
 def llm_get_stats_async_test_harness(tp_size: int = 1,
                                      return_context_logits: bool = False,
                                      pytorch_backend: bool = False,