Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions src/memos/memories/textual/tree_text_memory/retrieve/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,39 @@ def search_single(vec, filt=None):
or []
)

all_hits = []
# Path A: without filter
with ContextThreadPoolExecutor() as executor:
futures = [
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
]
for f in concurrent.futures.as_completed(futures):
all_hits.extend(f.result() or [])

# Path B: with filter
if search_filter:
def search_path_a():
"""Path A: search without filter"""
path_a_hits = []
with ContextThreadPoolExecutor() as executor:
futures = [
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
]
for f in concurrent.futures.as_completed(futures):
path_a_hits.extend(f.result() or [])
return path_a_hits

def search_path_b():
"""Path B: search with filter"""
if not search_filter:
return []
path_b_hits = []
with ContextThreadPoolExecutor() as executor:
futures = [
executor.submit(search_single, vec, search_filter)
for vec in query_embedding[:max_num]
]
for f in concurrent.futures.as_completed(futures):
all_hits.extend(f.result() or [])
path_b_hits.extend(f.result() or [])
return path_b_hits

# Execute both paths concurrently
all_hits = []
with ContextThreadPoolExecutor(max_workers=2) as executor:
path_a_future = executor.submit(search_path_a)
path_b_future = executor.submit(search_path_b)

all_hits.extend(path_a_future.result())
all_hits.extend(path_b_future.result())

if not all_hits:
return []
Expand Down
50 changes: 32 additions & 18 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,24 +269,38 @@ def _retrieve_from_long_term_and_user(
):
"""Retrieve and rerank from LongTermMemory and UserMemory"""
results = []
if memory_type in ["All", "LongTermMemory"]:
results += self.graph_retriever.retrieve(
query=query,
parsed_goal=parsed_goal,
query_embedding=query_embedding,
top_k=top_k * 2,
memory_scope="LongTermMemory",
search_filter=search_filter,
)
if memory_type in ["All", "UserMemory"]:
results += self.graph_retriever.retrieve(
query=query,
parsed_goal=parsed_goal,
query_embedding=query_embedding,
top_k=top_k * 2,
memory_scope="UserMemory",
search_filter=search_filter,
)
tasks = []

with ContextThreadPoolExecutor(max_workers=2) as executor:
if memory_type in ["All", "LongTermMemory"]:
tasks.append(
executor.submit(
self.graph_retriever.retrieve,
query=query,
parsed_goal=parsed_goal,
query_embedding=query_embedding,
top_k=top_k * 2,
memory_scope="LongTermMemory",
search_filter=search_filter,
)
)
if memory_type in ["All", "UserMemory"]:
tasks.append(
executor.submit(
self.graph_retriever.retrieve,
query=query,
parsed_goal=parsed_goal,
query_embedding=query_embedding,
top_k=top_k * 2,
memory_scope="UserMemory",
search_filter=search_filter,
)
)

# Collect results from all tasks
for task in tasks:
results.extend(task.result())

return self.reranker.rerank(
query=query,
query_embedding=query_embedding[0],
Expand Down
Loading