Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/memos/api/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
and request isolation.
"""

import os
import uuid

from collections.abc import Callable
Expand Down Expand Up @@ -117,6 +118,11 @@ def require_context() -> RequestContext:
_trace_id_getter: TraceIdGetter | None = None


def generate_trace_id() -> str:
"""Generate a random trace_id."""
return os.urandom(16).hex()


def set_trace_id_getter(getter: TraceIdGetter) -> None:
"""
Set a custom trace_id getter function.
Expand Down
8 changes: 1 addition & 7 deletions src/memos/api/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,19 @@
"""

import logging
import os

from collections.abc import Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from memos.api.context.context import RequestContext, set_request_context
from memos.api.context.context import RequestContext, generate_trace_id, set_request_context


logger = logging.getLogger(__name__)


def generate_trace_id() -> str:
"""Generate a random trace_id."""
return os.urandom(16).hex()


def extract_trace_id_from_headers(request: Request) -> str | None:
"""Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
trace_id = request.headers.get("g-trace-id")
Expand Down
108 changes: 7 additions & 101 deletions src/memos/log.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import atexit
import logging
import os
import threading

from logging.config import dictConfig
from pathlib import Path
from sys import stdout

import requests

from dotenv import load_dotenv

from memos import settings
from memos.api.context.context import get_current_trace_id
from memos.api.context.context_thread import ContextThreadPoolExecutor
from memos.api.context.context import generate_trace_id, get_current_trace_id


# Load environment variables
Expand All @@ -39,95 +33,12 @@ class TraceIDFilter(logging.Filter):
def filter(self, record):
try:
trace_id = get_current_trace_id()
record.trace_id = trace_id if trace_id else "no-trace-id"
record.trace_id = trace_id if trace_id else generate_trace_id()
except Exception:
record.trace_id = "no-trace-id"
record.trace_id = generate_trace_id()
return True


class CustomLoggerRequestHandler(logging.Handler):
_instance = None
_lock = threading.Lock()

def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
cls._instance._executor = None
cls._instance._session = None
cls._instance._is_shutting_down = None
return cls._instance

def __init__(self):
"""Initialize handler with minimal setup"""
if not self._initialized:
super().__init__()
workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2"))
self._executor = ContextThreadPoolExecutor(
max_workers=workers, thread_name_prefix="log_sender"
)
self._is_shutting_down = threading.Event()
self._session = requests.Session()
self._initialized = True
atexit.register(self._cleanup)

def emit(self, record):
"""Process log records of INFO or ERROR level (non-blocking)"""
if os.getenv("CUSTOM_LOGGER_URL") is None or self._is_shutting_down.is_set():
return

try:
trace_id = get_current_trace_id() or "no-trace-id"
self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
except Exception as e:
if not self._is_shutting_down.is_set():
print(f"Error sending log: {e}")

def _send_log_sync(self, message, trace_id):
"""Send log message synchronously in a separate thread"""
try:
logger_url = os.getenv("CUSTOM_LOGGER_URL")
token = os.getenv("CUSTOM_LOGGER_TOKEN")

headers = {"Content-Type": "application/json"}
post_content = {"message": message, "trace_id": trace_id}

# Add auth token if exists
if token:
headers["Authorization"] = f"Bearer {token}"

# Add traceId to headers for consistency
headers["traceId"] = trace_id

# Add custom attributes from env
for key, value in os.environ.items():
if key.startswith("CUSTOM_LOGGER_ATTRIBUTE_"):
attribute_key = key[len("CUSTOM_LOGGER_ATTRIBUTE_") :].lower()
post_content[attribute_key] = value

self._session.post(logger_url, headers=headers, json=post_content, timeout=5)
except Exception:
# Silently ignore errors to avoid affecting main application
pass

def _cleanup(self):
"""Clean up resources during program exit"""
if not self._initialized:
return

self._is_shutting_down.set()
try:
self._executor.shutdown(wait=False)
self._session.close()
except Exception as e:
print(f"Error during cleanup: {e}")

def close(self):
"""Override close to prevent premature shutdown"""


LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
Expand All @@ -151,7 +62,7 @@ def close(self):
"level": selected_log_level,
"class": "logging.StreamHandler",
"stream": stdout,
"formatter": "simplified",
"formatter": "no_datetime",
"filters": ["package_tree_filter", "trace_id_filter"],
},
"file": {
Expand All @@ -160,18 +71,13 @@ def close(self):
"filename": _setup_logfile(),
"maxBytes": 1024**2 * 10,
"backupCount": 10,
"formatter": "simplified",
"formatter": "standard",
"filters": ["trace_id_filter"],
},
"custom_logger": {
"level": selected_log_level,
"class": "memos.log.CustomLoggerRequestHandler",
"formatter": "simplified",
},
},
"root": { # Root logger handles all logs
"level": selected_log_level,
"handlers": ["console", "file", "custom_logger"],
"level": logging.DEBUG if settings.DEBUG else logging.INFO,
"handlers": ["console", "file"],
},
"loggers": {
"memos": {
Expand Down
Loading