diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 01246b06..2df91716 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -140,6 +140,10 @@ class NebulaGraphDBConfig(BaseGraphDBConfig): "If False: use a single shared database with logical isolation by user_name." ), ) + max_client: int = Field( + default=1000, + description=("max_client"), + ) embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding") @model_validator(mode="after") diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index cfe69b39..3a1136ba 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -8,7 +8,6 @@ import numpy as np -from memos import settings from memos.configs.graph_db import NebulaGraphDBConfig from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB @@ -143,8 +142,9 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str, ) sess_conf = SessionConfig(graph=getattr(cfg, "space", None)) - - pool_conf = SessionPoolConfig(size=int(getattr(cfg, "max_client", 1000))) + pool_conf = SessionPoolConfig( + size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000 + ) client = NebulaClient( hosts=conn_conf.hosts, @@ -257,23 +257,25 @@ def __init__(self, config: NebulaGraphDBConfig): if getattr(config, "auto_create", False): self._ensure_database_exists() - self.execute_query(f"SESSION SET GRAPH `{self.db_name}`") - # Create only if not exists self.create_index(dimensions=config.embedding_dimension) logger.info("Connected to NebulaGraph successfully.") @timed - def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True): + def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True): try: if auto_set_db and self.db_name: - self._client.execute(f"SESSION SET GRAPH `{self.db_name}`") + gql = f"""USE `{self.db_name}` + {gql}""" return self._client.execute(gql, timeout=timeout) except Exception as e: emsg = str(e) if "Session not found" in emsg or "Connection not established" in emsg: logger.warning(f"[execute_query] {e!s}, retry once...") try: + if auto_set_db and self.db_name: + gql = f"""USE `{self.db_name}` + {gql}""" return self._client.execute(gql, timeout=timeout) except Exception: logger.exception("[execute_query] retry failed") @@ -907,7 +909,6 @@ def search_by_embedding( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - USE `{self.db_name}` MATCH (n@Memory) {where_clause} ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC @@ -1262,7 +1263,6 @@ def get_structure_optimization_candidates( return_fields = self._build_return_fields(include_embedding) query = f""" - USE `{self.db_name}` MATCH (n@Memory) WHERE {where_clause} OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) @@ -1430,11 +1430,8 @@ def _ensure_database_exists(self): logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}") create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}" - set_graph_working = f"SESSION SET GRAPH `{self.db_name}`" - try: self.execute_query(create_graph, auto_set_db=False) - self.execute_query(set_graph_working) logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.") except Exception as e: logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}") diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 852b5192..082ae5ce 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -182,13 +182,13 @@ def mem_reorganizer_wait(self) -> bool: logger.info(f"close reorganizer for {mem_cube.text_mem.config.cube_id}") mem_cube.text_mem.memory_manager.wait_reorganizer() - def _register_chat_history(self, user_id: str | None = None) -> None: + def _register_chat_history( + self, user_id: str | None = None, session_id: str | None = None + ) -> None: """Initialize chat history with user ID.""" - if user_id is None: - user_id = self.user_id self.chat_history_manager[user_id] = ChatHistory( - user_id=user_id, - session_id=self.session_id, + user_id=user_id if user_id is not None else self.user_id, + session_id=session_id if session_id is not None else self.session_id, created_at=datetime.utcnow(), total_messages=0, chat_history=[], @@ -563,6 +563,7 @@ def search( Returns: MemoryResult: A dictionary containing the search results. """ + target_session_id = session_id if session_id is not None else self.session_id target_user_id = user_id if user_id is not None else self.user_id self._validate_user_exists(target_user_id) @@ -609,7 +610,7 @@ def search( manual_close_internet=not internet_search, info={ "user_id": target_user_id, - "session_id": session_id if session_id is not None else self.session_id, + "session_id": target_session_id, "chat_history": chat_history.chat_history, }, moscube=moscube, @@ -652,7 +653,8 @@ def add( assert (messages is not None) or (memory_content is not None) or (doc_path is not None), ( "messages_or_doc_path or memory_content or doc_path must be provided." ) - self.session_id = session_id + # TODO: asure that session_id is a valid string + target_session_id = session_id if session_id else self.session_id target_user_id = user_id if user_id is not None else self.user_id if mem_cube_id is None: # Try to find a default cube for the user @@ -675,7 +677,7 @@ def add( if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": add_memory = [] metadata = TextualMemoryMetadata( - user_id=target_user_id, session_id=self.session_id, source="conversation" + user_id=target_user_id, session_id=target_session_id, source="conversation" ) for message in messages: add_memory.append( @@ -687,7 +689,7 @@ def add( memories = self.mem_reader.get_memory( messages_list, type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}, + info={"user_id": target_user_id, "session_id": target_session_id}, ) mem_ids = [] @@ -719,7 +721,7 @@ def add( ): if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": metadata = TextualMemoryMetadata( - user_id=self.user_id, session_id=self.session_id, source="conversation" + user_id=target_user_id, session_id=target_session_id, source="conversation" ) self.mem_cubes[mem_cube_id].text_mem.add( [TextualMemoryItem(memory=memory_content, metadata=metadata)] @@ -731,7 +733,7 @@ def add( memories = self.mem_reader.get_memory( messages_list, type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}, + info={"user_id": target_user_id, "session_id": target_session_id}, ) mem_ids = [] @@ -765,7 +767,7 @@ def add( doc_memories = self.mem_reader.get_memory( documents, type="doc", - info={"user_id": target_user_id, "session_id": self.session_id}, + info={"user_id": target_user_id, "session_id": target_session_id}, ) mem_ids = [] @@ -998,7 +1000,7 @@ def load( def get_user_info(self) -> dict[str, Any]: """Get current user information including accessible cubes. - + TODO: maybe input user_id Returns: dict: User information and accessible cubes. """ diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index cba31d5f..6032add6 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -1036,7 +1036,7 @@ def chat_with_references( system_prompt = self._build_enhance_system_prompt(user_id, memories_list) # Get chat history if user_id not in self.chat_history_manager: - self._register_chat_history(user_id) + self._register_chat_history(user_id, session_id) chat_history = self.chat_history_manager[user_id] if history: diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index e3f10b54..4ca30d6f 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -103,6 +103,7 @@ def get_current_memory_size(self) -> dict[str, int]: """ Return the cached memory type counts. """ + self._refresh_memory_size() return self.current_memory_size def _refresh_memory_size(self) -> None: diff --git a/tests/mem_os/test_memos_core.py b/tests/mem_os/test_memos_core.py index 2c873e5a..6d2408d0 100644 --- a/tests/mem_os/test_memos_core.py +++ b/tests/mem_os/test_memos_core.py @@ -682,41 +682,8 @@ def test_chat_without_memories( # Verify response assert response == "This is a test response from the assistant." - @patch("memos.mem_os.core.UserManager") - @patch("memos.mem_os.core.MemReaderFactory") - @patch("memos.mem_os.core.LLMFactory") - def test_clear_messages( - self, - mock_llm_factory, - mock_reader_factory, - mock_user_manager_class, - mock_config, - mock_llm, - mock_mem_reader, - mock_user_manager, - ): - """Test clearing chat history.""" - # Setup mocks - mock_llm_factory.from_config.return_value = mock_llm - mock_reader_factory.from_config.return_value = mock_mem_reader - mock_user_manager_class.return_value = mock_user_manager - - mos = MOSCore(MOSConfig(**mock_config)) - - # Add some chat history - mos.chat_history_manager["test_user"].chat_history.append( - {"role": "user", "content": "Hello"} - ) - mos.chat_history_manager["test_user"].chat_history.append( - {"role": "assistant", "content": "Hi"} - ) - - assert len(mos.chat_history_manager["test_user"].chat_history) == 2 - - mos.clear_messages() - assert len(mos.chat_history_manager["test_user"].chat_history) == 0 - assert mos.chat_history_manager["test_user"].user_id == "test_user" +# TODO: test clear message class TestMOSSystemPrompt: