Skip to content
82 changes: 23 additions & 59 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ def get_start_default_config() -> dict[str, Any]:
def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]:
"""Create configuration for a specific user."""
openai_config = APIConfig.get_openai_config()

qwen_config = APIConfig.qwen_config()
vllm_config = APIConfig.vllm_config()
backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai")
Expand Down Expand Up @@ -351,8 +350,15 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General

default_config = MOSConfig(**config_dict)

if os.getenv("NEO4J_BACKEND", "neo4j_community").lower() == "neo4j_community":
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
neo4j_config = APIConfig.get_neo4j_config(user_id)

graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
}
graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
if graph_db_backend in graph_db_backend_map:
# Create MemCube config
default_cube_config = GeneralMemCubeConfig.model_validate(
{
Expand All @@ -364,8 +370,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
"extractor_llm": {"backend": "openai", "config": openai_config},
"dispatcher_llm": {"backend": "openai", "config": openai_config},
"graph_db": {
"backend": "neo4j-community",
"config": neo4j_community_config,
"backend": graph_db_backend,
"config": graph_db_backend_map[graph_db_backend],
},
"embedder": APIConfig.get_embedder_config(),
},
Expand All @@ -377,30 +383,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
}
)
else:
neo4j_config = APIConfig.get_neo4j_config(user_id)
# Create MemCube config
default_cube_config = GeneralMemCubeConfig.model_validate(
{
"user_id": user_id,
"cube_id": f"{user_name}_default_cube",
"text_mem": {
"backend": "tree_text",
"config": {
"extractor_llm": {"backend": "openai", "config": openai_config},
"dispatcher_llm": {"backend": "openai", "config": openai_config},
"graph_db": {
"backend": "neo4j",
"config": neo4j_config,
},
"embedder": APIConfig.get_embedder_config(),
},
},
"act_mem": {}
if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false"
else APIConfig.get_activation_vllm_config(),
"para_mem": {},
}
)
raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")

default_mem_cube = GeneralMemCube(default_cube_config)
return default_config, default_mem_cube
Expand All @@ -416,9 +399,14 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
return None

openai_config = APIConfig.get_openai_config()

if os.getenv("NEO4J_BACKEND", "neo4j_community").lower() == "neo4j_community":
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
}
graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower()
if graph_db_backend in graph_db_backend_map:
return GeneralMemCubeConfig.model_validate(
{
"user_id": "default",
Expand All @@ -429,8 +417,8 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
"extractor_llm": {"backend": "openai", "config": openai_config},
"dispatcher_llm": {"backend": "openai", "config": openai_config},
"graph_db": {
"backend": "neo4j-community",
"config": neo4j_community_config,
"backend": graph_db_backend,
"config": graph_db_backend_map[graph_db_backend],
},
"embedder": APIConfig.get_embedder_config(),
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
Expand All @@ -444,28 +432,4 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
}
)
else:
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
return GeneralMemCubeConfig.model_validate(
{
"user_id": "default",
"cube_id": "default_cube",
"text_mem": {
"backend": "tree_text",
"config": {
"extractor_llm": {"backend": "openai", "config": openai_config},
"dispatcher_llm": {"backend": "openai", "config": openai_config},
"graph_db": {
"backend": "neo4j",
"config": neo4j_config,
},
"embedder": APIConfig.get_embedder_config(),
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
== "true",
},
},
"act_mem": {}
if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false"
else APIConfig.get_activation_vllm_config(),
"para_mem": {},
}
)
raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
10 changes: 2 additions & 8 deletions src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,14 +646,8 @@ def add(
)
else:
messages_list = [
[
{"role": "user", "content": memory_content},
{
"role": "assistant",
"content": "",
}, # add by str to keep the format,assistant role is empty
]
]
[{"role": "user", "content": memory_content}]
] # for only user-str input and convert message
memories = self.mem_reader.get_memory(
messages_list,
type="chat",
Expand Down
23 changes: 14 additions & 9 deletions src/memos/mem_os/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from memos.mem_cube.general import GeneralMemCube
from memos.mem_os.core import MOSCore
from memos.mem_os.utils.format_utils import (
clean_json_response,
convert_graph_to_tree_forworkmem,
ensure_unique_tree_ids,
filter_nodes_by_tree_ids,
remove_embedding_recursive,
sort_children_by_memory_type,
Expand Down Expand Up @@ -656,15 +658,15 @@ def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]:
you should generate some suggestion query, the query should be user what to query,
user recently memories is:
{memories}
please generate 3 suggestion query in English,
if the user recently memories is empty, please generate 3 suggestion query in English,
output should be a json format, the key is "query", the value is a list of suggestion query.

example:
{{
"query": ["query1", "query2", "query3"]
}}
"""
text_mem_result = super().search("my recently memories", user_id=user_id, top_k=10)[
text_mem_result = super().search("my recently memories", user_id=user_id, top_k=3)[
"text_mem"
]
if text_mem_result:
Expand All @@ -673,8 +675,8 @@ def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]:
memories = ""
message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}]
response = self.chat_llm.generate(message_list)
response_json = json.loads(response)

clean_response = clean_json_response(response)
response_json = json.loads(clean_response)
return response_json["query"]

def chat(
Expand Down Expand Up @@ -762,11 +764,10 @@ def chat_with_references(
system_prompt = self._build_system_prompt(user_id, memories_list)

# Get chat history
target_user_id = user_id if user_id is not None else self.user_id
if target_user_id not in self.chat_history_manager:
self._register_chat_history(target_user_id)
if user_id not in self.chat_history_manager:
self._register_chat_history(user_id)

chat_history = self.chat_history_manager[target_user_id]
chat_history = self.chat_history_manager[user_id]
current_messages = [
{"role": "system", "content": system_prompt},
*chat_history.chat_history,
Expand Down Expand Up @@ -918,8 +919,10 @@ def get_all(
"UserMemory": 0.40,
}
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
memories, target_node_count=150, type_ratios=custom_type_ratios
memories, target_node_count=200, type_ratios=custom_type_ratios
)
# Ensure all node IDs are unique in the tree structure
tree_result = ensure_unique_tree_ids(tree_result)
memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
children = tree_result["children"]
children_sort = sort_children_by_memory_type(children)
Expand Down Expand Up @@ -1009,6 +1012,8 @@ def get_subgraph(
tree_result, node_type_count = convert_graph_to_tree_forworkmem(
memories, target_node_count=150, type_ratios=custom_type_ratios
)
# Ensure all node IDs are unique in the tree structure
tree_result = ensure_unique_tree_ids(tree_result)
memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
children = tree_result["children"]
children_sort = sort_children_by_memory_type(children)
Expand Down
Loading