diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 03ad7df27b1..e8e72aafba9 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -106,7 +106,7 @@ class CompletionResponse(OpenAIBaseModel):
     usage: UsageInfo
     # Add prompt_tokens_ids to the response to remove the tokenization
     # in the generation server in disaggreated serving
-    prompt_token_ids: Optional[List[List[int]]] = None
+    prompt_token_ids: Optional[Union[List[List[int]], List[int]]] = None
 
 
 class CompletionResponseStreamChoice(OpenAIBaseModel):
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 9223c9ddd7b..bf9fdb2859d 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -6,7 +6,7 @@
 from datetime import datetime
 from http import HTTPStatus
 from pathlib import Path
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple
+from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
 
 import uvicorn
 from fastapi import FastAPI, Request
@@ -127,7 +127,7 @@ def postproc_worker_enabled(self) -> bool:
     def create_error_response(
             message: str,
             err_type: str = "BadRequestError",
-            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> Response:
         error_response = ErrorResponse(message=message,
                                        type=err_type,
                                        code=status_code.value)
@@ -316,76 +316,77 @@ async def create_chat_response(
 
     async def openai_completion(self, request: CompletionRequest, raw_request: Request) -> Response:
 
-        def merge_promises(
-            promises: List[RequestOutput],
-            postproc_params_collections: List[Optional[PostprocParams]]
-        ) -> AsyncIterator[Tuple[RequestOutput, Optional[PostprocParams]]]:
-            outputs = asyncio.Queue()
-            finished = [False] * len(promises)
-
-            async def producer(i: int, promise: RequestOutput, postproc_params: Optional[PostprocParams]):
-                async for output in promise:
-                    await outputs.put((output, postproc_params))
-                finished[i] = True
-
-            _tasks = [
-                asyncio.create_task(producer(i, promise, postproc_params))
-                for i, (promise, postproc_params) in enumerate(zip(promises, postproc_params_collections))
-            ]
-
-            async def consumer():
-                while not all(finished) or not outputs.empty():
-                    item = await outputs.get()
-                    yield item
-                await asyncio.gather(*_tasks)
-
-            return consumer()
-
-        async def create_completion_generator(
-                generator: AsyncIterator[Tuple[RequestOutput, Optional[PostprocParams]]]):
-            async for request_output, postproc_params in generator:
-                if not self.postproc_worker_enabled:
-                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
-                    pp_result = post_processor(request_output, args)
-                else:
-                    pp_result = request_output.outputs[0]._postprocess_result
-                for pp_res in pp_result:
-                    yield pp_res
-            yield "data: [DONE]\n\n"
+        async def completion_response(promise: RequestOutput,
+                                      postproc_params: Optional[PostprocParams]) -> CompletionResponse:
+            response = await promise
+            if not self.postproc_worker_enabled:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                pp_result = post_processor(response, args)
+            else:
+                pp_result = response.outputs[0]._postprocess_result
+            if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
+                # Include prompt token ids for context-only requests
+                pp_result.prompt_token_ids = response.prompt_token_ids
+            return pp_result
 
-        async def create_completion_response(
-                generator: AsyncIterator[Tuple[RequestOutput, Optional[PostprocParams]]], disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> CompletionResponse:
+        def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse:
             all_choices: List[CompletionResponseChoice] = []
             all_prompt_token_ids: List[List[int]] = []
             num_prompt_tokens = num_gen_tokens = 0
-            async for request_output, postproc_params in generator:
-                pp_result: CompletionResponse
-                if not self.postproc_worker_enabled:
-                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
-                    pp_result = post_processor(request_output, args)
-                else:
-                    pp_result = request_output.outputs[0]._postprocess_result
-
-                choices, usage = pp_result.choices, pp_result.usage
+            for rsp in responses:
+                choices, usage = rsp.choices, rsp.usage
                 all_choices.extend(choices)
                 num_prompt_tokens += usage.prompt_tokens
                 num_gen_tokens += usage.completion_tokens
-                #Include prompt token ids for context-only requests
-                if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
-                    all_prompt_token_ids.append(request_output.prompt_token_ids)
+                # Aggregate prompt token ids for context-only requests
+                if rsp.prompt_token_ids is not None:
+                    all_prompt_token_ids.append(rsp.prompt_token_ids)
 
             usage_info = UsageInfo(
                 prompt_tokens=num_prompt_tokens,
                 completion_tokens=num_gen_tokens,
                 total_tokens=num_gen_tokens + num_prompt_tokens,
             )
-            response = CompletionResponse(
+            merged_rsp = CompletionResponse(
                 model=self.model,
                 choices=all_choices,
                 usage=usage_info,
                 prompt_token_ids=all_prompt_token_ids,
             )
-            return response
+            return merged_rsp
+
+        async def completion_generator(promise: RequestOutput, params: Optional[PostprocParams]):
+            async for output in promise:
+                if not self.postproc_worker_enabled:
+                    post_processor, args = params.post_processor, params.postproc_args
+                    pp_result = post_processor(output, args)
+                else:
+                    pp_result = output.outputs[0]._postprocess_result
+                for pp_res in pp_result:
+                    yield pp_res
+
+        async def merge_generators(generators: List[AsyncIterator[Any]]):
+            result_queue = asyncio.Queue()
+            finished = [False] * len(generators)
+
+            async def producer(generator: AsyncIterator[Any], idx: int):
+                async for output in generator:
+                    await result_queue.put(output)
+                finished[idx] = True
+
+            tasks = [
+                asyncio.create_task(producer(generator, idx)) for idx, generator in enumerate(generators)
+            ]
+
+            while not all(finished) or not result_queue.empty():
+                output = await result_queue.get()
+                yield output
+            await asyncio.gather(*tasks)
+
+        async def generator_wrapper(generator: AsyncIterator[Any]):
+            async for output in generator:
+                yield output
+            yield "data: [DONE]\n\n"
 
         try:
             check_multiple_response(request.n, self.llm.args.backend)
@@ -423,15 +424,16 @@ async def create_completion_response(
                 promises.append(promise)
                 postproc_params_collection.append(None if self.postproc_worker_enabled else postproc_params)
 
-            generator = merge_promises(promises, postproc_params_collection)
             if request.stream:
-                response_generator = create_completion_generator(
-                    generator)
-                return StreamingResponse(content=response_generator,
+                generators = [completion_generator(promise, params)
+                              for promise, params in zip(promises, postproc_params_collection)]
+                response_generator = merge_generators(generators) if len(promises) > 1 else generators[0]
+                return StreamingResponse(content=generator_wrapper(response_generator),
                                          media_type="text/event-stream")
             else:
-                response = await create_completion_response(
-                    generator, disaggregated_params)
+                rsps = await asyncio.gather(*[completion_response(promise, params)
+                                              for promise, params in zip(promises, postproc_params_collection)])
+                response = merge_completion_responses(rsps) if len(rsps) > 1 else rsps[0]
                 return JSONResponse(content=response.model_dump())
         except CppExecutorError:
             # If internal executor error is raised, shutdown the server
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 53518619658..56b90c74706 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -43,7 +43,6 @@ examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bf
 examples/test_whisper.py::test_llm_whisper_general[large-v3-enable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_python_runtime] SKIP (https://nvbugs/4866931)
 examples/test_nemotron.py::test_llm_nemotron_3_8b_1gpu[bfloat16-fp8] SKIP (https://nvbugs/4961624)
 examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-chunked_summarization_long] SKIP (https://nvbugs/5321371)
-test_e2e.py::test_openai_completions_example SKIP (https://nvbugspro.nvidia.com/bug/5004744)
 cpp/test_e2e.py::test_model[fp8-chatglm-90] SKIP (https://nvbugs/5034830)
 full:B200_PCIe/examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-enable_gemm_plugin] SKIP (Disable for Blackwell)
 full:B200_PCIe/examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-disable_gemm_plugin] SKIP (Disable for Blackwell)
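For context, a minimal standalone sketch of the fan-in pattern that the new `merge_generators` helper uses: each source async generator feeds a shared `asyncio.Queue` via a producer task, and a single consumer drains the queue until every producer has marked itself finished. The `ticker` generators and `main` below are hypothetical stand-ins for the per-request `completion_generator` streams in the diff; this is illustrative only, not code from the PR.

```python
import asyncio
from typing import Any, AsyncIterator, List


async def merge_generators(generators: List[AsyncIterator[Any]]) -> AsyncIterator[Any]:
    # Fan results from several async generators into one stream, in arrival order.
    result_queue: asyncio.Queue = asyncio.Queue()
    finished = [False] * len(generators)

    async def producer(generator: AsyncIterator[Any], idx: int):
        # Forward every item from one source generator into the shared queue,
        # then flag this source as finished.
        async for output in generator:
            await result_queue.put(output)
        finished[idx] = True

    tasks = [
        asyncio.create_task(producer(generator, idx))
        for idx, generator in enumerate(generators)
    ]

    # Keep draining while any producer is still running or items remain queued.
    while not all(finished) or not result_queue.empty():
        yield await result_queue.get()
    await asyncio.gather(*tasks)


async def ticker(name: str, n: int, delay: float) -> AsyncIterator[str]:
    # Hypothetical stand-in for a per-request streaming generator.
    for i in range(n):
        await asyncio.sleep(delay)
        yield f"{name}-{i}"


async def main():
    merged = merge_generators([ticker("a", 3, 0.01), ticker("b", 3, 0.02)])
    async for item in merged:
        print(item)


if __name__ == "__main__":
    asyncio.run(main())
```

In the streaming path of the diff this merge is only engaged when more than one promise exists (`len(promises) > 1`); a single-prompt request streams its lone generator directly, so the queue and producer tasks are skipped.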