diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index d67b6715..e18aa9d9 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -91,6 +91,12 @@ def run_with_scheduler_init(): mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri + mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user + mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password + mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name + mem_cube_config.text_mem.config.graph_db.config.auto_create = ( + auth_config.graph_db.auto_create + ) # Initialization mos = MOS(mos_config) @@ -118,8 +124,7 @@ def run_with_scheduler_init(): query = item["question"] print(f"Query:\n {query}\n") response = mos.chat(query=query, user_id=user_id) - print(f"Answer:\n {response}") - print("===== Chat End =====") + print(f"Answer:\n {response}\n") show_web_logs(mem_scheduler=mos.mem_scheduler) diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py index 2aae09ec..87646013 100644 --- a/examples/mem_scheduler/memos_w_scheduler_for_test.py +++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py @@ -103,6 +103,12 @@ def init_task(): mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri + mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user + mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password + mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name + mem_cube_config.text_mem.config.graph_db.config.auto_create = ( + auth_config.graph_db.auto_create + ) # Initialization mos = MOSForTestScheduler(mos_config) diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index 976e506b..29b166bf 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -151,6 +151,12 @@ def show_web_logs(mem_scheduler: GeneralScheduler): mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri + mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user + mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password + mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name + mem_cube_config.text_mem.config.graph_db.config.auto_create = ( + auth_config.graph_db.auto_create + ) # Initialization mos = MOSForTestScheduler(mos_config) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 295090f9..be74d617 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -84,6 +84,8 @@ def __init__(self, config: BaseSchedulerConfig): # other attributes self._context_lock = threading.Lock() self.current_user_id: UserID | str | None = None + self.current_mem_cube_id: MemCubeID | str | None = None + self.current_mem_cube: GeneralMemCube | None = None self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None) self.auth_config = None self.rabbitmq_config = None @@ -130,8 +132,8 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None: self.current_mem_cube_id = msg.mem_cube_id self.current_mem_cube = msg.mem_cube - def transform_memories_to_monitors( - self, memories: list[TextualMemoryItem] + def transform_working_memories_to_monitors( + self, query_keywords, memories: list[TextualMemoryItem] ) -> list[MemoryMonitorItem]: """ Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects @@ -143,10 +145,6 @@ def transform_memories_to_monitors( Returns: List of MemoryMonitorItem objects with computed importance scores. """ - query_keywords = self.monitor.query_monitors.get_keywords_collections() - logger.debug( - f"Processing {len(memories)} memories with {len(query_keywords)} query keywords" - ) result = [] mem_length = len(memories) @@ -195,7 +193,8 @@ def replace_working_memory( text_mem_base: TreeTextMemory = text_mem_base # process rerank memories with llm - query_history = self.monitor.query_monitors.get_queries_with_timesort() + query_monitor = self.monitor.query_monitors[user_id][mem_cube_id] + query_history = query_monitor.get_queries_with_timesort() memories_with_new_order, rerank_success_flag = ( self.retriever.process_and_rerank_memories( queries=query_history, @@ -206,8 +205,13 @@ def replace_working_memory( ) # update working memory monitors - new_working_memory_monitors = self.transform_memories_to_monitors( - memories=memories_with_new_order + query_keywords = query_monitor.get_keywords_collections() + logger.debug( + f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords" + ) + new_working_memory_monitors = self.transform_working_memories_to_monitors( + query_keywords=query_keywords, + memories=memories_with_new_order, ) if not rerank_success_flag: @@ -245,25 +249,6 @@ def replace_working_memory( return memories_with_new_order - def initialize_working_memory_monitors( - self, - user_id: UserID | str, - mem_cube_id: MemCubeID | str, - mem_cube: GeneralMemCube, - ): - text_mem_base: TreeTextMemory = mem_cube.text_mem - working_memories = text_mem_base.get_working_memory() - - working_memory_monitors = self.transform_memories_to_monitors( - memories=working_memories, - ) - self.monitor.update_working_memory_monitors( - new_working_memory_monitors=working_memory_monitors, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - def update_activation_memory( self, new_memories: list[str | TextualMemoryItem], @@ -367,13 +352,9 @@ def update_activation_memory_periodically( or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) == 0 ): logger.warning( - "No memories found in working_memory_monitors, initializing from current working_memories" - ) - self.initialize_working_memory_monitors( - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, + "No memories found in working_memory_monitors, activation memory update is skipped" ) + return self.monitor.update_activation_memory_monitors( user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index d0b2d61e..390a1445 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -9,6 +9,7 @@ ANSWER_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, QUERY_LABEL, + WORKING_MEMORY_TYPE, MemCubeID, UserID, ) @@ -35,11 +36,12 @@ def __init__(self, config: GeneralSchedulerConfig): # for evaluation def search_for_eval( - self, - query: str, - user_id: UserID | str, - top_k: int, - ) -> list[str]: + self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True + ) -> (list[str], bool): + self.monitor.register_query_monitor_if_not_exists( + user_id=user_id, mem_cube_id=self.current_mem_cube_id + ) + query_keywords = self.monitor.extract_query_keywords(query=query) logger.info(f'Extract keywords "{query_keywords}" from query "{query}"') @@ -48,35 +50,61 @@ def search_for_eval( keywords=query_keywords, max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, ) - self.monitor.query_monitors.put(item=item) - logger.debug( - f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}." - ) + query_monitor = self.monitor.query_monitors[user_id][self.current_mem_cube_id] + query_monitor.put(item=item) + logger.debug(f"Queries in monitor are {query_monitor.get_queries_with_timesort()}.") queries = [query] # recall - cur_working_memory, new_candidates = self.process_session_turn( - queries=queries, - user_id=user_id, - mem_cube_id=self.current_mem_cube_id, - mem_cube=self.current_mem_cube, - top_k=self.top_k, - ) - logger.info(f"Processed {queries} and get {len(new_candidates)} new candidate memories.") - - # rerank - new_order_working_memory = self.replace_working_memory( - user_id=user_id, - mem_cube_id=self.current_mem_cube_id, - mem_cube=self.current_mem_cube, - original_memory=cur_working_memory, - new_memory=new_candidates, + mem_cube = self.current_mem_cube + text_mem_base = mem_cube.text_mem + + cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() + text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] + intent_result = self.monitor.detect_intent( + q_list=queries, text_working_memory=text_working_memory ) - new_order_working_memory = new_order_working_memory[:top_k] - logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}") - return [m.memory for m in new_order_working_memory] + if not scheduler_flag: + return text_working_memory, intent_result["trigger_retrieval"] + else: + if intent_result["trigger_retrieval"]: + missing_evidences = intent_result["missing_evidences"] + num_evidence = len(missing_evidences) + k_per_evidence = max(1, top_k // max(1, num_evidence)) + new_candidates = [] + for item in missing_evidences: + logger.info(f"missing_evidences: {item}") + results: list[TextualMemoryItem] = self.retriever.search( + query=item, + mem_cube=mem_cube, + top_k=k_per_evidence, + method=self.search_method, + ) + logger.info( + f"search results for {missing_evidences}: {[one.memory for one in results]}" + ) + new_candidates.extend(results) + print( + f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories." + ) + else: + new_candidates = [] + print(f"intent_result: {intent_result}. not triggered") + + # rerank + new_order_working_memory = self.replace_working_memory( + user_id=user_id, + mem_cube_id=self.current_mem_cube_id, + mem_cube=self.current_mem_cube, + original_memory=cur_working_memory, + new_memory=new_candidates, + ) + new_order_working_memory = new_order_working_memory[:top_k] + logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}") + + return [m.memory for m in new_order_working_memory], intent_result["trigger_retrieval"] def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ @@ -105,6 +133,10 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: # update query monitors for msg in messages: + self.monitor.register_query_monitor_if_not_exists( + user_id=user_id, mem_cube_id=mem_cube_id + ) + query = msg.content query_keywords = self.monitor.extract_query_keywords(query=query) logger.info(f'Extract keywords "{query_keywords}" from query "{query}"') @@ -114,9 +146,11 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: keywords=query_keywords, max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, ) - self.monitor.query_monitors.put(item=item) + + self.monitor.query_monitors[user_id][mem_cube_id].put(item=item) logger.debug( - f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}." + f"Queries in monitor are " + f"{self.monitor.query_monitors[user_id][mem_cube_id].get_queries_with_timesort()}." ) queries = [msg.content for msg in messages] @@ -215,6 +249,9 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: mem_type = mem_item.metadata.memory_type mem_content = mem_item.memory + if mem_type == WORKING_MEMORY_TYPE: + continue + self.log_adding_memory( memory=mem_content, memory_type=mem_type, @@ -289,18 +326,20 @@ def process_session_turn( new_candidates = [] for item in missing_evidences: logger.info(f"missing_evidences: {item}") + info = { + "user_id": user_id, + "session_id": "", + } + results: list[TextualMemoryItem] = self.retriever.search( - query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method + query=item, + mem_cube=mem_cube, + top_k=k_per_evidence, + method=self.search_method, + info=info, ) logger.info( f"search results for {missing_evidences}: {[one.memory for one in results]}" ) new_candidates.extend(results) - - if len(new_candidates) == 0: - logger.warning( - f"As new_candidates is empty, new_candidates is set same to working_memory.\n" - f"time_trigger_flag: {time_trigger_flag}; intent_result: {intent_result}" - ) - new_candidates = cur_working_memory return cur_working_memory, new_candidates diff --git a/src/memos/mem_scheduler/modules/base.py b/src/memos/mem_scheduler/modules/base.py index 37538b0f..392f2bde 100644 --- a/src/memos/mem_scheduler/modules/base.py +++ b/src/memos/mem_scheduler/modules/base.py @@ -17,8 +17,7 @@ def __init__(self): self._chat_llm = None self._process_llm = None - self.current_mem_cube_id: str | None = None - self.current_mem_cube: GeneralMemCube | None = None + self.mem_cubes: dict[str, GeneralMemCube] = {} def load_template(self, template_name: str) -> str: diff --git a/src/memos/mem_scheduler/modules/monitor.py b/src/memos/mem_scheduler/modules/monitor.py index 0e7d4c39..04b8620b 100644 --- a/src/memos/mem_scheduler/modules/monitor.py +++ b/src/memos/mem_scheduler/modules/monitor.py @@ -1,4 +1,5 @@ from datetime import datetime +from threading import Lock from typing import Any from memos.configs.mem_scheduler import BaseSchedulerConfig @@ -46,9 +47,7 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): # attributes # recording query_messages - self.query_monitors: QueryMonitorQueue[QueryMonitorItem] = QueryMonitorQueue( - maxsize=self.config.context_window_size - ) + self.query_monitors: dict[UserID, dict[MemCubeID, QueryMonitorQueue[QueryMonitorItem]]] = {} self.working_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {} self.activation_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {} @@ -57,6 +56,7 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): self.last_activation_mem_update_time = datetime.min self.last_query_consume_time = datetime.min + self._register_lock = Lock() self._process_llm = process_llm def extract_query_keywords(self, query: str) -> list: @@ -78,15 +78,34 @@ def extract_query_keywords(self, query: str) -> list: keywords = [query] return keywords + def register_query_monitor_if_not_exists( + self, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, + ) -> None: + # First check (lock-free, fast path) + if user_id in self.query_monitors and mem_cube_id in self.query_monitors[user_id]: + return + + # Second check (with lock, ensures uniqueness) + with self._register_lock: + if user_id not in self.query_monitors: + self.query_monitors[user_id] = {} + if mem_cube_id not in self.query_monitors[user_id]: + self.query_monitors[user_id][mem_cube_id] = QueryMonitorQueue( + maxsize=self.config.context_window_size + ) + def register_memory_manager_if_not_exists( self, - user_id: str, - mem_cube_id: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]], max_capacity: int, ) -> None: """ Register a new MemoryMonitorManager for the given user and memory cube if it doesn't exist. + Thread-safe implementation using double-checked locking pattern. Checks if a MemoryMonitorManager already exists for the specified user_id and mem_cube_id. If not, creates a new MemoryMonitorManager with appropriate capacity settings and registers it. @@ -94,14 +113,34 @@ def register_memory_manager_if_not_exists( Args: user_id: The ID of the user to associate with the memory manager mem_cube_id: The ID of the memory cube to monitor + memory_monitors: Dictionary storing existing memory monitor managers + max_capacity: Maximum capacity for the new memory monitor manager + lock: Threading lock to ensure safe concurrent access Note: This function will update the loose_max_working_memory_capacity based on the current WorkingMemory size plus partial retention number before creating a new manager. """ - # Check if a MemoryMonitorManager already exists for the current user_id and mem_cube_id - # If doesn't exist, create and register a new one - if (user_id not in memory_monitors) or (mem_cube_id not in memory_monitors[user_id]): + # First check (lock-free, fast path) + # Quickly verify existence without lock overhead + if user_id in memory_monitors and mem_cube_id in memory_monitors[user_id]: + logger.info( + f"MemoryMonitorManager already exists for user_id={user_id}, " + f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary" + ) + return + + # Second check (with lock, ensures uniqueness) + # Acquire lock before modification and verify again to prevent race conditions + with self._register_lock: + # Re-check after acquiring lock, as another thread might have created it + if user_id in memory_monitors and mem_cube_id in memory_monitors[user_id]: + logger.info( + f"MemoryMonitorManager already exists for user_id={user_id}, " + f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary" + ) + return + # Initialize MemoryMonitorManager with user ID, memory cube ID, and max capacity monitor_manager = MemoryMonitorManager( user_id=user_id, mem_cube_id=mem_cube_id, max_capacity=max_capacity @@ -113,11 +152,6 @@ def register_memory_manager_if_not_exists( f"Registered new MemoryMonitorManager for user_id={user_id}," f" mem_cube_id={mem_cube_id} with max_capacity={max_capacity}" ) - else: - logger.info( - f"MemoryMonitorManager already exists for user_id={user_id}, " - f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary" - ) def update_working_memory_monitors( self, diff --git a/src/memos/mem_scheduler/modules/retriever.py b/src/memos/mem_scheduler/modules/retriever.py index 2cf3bd04..595e8e85 100644 --- a/src/memos/mem_scheduler/modules/retriever.py +++ b/src/memos/mem_scheduler/modules/retriever.py @@ -4,6 +4,7 @@ from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.general_schemas import ( + TreeTextMemory_FINE_SEARCH_METHOD, TreeTextMemory_SEARCH_METHOD, ) from memos.mem_scheduler.utils.filter_utils import ( @@ -32,7 +33,12 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): self.process_llm = process_llm def search( - self, query: str, mem_cube: GeneralMemCube, top_k: int, method=TreeTextMemory_SEARCH_METHOD + self, + query: str, + mem_cube: GeneralMemCube, + top_k: int, + method: str = TreeTextMemory_SEARCH_METHOD, + info: dict | None = None, ) -> list[TextualMemoryItem]: """Search in text memory with the given query. @@ -46,13 +52,21 @@ def search( """ text_mem_base = mem_cube.text_mem try: - if method == TreeTextMemory_SEARCH_METHOD: + if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: assert isinstance(text_mem_base, TreeTextMemory) + if info is None: + logger.warning( + "Please input 'info' when use tree.search so that " + "the database would store the consume history." + ) + info = {"user_id": "", "session_id": ""} + + mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine" results_long_term = text_mem_base.search( - query=query, top_k=top_k, memory_type="LongTermMemory" + query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info ) results_user = text_mem_base.search( - query=query, top_k=top_k, memory_type="UserMemory" + query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info ) results = results_long_term + results_user else: diff --git a/src/memos/mem_scheduler/mos_for_test_scheduler.py b/src/memos/mem_scheduler/mos_for_test_scheduler.py index d796f824..f275da2b 100644 --- a/src/memos/mem_scheduler/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/mos_for_test_scheduler.py @@ -81,7 +81,13 @@ def chat(self, query: str, user_id: str | None = None) -> str: # from mem_cube memories = mem_cube.text_mem.search( - query, top_k=self.config.top_k - topk_for_scheduler + query, + top_k=self.config.top_k - topk_for_scheduler, + info={ + "user_id": target_user_id, + "session_id": self.session_id, + "chat_history": chat_history.chat_history, + }, ) text_memories = [m.memory for m in memories] print(f"Search results with new working memories: {text_memories}") diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index a5dcc34b..070026f0 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -10,6 +10,7 @@ ADD_LABEL = "add" TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" +TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" TextMemory_SEARCH_METHOD = "text_memory_search" DIRECT_EXCHANGE_TYPE = "direct" FANOUT_EXCHANGE_TYPE = "fanout" diff --git a/src/memos/mem_scheduler/schemas/monitor_schemas.py b/src/memos/mem_scheduler/schemas/monitor_schemas.py index 68d53f55..a25015cc 100644 --- a/src/memos/mem_scheduler/schemas/monitor_schemas.py +++ b/src/memos/mem_scheduler/schemas/monitor_schemas.py @@ -1,3 +1,5 @@ +import threading + from collections import Counter from datetime import datetime from pathlib import Path @@ -76,7 +78,7 @@ class QueryMonitorQueue(AutoDroppingQueue[QueryMonitorItem]): Each item is expected to be a dictionary containing: """ - def put(self, item: QueryMonitorItem, block: bool = True, timeout: float | None = None) -> None: + def put(self, item: QueryMonitorItem, block: bool = True, timeout: float | None = 5.0) -> None: """ Add a query item to the queue. Ensures the item is of correct type. @@ -85,6 +87,9 @@ def put(self, item: QueryMonitorItem, block: bool = True, timeout: float | None """ if not isinstance(item, QueryMonitorItem): raise ValueError("Item must be an instance of QueryMonitorItem") + logger.debug( + f"Thread {threading.get_ident()} acquired mutex. Timeout is set to {timeout} seconds" + ) super().put(item, block, timeout) def get_queries_by_timestamp( @@ -94,6 +99,7 @@ def get_queries_by_timestamp( Retrieve queries added between the specified time range. """ with self.mutex: + logger.debug(f"Thread {threading.get_ident()} acquired mutex.") return [item for item in self.queue if start_time <= item.timestamp <= end_time] def get_keywords_collections(self) -> Counter: @@ -104,6 +110,7 @@ def get_keywords_collections(self) -> Counter: Counter object with keyword counts """ with self.mutex: + logger.debug(f"Thread {threading.get_ident()} acquired mutex.") all_keywords = [kw for item in self.queue for kw in item.keywords] return Counter(all_keywords) @@ -119,6 +126,7 @@ def get_queries_with_timesort(self, reverse: bool = True) -> list[str]: List of query items sorted by timestamp """ with self.mutex: + logger.debug(f"Thread {threading.get_ident()} acquired mutex.") return [ monitor.query_text for monitor in sorted(self.queue, key=lambda x: x.timestamp, reverse=reverse) diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 31e82c3c..8171fadc 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -36,7 +36,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) """Update a memory by memory_id.""" @abstractmethod - def search(self, query: str, top_k: int, info=None) -> list[TextualMemoryItem]: + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. diff --git a/src/memos/memories/textual/general.py b/src/memos/memories/textual/general.py index 4a1d90cb..9793224b 100644 --- a/src/memos/memories/textual/general.py +++ b/src/memos/memories/textual/general.py @@ -114,7 +114,7 @@ def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) self.vector_db.update(memory_id, vec_db_item) - def search(self, query: str, top_k: int) -> list[TextualMemoryItem]: + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py index c2688879..a15021d9 100644 --- a/src/memos/templates/mem_scheduler_prompts.py +++ b/src/memos/templates/mem_scheduler_prompts.py @@ -17,24 +17,29 @@ 4. Personalization (tailored to user's context) ## Decision Framework -1. Mark as satisfied ONLY if: +1. We have enough information (satisfied) ONLY when: - All question aspects are addressed - Supporting evidence exists in working memory - - No apparent gaps in information + - There's no obvious information missing -2. Mark as unsatisfied if: +2. We need more information (unsatisfied) if: - Any question aspect remains unanswered - Evidence is generic/non-specific - Personal context is missing ## Output Specification Return JSON with: -- "trigger_retrieval": Boolean (true if more evidence needed) -- "missing_evidences": List of specific evidence types required +- "trigger_retrieval": true/false (true if we need more information) +- "evidences": List of information from our working memory that helps answer the questions +- "missing_evidences": List of specific types of information we need to answer the questions ## Response Format {{ "trigger_retrieval": , + "evidences": [ + "", + "" + ], "missing_evidences": [ "", "" @@ -107,7 +112,7 @@ Output: {{ "new_order": [2, 1, 0], - "reasoning": "Threading (2) prioritized for matching newest query, followed by matplotlib (1) for older visualization query", + "reasoning": "Threading (2) prioritized for matching newest query, followed by matplotlib (1) for older visualization query" }} ## Current Task diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index b35b7b17..88f3a8f0 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -14,7 +14,6 @@ from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, QUERY_LABEL, - TreeTextMemory_SEARCH_METHOD, ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, @@ -126,32 +125,3 @@ def test_submit_web_logs(self): self.assertTrue(isinstance(actual_message.item_id, str)) self.assertTrue(hasattr(actual_message, "timestamp")) self.assertTrue(isinstance(actual_message.timestamp, datetime)) - - def test_search_with_empty_results(self): - """Test search method with empty results.""" - # Setup mock memory cube and text memory - mock_mem_cube = MagicMock() - mock_mem_cube.text_mem = self.tree_text_memory - - # Setup mock search results for both memory types - self.tree_text_memory.search.side_effect = [ - [], # results_long_term - [], # results_user - ] - - # Test search - results = self.scheduler.retriever.search( - query="Test query", mem_cube=mock_mem_cube, top_k=5, method=TreeTextMemory_SEARCH_METHOD - ) - - # Verify results - self.assertEqual(results, []) - - # Verify search was called twice (for LongTermMemory and UserMemory) - self.assertEqual(self.tree_text_memory.search.call_count, 2) - self.tree_text_memory.search.assert_any_call( - query="Test query", top_k=5, memory_type="LongTermMemory" - ) - self.tree_text_memory.search.assert_any_call( - query="Test query", top_k=5, memory_type="UserMemory" - )