Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/memos/configs/graph_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ class NebulaGraphDBConfig(BaseGraphDBConfig):
"If False: use a single shared database with logical isolation by user_name."
),
)
max_client: int = Field(
default=1000,
description=("max_client"),
)
embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding")

@model_validator(mode="after")
Expand Down
21 changes: 9 additions & 12 deletions src/memos/graph_dbs/nebular.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np

from memos import settings
from memos.configs.graph_db import NebulaGraphDBConfig
from memos.dependency import require_python_package
from memos.graph_dbs.base import BaseGraphDB
Expand Down Expand Up @@ -143,8 +142,9 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str,
)

sess_conf = SessionConfig(graph=getattr(cfg, "space", None))

pool_conf = SessionPoolConfig(size=int(getattr(cfg, "max_client", 1000)))
pool_conf = SessionPoolConfig(
size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
)

client = NebulaClient(
hosts=conn_conf.hosts,
Expand Down Expand Up @@ -257,23 +257,25 @@ def __init__(self, config: NebulaGraphDBConfig):
if getattr(config, "auto_create", False):
self._ensure_database_exists()

self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")

# Create only if not exists
self.create_index(dimensions=config.embedding_dimension)
logger.info("Connected to NebulaGraph successfully.")

@timed
def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
try:
if auto_set_db and self.db_name:
self._client.execute(f"SESSION SET GRAPH `{self.db_name}`")
gql = f"""USE `{self.db_name}`
{gql}"""
return self._client.execute(gql, timeout=timeout)
except Exception as e:
emsg = str(e)
if "Session not found" in emsg or "Connection not established" in emsg:
logger.warning(f"[execute_query] {e!s}, retry once...")
try:
if auto_set_db and self.db_name:
gql = f"""USE `{self.db_name}`
{gql}"""
return self._client.execute(gql, timeout=timeout)
except Exception:
logger.exception("[execute_query] retry failed")
Expand Down Expand Up @@ -907,7 +909,6 @@ def search_by_embedding(
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

gql = f"""
USE `{self.db_name}`
MATCH (n@Memory)
{where_clause}
ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
Expand Down Expand Up @@ -1262,7 +1263,6 @@ def get_structure_optimization_candidates(
return_fields = self._build_return_fields(include_embedding)

query = f"""
USE `{self.db_name}`
MATCH (n@Memory)
WHERE {where_clause}
OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
Expand Down Expand Up @@ -1430,11 +1430,8 @@ def _ensure_database_exists(self):
logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")

create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"

try:
self.execute_query(create_graph, auto_set_db=False)
self.execute_query(set_graph_working)
logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
except Exception as e:
logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
Expand Down
28 changes: 15 additions & 13 deletions src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,13 @@ def mem_reorganizer_wait(self) -> bool:
logger.info(f"close reorganizer for {mem_cube.text_mem.config.cube_id}")
mem_cube.text_mem.memory_manager.wait_reorganizer()

def _register_chat_history(self, user_id: str | None = None) -> None:
def _register_chat_history(
self, user_id: str | None = None, session_id: str | None = None
) -> None:
"""Initialize chat history with user ID."""
if user_id is None:
user_id = self.user_id
self.chat_history_manager[user_id] = ChatHistory(
user_id=user_id,
session_id=self.session_id,
user_id=user_id if user_id is not None else self.user_id,
session_id=session_id if session_id is not None else self.session_id,
created_at=datetime.utcnow(),
total_messages=0,
chat_history=[],
Expand Down Expand Up @@ -563,6 +563,7 @@ def search(
Returns:
MemoryResult: A dictionary containing the search results.
"""
target_session_id = session_id if session_id is not None else self.session_id
target_user_id = user_id if user_id is not None else self.user_id

self._validate_user_exists(target_user_id)
Expand Down Expand Up @@ -609,7 +610,7 @@ def search(
manual_close_internet=not internet_search,
info={
"user_id": target_user_id,
"session_id": session_id if session_id is not None else self.session_id,
"session_id": target_session_id,
"chat_history": chat_history.chat_history,
},
moscube=moscube,
Expand Down Expand Up @@ -652,7 +653,8 @@ def add(
assert (messages is not None) or (memory_content is not None) or (doc_path is not None), (
"messages_or_doc_path or memory_content or doc_path must be provided."
)
self.session_id = session_id
# TODO: asure that session_id is a valid string
target_session_id = session_id if session_id else self.session_id
target_user_id = user_id if user_id is not None else self.user_id
if mem_cube_id is None:
# Try to find a default cube for the user
Expand All @@ -675,7 +677,7 @@ def add(
if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text":
add_memory = []
metadata = TextualMemoryMetadata(
user_id=target_user_id, session_id=self.session_id, source="conversation"
user_id=target_user_id, session_id=target_session_id, source="conversation"
)
for message in messages:
add_memory.append(
Expand All @@ -687,7 +689,7 @@ def add(
memories = self.mem_reader.get_memory(
messages_list,
type="chat",
info={"user_id": target_user_id, "session_id": self.session_id},
info={"user_id": target_user_id, "session_id": target_session_id},
)

mem_ids = []
Expand Down Expand Up @@ -719,7 +721,7 @@ def add(
):
if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text":
metadata = TextualMemoryMetadata(
user_id=self.user_id, session_id=self.session_id, source="conversation"
user_id=target_user_id, session_id=target_session_id, source="conversation"
)
self.mem_cubes[mem_cube_id].text_mem.add(
[TextualMemoryItem(memory=memory_content, metadata=metadata)]
Expand All @@ -731,7 +733,7 @@ def add(
memories = self.mem_reader.get_memory(
messages_list,
type="chat",
info={"user_id": target_user_id, "session_id": self.session_id},
info={"user_id": target_user_id, "session_id": target_session_id},
)

mem_ids = []
Expand Down Expand Up @@ -765,7 +767,7 @@ def add(
doc_memories = self.mem_reader.get_memory(
documents,
type="doc",
info={"user_id": target_user_id, "session_id": self.session_id},
info={"user_id": target_user_id, "session_id": target_session_id},
)

mem_ids = []
Expand Down Expand Up @@ -998,7 +1000,7 @@ def load(

def get_user_info(self) -> dict[str, Any]:
"""Get current user information including accessible cubes.

TODO: maybe input user_id
Returns:
dict: User information and accessible cubes.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/memos/mem_os/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ def chat_with_references(
system_prompt = self._build_enhance_system_prompt(user_id, memories_list)
# Get chat history
if user_id not in self.chat_history_manager:
self._register_chat_history(user_id)
self._register_chat_history(user_id, session_id)

chat_history = self.chat_history_manager[user_id]
if history:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def get_current_memory_size(self) -> dict[str, int]:
"""
Return the cached memory type counts.
"""
self._refresh_memory_size()
return self.current_memory_size

def _refresh_memory_size(self) -> None:
Expand Down
35 changes: 1 addition & 34 deletions tests/mem_os/test_memos_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,41 +682,8 @@ def test_chat_without_memories(
# Verify response
assert response == "This is a test response from the assistant."

@patch("memos.mem_os.core.UserManager")
@patch("memos.mem_os.core.MemReaderFactory")
@patch("memos.mem_os.core.LLMFactory")
def test_clear_messages(
self,
mock_llm_factory,
mock_reader_factory,
mock_user_manager_class,
mock_config,
mock_llm,
mock_mem_reader,
mock_user_manager,
):
"""Test clearing chat history."""
# Setup mocks
mock_llm_factory.from_config.return_value = mock_llm
mock_reader_factory.from_config.return_value = mock_mem_reader
mock_user_manager_class.return_value = mock_user_manager

mos = MOSCore(MOSConfig(**mock_config))

# Add some chat history
mos.chat_history_manager["test_user"].chat_history.append(
{"role": "user", "content": "Hello"}
)
mos.chat_history_manager["test_user"].chat_history.append(
{"role": "assistant", "content": "Hi"}
)

assert len(mos.chat_history_manager["test_user"].chat_history) == 2

mos.clear_messages()

assert len(mos.chat_history_manager["test_user"].chat_history) == 0
assert mos.chat_history_manager["test_user"].user_id == "test_user"
# TODO: test clear message


class TestMOSSystemPrompt:
Expand Down
Loading