diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 75bec556..61e73f9a 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -300,7 +300,6 @@ def get_start_default_config() -> dict[str, Any]: def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, GeneralMemCube]: """Create configuration for a specific user.""" openai_config = APIConfig.get_openai_config() - qwen_config = APIConfig.qwen_config() vllm_config = APIConfig.vllm_config() backend = os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai") @@ -351,8 +350,15 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General default_config = MOSConfig(**config_dict) - if os.getenv("NEO4J_BACKEND", "neo4j_community").lower() == "neo4j_community": - neo4j_community_config = APIConfig.get_neo4j_community_config(user_id) + neo4j_community_config = APIConfig.get_neo4j_community_config(user_id) + neo4j_config = APIConfig.get_neo4j_config(user_id) + + graph_db_backend_map = { + "neo4j-community": neo4j_community_config, + "neo4j": neo4j_config, + } + graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower() + if graph_db_backend in graph_db_backend_map: # Create MemCube config default_cube_config = GeneralMemCubeConfig.model_validate( { @@ -364,8 +370,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "extractor_llm": {"backend": "openai", "config": openai_config}, "dispatcher_llm": {"backend": "openai", "config": openai_config}, "graph_db": { - "backend": "neo4j-community", - "config": neo4j_community_config, + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], }, "embedder": APIConfig.get_embedder_config(), }, @@ -377,30 +383,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General } ) else: - neo4j_config = APIConfig.get_neo4j_config(user_id) - # Create MemCube config - default_cube_config = GeneralMemCubeConfig.model_validate( - { - "user_id": user_id, - "cube_id": f"{user_name}_default_cube", - "text_mem": { - "backend": "tree_text", - "config": { - "extractor_llm": {"backend": "openai", "config": openai_config}, - "dispatcher_llm": {"backend": "openai", "config": openai_config}, - "graph_db": { - "backend": "neo4j", - "config": neo4j_config, - }, - "embedder": APIConfig.get_embedder_config(), - }, - }, - "act_mem": {} - if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" - else APIConfig.get_activation_vllm_config(), - "para_mem": {}, - } - ) + raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") default_mem_cube = GeneralMemCube(default_cube_config) return default_config, default_mem_cube @@ -416,9 +399,14 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: return None openai_config = APIConfig.get_openai_config() - - if os.getenv("NEO4J_BACKEND", "neo4j_community").lower() == "neo4j_community": - neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default") + neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default") + neo4j_config = APIConfig.get_neo4j_config(user_id="default") + graph_db_backend_map = { + "neo4j-community": neo4j_community_config, + "neo4j": neo4j_config, + } + graph_db_backend = os.getenv("NEO4J_BACKEND", "neo4j-community").lower() + if graph_db_backend in graph_db_backend_map: return GeneralMemCubeConfig.model_validate( { "user_id": "default", @@ -429,8 +417,8 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "extractor_llm": {"backend": "openai", "config": openai_config}, "dispatcher_llm": {"backend": "openai", "config": openai_config}, "graph_db": { - "backend": "neo4j-community", - "config": neo4j_community_config, + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], }, "embedder": APIConfig.get_embedder_config(), "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() @@ -444,28 +432,4 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: } ) else: - neo4j_config = APIConfig.get_neo4j_config(user_id="default") - return GeneralMemCubeConfig.model_validate( - { - "user_id": "default", - "cube_id": "default_cube", - "text_mem": { - "backend": "tree_text", - "config": { - "extractor_llm": {"backend": "openai", "config": openai_config}, - "dispatcher_llm": {"backend": "openai", "config": openai_config}, - "graph_db": { - "backend": "neo4j", - "config": neo4j_config, - }, - "embedder": APIConfig.get_embedder_config(), - "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() - == "true", - }, - }, - "act_mem": {} - if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" - else APIConfig.get_activation_vllm_config(), - "para_mem": {}, - } - ) + raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index d0938033..906420e2 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -646,14 +646,8 @@ def add( ) else: messages_list = [ - [ - {"role": "user", "content": memory_content}, - { - "role": "assistant", - "content": "", - }, # add by str to keep the format,assistant role is empty - ] - ] + [{"role": "user", "content": memory_content}] + ] # for only user-str input and convert message memories = self.mem_reader.get_memory( messages_list, type="chat", diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 071002f0..ccb0e52f 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -16,7 +16,9 @@ from memos.mem_cube.general import GeneralMemCube from memos.mem_os.core import MOSCore from memos.mem_os.utils.format_utils import ( + clean_json_response, convert_graph_to_tree_forworkmem, + ensure_unique_tree_ids, filter_nodes_by_tree_ids, remove_embedding_recursive, sort_children_by_memory_type, @@ -656,7 +658,7 @@ def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]: you should generate some suggestion query, the query should be user what to query, user recently memories is: {memories} - please generate 3 suggestion query in English, + if the user recently memories is empty, please generate 3 suggestion query in English, output should be a json format, the key is "query", the value is a list of suggestion query. example: @@ -664,7 +666,7 @@ def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]: "query": ["query1", "query2", "query3"] }} """ - text_mem_result = super().search("my recently memories", user_id=user_id, top_k=10)[ + text_mem_result = super().search("my recently memories", user_id=user_id, top_k=3)[ "text_mem" ] if text_mem_result: @@ -673,8 +675,8 @@ def get_suggestion_query(self, user_id: str, language: str = "zh") -> list[str]: memories = "" message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}] response = self.chat_llm.generate(message_list) - response_json = json.loads(response) - + clean_response = clean_json_response(response) + response_json = json.loads(clean_response) return response_json["query"] def chat( @@ -762,11 +764,10 @@ def chat_with_references( system_prompt = self._build_system_prompt(user_id, memories_list) # Get chat history - target_user_id = user_id if user_id is not None else self.user_id - if target_user_id not in self.chat_history_manager: - self._register_chat_history(target_user_id) + if user_id not in self.chat_history_manager: + self._register_chat_history(user_id) - chat_history = self.chat_history_manager[target_user_id] + chat_history = self.chat_history_manager[user_id] current_messages = [ {"role": "system", "content": system_prompt}, *chat_history.chat_history, @@ -918,8 +919,10 @@ def get_all( "UserMemory": 0.40, } tree_result, node_type_count = convert_graph_to_tree_forworkmem( - memories, target_node_count=150, type_ratios=custom_type_ratios + memories, target_node_count=200, type_ratios=custom_type_ratios ) + # Ensure all node IDs are unique in the tree structure + tree_result = ensure_unique_tree_ids(tree_result) memories_filtered = filter_nodes_by_tree_ids(tree_result, memories) children = tree_result["children"] children_sort = sort_children_by_memory_type(children) @@ -1009,6 +1012,8 @@ def get_subgraph( tree_result, node_type_count = convert_graph_to_tree_forworkmem( memories, target_node_count=150, type_ratios=custom_type_ratios ) + # Ensure all node IDs are unique in the tree structure + tree_result = ensure_unique_tree_ids(tree_result) memories_filtered = filter_nodes_by_tree_ids(tree_result, memories) children = tree_result["children"] children_sort = sort_children_by_memory_type(children) diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py index 9722779f..8465b44f 100644 --- a/src/memos/mem_os/utils/format_utils.py +++ b/src/memos/mem_os/utils/format_utils.py @@ -239,10 +239,10 @@ def sample_nodes_with_type_balance( "MetaMemory": 0.05, # 5% } - print( + logger.info( f"Starting type-balanced sampling, original nodes: {len(nodes)}, target nodes: {target_count}" ) - print(f"Target type ratios: {type_ratios}") + logger.info(f"Target type ratios: {type_ratios}") # Analyze current node type distribution current_type_counts = {} @@ -255,7 +255,7 @@ def sample_nodes_with_type_balance( nodes_by_type[memory_type] = [] nodes_by_type[memory_type].append(node) - print(f"Current type distribution: {current_type_counts}") + logger.info(f"Current type distribution: {current_type_counts}") # Calculate target node count for each type type_targets = {} @@ -290,7 +290,7 @@ def sample_nodes_with_type_balance( ) type_targets[memory_type] = type_targets.get(memory_type, 0) + additional - print(f"Target node count for each type: {type_targets}") + logger.info(f"Target node count for each type: {type_targets}") # Perform subtree quality sampling for each type selected_nodes = [] @@ -300,16 +300,18 @@ def sample_nodes_with_type_balance( continue type_nodes = nodes_by_type[memory_type] - print(f"\n--- Processing {memory_type} type: {len(type_nodes)} -> {target_for_type} ---") + logger.info( + f"\n--- Processing {memory_type} type: {len(type_nodes)} -> {target_for_type} ---" + ) if len(type_nodes) <= target_for_type: selected_nodes.extend(type_nodes) - print(f" Select all: {len(type_nodes)} nodes") + logger.info(f" Select all: {len(type_nodes)} nodes") else: # Use enhanced subtree quality sampling type_selected = sample_by_enhanced_subtree_quality(type_nodes, edges, target_for_type) selected_nodes.extend(type_selected) - print(f" Sampled selection: {len(type_selected)} nodes") + logger.info(f" Sampled selection: {len(type_selected)} nodes") # Filter edges selected_node_ids = {node["id"] for node in selected_nodes} @@ -319,8 +321,8 @@ def sample_nodes_with_type_balance( if edge["source"] in selected_node_ids and edge["target"] in selected_node_ids ] - print(f"\nFinal selected nodes: {len(selected_nodes)}") - print(f"Final edges: {len(filtered_edges)}") + logger.info(f"\nFinal selected nodes: {len(selected_nodes)}") + logger.info(f"Final edges: {len(filtered_edges)}") # Verify final type distribution final_type_counts = {} @@ -328,11 +330,11 @@ def sample_nodes_with_type_balance( memory_type = node.get("metadata", {}).get("memory_type", "Unknown") final_type_counts[memory_type] = final_type_counts.get(memory_type, 0) + 1 - print(f"Final type distribution: {final_type_counts}") + logger.info(f"Final type distribution: {final_type_counts}") for memory_type, count in final_type_counts.items(): percentage = count / len(selected_nodes) * 100 target_percentage = type_ratios.get(memory_type, 0) * 100 - print( + logger.info( f" {memory_type}: {count} nodes ({percentage:.1f}%, target: {target_percentage:.1f}%)" ) @@ -358,9 +360,9 @@ def sample_by_enhanced_subtree_quality( subtree_analysis.items(), key=lambda x: x[1]["quality_score"], reverse=True ) - print(" Subtree quality ranking:") + logger.info(" Subtree quality ranking:") for i, (root_id, analysis) in enumerate(sorted_subtrees[:5]): - print( + logger.info( f" #{i + 1} Root node {root_id}: Quality={analysis['quality_score']:.2f}, " f"Depth={analysis['max_depth']}, Branches={analysis['branch_nodes']}, " f"Leaves={analysis['leaf_count']}, Max Width={analysis['max_width']}" @@ -386,7 +388,7 @@ def sample_by_enhanced_subtree_quality( if node: selected_nodes.append(node) selected_node_ids.add(node_id) - print(f" Select entire subtree {root_id}: +{len(new_nodes)} nodes") + logger.info(f" Select entire subtree {root_id}: +{len(new_nodes)} nodes") else: # Subtree too large, need partial selection if analysis["quality_score"] > 5: # Only partial selection for high-quality subtrees @@ -398,7 +400,7 @@ def sample_by_enhanced_subtree_quality( selected_nodes.extend(partial_selection) for node in partial_selection: selected_node_ids.add(node["id"]) - print( + logger.info( f" Partial selection of subtree {root_id}: +{len(partial_selection)} nodes" ) @@ -411,7 +413,7 @@ def sample_by_enhanced_subtree_quality( remaining_count = target_count - len(selected_nodes) additional = sample_nodes_by_importance(remaining_nodes, edges, remaining_count) selected_nodes.extend(additional) - print(f" Supplementary selection: +{len(additional)} nodes") + logger.info(f" Supplementary selection: +{len(additional)} nodes") return selected_nodes @@ -493,7 +495,7 @@ def sample_nodes_by_importance( # Modified main function to use new sampling strategy def convert_graph_to_tree_forworkmem( json_data: dict[str, Any], - target_node_count: int = 150, + target_node_count: int = 200, type_ratios: dict[str, float] | None = None, ) -> dict[str, Any]: """ @@ -502,8 +504,8 @@ def convert_graph_to_tree_forworkmem( original_nodes = json_data.get("nodes", []) original_edges = json_data.get("edges", []) - print(f"Original node count: {len(original_nodes)}") - print(f"Target node count: {target_node_count}") + logger.info(f"Original node count: {len(original_nodes)}") + logger.info(f"Target node count: {target_node_count}") filter_original_edges = [] for original_edge in original_edges: if original_edge["type"] == "PARENT": @@ -633,7 +635,7 @@ def build_tree(node_id: str) -> dict[str, Any]: def print_tree_structure(node: dict[str, Any], level: int = 0, max_level: int = 5): - """Print the first few layers of tree structure for easy viewing""" + """logger.info the first few layers of tree structure for easy viewing""" if level > max_level: return @@ -647,21 +649,21 @@ def print_tree_structure(node: dict[str, Any], level: int = 0, max_level: int = children = node.get("children", []) if children: # Intermediate node, display name, type and child count - print(f"{indent}- {node_name} [{memory_type}] ({len(children)} children)") - print(f"{indent} ID: {node_id}") + logger.info(f"{indent}- {node_name} [{memory_type}] ({len(children)} children)") + logger.info(f"{indent} ID: {node_id}") display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value - print(f"{indent} Value: {display_value}") + logger.info(f"{indent} Value: {display_value}") if level < max_level: for child in children: print_tree_structure(child, level + 1, max_level) elif level == max_level: - print(f"{indent} ... (expansion limited)") + logger.info(f"{indent} ... (expansion limited)") else: # Leaf node, display name, type and value display_value = node_value[:80] + "..." if len(node_value) > 80 else node_value - print(f"{indent}- {node_name} [{memory_type}]: {display_value}") - print(f"{indent} ID: {node_id}") + logger.info(f"{indent}- {node_name} [{memory_type}]: {display_value}") + logger.info(f"{indent} ID: {node_id}") def analyze_final_tree_quality(tree_data: dict[str, Any]) -> dict: @@ -856,107 +858,107 @@ def count_subtree(subnode, subdepth): def print_tree_analysis(tree_data: dict[str, Any]): - """Print enhanced tree analysis results""" + """logger.info enhanced tree analysis results""" stats = analyze_final_tree_quality(tree_data) - print("\n" + "=" * 60) - print("🌳 Enhanced Tree Structure Quality Analysis Report") - print("=" * 60) + logger.info("\n" + "=" * 60) + logger.info("🌳 Enhanced Tree Structure Quality Analysis Report") + logger.info("=" * 60) # Basic statistics - print("\nšŸ“Š Basic Statistics:") - print(f" Total nodes: {stats['total_nodes']}") - print(f" Max depth: {stats['max_depth']}") - print( + logger.info("\nšŸ“Š Basic Statistics:") + logger.info(f" Total nodes: {stats['total_nodes']}") + logger.info(f" Max depth: {stats['max_depth']}") + logger.info( f" Leaf nodes: {stats['total_leaves']} ({stats['total_leaves'] / stats['total_nodes'] * 100:.1f}%)" ) - print( + logger.info( f" Branch nodes: {stats['total_branches']} ({stats['total_branches'] / stats['total_nodes'] * 100:.1f}%)" ) # Structure quality assessment structure = stats.get("structure_quality", {}) if structure: - print("\nšŸ—ļø Structure Quality Assessment:") - print( + logger.info("\nšŸ—ļø Structure Quality Assessment:") + logger.info( f" Branch density: {structure['branch_density']:.3f} ({'āœ… Good' if 0.2 <= structure['branch_density'] <= 0.6 else 'āš ļø Needs improvement'})" ) - print( + logger.info( f" Leaf ratio: {structure['leaf_ratio']:.3f} ({'āœ… Good' if 0.3 <= structure['leaf_ratio'] <= 0.7 else 'āš ļø Needs improvement'})" ) - print(f" Max width: {structure['max_width']}") - print( + logger.info(f" Max width: {structure['max_width']}") + logger.info( f" Depth-width ratio: {structure['depth_width_ratio']:.2f} ({'āœ… Good' if structure['depth_width_ratio'] <= 3 else 'āš ļø Too thin'})" ) - print( + logger.info( f" Overall balance: {'āœ… Good' if structure['is_well_balanced'] else 'āš ļø Needs improvement'}" ) # Single chain analysis chain_analysis = stats.get("chain_analysis", {}) if chain_analysis: - print("\nšŸ”— Single Chain Structure Analysis:") - print(f" Longest chain: {chain_analysis.get('max_chain_length', 0)} layers") - print(f" Single chain subtrees: {chain_analysis.get('single_chain_subtrees', 0)}") - print( + logger.info("\nšŸ”— Single Chain Structure Analysis:") + logger.info(f" Longest chain: {chain_analysis.get('max_chain_length', 0)} layers") + logger.info(f" Single chain subtrees: {chain_analysis.get('single_chain_subtrees', 0)}") + logger.info( f" Single chain subtree ratio: {chain_analysis.get('chain_subtree_ratio', 0) * 100:.1f}%" ) if chain_analysis.get("max_chain_length", 0) > 5: - print(" āš ļø Warning: Overly long single chain structure may affect display") + logger.info(" āš ļø Warning: Overly long single chain structure may affect display") elif chain_analysis.get("chain_subtree_ratio", 0) > 0.3: - print( + logger.info( " āš ļø Warning: Too many single chain subtrees, suggest increasing branch structure" ) else: - print(" āœ… Single chain structure well controlled") + logger.info(" āœ… Single chain structure well controlled") # Type diversity type_div = stats.get("type_diversity", {}) if type_div: - print("\nšŸŽØ Type Diversity Analysis:") - print(f" Total types: {type_div['total_types']}") - print(f" Diversity index: {type_div['shannon_diversity']:.3f}") - print(f" Normalized diversity: {type_div['normalized_diversity']:.3f}") - print(f" Distribution balance: {type_div['distribution_balance']:.3f}") + logger.info("\nšŸŽØ Type Diversity Analysis:") + logger.info(f" Total types: {type_div['total_types']}") + logger.info(f" Diversity index: {type_div['shannon_diversity']:.3f}") + logger.info(f" Normalized diversity: {type_div['normalized_diversity']:.3f}") + logger.info(f" Distribution balance: {type_div['distribution_balance']:.3f}") # Type distribution - print("\nšŸ“‹ Type Distribution Details:") + logger.info("\nšŸ“‹ Type Distribution Details:") for mem_type, count in sorted(stats["by_type"].items(), key=lambda x: x[1], reverse=True): percentage = count / stats["total_nodes"] * 100 - print(f" {mem_type}: {count} nodes ({percentage:.1f}%)") + logger.info(f" {mem_type}: {count} nodes ({percentage:.1f}%)") # Depth distribution - print("\nšŸ“ Depth Distribution:") + logger.info("\nšŸ“ Depth Distribution:") for depth in sorted(stats["by_depth"].keys()): count = stats["by_depth"][depth] - print(f" Depth {depth}: {count} nodes") + logger.info(f" Depth {depth}: {count} nodes") # Major subtree analysis if stats["subtrees"]: - print("\n🌲 Major Subtree Analysis (sorted by quality):") + logger.info("\n🌲 Major Subtree Analysis (sorted by quality):") sorted_subtrees = sorted( stats["subtrees"], key=lambda x: x.get("quality_score", 0), reverse=True ) for i, subtree in enumerate(sorted_subtrees[:8]): # Show first 8 quality = subtree.get("quality_score", 0) - print(f" #{i + 1} {subtree['root']} [{subtree['type']}]:") - print(f" Quality score: {quality:.2f}") - print( + logger.info(f" #{i + 1} {subtree['root']} [{subtree['type']}]:") + logger.info(f" Quality score: {quality:.2f}") + logger.info( f" Structure: Depth={subtree['depth']}, Branches={subtree['branches']}, Leaves={subtree['leaves']}" ) - print( + logger.info( f" Density: Branch density={subtree.get('branch_density', 0):.3f}, Leaf ratio={subtree.get('leaf_ratio', 0):.3f}" ) if quality > 15: - print(" āœ… High quality subtree") + logger.info(" āœ… High quality subtree") elif quality > 8: - print(" 🟔 Medium quality subtree") + logger.info(" 🟔 Medium quality subtree") else: - print(" šŸ”“ Low quality subtree") + logger.info(" šŸ”“ Low quality subtree") - print("\n" + "=" * 60) + logger.info("\n" + "=" * 60) def remove_embedding_recursive(memory_info: dict) -> Any: @@ -1152,3 +1154,204 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[ "total_parameters": total_parameters, "summary": f"Activation memory contains {total_items} items with {total_layers} layers and approximately {total_parameters:,} parameters", } + + +def detect_and_remove_duplicate_ids(tree_node: dict[str, Any]) -> dict[str, Any]: + """ + Detect and remove duplicate IDs in tree structure by skipping duplicate nodes. + First occurrence of each ID is kept, subsequent duplicates are removed. + + Args: + tree_node: Tree node (dictionary format) + + Returns: + dict: Fixed tree node with duplicate nodes removed + """ + used_ids = set() + removed_count = 0 + + def remove_duplicates_recursive( + node: dict[str, Any], parent_path: str = "" + ) -> dict[str, Any] | None: + """Recursively remove duplicate IDs by skipping duplicate nodes""" + nonlocal removed_count + + if not isinstance(node, dict): + return node + + # Create node copy + fixed_node = node.copy() + + # Handle current node ID + current_id = fixed_node.get("id", "") + if current_id in used_ids and current_id not in ["root", "WorkingMemory"]: + # Skip this duplicate node + logger.info(f"Skipping duplicate node: {current_id} (path: {parent_path})") + removed_count += 1 + return None # Return None to indicate this node should be removed + else: + used_ids.add(current_id) + + # Recursively process child nodes + if "children" in fixed_node and isinstance(fixed_node["children"], list): + fixed_children = [] + for i, child in enumerate(fixed_node["children"]): + child_path = f"{parent_path}/{fixed_node.get('node_name', 'unknown')}[{i}]" + fixed_child = remove_duplicates_recursive(child, child_path) + if fixed_child is not None: # Only add non-None children + fixed_children.append(fixed_child) + fixed_node["children"] = fixed_children + + return fixed_node + + result = remove_duplicates_recursive(tree_node) + if result is not None: + logger.info(f"Removed {removed_count} duplicate nodes") + return result + else: + # If root node itself was removed (shouldn't happen), return empty root + return { + "id": "root", + "node_name": "root", + "value": "root", + "memory_type": "Root", + "children": [], + } + + +def validate_tree_structure(tree_node: dict[str, Any]) -> dict[str, Any]: + """ + Validate tree structure integrity, including ID uniqueness check + + Args: + tree_node: Tree node (dictionary format) + + Returns: + dict: Validation result containing error messages and fix suggestions + """ + validation_result = { + "is_valid": True, + "errors": [], + "warnings": [], + "total_nodes": 0, + "unique_ids": set(), + "duplicate_ids": set(), + "missing_ids": set(), + "invalid_structure": [], + } + + def validate_recursive(node: dict[str, Any], path: str = "", depth: int = 0): + """Recursively validate tree structure""" + if not isinstance(node, dict): + validation_result["errors"].append(f"Node is not a dictionary: {path}") + validation_result["is_valid"] = False + return + + validation_result["total_nodes"] += 1 + + # Check required fields + if "id" not in node: + validation_result["errors"].append(f"Node missing ID field: {path}") + validation_result["missing_ids"].add(path) + validation_result["is_valid"] = False + else: + node_id = node["id"] + if node_id in validation_result["unique_ids"]: + validation_result["errors"].append(f"Duplicate node ID: {node_id} (path: {path})") + validation_result["duplicate_ids"].add(node_id) + validation_result["is_valid"] = False + else: + validation_result["unique_ids"].add(node_id) + + # Check other required fields + required_fields = ["node_name", "value", "memory_type"] + for field in required_fields: + if field not in node: + validation_result["warnings"].append(f"Node missing field '{field}': {path}") + + # Recursively validate child nodes + if "children" in node: + if not isinstance(node["children"], list): + validation_result["errors"].append(f"Children field is not a list: {path}") + validation_result["is_valid"] = False + else: + for i, child in enumerate(node["children"]): + child_path = f"{path}/children[{i}]" + validate_recursive(child, child_path, depth + 1) + + # Check depth limit + if depth > 20: + validation_result["warnings"].append(f"Tree depth too deep ({depth}): {path}") + + validate_recursive(tree_node) + + # Generate fix suggestions + if validation_result["duplicate_ids"]: + validation_result["fix_suggestion"] = ( + "Use detect_and_fix_duplicate_ids() function to fix duplicate IDs" + ) + + return validation_result + + +def ensure_unique_tree_ids(tree_result: dict[str, Any]) -> dict[str, Any]: + """ + Ensure all node IDs in tree structure are unique by removing duplicate nodes, + this is a post-processing function for convert_graph_to_tree_forworkmem + + Args: + tree_result: Tree structure returned by convert_graph_to_tree_forworkmem + + Returns: + dict: Fixed tree structure with duplicate nodes removed + """ + logger.info("šŸ” Starting duplicate ID check in tree structure...") + + # First validate tree structure + validation = validate_tree_structure(tree_result) + + if validation["is_valid"]: + logger.info("Tree structure validation passed, no duplicate IDs found") + return tree_result + + # Report issues + logger.info(f"Found {len(validation['errors'])} errors:") + for error in validation["errors"][:5]: # Only show first 5 errors + logger.info(f" - {error}") + + if len(validation["errors"]) > 5: + logger.info(f" ... and {len(validation['errors']) - 5} more errors") + + logger.info("Statistics:") + logger.info(f" - Total nodes: {validation['total_nodes']}") + logger.info(f" - Unique IDs: {len(validation['unique_ids'])}") + logger.info(f" - Duplicate IDs: {len(validation['duplicate_ids'])}") + + # Remove duplicate nodes + logger.info(" Starting duplicate node removal...") + fixed_tree = detect_and_remove_duplicate_ids(tree_result) + + # Validate again + post_validation = validate_tree_structure(fixed_tree) + if post_validation["is_valid"]: + logger.info("Removal completed, tree structure is now valid") + logger.info(f"Final node count: {post_validation['total_nodes']}") + else: + logger.info("Issues remain after removal, please check code logic") + for error in post_validation["errors"][:3]: + logger.info(f" - {error}") + + return fixed_tree + + +def clean_json_response(response: str) -> str: + """ + Remove markdown JSON code block formatting from LLM response. + + Args: + response: Raw response string that may contain ```json and ``` + + Returns: + str: Clean JSON string without markdown formatting + """ + return response.replace("```json", "").replace("```", "").strip()