llama_cpp/server/app.py (34 changes: 14 additions & 20 deletions)
@@ -5,7 +5,7 @@
 import typing
 import contextlib

-from threading import Lock
+from anyio import Lock
 from functools import partial
 from typing import Iterator, List, Optional, Union, Dict

@@ -70,14 +70,14 @@ def set_llama_proxy(model_settings: List[ModelSettings]):
     _llama_proxy = LlamaProxy(models=model_settings)


-def get_llama_proxy():
+async def get_llama_proxy():
     # NOTE: This double lock allows the currently streaming llama model to
     # check if any other requests are pending in the same thread and cancel
     # the stream if so.
-    llama_outer_lock.acquire()
+    await llama_outer_lock.acquire()
     release_outer_lock = True
     try:
-        llama_inner_lock.acquire()
+        await llama_inner_lock.acquire()
         try:
             llama_outer_lock.release()
             release_outer_lock = False
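The lock swap above is the substance of this hunk: anyio.Lock.acquire() is a coroutine, so the double-lock handoff now suspends on the event loop instead of blocking a thread. Below is a minimal, self-contained sketch of that handoff pattern with anyio; the names (outer_lock, inner_lock, use_shared_resource) are illustrative, not the project's.

```python
import anyio

outer_lock = anyio.Lock()   # lets a new request signal the one currently streaming
inner_lock = anyio.Lock()   # guards the shared resource itself

async def use_shared_resource(name: str) -> None:
    # Take the outer lock, then the inner one, then hand the outer lock back
    # so the next caller can start waiting while we keep the inner lock.
    await outer_lock.acquire()
    release_outer = True
    try:
        await inner_lock.acquire()
        try:
            outer_lock.release()
            release_outer = False
            print(f"{name} is using the resource")
            await anyio.sleep(0.1)   # stand-in for the actual work
        finally:
            inner_lock.release()
    finally:
        if release_outer:
            outer_lock.release()

async def main() -> None:
    async with anyio.create_task_group() as tg:
        for i in range(3):
            tg.start_soon(use_shared_resource, f"request-{i}")

anyio.run(main)
```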
@@ -159,7 +159,7 @@ async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
     iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], None]] = None,
+    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
@@ -182,7 +182,7 @@
                 raise e
         finally:
             if on_complete:
-                on_complete()
+                await on_complete()


 def _logit_bias_tokens_to_input_ids(
@@ -267,10 +267,8 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
@@ -332,7 +330,6 @@ async def create_completion(
         def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -342,13 +339,13 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion


@@ -477,10 +474,8 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
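For context on the exit-stack change: contextlib.asynccontextmanager() wraps the now-async get_llama_proxy generator into an async context manager, and entering it on an AsyncExitStack keeps the dependency open after the handler returns, so a streaming response can release it later by awaiting exit_stack.aclose (the on_complete callback above). A rough sketch of that pattern follows; the names (get_resource, handler) are made up and stand in for the real dependency and route.

```python
import contextlib
import anyio

async def get_resource():
    # Async-generator "dependency": yields a shared object, cleans up afterwards.
    resource = {"model": "stub"}         # stands in for the LlamaProxy
    try:
        yield resource
    finally:
        print("resource released")

async def handler() -> None:
    exit_stack = contextlib.AsyncExitStack()
    resource = await exit_stack.enter_async_context(
        contextlib.asynccontextmanager(get_resource)()
    )
    print(f"streaming with {resource}")
    # The stack outlives the handler body; whoever drains the stream calls:
    on_complete = exit_stack.aclose      # fits Callable[[], Awaitable[None]]
    await on_complete()                  # runs the generator's finally block

anyio.run(handler)
```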
@@ -530,7 +525,6 @@ async def create_chat_completion(
         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -540,13 +534,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion

