From 39d555acfe8ae26ec71379d1a86773f3e5b21277 Mon Sep 17 00:00:00 2001
From: harvey_xiang
Date: Wed, 10 Sep 2025 19:33:13 +0800
Subject: [PATCH 1/9] feat: custom_logger supports INFO and ERROR

---
 src/memos/api/context/context.py        |  6 ++++++
 src/memos/api/routers/product_router.py |  7 +++++++
 src/memos/log.py                        | 12 ++++++++----
 3 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/src/memos/api/context/context.py b/src/memos/api/context/context.py
index 8aee2cfe..d360471c 100644
--- a/src/memos/api/context/context.py
+++ b/src/memos/api/context/context.py
@@ -6,6 +6,7 @@
 and request isolation.
 """
 
+import os
 import uuid
 
 from collections.abc import Callable
@@ -117,6 +118,11 @@ def require_context() -> RequestContext:
 _trace_id_getter: TraceIdGetter | None = None
 
 
+def generate_trace_id() -> str:
+    """Generate a random trace_id."""
+    return os.urandom(16).hex()
+
+
 def set_trace_id_getter(getter: TraceIdGetter) -> None:
     """
     Set a custom trace_id getter function.
diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py
index a27e4e48..11ac6f0f 100644
--- a/src/memos/api/routers/product_router.py
+++ b/src/memos/api/routers/product_router.py
@@ -70,6 +70,13 @@ def get_mos_product_instance():
 get_mos_product_instance()
 
 
+@router.post("/test", summary="Test MOSProduct", response_model=SimpleResponse)
+def test():
+    """Test MOSProduct."""
+    logger.info("Test trace_id")
+    return SimpleResponse(message="Test completed successfully")
+
+
 @router.post("/configure", summary="Configure MOSProduct", response_model=SimpleResponse)
 def set_config(config):
     """Set MOSProduct configuration."""
diff --git a/src/memos/log.py b/src/memos/log.py
index a5b6648f..b6f889e9 100644
--- a/src/memos/log.py
+++ b/src/memos/log.py
@@ -12,7 +12,7 @@
 from dotenv import load_dotenv
 
 from memos import settings
-from memos.api.context.context import get_current_trace_id
+from memos.api.context.context import generate_trace_id, get_current_trace_id
 from memos.api.context.context_thread import ContextThreadPoolExecutor
 
 
@@ -39,9 +39,9 @@ class TraceIDFilter(logging.Filter):
     def filter(self, record):
         try:
             trace_id = get_current_trace_id()
-            record.trace_id = trace_id if trace_id else "no-trace-id"
+            record.trace_id = trace_id if trace_id else generate_trace_id()
         except Exception:
-            record.trace_id = "no-trace-id"
+            record.trace_id = generate_trace_id()
         return True
 
 
@@ -78,6 +78,10 @@ def emit(self, record):
         if os.getenv("CUSTOM_LOGGER_URL") is None or self._is_shutting_down.is_set():
             return
 
+        # Only process INFO and ERROR level logs
+        if record.levelno < logging.INFO:  # Skip DEBUG and lower
+            return
+
         try:
             trace_id = get_current_trace_id() or "no-trace-id"
             self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
@@ -164,7 +168,7 @@ def close(self):
             "filters": ["trace_id_filter"],
         },
         "custom_logger": {
-            "level": selected_log_level,
+            "level": "INFO",
             "class": "memos.log.CustomLoggerRequestHandler",
             "formatter": "simplified",
        },
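Usage sketch (illustrative, not part of the patch): after this change, a handler carrying TraceIDFilter stamps every record with either the request's trace_id or a freshly generated 32-char hex id, never the literal "no-trace-id". The handler wiring below is an assumption for demonstration only.

import logging

from memos.log import TraceIDFilter

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("[%(trace_id)s] %(message)s"))
handler.addFilter(TraceIDFilter())  # outside a request, trace_id becomes a random hex id
logging.getLogger("memos").addHandler(handler)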
a/src/memos/api/context/context.py b/src/memos/api/context/context.py index d360471c..f7bbe380 100644 --- a/src/memos/api/context/context.py +++ b/src/memos/api/context/context.py @@ -6,14 +6,19 @@ and request isolation. """ +import functools import os +import threading import uuid from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor from contextvars import ContextVar -from typing import Any +from typing import Any, TypeVar +T = TypeVar("T") + # Global context variable for request-scoped data _request_context: ContextVar[dict[str, Any] | None] = ContextVar("request_context", default=None) @@ -111,6 +116,86 @@ def require_context() -> RequestContext: return context +class ContextThread(threading.Thread): + """ + Thread class that automatically propagates the main thread's trace_id to child threads. + """ + + def __init__(self, target, args=(), kwargs=None, **thread_kwargs): + super().__init__(**thread_kwargs) + self.target = target + self.args = args + self.kwargs = kwargs or {} + + self.main_trace_id = get_current_trace_id() + self.main_context = get_current_context() + + def run(self): + # Create a new RequestContext with the main thread's trace_id + if self.main_context: + # Copy the context data + child_context = RequestContext(trace_id=self.main_trace_id) + child_context._data = self.main_context._data.copy() + + # Set the context in the child thread + set_request_context(child_context) + + # Run the target function + self.target(*self.args, **self.kwargs) + + +class ContextThreadPoolExecutor(ThreadPoolExecutor): + """ + ThreadPoolExecutor that automatically propagates the main thread's trace_id to worker threads. + """ + + def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any: + """ + Submit a callable to be executed with the given arguments. + Automatically propagates the current thread's context to the worker thread. + """ + main_trace_id = get_current_trace_id() + main_context = get_current_context() + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if main_context: + # Create and set new context in worker thread + child_context = RequestContext(trace_id=main_trace_id) + child_context._data = main_context._data.copy() + set_request_context(child_context) + + return fn(*args, **kwargs) + + return super().submit(wrapper, *args, **kwargs) + + def map( + self, + fn: Callable[..., T], + *iterables: Any, + timeout: float | None = None, + chunksize: int = 1, + ) -> Any: + """ + Returns an iterator equivalent to map(fn, iter). + Automatically propagates the current thread's context to worker threads. 
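Usage sketch (illustrative, not part of the patch) for the executor added above, assuming a request context has already been set in the calling thread:

from memos.api.context.context import (
    ContextThreadPoolExecutor,
    RequestContext,
    get_current_trace_id,
    set_request_context,
)

set_request_context(RequestContext(trace_id="abc123"))

with ContextThreadPoolExecutor(max_workers=2) as executor:
    # submit() wraps fn so the worker re-creates the caller's context first
    assert executor.submit(get_current_trace_id).result() == "abc123"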
+ """ + main_trace_id = get_current_trace_id() + main_context = get_current_context() + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if main_context: + # Create and set new context in worker thread + child_context = RequestContext(trace_id=main_trace_id) + child_context._data = main_context._data.copy() + set_request_context(child_context) + + return fn(*args, **kwargs) + + return super().map(wrapper, *iterables, timeout=timeout, chunksize=chunksize) + + # Type for trace_id getter function TraceIdGetter = Callable[[], str | None] diff --git a/src/memos/api/context/context_thread.py b/src/memos/api/context/context_thread.py deleted file mode 100644 index 41de13a6..00000000 --- a/src/memos/api/context/context_thread.py +++ /dev/null @@ -1,96 +0,0 @@ -import functools -import threading - -from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor -from typing import Any, TypeVar - -from memos.api.context.context import ( - RequestContext, - get_current_context, - get_current_trace_id, - set_request_context, -) - - -T = TypeVar("T") - - -class ContextThread(threading.Thread): - """ - Thread class that automatically propagates the main thread's trace_id to child threads. - """ - - def __init__(self, target, args=(), kwargs=None, **thread_kwargs): - super().__init__(**thread_kwargs) - self.target = target - self.args = args - self.kwargs = kwargs or {} - - self.main_trace_id = get_current_trace_id() - self.main_context = get_current_context() - - def run(self): - # Create a new RequestContext with the main thread's trace_id - if self.main_context: - # Copy the context data - child_context = RequestContext(trace_id=self.main_trace_id) - child_context._data = self.main_context._data.copy() - - # Set the context in the child thread - set_request_context(child_context) - - # Run the target function - self.target(*self.args, **self.kwargs) - - -class ContextThreadPoolExecutor(ThreadPoolExecutor): - """ - ThreadPoolExecutor that automatically propagates the main thread's trace_id to worker threads. - """ - - def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any: - """ - Submit a callable to be executed with the given arguments. - Automatically propagates the current thread's context to the worker thread. - """ - main_trace_id = get_current_trace_id() - main_context = get_current_context() - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - if main_context: - # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id) - child_context._data = main_context._data.copy() - set_request_context(child_context) - - return fn(*args, **kwargs) - - return super().submit(wrapper, *args, **kwargs) - - def map( - self, - fn: Callable[..., T], - *iterables: Any, - timeout: float | None = None, - chunksize: int = 1, - ) -> Any: - """ - Returns an iterator equivalent to map(fn, iter). - Automatically propagates the current thread's context to worker threads. 
- """ - main_trace_id = get_current_trace_id() - main_context = get_current_context() - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - if main_context: - # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id) - child_context._data = main_context._data.copy() - set_request_context(child_context) - - return fn(*args, **kwargs) - - return super().map(wrapper, *iterables, timeout=timeout, chunksize=chunksize) diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 11ac6f0f..a27e4e48 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -70,13 +70,6 @@ def get_mos_product_instance(): get_mos_product_instance() -@router.post("/test", summary="Test MOSProduct", response_model=SimpleResponse) -def test(): - """Test MOSProduct.""" - logger.info("Test trace_id") - return SimpleResponse(message="Test completed successfully") - - @router.post("/configure", summary="Configure MOSProduct", response_model=SimpleResponse) def set_config(config): """Set MOSProduct configuration.""" diff --git a/src/memos/log.py b/src/memos/log.py index b6f889e9..c597dcef 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -12,8 +12,11 @@ from dotenv import load_dotenv from memos import settings -from memos.api.context.context import generate_trace_id, get_current_trace_id -from memos.api.context.context_thread import ContextThreadPoolExecutor +from memos.api.context.context import ( + ContextThreadPoolExecutor, + generate_trace_id, + get_current_trace_id, +) # Load environment variables From b1d1c2037baccb18f947a2a448edadaeb0eeaf84 Mon Sep 17 00:00:00 2001 From: harvey_xiang Date: Thu, 11 Sep 2025 21:29:23 +0800 Subject: [PATCH 3/9] feat: replace all Thread to contextThread --- poetry.lock | 82 +++++++++++++++++-- pyproject.toml | 3 + src/memos/api/context/dependencies.py | 70 ++++------------ src/memos/api/middleware/request_context.py | 28 ++----- src/memos/{api => }/context/context.py | 35 ++++++-- src/memos/graph_dbs/nebular.py | 12 +-- src/memos/log.py | 33 ++++---- src/memos/mem_os/main.py | 9 +- src/memos/mem_os/product.py | 8 +- src/memos/mem_reader/simple_struct.py | 7 +- .../general_modules/dispatcher.py | 4 +- .../monitors/dispatcher_monitor.py | 6 +- .../tree_text_memory/organize/manager.py | 7 +- .../tree_text_memory/organize/reorganizer.py | 7 +- .../tree_text_memory/retrieve/bochasearch.py | 5 +- .../tree_text_memory/retrieve/recall.py | 5 +- .../tree_text_memory/retrieve/searcher.py | 8 +- .../tree_text_memory/retrieve/xinyusearch.py | 5 +- src/memos/settings.py | 1 + src/memos/utils.py | 4 +- 20 files changed, 190 insertions(+), 149 deletions(-) rename src/memos/{api => }/context/context.py (87%) diff --git a/poetry.lock b/poetry.lock index c6b6a0eb..e99d7d17 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. 
@@ -1250,7 +1263,7 @@ files = [
     {file = "h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0"},
     {file = "h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f"},
 ]
-markers = {main = "extra == \"all\""}
+markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
 
 [package.dependencies]
 hpack = ">=4.1,<5"
@@ -1289,7 +1302,7 @@ files = [
     {file = "hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496"},
     {file = "hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca"},
 ]
-markers = {main = "extra == \"all\""}
+markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
 
 [[package]]
 name = "httpcore"
@@ -1313,6 +1326,22 @@ http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
 trio = ["trio (>=0.22.0,<1.0)"]
 
+[[package]]
+name = "httplib2"
+version = "0.30.2"
+description = "A comprehensive HTTP client library."
+optional = true
+python-versions = ">=3.6"
+groups = ["main"]
+markers = "extra == \"tree-mem\" or extra == \"all\""
+files = [
+    {file = "httplib2-0.30.2-py3-none-any.whl", hash = "sha256:62a665905c1f1d1069c34f933787d2a4435c67c0bc2b323645dcfbb64661b5ec"},
+    {file = "httplib2-0.30.2.tar.gz", hash = "sha256:050bde6a332824b05a3deef5238f2b0372f71af46f8ca2190c2cb1f66aa376cd"},
+]
+
+[package.dependencies]
+pyparsing = ">=3.0.4,<4"
+
 [[package]]
 name = "httptools"
 version = "0.6.4"
@@ -1473,7 +1502,7 @@ files = [
     {file = "hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5"},
     {file = "hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08"},
 ]
-markers = {main = "extra == \"all\""}
+markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
 
 [[package]]
 name = "identify"
@@ -2481,6 +2510,26 @@ docs = ["sphinx"]
 gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""]
 tests = ["pytest (>=4.6)"]
 
+[[package]]
+name = "nebula3-python"
+version = "3.8.3"
+description = "Python client for NebulaGraph v3"
+optional = true
+python-versions = ">=3.6.2"
+groups = ["main"]
+markers = "extra == \"tree-mem\" or extra == \"all\""
+files = [
+    {file = "nebula3_python-3.8.3-py3-none-any.whl", hash = "sha256:ef583d15c012751cce05bcf8b3869881dbbc3286c6e3fa8b5c65b7c5bb808c38"},
+    {file = "nebula3_python-3.8.3.tar.gz", hash = "sha256:da4693171079e9d5efb01a579f7fb62ef7655f6169cc2ff2cbb7b7d902e2cf3d"},
+]
+
+[package.dependencies]
+future = ">=0.18.0"
+httplib2 = ">=0.20.0"
+httpx = {version = ">=0.22.0", extras = ["http2"]}
+pytz = ">=2021.1"
+six = ">=1.16.0"
+
 [[package]]
 name = "neo4j"
 version = "5.28.1"
@@ -3773,17 +3822,34 @@ files = [
 
 [package.extras]
 windows-terminal = ["colorama (>=0.4.6)"]
 
+[[package]]
+name = "pymysql"
+version = "1.1.2"
+description = "Pure Python MySQL Driver"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9"},
+    {file = "pymysql-1.1.2.tar.gz", hash = "sha256:4961d3e165614ae65014e361811a724e2044ad3ea3739de9903ae7c21f539f03"},
+]
+
+[package.extras]
+ed25519 = ["PyNaCl (>=1.4.0)"]
+rsa = ["cryptography"]
+
 [[package]]
 name = "pyparsing"
 version = "3.2.3"
 description = "pyparsing module - Classes and methods to define and execute parsing grammars"
 optional = false
 python-versions = ">=3.9"
-groups = ["eval"]
+groups = ["main", "eval"]
 files = [
     {file = "pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf"},
     {file = "pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be"},
 ]
+markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
 
 [package.extras]
 diagrams = ["jinja2", "railroad-diagrams"]
@@ -6285,12 +6351,12 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
 
 [package.extras]
 cffi = ["cffi (>=1.11)"]
 
 [extras]
-all = ["chonkie", "markitdown", "neo4j", "pika", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
+all = ["chonkie", "markitdown", "nebula3-python", "neo4j", "pika", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
 mem-reader = ["chonkie", "markitdown"]
 mem-scheduler = ["pika", "redis"]
-tree-mem = ["neo4j", "schedule"]
+tree-mem = ["nebula3-python", "neo4j", "schedule"]
 
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "94a3c4f97f0deda4c6ccbfd8ceda194f18dbc7525aa49004ffcc7846a1c40f7e"
+content-hash = "982c24cc8c7961f543df341e73ab8fe04bfdc47dac1784af66985d34434007ae"
diff --git a/pyproject.toml b/pyproject.toml
index 270fd712..b1f2f441 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
     "tenacity (>=9.1.2,<10.0.0)", # Error handling and retrying library
     "fastapi[all] (>=0.115.12,<0.116.0)", # Web framework for building APIs
     "sqlalchemy (>=2.0.41,<3.0.0)", # SQL toolkit
+    "pymysql (>=1.1.0,<2.0.0)", # MySQL Python driver
     "scikit-learn (>=1.7.0,<2.0.0)", # Machine learning
     "fastmcp (>=2.10.5,<3.0.0)",
     "python-dateutil (>=2.9.0.post0,<3.0.0)",
@@ -67,6 +68,7 @@ memos = "memos.cli:main"
 # TreeTextualMemory
 tree-mem = [
     "neo4j (>=5.28.1,<6.0.0)", # Graph database
+    "nebula3-python (>=3.5.0,<4.0.0)", # NebulaGraph Python client
     "schedule (>=1.2.2,<2.0.0)", # Task scheduling
 ]
 
@@ -87,6 +89,7 @@ mem-reader = [
 all = [
     # Exist in the above optional groups
     "neo4j (>=5.28.1,<6.0.0)",
+    "nebula3-python (>=3.5.0,<4.0.0)",
     "schedule (>=1.2.2,<2.0.0)",
     "redis (>=6.2.0,<7.0.0)",
     "pika (>=1.3.2,<2.0.0)",
diff --git a/src/memos/api/context/dependencies.py b/src/memos/api/context/dependencies.py
index d26cadaa..d163fa0d 100644
--- a/src/memos/api/context/dependencies.py
+++ b/src/memos/api/context/dependencies.py
@@ -1,8 +1,6 @@
 import logging
 
-from fastapi import Depends, Header, Request
-
-from memos.api.context.context import RequestContext, set_request_context
+from memos.context.context import RequestContext, get_current_context
 
 
 logger = logging.getLogger(__name__)
@@ -11,56 +9,17 @@
 G = RequestContext
 
 
-def get_trace_id_from_header(
-    trace_id: str | None = Header(None, alias="trace-id"),
-    x_trace_id: str | None = Header(None, alias="x-trace-id"),
-    g_trace_id: str | None = Header(None, alias="g-trace-id"),
-) -> str | None:
-    """
-    Extract trace_id from various possible headers.
-
-    Priority: g-trace-id > x-trace-id > trace-id
-    """
-    return g_trace_id or x_trace_id or trace_id
-
-
-def get_request_context(
-    request: Request, trace_id: str | None = Depends(get_trace_id_from_header)
-) -> RequestContext:
-    """
-    Get request context object with trace_id and request metadata.
-
-    This function creates a RequestContext and automatically sets it
-    in the global context for use throughout the request lifecycle.
-    """
-    # Create context object
-    ctx = RequestContext(trace_id=trace_id)
-
-    # Set the context globally for this request
-    set_request_context(ctx)
-
-    # Log request start
-    logger.info(f"Request started with trace_id: {ctx.trace_id}")
-
-    # Add request metadata to context
-    ctx.set("method", request.method)
-    ctx.set("path", request.url.path)
-    ctx.set("client_ip", request.client.host if request.client else None)
-
-    return ctx
-
-
-def get_g_object(trace_id: str | None = Depends(get_trace_id_from_header)) -> G:
+def get_g_object() -> G:
     """
     Get Flask g-like object for the current request.
-
-    This creates a RequestContext and sets it globally for access
-    throughout the request lifecycle.
+    Returns the context created by middleware.
     """
-    g = RequestContext(trace_id=trace_id)
-    set_request_context(g)
-    logger.info(f"Request g object created with trace_id: {g.trace_id}")
-    return g
+    ctx = get_current_context()
+    if ctx is None:
+        raise RuntimeError(
+            "No request context available. Make sure RequestContextMiddleware is properly configured."
+        )
+    return ctx
 
 
 def get_current_g() -> G | None:
@@ -70,8 +29,6 @@ def get_current_g() -> G | None:
     Returns:
         The current request's g object if available, None otherwise.
     """
-    from memos.context import get_current_context
-
     return get_current_context()
 
 
@@ -85,6 +42,9 @@ def require_g() -> G:
     Raises:
         RuntimeError: If called outside of a request context.
     """
-    from memos.context import require_context
-
-    return require_context()
+    ctx = get_current_context()
+    if ctx is None:
+        raise RuntimeError(
+            "No request context available. This function must be called within a request handler."
+        )
+    return ctx
diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py
index 01f57a27..91b13aff 100644
--- a/src/memos/api/middleware/request_context.py
+++ b/src/memos/api/middleware/request_context.py
@@ -11,7 +11,7 @@
 from starlette.requests import Request
 from starlette.responses import Response
 
-from memos.api.context.context import RequestContext, set_request_context
+from memos.context.context import RequestContext, set_request_context
 
 
 logger = logging.getLogger(__name__)
@@ -24,18 +24,9 @@ def generate_trace_id() -> str:
 
 def extract_trace_id_from_headers(request: Request) -> str | None:
     """Extract trace_id from various possible headers with priority: g-trace-id > x-trace-id > trace-id."""
-    trace_id = request.headers.get("g-trace-id")
-    if trace_id:
-        return trace_id
-
-    trace_id = request.headers.get("x-trace-id")
-    if trace_id:
-        return trace_id
-
-    trace_id = request.headers.get("trace-id")
-    if trace_id:
-        return trace_id
-
+    for header in ["g-trace-id", "x-trace-id", "trace-id"]:
+        if trace_id := request.headers.get(header):
+            return trace_id
     return None
 
 
@@ -51,19 +42,12 @@ class RequestContextMiddleware(BaseHTTPMiddleware):
 
     async def dispatch(self, request: Request, call_next: Callable) -> Response:
         # Extract or generate trace_id
-        trace_id = extract_trace_id_from_headers(request)
-        if not trace_id:
-            trace_id = generate_trace_id()
+        trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
 
         # Create and set request context
-        context = RequestContext(trace_id=trace_id)
+        context = RequestContext(trace_id=trace_id, api_path=request.url.path)
         set_request_context(context)
 
-        # Add request metadata to context
-        context.set("method", request.method)
-        context.set("path", request.url.path)
-        context.set("client_ip", request.client.host if request.client else None)
-
         # Log request start with parameters
         params_log = {}
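Wiring sketch (illustrative, not part of the patch), assuming a FastAPI app object; header priority follows the loop implemented above:

from fastapi import FastAPI

from memos.api.middleware.request_context import RequestContextMiddleware

app = FastAPI()
# Every request now gets a context carrying trace_id (from the g-trace-id,
# x-trace-id, or trace-id header, else freshly generated) and api_path.
app.add_middleware(RequestContextMiddleware)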
""" - def __init__(self, trace_id: str | None = None): - self.trace_id = trace_id or str(uuid.uuid4()) + def __init__(self, trace_id: str | None = None, api_path: str | None = None): + self.trace_id = trace_id or "trace-id" + self.api_path = api_path self._data: dict[str, Any] = {} def set(self, key: str, value: Any) -> None: @@ -43,7 +43,7 @@ def get(self, key: str, default: Any | None = None) -> Any: return self._data.get(key, default) def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_") or name == "trace_id": + if name.startswith("_") or name in ("trace_id", "api_path"): super().__setattr__(name, value) else: if not hasattr(self, "_data"): @@ -58,7 +58,7 @@ def __getattr__(self, name: str) -> Any: def to_dict(self) -> dict[str, Any]: """Convert context to dictionary.""" - return {"trace_id": self.trace_id, "data": self._data.copy()} + return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()} def set_request_context(context: RequestContext) -> None: @@ -83,6 +83,16 @@ def get_current_trace_id() -> str | None: return None +def get_current_api_path() -> str | None: + """ + Get the current request's api path. + """ + context = _request_context.get() + if context: + return context.get("api_path") + return None + + def get_current_context() -> RequestContext | None: """ Get the current request context. @@ -92,7 +102,9 @@ def get_current_context() -> RequestContext | None: """ context_dict = _request_context.get() if context_dict: - ctx = RequestContext(trace_id=context_dict.get("trace_id")) + ctx = RequestContext( + trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path") + ) ctx._data = context_dict.get("data", {}).copy() return ctx return None @@ -128,13 +140,16 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs): self.kwargs = kwargs or {} self.main_trace_id = get_current_trace_id() + self.main_api_path = get_current_api_path() self.main_context = get_current_context() def run(self): # Create a new RequestContext with the main thread's trace_id if self.main_context: # Copy the context data - child_context = RequestContext(trace_id=self.main_trace_id) + child_context = RequestContext( + trace_id=self.main_trace_id, api_path=self.main_context.api_path + ) child_context._data = self.main_context._data.copy() # Set the context in the child thread @@ -155,13 +170,14 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any: Automatically propagates the current thread's context to the worker thread. """ main_trace_id = get_current_trace_id() + main_api_path = get_current_api_path() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id) + child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) child_context._data = main_context._data.copy() set_request_context(child_context) @@ -181,13 +197,14 @@ def map( Automatically propagates the current thread's context to worker threads. 
""" main_trace_id = get_current_trace_id() + main_api_path = get_current_api_path() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id) + child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) child_context._data = main_context._data.copy() set_request_context(child_context) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 5ca8c895..1884bf9b 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -93,9 +93,9 @@ class SessionPoolError(Exception): class SessionPool: @require_python_package( - import_name="nebulagraph_python", - install_command="pip install ... @Tianxing", - install_link=".....", + import_name="nebula3", + install_command="pip install nebula3-python", + install_link="https://pypi.org/project/nebula3-python/", ) def __init__( self, @@ -312,9 +312,9 @@ def close_all_shared_pools(cls): cls._POOL_REFCOUNT.clear() @require_python_package( - import_name="nebulagraph_python", - install_command="pip install ... @Tianxing", - install_link=".....", + import_name="nebula3", + install_command="pip install nebula3-python", + install_link="https://pypi.org/project/nebula3-python/", ) def __init__(self, config: NebulaGraphDBConfig): """ diff --git a/src/memos/log.py b/src/memos/log.py index c597dcef..61cf84d9 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -2,7 +2,9 @@ import logging import os import threading +import time +from concurrent.futures import ThreadPoolExecutor from logging.config import dictConfig from pathlib import Path from sys import stdout @@ -12,11 +14,7 @@ from dotenv import load_dotenv from memos import settings -from memos.api.context.context import ( - ContextThreadPoolExecutor, - generate_trace_id, - get_current_trace_id, -) +from memos.context.context import get_current_api_path, get_current_trace_id # Load environment variables @@ -42,9 +40,9 @@ class TraceIDFilter(logging.Filter): def filter(self, record): try: trace_id = get_current_trace_id() - record.trace_id = trace_id if trace_id else generate_trace_id() + record.trace_id = trace_id if trace_id else "trace-id" except Exception: - record.trace_id = generate_trace_id() + record.trace_id = "trace-id" return True @@ -68,7 +66,7 @@ def __init__(self): if not self._initialized: super().__init__() workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2")) - self._executor = ContextThreadPoolExecutor( + self._executor = ThreadPoolExecutor( max_workers=workers, thread_name_prefix="log_sender" ) self._is_shutting_down = threading.Event() @@ -86,20 +84,27 @@ def emit(self, record): return try: - trace_id = get_current_trace_id() or "no-trace-id" - self._executor.submit(self._send_log_sync, record.getMessage(), trace_id) + trace_id = get_current_trace_id() or "trace-id" + api_path = get_current_api_path() + if api_path is not None: + self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path) except Exception as e: if not self._is_shutting_down.is_set(): print(f"Error sending log: {e}") - def _send_log_sync(self, message, trace_id): + def _send_log_sync(self, message, trace_id, api_path): """Send log message synchronously in a separate thread""" try: logger_url = os.getenv("CUSTOM_LOGGER_URL") token = os.getenv("CUSTOM_LOGGER_TOKEN") headers = {"Content-Type": "application/json"} - post_content = {"message": message, "trace_id": trace_id} 
diff --git a/src/memos/log.py b/src/memos/log.py
index c597dcef..61cf84d9 100644
--- a/src/memos/log.py
+++ b/src/memos/log.py
@@ -2,7 +2,9 @@
 import logging
 import os
 import threading
+import time
 
+from concurrent.futures import ThreadPoolExecutor
 from logging.config import dictConfig
 from pathlib import Path
 from sys import stdout
@@ -12,11 +14,7 @@
 from dotenv import load_dotenv
 
 from memos import settings
-from memos.api.context.context import (
-    ContextThreadPoolExecutor,
-    generate_trace_id,
-    get_current_trace_id,
-)
+from memos.context.context import get_current_api_path, get_current_trace_id
 
 
 # Load environment variables
@@ -42,9 +40,9 @@ class TraceIDFilter(logging.Filter):
     def filter(self, record):
         try:
             trace_id = get_current_trace_id()
-            record.trace_id = trace_id if trace_id else generate_trace_id()
+            record.trace_id = trace_id if trace_id else "trace-id"
         except Exception:
-            record.trace_id = generate_trace_id()
+            record.trace_id = "trace-id"
         return True
 
 
@@ -68,7 +66,7 @@ def __init__(self):
         if not self._initialized:
             super().__init__()
             workers = int(os.getenv("CUSTOM_LOGGER_WORKERS", "2"))
-            self._executor = ContextThreadPoolExecutor(
+            self._executor = ThreadPoolExecutor(
                 max_workers=workers, thread_name_prefix="log_sender"
             )
             self._is_shutting_down = threading.Event()
@@ -86,20 +84,27 @@ def emit(self, record):
             return
 
         try:
-            trace_id = get_current_trace_id() or "no-trace-id"
-            self._executor.submit(self._send_log_sync, record.getMessage(), trace_id)
+            trace_id = get_current_trace_id() or "trace-id"
+            api_path = get_current_api_path()
+            if api_path is not None:
+                self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path)
         except Exception as e:
             if not self._is_shutting_down.is_set():
                 print(f"Error sending log: {e}")
 
-    def _send_log_sync(self, message, trace_id):
+    def _send_log_sync(self, message, trace_id, api_path):
         """Send log message synchronously in a separate thread"""
         try:
             logger_url = os.getenv("CUSTOM_LOGGER_URL")
             token = os.getenv("CUSTOM_LOGGER_TOKEN")
 
             headers = {"Content-Type": "application/json"}
-            post_content = {"message": message, "trace_id": trace_id}
+            post_content = {
+                "message": message,
+                "trace_id": trace_id,
+                "action": api_path,
+                "current_time": round(time.time(), 3),
+            }
 
             # Add auth token if exists
             if token:
@@ -146,7 +151,7 @@ def close(self):
             "format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
         },
         "simplified": {
-            "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s | %(message)s"
+            "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s"
         },
     },
     "filters": {
@@ -155,7 +160,7 @@ def close(self):
     },
     "handlers": {
         "console": {
-            "level": selected_log_level,
+            "level": "INFO",
             "class": "logging.StreamHandler",
             "stream": stdout,
             "formatter": "simplified",
diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py
index 2520c8fd..2e5b3254 100644
--- a/src/memos/mem_os/main.py
+++ b/src/memos/mem_os/main.py
@@ -5,6 +5,7 @@
 from typing import Any
 
 from memos.configs.mem_os import MOSConfig
+from memos.context.context import ContextThreadPoolExecutor
 from memos.llms.factory import LLMFactory
 from memos.log import get_logger
 from memos.mem_os.core import MOSCore
@@ -487,9 +488,7 @@ def generate_answer_for_question(question_index: int, sub_question: str) -> tupl
 
         # Generate answers in parallel while maintaining order
         sub_answers = [None] * len(sub_questions)
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=min(len(sub_questions), 10)
-        ) as executor:
+        with ContextThreadPoolExecutor(max_workers=min(len(sub_questions), 10)) as executor:
             # Submit all answer generation tasks
             future_to_index = {
                 executor.submit(generate_answer_for_question, i, question): i
@@ -552,9 +551,7 @@ def search_single_question(question: str) -> list[Any]:
 
         # Search in parallel while maintaining order
         all_memories = []
-        with concurrent.futures.ThreadPoolExecutor(
-            max_workers=min(len(sub_questions), 10)
-        ) as executor:
+        with ContextThreadPoolExecutor(max_workers=min(len(sub_questions), 10)) as executor:
             # Submit all search tasks and keep track of their order
             future_to_index = {
                 executor.submit(search_single_question, question): i
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index 5899b680..388af084 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -2,7 +2,6 @@
 import json
 import os
 import random
-import threading
 import time
 
 from collections.abc import Generator
@@ -14,6 +13,7 @@
 
 from memos.configs.mem_cube import GeneralMemCubeConfig
 from memos.configs.mem_os import MOSConfig
+from memos.context.context import ContextThread
 from memos.log import get_logger
 from memos.mem_cube.general import GeneralMemCube
 from memos.mem_os.core import MOSCore
@@ -642,6 +642,8 @@ def _start_post_chat_processing(
 
         def run_async_in_thread():
             """Running asynchronous tasks in a new thread"""
+
+            logger.info(f"Running asynchronous tasks in a new thread for user {user_id}")
             try:
                 loop = asyncio.new_event_loop()
                 asyncio.set_event_loop(loop)
@@ -694,8 +696,8 @@ def run_async_in_thread():
                     else None
                 )
         except RuntimeError:
-            # No event loop, run in a new thread
-            thread = threading.Thread(
+            # No event loop, run in a new thread with context propagation
+            thread = ContextThread(
                 target=run_async_in_thread,
                 name=f"PostChatProcessing-{user_id}",
                 # Set as a daemon thread to avoid blocking program exit
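For reference, a sketch of the JSON body that _send_log_sync now posts to CUSTOM_LOGGER_URL (all values below are illustrative, not real output):

post_content = {
    "message": "Request started",     # record.getMessage()
    "trace_id": "9f2c4a...",          # from get_current_trace_id()
    "action": "/product/chat",        # from get_current_api_path()
    "current_time": 1757500000.123,   # round(time.time(), 3)
}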
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 2b0bbc5d..b9032f07 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -13,6 +13,7 @@
 from memos.chunkers import ChunkerFactory
 from memos.configs.mem_reader import SimpleStructMemReaderConfig
 from memos.configs.parser import ParserConfigFactory
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import EmbedderFactory
 from memos.llms.factory import LLMFactory
 from memos.mem_reader.base import BaseMemReader
@@ -200,8 +201,8 @@ def get_memory(
         else:
             processing_func = self._process_doc_data
 
-        # Process Q&A pairs concurrently
-        with concurrent.futures.ThreadPoolExecutor() as executor:
+        # Process Q&A pairs concurrently with context propagation
+        with ContextThreadPoolExecutor() as executor:
             futures = [
                 executor.submit(processing_func, scene_data_info, info)
                 for scene_data_info in list_scene_data_info
@@ -277,7 +278,7 @@ def _process_doc_data(self, scene_data_info, info):
         doc_nodes = []
         scene_file = scene_data_info["file"]
 
-        with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
+        with ContextThreadPoolExecutor(max_workers=50) as executor:
             futures = {
                 executor.submit(
                     _build_node,
diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py
index b2cb4bd3..ce6df4d5 100644
--- a/src/memos/mem_scheduler/general_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/general_modules/dispatcher.py
@@ -2,8 +2,8 @@
 
 from collections import defaultdict
 from collections.abc import Callable
-from concurrent.futures import ThreadPoolExecutor
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
@@ -33,7 +33,7 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False):
         self.enable_parallel_dispatch = enable_parallel_dispatch
         self.thread_name_prefix = "dispatcher"
         if self.enable_parallel_dispatch:
-            self.dispatcher_executor = ThreadPoolExecutor(
+            self.dispatcher_executor = ContextThreadPoolExecutor(
                 max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix
             )
         else:
diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
index 229e9c3a..d57804ea 100644
--- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
+++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
@@ -1,11 +1,11 @@
 import threading
 import time
 
-from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from time import perf_counter
 
 from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.context.context import ContextThreadPoolExecutor
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
 from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
@@ -49,7 +49,7 @@ def initialize(self, dispatcher: SchedulerDispatcher):
     def register_pool(
         self,
         name: str,
-        executor: ThreadPoolExecutor,
+        executor: ContextThreadPoolExecutor,
         max_workers: int,
         restart_on_failure: bool = True,
     ) -> bool:
@@ -243,7 +243,7 @@ def _restart_pool(self, name: str, pool_info: dict) -> None:
             self.dispatcher.shutdown()
 
             # Create new executor with same parameters
-            new_executor = ThreadPoolExecutor(
+            new_executor = ContextThreadPoolExecutor(
                 max_workers=pool_info["max_workers"],
                 thread_name_prefix=self.dispatcher.thread_name_prefix,  # pylint: disable=protected-access
             )
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index 6b0a6a55..e3f10b54 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -1,8 +1,9 @@
 import uuid
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from datetime import datetime
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.neo4j import Neo4jGraphDB
 from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
@@ -55,7 +56,7 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]:
         """
         added_ids: list[str] = []
 
-        with ThreadPoolExecutor(max_workers=8) as executor:
+        with ContextThreadPoolExecutor(max_workers=8) as executor:
             futures = {executor.submit(self._process_memory, m): m for m in memories}
             for future in as_completed(futures):
                 try:
@@ -82,7 +83,7 @@ def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
         Replace WorkingMemory
         """
         working_memory_top_k = memories[: self.memory_size["WorkingMemory"]]
-        with ThreadPoolExecutor(max_workers=8) as executor:
+        with ContextThreadPoolExecutor(max_workers=8) as executor:
             futures = [
                 executor.submit(self._add_memory_to_db, memory, "WorkingMemory")
                 for memory in working_memory_top_k
diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
index 586deaab..91c5b273 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
@@ -4,12 +4,13 @@
 import traceback
 
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from queue import PriorityQueue
 from typing import Literal
 
 import numpy as np
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.dependency import require_python_package
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.item import GraphDBEdge, GraphDBNode
@@ -223,7 +224,7 @@ def optimize_structure(
             f"[GraphStructureReorganize] Partitioned into {len(partitioned_groups)} clusters."
         )
 
-        with ThreadPoolExecutor(max_workers=4) as executor:
+        with ContextThreadPoolExecutor(max_workers=4) as executor:
             futures = []
             for cluster_nodes in partitioned_groups:
                 futures.append(
@@ -282,7 +283,7 @@ def _process_cluster_and_write(
         nodes_to_check = cluster_nodes
 
         exclude_ids = [n.id for n in nodes_to_check]
-        with ThreadPoolExecutor(max_workers=4) as executor:
+        with ContextThreadPoolExecutor(max_workers=4) as executor:
             futures = []
             for node in nodes_to_check:
                 futures.append(
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
index 07f2c0a5..80f93c1e 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py
@@ -2,11 +2,12 @@
 
 import json
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from datetime import datetime
 
 import requests
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.log import get_logger
 from memos.mem_reader.base import BaseMemReader
@@ -177,7 +178,7 @@ def _convert_to_mem_items(
         if not info:
             info = {"user_id": "", "session_id": ""}
 
-        with ThreadPoolExecutor(max_workers=8) as executor:
+        with ContextThreadPoolExecutor(max_workers=8) as executor:
             futures = [
                 executor.submit(self._process_result, r, query, parsed_goal, info)
                 for r in search_results
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 1f6a5a41..ca0c2be1 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -1,5 +1,6 @@
 import concurrent.futures
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.neo4j import Neo4jGraphDB
 from memos.memories.textual.item import TextualMemoryItem
@@ -49,7 +50,7 @@ def retrieve(
             )
             return [TextualMemoryItem.from_dict(record) for record in working_memories]
 
-        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        with ContextThreadPoolExecutor(max_workers=2) as executor:
             # Structured graph-based retrieval
             future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope)
             # Vector similarity search
@@ -196,7 +197,7 @@ def search_single(vec):
                 or []
             )
 
-        with concurrent.futures.ThreadPoolExecutor() as executor:
+        with ContextThreadPoolExecutor() as executor:
             futures = [executor.submit(search_single, vec) for vec in query_embedding[:max_num]]
             for future in concurrent.futures.as_completed(futures):
                 result = future.result()
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 9ac1646e..d77190f8 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -1,8 +1,8 @@
-import concurrent.futures
 import json
 
 from datetime import datetime
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.graph_dbs.factory import Neo4jGraphDB
 from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM
@@ -42,9 +42,7 @@ def __init__(
         self.internet_retriever = internet_retriever
         self.moscube = moscube
 
-        self._usage_executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=4, thread_name_prefix="usage"
-        )
+        self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
 
     @timed
     def search(
@@ -138,7 +136,7 @@ def _parse_task(self, query, info, mode, top_k=5):
     def _retrieve_paths(self, query, parsed_goal, query_embedding, info, top_k, mode, memory_type):
         """Run A/B/C retrieval paths in parallel"""
         tasks = []
-        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+        with ContextThreadPoolExecutor(max_workers=3) as executor:
             tasks.append(
                 executor.submit(
                     self._retrieve_from_working_memory,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py
index 2fae16c1..e36777f3 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/xinyusearch.py
@@ -3,11 +3,12 @@
 import json
 import uuid
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from datetime import datetime
 
 import requests
 
+from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import OllamaEmbedder
 from memos.log import get_logger
 from memos.mem_reader.base import BaseMemReader
@@ -150,7 +151,7 @@ def retrieve_from_internet(
         # Convert to TextualMemoryItem format
         memory_items: list[TextualMemoryItem] = []
 
-        with ThreadPoolExecutor(max_workers=8) as executor:
+        with ContextThreadPoolExecutor(max_workers=8) as executor:
             futures = [
                 executor.submit(self._process_result, result, query, parsed_goal, info)
                 for result in search_results
diff --git a/src/memos/settings.py b/src/memos/settings.py
index 3b3f05eb..b4553c88 100644
--- a/src/memos/settings.py
+++ b/src/memos/settings.py
@@ -5,6 +5,7 @@
 
 MEMOS_DIR = Path(os.getenv("MEMOS_BASE_PATH", Path.cwd())) / ".memos"
 DEBUG = False
+TIMED_LOG = False
 
 # "memos" or "memos.submodules" ... to filter logs from specific packages
 LOG_FILTER_TREE_PREFIX = ""
diff --git a/src/memos/utils.py b/src/memos/utils.py
index 6a1d4255..261ab931 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -1,5 +1,6 @@
 import time
 
+from memos import settings
 from memos.log import get_logger
 
 
@@ -13,7 +14,8 @@ def wrapper(*args, **kwargs):
         start = time.perf_counter()
         result = func(*args, **kwargs)
         elapsed = time.perf_counter() - start
-        logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
+        if settings.TIMED_LOG:
+            logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
         return result
 
     return wrapper
From 57598e300d4c95307cfdd90848414b9366dd2e75 Mon Sep 17 00:00:00 2001
From: harvey_xiang
Date: Thu, 11 Sep 2025 21:36:27 +0800
Subject: [PATCH 4/9] feat: add debug logging for the nebular session pool

---
 src/memos/graph_dbs/nebular.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
index 1884bf9b..ef88e120 100644
--- a/src/memos/graph_dbs/nebular.py
+++ b/src/memos/graph_dbs/nebular.py
@@ -13,6 +13,7 @@
 from memos.dependency import require_python_package
 from memos.graph_dbs.base import BaseGraphDB
 from memos.log import get_logger
+from memos.settings import settings
 from memos.utils import timed
 
 
@@ -146,7 +147,9 @@ def return_client(self, client):
             client.execute("YIELD 1")
             self.pool.put(client)
         except Exception:
-            logger.info("[Pool] Client dead, replacing...")
+            if settings.debug:
+                logger.info("[Pool] Client dead, replacing...")
+
             self.replace_client(client)
 
     @timed
@@ -214,7 +217,9 @@ def replace_client(self, client):
 
         self.pool.put(new_client)
 
-        logger.info("[Pool] Replaced dead client with a new one.")
+        if settings.debug:
+            logger.info(f"[Pool] Replaced dead client with a new one. {new_client}")
+
         return new_client
From 358ab27ac3c0ec6a84662e70b7fe873449f95d99 Mon Sep 17 00:00:00 2001
From: harvey_xiang
Date: Thu, 11 Sep 2025 21:38:23 +0800
Subject: [PATCH 5/9] feat: remove nebula3-python dependency

---
 poetry.lock    | 66 ++++++--------------------------------------------
 pyproject.toml |  2 --
 2 files changed, 8 insertions(+), 60 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index e99d7d17..2bee595c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -1093,19 +1093,6 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto
 test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard ; python_version < \"3.14\""]
 tqdm = ["tqdm"]
 
-[[package]]
-name = "future"
-version = "1.0.0"
-description = "Clean single-source support for Python 3 and 2"
-optional = true
-python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-groups = ["main"]
-markers = "extra == \"tree-mem\" or extra == \"all\""
-files = [
-    {file = "future-1.0.0-py3-none-any.whl", hash = "sha256:929292d34f5872e70396626ef385ec22355a1fae8ad29e1a734c3e43f9fbc216"},
-    {file = "future-1.0.0.tar.gz", hash = "sha256:bd2968309307861edae1458a4f8a4f3598c03be43b97521076aebf5d94c07b05"},
-]
-
 [[package]]
 name = "greenlet"
 version = "3.2.3"
@@ -1263,7 +1250,7 @@ files = [
     {file = "h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0"},
     {file = "h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f"},
 ]
-markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
+markers = {main = "extra == \"all\""}
 
 [package.dependencies]
 hpack = ">=4.1,<5"
@@ -1302,7 +1289,7 @@ files = [
     {file = "hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496"},
     {file = "hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca"},
 ]
-markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
+markers = {main = "extra == \"all\""}
 
 [[package]]
 name = "httpcore"
@@ -1326,22 +1313,6 @@ http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
 trio = ["trio (>=0.22.0,<1.0)"]
 
-[[package]]
-name = "httplib2"
-version = "0.30.2"
-description = "A comprehensive HTTP client library."
-optional = true
-python-versions = ">=3.6"
-groups = ["main"]
-markers = "extra == \"tree-mem\" or extra == \"all\""
-files = [
-    {file = "httplib2-0.30.2-py3-none-any.whl", hash = "sha256:62a665905c1f1d1069c34f933787d2a4435c67c0bc2b323645dcfbb64661b5ec"},
-    {file = "httplib2-0.30.2.tar.gz", hash = "sha256:050bde6a332824b05a3deef5238f2b0372f71af46f8ca2190c2cb1f66aa376cd"},
-]
-
-[package.dependencies]
-pyparsing = ">=3.0.4,<4"
-
 [[package]]
 name = "httptools"
 version = "0.6.4"
@@ -1502,7 +1473,7 @@ files = [
     {file = "hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5"},
     {file = "hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08"},
 ]
-markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
+markers = {main = "extra == \"all\""}
 
 [[package]]
 name = "identify"
@@ -2510,26 +2481,6 @@ docs = ["sphinx"]
 gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""]
 tests = ["pytest (>=4.6)"]
 
-[[package]]
-name = "nebula3-python"
-version = "3.8.3"
-description = "Python client for NebulaGraph v3"
-optional = true
-python-versions = ">=3.6.2"
-groups = ["main"]
-markers = "extra == \"tree-mem\" or extra == \"all\""
-files = [
-    {file = "nebula3_python-3.8.3-py3-none-any.whl", hash = "sha256:ef583d15c012751cce05bcf8b3869881dbbc3286c6e3fa8b5c65b7c5bb808c38"},
-    {file = "nebula3_python-3.8.3.tar.gz", hash = "sha256:da4693171079e9d5efb01a579f7fb62ef7655f6169cc2ff2cbb7b7d902e2cf3d"},
-]
-
-[package.dependencies]
-future = ">=0.18.0"
-httplib2 = ">=0.20.0"
-httpx = {version = ">=0.22.0", extras = ["http2"]}
-pytz = ">=2021.1"
-six = ">=1.16.0"
-
 [[package]]
 name = "neo4j"
 version = "5.28.1"
@@ -3844,12 +3795,11 @@ version = "3.2.3"
 description = "pyparsing module - Classes and methods to define and execute parsing grammars"
 optional = false
 python-versions = ">=3.9"
-groups = ["main", "eval"]
+groups = ["eval"]
 files = [
     {file = "pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf"},
     {file = "pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be"},
 ]
-markers = {main = "extra == \"tree-mem\" or extra == \"all\""}
 
 [package.extras]
 diagrams = ["jinja2", "railroad-diagrams"]
@@ -6351,12 +6301,12 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
 
 [package.extras]
 cffi = ["cffi (>=1.11)"]
 
 [extras]
-all = ["chonkie", "markitdown", "nebula3-python", "neo4j", "pika", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
+all = ["chonkie", "markitdown", "neo4j", "pika", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"]
 mem-reader = ["chonkie", "markitdown"]
 mem-scheduler = ["pika", "redis"]
-tree-mem = ["nebula3-python", "neo4j", "schedule"]
+tree-mem = ["neo4j", "schedule"]
 
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<4.0"
-content-hash = "982c24cc8c7961f543df341e73ab8fe04bfdc47dac1784af66985d34434007ae"
+content-hash = "d8bc3fed94d84d40795f820fbe38ba648f04b00e9e6d54cb2ca1629d6912a7cb"
diff --git a/pyproject.toml b/pyproject.toml
index b1f2f441..a0087272 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,6 @@ memos = "memos.cli:main"
 # TreeTextualMemory
 tree-mem = [
     "neo4j (>=5.28.1,<6.0.0)", # Graph database
-    "nebula3-python (>=3.5.0,<4.0.0)", # NebulaGraph Python client
     "schedule (>=1.2.2,<2.0.0)", # Task scheduling
 ]
 
@@ -89,7 +88,6 @@ mem-reader = [
 all = [
     # Exist in the above optional groups
     "neo4j (>=5.28.1,<6.0.0)",
-    "nebula3-python (>=3.5.0,<4.0.0)",
     "schedule (>=1.2.2,<2.0.0)",
     "redis (>=6.2.0,<7.0.0)",
     "pika (>=1.3.2,<2.0.0)",
From 8f4d8578244bc13a5521bd8e24a7ff17585716fb Mon Sep 17 00:00:00 2001
From: harvey_xiang
Date: Thu, 11 Sep 2025 21:42:06 +0800
Subject: [PATCH 6/9] feat: gate debug logging behind settings.DEBUG

---
 src/memos/graph_dbs/nebular.py | 6 +++---
 src/memos/settings.py          | 1 -
 src/memos/utils.py             | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
index ef88e120..e5292557 100644
--- a/src/memos/graph_dbs/nebular.py
+++ b/src/memos/graph_dbs/nebular.py
@@ -9,11 +9,11 @@
 
 import numpy as np
 
+from memos import settings
 from memos.configs.graph_db import NebulaGraphDBConfig
 from memos.dependency import require_python_package
 from memos.graph_dbs.base import BaseGraphDB
 from memos.log import get_logger
-from memos.settings import settings
 from memos.utils import timed
 
 
@@ -147,7 +147,7 @@ def return_client(self, client):
             client.execute("YIELD 1")
             self.pool.put(client)
         except Exception:
-            if settings.debug:
+            if settings.DEBUG:
                 logger.info("[Pool] Client dead, replacing...")
 
             self.replace_client(client)
@@ -217,7 +217,7 @@ def replace_client(self, client):
 
         self.pool.put(new_client)
 
-        if settings.debug:
+        if settings.DEBUG:
             logger.info(f"[Pool] Replaced dead client with a new one. {new_client}")
 
         return new_client
diff --git a/src/memos/settings.py b/src/memos/settings.py
index b4553c88..3b3f05eb 100644
--- a/src/memos/settings.py
+++ b/src/memos/settings.py
@@ -5,7 +5,6 @@
 
 MEMOS_DIR = Path(os.getenv("MEMOS_BASE_PATH", Path.cwd())) / ".memos"
 DEBUG = False
-TIMED_LOG = False
 
 # "memos" or "memos.submodules" ... to filter logs from specific packages
 LOG_FILTER_TREE_PREFIX = ""
diff --git a/src/memos/utils.py b/src/memos/utils.py
index 261ab931..5801bc2d 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -14,7 +14,7 @@ def wrapper(*args, **kwargs):
         start = time.perf_counter()
         result = func(*args, **kwargs)
         elapsed = time.perf_counter() - start
-        if settings.TIMED_LOG:
+        if settings.DEBUG:
             logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
         return result
 
     return wrapper
From b3b814a0f71dd98a144f119fb02ac56ae25fd5df Mon Sep 17 00:00:00 2001
From: harvey_xiang
Date: Thu, 11 Sep 2025 21:43:15 +0800
Subject: [PATCH 7/9] revert: log.py console config

---
 src/memos/log.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/memos/log.py b/src/memos/log.py
index 61cf84d9..7266df24 100644
--- a/src/memos/log.py
+++ b/src/memos/log.py
@@ -160,7 +160,7 @@ def close(self):
     },
     "handlers": {
         "console": {
-            "level": "INFO",
+            "level": selected_log_level,
            "class": "logging.StreamHandler",
             "stream": stdout,
             "formatter": "simplified",

From 8a740f5e2d264958963a7e6ade06e6fe87b2d36e Mon Sep 17 00:00:00 2001
From: harvey_xiang
Date: Thu, 11 Sep 2025 21:44:20 +0800
Subject: [PATCH 8/9] feat: remove useless log

---
 src/memos/mem_os/product.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index 388af084..58375591 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -642,8 +642,6 @@ def _start_post_chat_processing(
 
         def run_async_in_thread():
             """Running asynchronous tasks in a new thread"""
-
-            logger.info(f"Running asynchronous tasks in a new thread for user {user_id}")
             try:
                 loop = asyncio.new_event_loop()
                 asyncio.set_event_loop(loop)

From d57f551fec3f2b6cda93058f9ce5bad7e0a54d4b Mon Sep 17 00:00:00 2001
From: harvey_xiang
Date: Mon, 15 Sep 2025 11:32:18 +0800
Subject: [PATCH 9/9] fix: unit test error

---
 src/memos/api/middleware/request_context.py | 2 +-
 tests/api/test_thread_context.py            | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py
index 5aca866d..71f069f8 100644
--- a/src/memos/api/middleware/request_context.py
+++ b/src/memos/api/middleware/request_context.py
@@ -10,7 +10,7 @@
 from starlette.requests import Request
 from starlette.responses import Response
 
-from memos.api.context.context import RequestContext, generate_trace_id, set_request_context
+from memos.context.context import RequestContext, generate_trace_id, set_request_context
 
 
 logger = logging.getLogger(__name__)
diff --git a/tests/api/test_thread_context.py b/tests/api/test_thread_context.py
index 36da692f..97c395db 100644
--- a/tests/api/test_thread_context.py
+++ b/tests/api/test_thread_context.py
@@ -1,7 +1,12 @@
 import time
 
-from memos.api.context.context import RequestContext, get_current_context, set_request_context
-from memos.api.context.context_thread import ContextThread, ContextThreadPoolExecutor
+from memos.context.context import (
+    ContextThread,
+    ContextThreadPoolExecutor,
+    RequestContext,
+    get_current_context,
+    set_request_context,
+)
 from memos.log import get_logger
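A minimal propagation test in the spirit of tests/api/test_thread_context.py (a sketch under the post-rename import path, not the repository's actual test body):

from memos.context.context import (
    ContextThreadPoolExecutor,
    RequestContext,
    get_current_trace_id,
    set_request_context,
)


def test_pool_propagates_trace_id():
    set_request_context(RequestContext(trace_id="t-123"))
    with ContextThreadPoolExecutor(max_workers=1) as pool:
        # The worker thread should observe the caller's trace_id
        assert pool.submit(get_current_trace_id).result() == "t-123"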