Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 36 additions & 30 deletions python/mlc_llm/serve/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,12 +1056,12 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local
logprob_results[ # pylint: disable=unsupported-assignment-operation
choice.index
] += choice.logprobs.content
except (
Exception,
asyncio.CancelledError,
) as err: # pylint: disable=broad-exception-caught
except asyncio.CancelledError: # pylint: disable=try-except-raise
# for cancelled error, we can simply pass it through
raise
except Exception as err: # pylint: disable=broad-exception-caught
logger.error("Error in chat completion with request ID %s: %s", request_id, err)
raise err
raise

assert all(finish_reason is not None for finish_reason in finish_reasons)
use_function_calling, tool_calls_list = engine_base.process_function_call_output(
Expand Down Expand Up @@ -1260,12 +1260,12 @@ async def _handle_chat_completion(
response.usage.extra = None
yield response
self.state.record_event(request_id, event="finish")
except (
Exception,
asyncio.CancelledError,
) as err: # pylint: disable=broad-exception-caught
except asyncio.CancelledError: # pylint: disable=try-except-raise
# for cancelled error, we can simply pass it through
raise
except Exception as err: # pylint: disable=broad-exception-caught
logger.error("Error in _handle_chat_completion for request %s: %s", request_id, err)
raise err
raise

async def _handle_completion(
self,
Expand Down Expand Up @@ -1330,12 +1330,12 @@ async def _handle_completion(
if suffix_response is not None:
yield suffix_response
self.state.record_event(request_id, event="finish")
except (
Exception,
asyncio.CancelledError,
) as err: # pylint: disable=broad-exception-caught
except asyncio.CancelledError: # pylint: disable=try-except-raise
# for cancelled error, we can simply pass it through
raise
except Exception as err: # pylint: disable=broad-exception-caught
logger.error("Error in _handle_completion for request %s: %s", request_id, err)
raise err
raise

async def _generate(
self,
Expand Down Expand Up @@ -1396,17 +1396,22 @@ async def _generate(
)
self._ffi["add_request"](request)

# Iterate the stream asynchronously and yield the output.
try:
async for request_output in stream:
yield request_output
except (
Exception,
asyncio.CancelledError,
) as exception: # pylint: disable=broad-exception-caught
logger.error("Error in _generate for request %s: %s", request_id, exception)
await self.abort(request_id)
raise exception
def abort_request():
"""clean up"""
self._abort(request_id)
logger.info("request %s cancelled", request_id)

with engine_utils.ErrorCleanupScope(abort_request):
# Iterate the stream asynchronously and yield the output.
try:
async for request_output in stream:
yield request_output
except asyncio.CancelledError: # pylint: disable=try-except-raise
# for cancelled error, we can simply pass it through
raise
except Exception as exception: # pylint: disable=broad-exception-caught
logger.error("Exception in _generate for request %s: %s", request_id, exception)
raise

def _abort(self, request_id: str):
"""Internal implementation of request abortion."""
Expand Down Expand Up @@ -1914,8 +1919,12 @@ def _generate( # pylint: disable=too-many-locals
]
self._ffi["add_request"](request)

def abort_request():
"""clean up request if exception happens"""
self.abort(request_id)

# Iterate the stream asynchronously and yield the token.
try:
with engine_utils.ErrorCleanupScope(abort_request):
while True:
delta_outputs = self.state.sync_output_queue.get()
request_outputs, request_final_usage_json_str = self._request_stream_callback_impl(
Expand All @@ -1934,9 +1943,6 @@ def _generate( # pylint: disable=too-many-locals
)
yield [output]
break
except Exception as exception: # pylint: disable=broad-exception-caught
self.abort(request_id)
raise exception

def _request_stream_callback_impl(
self, delta_outputs: List[data.RequestStreamOutput]
Expand Down
80 changes: 80 additions & 0 deletions python/mlc_llm/serve/engine_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,83 @@ def convert_prompts_to_data(
assert isinstance(prompts, list) and all(isinstance(token_id, int) for token_id in prompts)
return [data.TokenData(prompts)] # type: ignore
return [convert_prompts_to_data(x)[0] for x in prompts] # type: ignore


class ErrorCleanupScope:
    """Scope that invokes a cleanup callback when an error is thrown.

    This class provides a pattern to properly clean up
    when an async scope raises CancelledError or another exception.

    Parameters
    ----------
    cleanup : Callable
        A callable to trigger at scope exit when an exception occurred.

    Note
    ----
    This helper is motivated by the need to properly
    abort an async generator and trigger the corresponding
    cleanup functions. Naively using the try/except
    pattern results in a bug when we chain up
    async generators.

    .. code:: python

        class EngineNotSafe:
            async def _inner_gen(self, request):
                request_id = self.get_request_id()
                self.add_request(request)
                try:
                    async for res in await producer_stream:
                        yield res
                except asyncio.CancelledError:
                    self.abort(request_id)

            async def generate(self, request):
                async for res in await self._inner_gen(request):
                    # an async error can be raised here,
                    # in between iterations of _inner_gen
                    res = await process(res)
                    yield res

    The above except pattern is not safe.
    This is because CancelledError may also be raised
    outside _inner_gen, during the processing in the generate
    function in between iterations.

    Instead, we use ErrorCleanupScope to safeguard the
    generation process. The scope will always properly
    run cleanup in its exit function when an exception is raised.

    .. code:: python

        class EngineSafe:
            async def _inner_gen(self, request):
                request_id = self.get_request_id()
                self.add_request(request)
                with ErrorCleanupScope(lambda: self.abort(request_id)):
                    async for res in await producer_stream:
                        yield res

            async def generate(self, request):
                async for res in await self._inner_gen(request):
                    # even if an async error is raised here,
                    # the ErrorCleanupScope will run its cleanup
                    # properly during generator finalization
                    res = await process(res)
                    yield res
    """

    # Callback invoked exactly once if the scope exits with an exception.
    cleanup: Callable

    def __init__(self, cleanup: Callable):
        self.cleanup = cleanup

    def __enter__(self) -> "ErrorCleanupScope":
        # Return self so the scope can be bound with ``as`` if desired.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Only run cleanup when an exception is propagating; a clean
        # exit (exc_type is None) must not trigger the cleanup callback.
        if exc_type is not None:
            self.cleanup()
8 changes: 4 additions & 4 deletions python/mlc_llm/serve/entrypoints/openai_entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
# In non-streaming cases, the engine will not be notified
# when the request is disconnected.
# Therefore, we check if it is disconnected each time,
# and abort the request from engine if so.
await async_engine.abort(request_id)
# and explicitly return.
# Note that request abort is triggered when the async for and function scope ends.
return error_protocol.create_error_response(
HTTPStatus.BAD_REQUEST, message="The request has disconnected"
)
Expand Down Expand Up @@ -207,8 +207,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
# In non-streaming cases, the engine will not be notified
# when the request is disconnected.
# Therefore, we check if it is disconnected each time,
# and abort the request from engine if so.
await async_engine.abort(request_id)
# No need to explicitly abort, as returning from the chat
# completion will trigger the abort call.
return error_protocol.create_error_response(
HTTPStatus.BAD_REQUEST, message="The request has disconnected"
)
Expand Down