From 827beb0613f846567509a77ffcc2eaaecfefe05d Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 26 Jul 2024 16:37:38 -0400 Subject: [PATCH] [ASYNC] Properly abort cleanup in async handling This PR adds a context manager to properly cleanup during async for exception. Naively use the try except pattern will results in bug when we chain up async generators and exception get raised not inside the try except in between iterations. --- python/mlc_llm/serve/engine.py | 66 ++++++++------- python/mlc_llm/serve/engine_utils.py | 80 +++++++++++++++++++ .../serve/entrypoints/openai_entrypoints.py | 8 +- 3 files changed, 120 insertions(+), 34 deletions(-) diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index fa67c7a81c..24bf2b9b4f 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1056,12 +1056,12 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local logprob_results[ # pylint: disable=unsupported-assignment-operation choice.index ] += choice.logprobs.content - except ( - Exception, - asyncio.CancelledError, - ) as err: # pylint: disable=broad-exception-caught + except asyncio.CancelledError: # pylint: disable=try-except-raise + # for cancelled error, we can simply pass it through + raise + except Exception as err: # pylint: disable=broad-exception-caught logger.error("Error in chat completion with request ID %s: %s", request_id, err) - raise err + raise assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( @@ -1260,12 +1260,12 @@ async def _handle_chat_completion( response.usage.extra = None yield response self.state.record_event(request_id, event="finish") - except ( - Exception, - asyncio.CancelledError, - ) as err: # pylint: disable=broad-exception-caught + except asyncio.CancelledError: # pylint: disable=try-except-raise + # for cancelled error, we can simply pass it through + raise + except Exception as err: # pylint: disable=broad-exception-caught logger.error("Error in _handle_chat_completion for request %s: %s", request_id, err) - raise err + raise async def _handle_completion( self, @@ -1330,12 +1330,12 @@ async def _handle_completion( if suffix_response is not None: yield suffix_response self.state.record_event(request_id, event="finish") - except ( - Exception, - asyncio.CancelledError, - ) as err: # pylint: disable=broad-exception-caught + except asyncio.CancelledError: # pylint: disable=try-except-raise + # for cancelled error, we can simply pass it through + raise + except Exception as err: # pylint: disable=broad-exception-caught logger.error("Error in _handle_completion for request %s: %s", request_id, err) - raise err + raise async def _generate( self, @@ -1396,17 +1396,22 @@ async def _generate( ) self._ffi["add_request"](request) - # Iterate the stream asynchronously and yield the output. - try: - async for request_output in stream: - yield request_output - except ( - Exception, - asyncio.CancelledError, - ) as exception: # pylint: disable=broad-exception-caught - logger.error("Error in _generate for request %s: %s", request_id, exception) - await self.abort(request_id) - raise exception + def abort_request(): + """clean up""" + self._abort(request_id) + logger.info("request %s cancelled", request_id) + + with engine_utils.ErrorCleanupScope(abort_request): + # Iterate the stream asynchronously and yield the output. + try: + async for request_output in stream: + yield request_output + except asyncio.CancelledError: # pylint: disable=try-except-raise + # for cancelled error, we can simply pass it through + raise + except Exception as exception: # pylint: disable=broad-exception-caught + logger.error("Exception in _generate for request %s: %s", request_id, exception) + raise def _abort(self, request_id: str): """Internal implementation of request abortion.""" @@ -1914,8 +1919,12 @@ def _generate( # pylint: disable=too-many-locals ] self._ffi["add_request"](request) + def abort_request(): + """clean up request if exception happens""" + self.abort(request_id) + # Iterate the stream asynchronously and yield the token. - try: + with engine_utils.ErrorCleanupScope(abort_request): while True: delta_outputs = self.state.sync_output_queue.get() request_outputs, request_final_usage_json_str = self._request_stream_callback_impl( @@ -1934,9 +1943,6 @@ def _generate( # pylint: disable=too-many-locals ) yield [output] break - except Exception as exception: # pylint: disable=broad-exception-caught - self.abort(request_id) - raise exception def _request_stream_callback_impl( self, delta_outputs: List[data.RequestStreamOutput] diff --git a/python/mlc_llm/serve/engine_utils.py b/python/mlc_llm/serve/engine_utils.py index 6ccbc0e621..68cb501a22 100644 --- a/python/mlc_llm/serve/engine_utils.py +++ b/python/mlc_llm/serve/engine_utils.py @@ -167,3 +167,83 @@ def convert_prompts_to_data( assert isinstance(prompts, list) and all(isinstance(token_id, int) for token_id in prompts) return [data.TokenData(prompts)] # type: ignore return [convert_prompts_to_data(x)[0] for x in prompts] # type: ignore + + +class ErrorCleanupScope: + """Scope to call cleanup when an error is thrown. + + This class provides an important pattern properly cleanup + when async scope CancelledError or other exception happens. + + Parameters + ---------- + cleanup : Callable + A callable function to trigger at scope exit during an exception. + + Note + ---- + This helper is motivated by the need to properly + abort an async generator and trigger corresponding + cleanup functions. Naively use the try except + pattern will results in bug when we chain up + async generators. + + .. code:: python + + class EngineNotSafe: + async def _inner_gen(self, request): + request_id = self.get_request_id() + self.add_request(request) + try: + async for res in await producer_stream: + yield res + except asyncio.CancelledError: + self.abort(request_id) + + async def generate(self, request): + async for res in await self._inner_gen(request): + # async error can he raised in here + # this will cause + res = await process(res) + yield res + + The above except pattern is not safe. + This is because CancelledError may also be raised + outside _inner_gen during the process of generate + function in between iterations. + + Instead, we use ErrorCleanupScope to safeguard the + generation process. The scope will always properly + cleanup in exit function when the exception is raised + + .. code:: python + + class EngineSafe: + async def _inner_gen(self, request): + request_id = self.get_request_id() + self.add_request(request) + with ErrorCleanupScope(lambda: self.abort(request_id)) + async for res in await producer_stream: + yield res + + async def generate(self, request): + async for res in await self._inner_gen(request): + # even if async error is raised here + # it will cleanup the ErrorCleanupScope + # properly during function exit + res = await process(res) + yield res + """ + + cleanup: Callable + + def __init__(self, cleanup: Callable): + self.cleanup = cleanup + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback) -> None: + # only cleanup when exc type is not none + if exc_type is not None: + self.cleanup() diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 7f62c2ad3f..6e19d34df5 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -95,8 +95,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, - # and abort the request from engine if so. - await async_engine.abort(request_id) + # and explicitly return. + # Note that requesta abort is triggered when the async for and funciton scope ends. return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" ) @@ -207,8 +207,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # In non-streaming cases, the engine will not be notified # when the request is disconnected. # Therefore, we check if it is disconnected each time, - # and abort the request from engine if so. - await async_engine.abort(request_id) + # no need to explicitly abort, as the chat completion + # return will trigger abort call return error_protocol.create_error_response( HTTPStatus.BAD_REQUEST, message="The request has disconnected" )