diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 55bf673b..84cc8ecb 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -214,24 +214,39 @@ def search_single(vec, filt=None): or [] ) - all_hits = [] - # Path A: without filter - with ContextThreadPoolExecutor() as executor: - futures = [ - executor.submit(search_single, vec, None) for vec in query_embedding[:max_num] - ] - for f in concurrent.futures.as_completed(futures): - all_hits.extend(f.result() or []) - - # Path B: with filter - if search_filter: + def search_path_a(): + """Path A: search without filter""" + path_a_hits = [] + with ContextThreadPoolExecutor() as executor: + futures = [ + executor.submit(search_single, vec, None) for vec in query_embedding[:max_num] + ] + for f in concurrent.futures.as_completed(futures): + path_a_hits.extend(f.result() or []) + return path_a_hits + + def search_path_b(): + """Path B: search with filter""" + if not search_filter: + return [] + path_b_hits = [] with ContextThreadPoolExecutor() as executor: futures = [ executor.submit(search_single, vec, search_filter) for vec in query_embedding[:max_num] ] for f in concurrent.futures.as_completed(futures): - all_hits.extend(f.result() or []) + path_b_hits.extend(f.result() or []) + return path_b_hits + + # Execute both paths concurrently + all_hits = [] + with ContextThreadPoolExecutor(max_workers=2) as executor: + path_a_future = executor.submit(search_path_a) + path_b_future = executor.submit(search_path_b) + + all_hits.extend(path_a_future.result()) + all_hits.extend(path_b_future.result()) if not all_hits: return [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index c15792c0..df154f23 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -269,24 +269,38 @@ def _retrieve_from_long_term_and_user( ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] - if memory_type in ["All", "LongTermMemory"]: - results += self.graph_retriever.retrieve( - query=query, - parsed_goal=parsed_goal, - query_embedding=query_embedding, - top_k=top_k * 2, - memory_scope="LongTermMemory", - search_filter=search_filter, - ) - if memory_type in ["All", "UserMemory"]: - results += self.graph_retriever.retrieve( - query=query, - parsed_goal=parsed_goal, - query_embedding=query_embedding, - top_k=top_k * 2, - memory_scope="UserMemory", - search_filter=search_filter, - ) + tasks = [] + + with ContextThreadPoolExecutor(max_workers=2) as executor: + if memory_type in ["All", "LongTermMemory"]: + tasks.append( + executor.submit( + self.graph_retriever.retrieve, + query=query, + parsed_goal=parsed_goal, + query_embedding=query_embedding, + top_k=top_k * 2, + memory_scope="LongTermMemory", + search_filter=search_filter, + ) + ) + if memory_type in ["All", "UserMemory"]: + tasks.append( + executor.submit( + self.graph_retriever.retrieve, + query=query, + parsed_goal=parsed_goal, + query_embedding=query_embedding, + top_k=top_k * 2, + memory_scope="UserMemory", + search_filter=search_filter, + ) + ) + + # Collect results from all tasks + for task in tasks: + results.extend(task.result()) + return self.reranker.rerank( query=query, query_embedding=query_embedding[0],