diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py
index fa67c7a81c..24bf2b9b4f 100644
--- a/python/mlc_llm/serve/engine.py
+++ b/python/mlc_llm/serve/engine.py
@@ -1056,12 +1056,12 @@ async def _chat_completion(  # pylint: disable=too-many-arguments,too-many-local
                     logprob_results[  # pylint: disable=unsupported-assignment-operation
                         choice.index
                     ] += choice.logprobs.content
-        except (
-            Exception,
-            asyncio.CancelledError,
-        ) as err:  # pylint: disable=broad-exception-caught
+        except asyncio.CancelledError:  # pylint: disable=try-except-raise
+            # for cancelled error, we can simply pass it through
+            raise
+        except Exception as err:  # pylint: disable=broad-exception-caught
             logger.error("Error in chat completion with request ID %s: %s", request_id, err)
-            raise err
+            raise

     assert all(finish_reason is not None for finish_reason in finish_reasons)
     use_function_calling, tool_calls_list = engine_base.process_function_call_output(
@@ -1260,12 +1260,12 @@ async def _handle_chat_completion(
                 response.usage.extra = None
                 yield response
             self.state.record_event(request_id, event="finish")
-        except (
-            Exception,
-            asyncio.CancelledError,
-        ) as err:  # pylint: disable=broad-exception-caught
+        except asyncio.CancelledError:  # pylint: disable=try-except-raise
+            # for cancelled error, we can simply pass it through
+            raise
+        except Exception as err:  # pylint: disable=broad-exception-caught
             logger.error("Error in _handle_chat_completion for request %s: %s", request_id, err)
-            raise err
+            raise

     async def _handle_completion(
         self,
@@ -1330,12 +1330,12 @@ async def _handle_completion(
             if suffix_response is not None:
                 yield suffix_response
             self.state.record_event(request_id, event="finish")
-        except (
-            Exception,
-            asyncio.CancelledError,
-        ) as err:  # pylint: disable=broad-exception-caught
+        except asyncio.CancelledError:  # pylint: disable=try-except-raise
+            # for cancelled error, we can simply pass it through
+            raise
+        except Exception as err:  # pylint: disable=broad-exception-caught
             logger.error("Error in _handle_completion for request %s: %s", request_id, err)
-            raise err
+            raise

     async def _generate(
         self,
@@ -1396,17 +1396,22 @@ async def _generate(
         )
         self._ffi["add_request"](request)

-        # Iterate the stream asynchronously and yield the output.
-        try:
-            async for request_output in stream:
-                yield request_output
-        except (
-            Exception,
-            asyncio.CancelledError,
-        ) as exception:  # pylint: disable=broad-exception-caught
-            logger.error("Error in _generate for request %s: %s", request_id, exception)
-            await self.abort(request_id)
-            raise exception
+        def abort_request():
+            """Clean up the request upon an error or cancellation."""
+            self._abort(request_id)
+            logger.info("request %s cancelled", request_id)
+
+        with engine_utils.ErrorCleanupScope(abort_request):
+            # Iterate the stream asynchronously and yield the output.
+            try:
+                async for request_output in stream:
+                    yield request_output
+            except asyncio.CancelledError:  # pylint: disable=try-except-raise
+                # for cancelled error, we can simply pass it through
+                raise
+            except Exception as exception:  # pylint: disable=broad-exception-caught
+                logger.error("Exception in _generate for request %s: %s", request_id, exception)
+                raise

     def _abort(self, request_id: str):
         """Internal implementation of request abortion."""
@@ -1914,8 +1919,12 @@ def _generate(  # pylint: disable=too-many-locals
         ]
         self._ffi["add_request"](request)

+        def abort_request():
+            """Clean up the request if an exception happens."""
+            self.abort(request_id)
+
         # Iterate the stream asynchronously and yield the token.
-        try:
+        with engine_utils.ErrorCleanupScope(abort_request):
             while True:
                 delta_outputs = self.state.sync_output_queue.get()
                 request_outputs, request_final_usage_json_str = self._request_stream_callback_impl(
@@ -1934,9 +1943,6 @@ def _generate(  # pylint: disable=too-many-locals
                     )
                     yield [output]
                     break
-        except Exception as exception:  # pylint: disable=broad-exception-caught
-            self.abort(request_id)
-            raise exception

     def _request_stream_callback_impl(
         self, delta_outputs: List[data.RequestStreamOutput]
diff --git a/python/mlc_llm/serve/engine_utils.py b/python/mlc_llm/serve/engine_utils.py
index 6ccbc0e621..68cb501a22 100644
--- a/python/mlc_llm/serve/engine_utils.py
+++ b/python/mlc_llm/serve/engine_utils.py
@@ -167,3 +167,83 @@ def convert_prompts_to_data(
         assert isinstance(prompts, list) and all(isinstance(token_id, int) for token_id in prompts)
         return [data.TokenData(prompts)]  # type: ignore
     return [convert_prompts_to_data(x)[0] for x in prompts]  # type: ignore
+
+
+class ErrorCleanupScope:
+    """Scope to call cleanup when an error is thrown.
+
+    This class provides an important pattern to properly clean up
+    when a CancelledError or another exception happens in an async scope.
+
+    Parameters
+    ----------
+    cleanup : Callable
+        A callable function to trigger at scope exit during an exception.
+
+    Note
+    ----
+    This helper is motivated by the need to properly
+    abort an async generator and trigger the corresponding
+    cleanup functions. Naively using the try/except
+    pattern will result in bugs when we chain up
+    async generators.
+
+    .. code:: python
+
+        class EngineNotSafe:
+            async def _inner_gen(self, request):
+                request_id = self.get_request_id()
+                self.add_request(request)
+                try:
+                    async for res in await producer_stream:
+                        yield res
+                except asyncio.CancelledError:
+                    self.abort(request_id)
+
+            async def generate(self, request):
+                async for res in self._inner_gen(request):
+                    # an async CancelledError can be raised in here,
+                    # bypassing the except clause in _inner_gen
+                    res = await process(res)
+                    yield res
+
+    The above except pattern is not safe.
+    This is because CancelledError may also be raised
+    outside _inner_gen, in the body of the generate
+    function, in between iterations.
+
+    Instead, we use ErrorCleanupScope to safeguard the
+    generation process. The scope will always properly
+    run the cleanup in its __exit__ function when an exception is raised.
+
+    .. code:: python
+
+        class EngineSafe:
+            async def _inner_gen(self, request):
+                request_id = self.get_request_id()
+                self.add_request(request)
+                with ErrorCleanupScope(lambda: self.abort(request_id)):
+                    async for res in await producer_stream:
+                        yield res
+
+            async def generate(self, request):
+                async for res in self._inner_gen(request):
+                    # even if an async error is raised here,
+                    # the ErrorCleanupScope will still run its
+                    # cleanup properly during generator exit
+                    res = await process(res)
+                    yield res
+    """
+
+    cleanup: Callable
+
+    def __init__(self, cleanup: Callable):
+        self.cleanup = cleanup
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        # only run cleanup when exc_type is not None
+        if exc_type is not None:
+            self.cleanup()
diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py
index 7f62c2ad3f..6e19d34df5 100644
--- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py
+++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py
@@ -95,8 +95,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
             # In non-streaming cases, the engine will not be notified
             # when the request is disconnected.
             # Therefore, we check if it is disconnected each time,
-            # and abort the request from engine if so.
-            await async_engine.abort(request_id)
+            # and explicitly return.
+            # Note that the request abort is triggered when the async for loop and the function scope end.
             return error_protocol.create_error_response(
                 HTTPStatus.BAD_REQUEST, message="The request has disconnected"
             )
@@ -207,8 +207,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
             # In non-streaming cases, the engine will not be notified
             # when the request is disconnected.
             # Therefore, we check if it is disconnected each time,
-            # and abort the request from engine if so.
-            await async_engine.abort(request_id)
+            # no need to explicitly abort, as returning from the chat
+            # completion will trigger the abort call.
             return error_protocol.create_error_response(
                 HTTPStatus.BAD_REQUEST, message="The request has disconnected"
             )
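
Reviewer note, not part of the patch: below is a minimal, self-contained sketch of the finalization behavior the new ErrorCleanupScope relies on. When a consumer stops iterating early (an early return from the endpoint handler, a break, or a task cancellation between iterations), closing the generator raises GeneratorExit at the suspended yield inside the with block, so __exit__ observes a non-None exc_type and runs the cleanup. The names request_stream and the print cleanup are illustrative only.

.. code:: python

    import asyncio


    class ErrorCleanupScope:
        """Minimal copy of the context manager added in engine_utils.py."""

        def __init__(self, cleanup):
            self.cleanup = cleanup

        def __enter__(self):
            pass

        def __exit__(self, exc_type, exc_value, traceback):
            # GeneratorExit (thrown when a suspended generator is closed)
            # also lands here, so the cleanup runs even when the consumer
            # simply stops iterating early.
            if exc_type is not None:
                self.cleanup()


    async def request_stream():
        with ErrorCleanupScope(lambda: print("cleanup: request aborted")):
            for i in range(5):
                yield i


    async def main():
        gen = request_stream()
        async for item in gen:
            print("got", item)
            if item == 1:
                break  # e.g. the client disconnected mid-stream
        # Closing the generator raises GeneratorExit at the suspended
        # yield; ErrorCleanupScope.__exit__ sees it and runs the cleanup.
        await gen.aclose()


    asyncio.run(main())

This also shows why __exit__ stays synchronous: it cannot await, which matches the synchronous self._abort(request_id) / self.abort(request_id) cleanups used in the patch.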