Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"tenacity (>=9.1.2,<10.0.0)", # Error handling and retrying library
"fastapi[all] (>=0.115.12,<0.116.0)", # Web framework for building APIs
"sqlalchemy (>=2.0.41,<3.0.0)", # SQL toolkit
"pymysql (>=1.1.0,<2.0.0)", # MySQL Python driver
"scikit-learn (>=1.7.0,<2.0.0)", # Machine learning
"fastmcp (>=2.10.5,<3.0.0)",
"python-dateutil (>=2.9.0.post0,<3.0.0)",
Expand Down
96 changes: 0 additions & 96 deletions src/memos/api/context/context_thread.py

This file was deleted.

70 changes: 15 additions & 55 deletions src/memos/api/context/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging

from fastapi import Depends, Header, Request

from memos.api.context.context import RequestContext, set_request_context
from memos.context.context import RequestContext, get_current_context


logger = logging.getLogger(__name__)
Expand All @@ -11,56 +9,17 @@
G = RequestContext


def get_trace_id_from_header(
trace_id: str | None = Header(None, alias="trace-id"),
x_trace_id: str | None = Header(None, alias="x-trace-id"),
g_trace_id: str | None = Header(None, alias="g-trace-id"),
) -> str | None:
"""
Extract trace_id from various possible headers.

Priority: g-trace-id > x-trace-id > trace-id
"""
return g_trace_id or x_trace_id or trace_id


def get_request_context(
request: Request, trace_id: str | None = Depends(get_trace_id_from_header)
) -> RequestContext:
"""
Get request context object with trace_id and request metadata.

This function creates a RequestContext and automatically sets it
in the global context for use throughout the request lifecycle.
"""
# Create context object
ctx = RequestContext(trace_id=trace_id)

# Set the context globally for this request
set_request_context(ctx)

# Log request start
logger.info(f"Request started with trace_id: {ctx.trace_id}")

# Add request metadata to context
ctx.set("method", request.method)
ctx.set("path", request.url.path)
ctx.set("client_ip", request.client.host if request.client else None)

return ctx


def get_g_object(trace_id: str | None = Depends(get_trace_id_from_header)) -> G:
def get_g_object() -> G:
"""
Get Flask g-like object for the current request.

This creates a RequestContext and sets it globally for access
throughout the request lifecycle.
Returns the context created by middleware.
"""
g = RequestContext(trace_id=trace_id)
set_request_context(g)
logger.info(f"Request g object created with trace_id: {g.trace_id}")
return g
ctx = get_current_context()
if ctx is None:
raise RuntimeError(
"No request context available. Make sure RequestContextMiddleware is properly configured."
)
return ctx


def get_current_g() -> G | None:
Expand All @@ -70,8 +29,6 @@ def get_current_g() -> G | None:
Returns:
The current request's g object if available, None otherwise.
"""
from memos.context import get_current_context

return get_current_context()


Expand All @@ -85,6 +42,9 @@ def require_g() -> G:
Raises:
RuntimeError: If called outside of a request context.
"""
from memos.context import require_context

return require_context()
ctx = get_current_context()
if ctx is None:
raise RuntimeError(
"No request context available. This function must be called within a request handler."
)
return ctx
28 changes: 6 additions & 22 deletions src/memos/api/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,17 @@
from starlette.requests import Request
from starlette.responses import Response

from memos.api.context.context import RequestContext, generate_trace_id, set_request_context
from memos.context.context import RequestContext, generate_trace_id, set_request_context


logger = logging.getLogger(__name__)


def extract_trace_id_from_headers(request: Request) -> str | None:
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
trace_id = request.headers.get("g-trace-id")
if trace_id:
return trace_id

trace_id = request.headers.get("x-trace-id")
if trace_id:
return trace_id

trace_id = request.headers.get("trace-id")
if trace_id:
return trace_id

for header in ["g-trace-id", "x-trace-id", "trace-id"]:
if trace_id := request.headers.get(header):
return trace_id
return None


Expand All @@ -45,19 +36,12 @@ class RequestContextMiddleware(BaseHTTPMiddleware):

async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Extract or generate trace_id
trace_id = extract_trace_id_from_headers(request)
if not trace_id:
trace_id = generate_trace_id()
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()

# Create and set request context
context = RequestContext(trace_id=trace_id)
context = RequestContext(trace_id=trace_id, api_path=request.url.path)
set_request_context(context)

# Add request metadata to context
context.set("method", request.method)
context.set("path", request.url.path)
context.set("client_ip", request.client.host if request.client else None)

# Log request start with parameters
params_log = {}

Expand Down
Loading
Loading