From fa5b9efccf688dcca08c6b59eaf9294ef1ea1c1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Tue, 16 Sep 2025 20:50:00 +0800 Subject: [PATCH 1/4] feat: update nebula to nebula 5.1.1 --- src/memos/graph_dbs/nebular.py | 311 ++++++++++----------------------- 1 file changed, 95 insertions(+), 216 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 5ca8c895..2ee3593a 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -3,7 +3,6 @@ from contextlib import suppress from datetime import datetime -from queue import Empty, Queue from threading import Lock from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -17,7 +16,9 @@ if TYPE_CHECKING: - from nebulagraph_python.client.pool import NebulaPool + from nebulagraph_python import ( + NebulaClient, + ) logger = get_logger(__name__) @@ -87,137 +88,6 @@ def _normalize_datetime(val): return str(val) -class SessionPoolError(Exception): - pass - - -class SessionPool: - @require_python_package( - import_name="nebulagraph_python", - install_command="pip install ... @Tianxing", - install_link=".....", - ) - def __init__( - self, - hosts: list[str], - user: str, - password: str, - minsize: int = 1, - maxsize: int = 10000, - ): - self.hosts = hosts - self.user = user - self.password = password - self.minsize = minsize - self.maxsize = maxsize - self.pool = Queue(maxsize) - self.lock = Lock() - - self.clients = [] - - for _ in range(minsize): - self._create_and_add_client() - - @timed - def _create_and_add_client(self): - from nebulagraph_python import NebulaClient - - client = NebulaClient(self.hosts, self.user, self.password) - self.pool.put(client) - self.clients.append(client) - - @timed - def get_client(self, timeout: float = 5.0): - try: - return self.pool.get(timeout=timeout) - except Empty: - with self.lock: - if len(self.clients) < self.maxsize: - from nebulagraph_python import NebulaClient - - client = NebulaClient(self.hosts, self.user, self.password) - self.clients.append(client) - return client - raise RuntimeError("NebulaClientPool exhausted") from None - - @timed - def return_client(self, client): - try: - client.execute("YIELD 1") - self.pool.put(client) - except Exception: - logger.info("[Pool] Client dead, replacing...") - self.replace_client(client) - - @timed - def close(self): - for client in self.clients: - with suppress(Exception): - client.close() - self.clients.clear() - - @timed - def get(self): - """ - Context manager: with pool.get() as client: - """ - - class _ClientContext: - def __init__(self, outer): - self.outer = outer - self.client = None - - def __enter__(self): - self.client = self.outer.get_client() - return self.client - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.client: - self.outer.return_client(self.client) - - return _ClientContext(self) - - @timed - def reset_pool(self): - """⚠️ Emergency reset: Close all clients and clear the pool.""" - logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.") - with self.lock: - for client in self.clients: - try: - client.close() - except Exception: - logger.error("Fail to close!!!") - self.clients.clear() - while not self.pool.empty(): - try: - self.pool.get_nowait() - except Empty: - break - for _ in range(self.minsize): - self._create_and_add_client() - logger.info("[Pool] Pool has been reset successfully.") - - @timed - def replace_client(self, client): - try: - client.close() - except Exception: - logger.error("Fail to close client") - - if client in self.clients: - self.clients.remove(client) - - from nebulagraph_python import NebulaClient - - new_client = NebulaClient(self.hosts, self.user, self.password) - self.clients.append(new_client) - - self.pool.put(new_client) - - logger.info("[Pool] Replaced dead client with a new one.") - return new_client - - class NebulaGraphDB(BaseGraphDB): """ NebulaGraph-based implementation of a graph memory store. @@ -226,94 +96,102 @@ class NebulaGraphDB(BaseGraphDB): # ====== shared pool cache & refcount ====== # These are process-local; in a multi-process model each process will # have its own cache. - _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {} - _POOL_REFCOUNT: ClassVar[dict[str, int]] = {} - _POOL_LOCK: ClassVar[Lock] = Lock() + _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {} + _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {} + _CLIENT_LOCK: ClassVar[Lock] = Lock() @staticmethod - def _make_pool_key(cfg: NebulaGraphDBConfig) -> str: - """ - Build a cache key that captures all connection-affecting options. - Keep this key stable and include fields that change the underlying pool behavior. - """ - # NOTE: Do not include tenant-like or query-scope-only fields here. - # Only include things that affect the actual TCP/auth/session pool. + def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]: + hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None) + if isinstance(hosts, str): + return [hosts] + return list(hosts or []) + + @staticmethod + def _make_client_key(cfg: NebulaGraphDBConfig) -> str: + hosts = NebulaGraphDB._get_hosts_from_cfg(cfg) return "|".join( [ - "nebula", - str(getattr(cfg, "uri", "")), + "nebula-sync", + ",".join(hosts), str(getattr(cfg, "user", "")), str(getattr(cfg, "password", "")), - # pool sizing / tls / timeouts if you have them in config: - str(getattr(cfg, "max_client", 1000)), - # multi-db mode can impact how we use sessions; keep it to be safe str(getattr(cfg, "use_multi_db", False)), ] ) @classmethod - def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig): - """ - Get a shared NebulaPool from cache or create one if missing. - Thread-safe with a lock; maintains a simple refcount. - """ - key = cls._make_pool_key(cfg) - - with cls._POOL_LOCK: - pool = cls._POOL_CACHE.get(key) - if pool is None: - # Create a new pool and put into cache - pool = SessionPool( - hosts=cfg.get("uri"), - user=cfg.get("user"), - password=cfg.get("password"), - minsize=1, - maxsize=cfg.get("max_client", 1000), + def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str, "NebulaClient"]: + from nebulagraph_python import ( + ConnectionConfig, + NebulaClient, + SessionConfig, + SessionPoolConfig, + ) + + key = cls._make_client_key(cfg) + with cls._CLIENT_LOCK: + client = cls._CLIENT_CACHE.get(key) + if client is None: + # Connection setting + conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None) + if conn_conf is None: + conn_conf = ConnectionConfig.from_defults( + cls._get_hosts_from_cfg(cfg), + getattr(cfg, "ssl_param", None), + ) + + sess_conf = SessionConfig(graph=getattr(cfg, "space", None)) + + pool_conf = SessionPoolConfig(size=int(getattr(cfg, "max_client", 1000))) + + client = NebulaClient( + hosts=conn_conf.hosts, + username=cfg.user, + password=cfg.password, + conn_config=conn_conf, + session_config=sess_conf, + session_pool_config=pool_conf, ) - cls._POOL_CACHE[key] = pool - cls._POOL_REFCOUNT[key] = 0 - logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}") + cls._CLIENT_CACHE[key] = client + cls._CLIENT_REFCOUNT[key] = 0 + logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}") - # Increase refcount for the caller - cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1 - return key, pool + cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1 + return key, client @classmethod - def _release_shared_pool(cls, key: str): - """ - Decrease refcount for the given pool key; only close when refcount hits zero. - """ - with cls._POOL_LOCK: - if key not in cls._POOL_CACHE: + def _release_shared_client(cls, key: str): + with cls._CLIENT_LOCK: + if key not in cls._CLIENT_CACHE: return - cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1) - if cls._POOL_REFCOUNT[key] == 0: + cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1) + if cls._CLIENT_REFCOUNT[key] == 0: try: - cls._POOL_CACHE[key].close() + cls._CLIENT_CACHE[key].close() except Exception as e: - logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}") + logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}") finally: - cls._POOL_CACHE.pop(key, None) - cls._POOL_REFCOUNT.pop(key, None) - logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}") + cls._CLIENT_CACHE.pop(key, None) + cls._CLIENT_REFCOUNT.pop(key, None) + logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}") @classmethod - def close_all_shared_pools(cls): - """Force close all cached pools. Call this on graceful shutdown.""" - with cls._POOL_LOCK: - for key, pool in list(cls._POOL_CACHE.items()): + def close_all_shared_clients(cls): + with cls._CLIENT_LOCK: + for key, client in list(cls._CLIENT_CACHE.items()): try: - pool.close() + client.close() except Exception as e: - logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}") + logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}") finally: - logger.info(f"[NebulaGraphDB] Closed pool key={key}") - cls._POOL_CACHE.clear() - cls._POOL_REFCOUNT.clear() + logger.info(f"[NebulaGraphDBSync] Closed client key={key}") + cls._CLIENT_CACHE.clear() + cls._CLIENT_REFCOUNT.clear() @require_python_package( import_name="nebulagraph_python", - install_command="pip install ... @Tianxing", + install_command="pip install nebulagraph-python>=5.1.1", install_link=".....", ) def __init__(self, config: NebulaGraphDBConfig): @@ -371,34 +249,35 @@ def __init__(self, config: NebulaGraphDBConfig): # ---- NEW: pool acquisition strategy # Get or create a shared pool from the class-level cache - self._pool_key, self.pool = self._get_or_create_shared_pool(config) - self._owns_pool = True # We manage refcount for this instance + self._client_key, self._client = self._get_or_create_shared_client(config) + self._owns_client = True # auto-create graph type / graph / index if needed - if config.auto_create: + if getattr(config, "auto_create", False): self._ensure_database_exists() self.execute_query(f"SESSION SET GRAPH `{self.db_name}`") # Create only if not exists self.create_index(dimensions=config.embedding_dimension) - logger.info("Connected to NebulaGraph successfully.") @timed def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True): - with self.pool.get() as client: - try: - if auto_set_db and self.db_name: - client.execute(f"SESSION SET GRAPH `{self.db_name}`") - return client.execute(gql, timeout=timeout) - - except Exception as e: - if "Session not found" in str(e) or "Connection not established" in str(e): - logger.warning(f"[execute_query] {e!s}, replacing client...") - self.pool.replace_client(client) - return self.execute_query(gql, timeout, auto_set_db) - raise + try: + if auto_set_db and self.db_name: + self._client.execute(f"SESSION SET GRAPH `{self.db_name}`") + return self._client.execute(gql, timeout=timeout) + except Exception as e: + emsg = str(e) + if "Session not found" in emsg or "Connection not established" in emsg: + logger.warning(f"[execute_query] {e!s}, retry once...") + try: + return self._client.execute(gql, timeout=timeout) + except Exception: + logger.exception("[execute_query] retry failed") + raise + raise @timed def close(self): @@ -409,13 +288,13 @@ def close(self): - If pool was acquired via shared cache, decrement refcount and close when the last owner releases it. """ - if not self._owns_pool: - logger.debug("[NebulaGraphDB] close() skipped (injected pool).") + if not self._owns_client: + logger.debug("[NebulaGraphDBSync] close() skipped (injected client).") return - if self._pool_key: - self._release_shared_pool(self._pool_key) - self._pool_key = None - self.pool = None + if self._client_key: + self._release_shared_client(self._client_key) + self._client_key = None + self._client = None # NOTE: __del__ is best-effort; do not rely on GC order. def __del__(self): From be349379e26ca1390e1a923a282f4092dc4dd498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 17 Sep 2025 17:21:10 +0800 Subject: [PATCH 2/4] fix: bug in nebula and manager --- src/memos/graph_dbs/nebular.py | 19 +++++++++---------- .../tree_text_memory/organize/manager.py | 1 + 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 2ee3593a..dfbb70be 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -142,8 +142,9 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str, ) sess_conf = SessionConfig(graph=getattr(cfg, "space", None)) - - pool_conf = SessionPoolConfig(size=int(getattr(cfg, "max_client", 1000))) + pool_conf = SessionPoolConfig( + size=int(getattr(cfg, "max_client", 1000)), wait_timeout=500 + ) client = NebulaClient( hosts=conn_conf.hosts, @@ -256,8 +257,6 @@ def __init__(self, config: NebulaGraphDBConfig): if getattr(config, "auto_create", False): self._ensure_database_exists() - self.execute_query(f"SESSION SET GRAPH `{self.db_name}`") - # Create only if not exists self.create_index(dimensions=config.embedding_dimension) logger.info("Connected to NebulaGraph successfully.") @@ -266,13 +265,18 @@ def __init__(self, config: NebulaGraphDBConfig): def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True): try: if auto_set_db and self.db_name: - self._client.execute(f"SESSION SET GRAPH `{self.db_name}`") + gql = f"""USE `{self.db_name}` + {gql}""" return self._client.execute(gql, timeout=timeout) + except Exception as e: emsg = str(e) if "Session not found" in emsg or "Connection not established" in emsg: logger.warning(f"[execute_query] {e!s}, retry once...") try: + if auto_set_db and self.db_name: + gql = f"""USE `{self.db_name}` + {gql}""" return self._client.execute(gql, timeout=timeout) except Exception: logger.exception("[execute_query] retry failed") @@ -894,7 +898,6 @@ def search_by_embedding( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - USE `{self.db_name}` MATCH (n@Memory) {where_clause} ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC @@ -1249,7 +1252,6 @@ def get_structure_optimization_candidates( return_fields = self._build_return_fields(include_embedding) query = f""" - USE `{self.db_name}` MATCH (n@Memory) WHERE {where_clause} OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) @@ -1417,11 +1419,8 @@ def _ensure_database_exists(self): logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}") create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}" - set_graph_working = f"SESSION SET GRAPH `{self.db_name}`" - try: self.execute_query(create_graph, auto_set_db=False) - self.execute_query(set_graph_working) logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.") except Exception as e: logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}") diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 6b0a6a55..bbd11c53 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -102,6 +102,7 @@ def get_current_memory_size(self) -> dict[str, int]: """ Return the cached memory type counts. """ + self._refresh_memory_size() return self.current_memory_size def _refresh_memory_size(self) -> None: From 498d62c0e9a6cbcadbbf05685685852088f12997 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 17 Sep 2025 21:27:33 +0800 Subject: [PATCH 3/4] feat: update product --- src/memos/configs/graph_db.py | 4 ++++ src/memos/graph_dbs/nebular.py | 4 ++-- src/memos/mem_os/core.py | 28 +++++++++++++++------------- src/memos/mem_os/product.py | 2 +- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 01246b06..2df91716 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -140,6 +140,10 @@ class NebulaGraphDBConfig(BaseGraphDBConfig): "If False: use a single shared database with logical isolation by user_name." ), ) + max_client: int = Field( + default=1000, + description=("max_client"), + ) embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding") @model_validator(mode="after") diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 57789f10..3a1136ba 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -143,7 +143,7 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str, sess_conf = SessionConfig(graph=getattr(cfg, "space", None)) pool_conf = SessionPoolConfig( - size=int(getattr(cfg, "max_client", 1000)), wait_timeout=500 + size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000 ) client = NebulaClient( @@ -262,7 +262,7 @@ def __init__(self, config: NebulaGraphDBConfig): logger.info("Connected to NebulaGraph successfully.") @timed - def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True): + def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True): try: if auto_set_db and self.db_name: gql = f"""USE `{self.db_name}` diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 852b5192..082ae5ce 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -182,13 +182,13 @@ def mem_reorganizer_wait(self) -> bool: logger.info(f"close reorganizer for {mem_cube.text_mem.config.cube_id}") mem_cube.text_mem.memory_manager.wait_reorganizer() - def _register_chat_history(self, user_id: str | None = None) -> None: + def _register_chat_history( + self, user_id: str | None = None, session_id: str | None = None + ) -> None: """Initialize chat history with user ID.""" - if user_id is None: - user_id = self.user_id self.chat_history_manager[user_id] = ChatHistory( - user_id=user_id, - session_id=self.session_id, + user_id=user_id if user_id is not None else self.user_id, + session_id=session_id if session_id is not None else self.session_id, created_at=datetime.utcnow(), total_messages=0, chat_history=[], @@ -563,6 +563,7 @@ def search( Returns: MemoryResult: A dictionary containing the search results. """ + target_session_id = session_id if session_id is not None else self.session_id target_user_id = user_id if user_id is not None else self.user_id self._validate_user_exists(target_user_id) @@ -609,7 +610,7 @@ def search( manual_close_internet=not internet_search, info={ "user_id": target_user_id, - "session_id": session_id if session_id is not None else self.session_id, + "session_id": target_session_id, "chat_history": chat_history.chat_history, }, moscube=moscube, @@ -652,7 +653,8 @@ def add( assert (messages is not None) or (memory_content is not None) or (doc_path is not None), ( "messages_or_doc_path or memory_content or doc_path must be provided." ) - self.session_id = session_id + # TODO: asure that session_id is a valid string + target_session_id = session_id if session_id else self.session_id target_user_id = user_id if user_id is not None else self.user_id if mem_cube_id is None: # Try to find a default cube for the user @@ -675,7 +677,7 @@ def add( if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": add_memory = [] metadata = TextualMemoryMetadata( - user_id=target_user_id, session_id=self.session_id, source="conversation" + user_id=target_user_id, session_id=target_session_id, source="conversation" ) for message in messages: add_memory.append( @@ -687,7 +689,7 @@ def add( memories = self.mem_reader.get_memory( messages_list, type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}, + info={"user_id": target_user_id, "session_id": target_session_id}, ) mem_ids = [] @@ -719,7 +721,7 @@ def add( ): if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": metadata = TextualMemoryMetadata( - user_id=self.user_id, session_id=self.session_id, source="conversation" + user_id=target_user_id, session_id=target_session_id, source="conversation" ) self.mem_cubes[mem_cube_id].text_mem.add( [TextualMemoryItem(memory=memory_content, metadata=metadata)] @@ -731,7 +733,7 @@ def add( memories = self.mem_reader.get_memory( messages_list, type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}, + info={"user_id": target_user_id, "session_id": target_session_id}, ) mem_ids = [] @@ -765,7 +767,7 @@ def add( doc_memories = self.mem_reader.get_memory( documents, type="doc", - info={"user_id": target_user_id, "session_id": self.session_id}, + info={"user_id": target_user_id, "session_id": target_session_id}, ) mem_ids = [] @@ -998,7 +1000,7 @@ def load( def get_user_info(self) -> dict[str, Any]: """Get current user information including accessible cubes. - + TODO: maybe input user_id Returns: dict: User information and accessible cubes. """ diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index cba31d5f..6032add6 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -1036,7 +1036,7 @@ def chat_with_references( system_prompt = self._build_enhance_system_prompt(user_id, memories_list) # Get chat history if user_id not in self.chat_history_manager: - self._register_chat_history(user_id) + self._register_chat_history(user_id, session_id) chat_history = self.chat_history_manager[user_id] if history: From d629dd640216edeb955460bc0b4b80ab8e993ab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 17 Sep 2025 21:31:21 +0800 Subject: [PATCH 4/4] test: update --- tests/mem_os/test_memos_core.py | 35 +-------------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/tests/mem_os/test_memos_core.py b/tests/mem_os/test_memos_core.py index 2c873e5a..6d2408d0 100644 --- a/tests/mem_os/test_memos_core.py +++ b/tests/mem_os/test_memos_core.py @@ -682,41 +682,8 @@ def test_chat_without_memories( # Verify response assert response == "This is a test response from the assistant." - @patch("memos.mem_os.core.UserManager") - @patch("memos.mem_os.core.MemReaderFactory") - @patch("memos.mem_os.core.LLMFactory") - def test_clear_messages( - self, - mock_llm_factory, - mock_reader_factory, - mock_user_manager_class, - mock_config, - mock_llm, - mock_mem_reader, - mock_user_manager, - ): - """Test clearing chat history.""" - # Setup mocks - mock_llm_factory.from_config.return_value = mock_llm - mock_reader_factory.from_config.return_value = mock_mem_reader - mock_user_manager_class.return_value = mock_user_manager - - mos = MOSCore(MOSConfig(**mock_config)) - - # Add some chat history - mos.chat_history_manager["test_user"].chat_history.append( - {"role": "user", "content": "Hello"} - ) - mos.chat_history_manager["test_user"].chat_history.append( - {"role": "assistant", "content": "Hi"} - ) - - assert len(mos.chat_history_manager["test_user"].chat_history) == 2 - - mos.clear_messages() - assert len(mos.chat_history_manager["test_user"].chat_history) == 0 - assert mos.chat_history_manager["test_user"].user_id == "test_user" +# TODO: test clear message class TestMOSSystemPrompt: