From 1e4ab96c43323bb129c0c135f6c89ee9e6e8dd26 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 23 Jul 2025 19:51:40 +0800 Subject: [PATCH 1/2] fix bugs: fix bugs in the example of memos_w_scheduler and fix bugs of log submittion in MOS --- .../memos_config_w_scheduler_and_openai.yaml | 1 + examples/mem_scheduler/memos_w_scheduler.py | 112 +++--------------- src/memos/mem_os/core.py | 3 +- src/memos/mem_scheduler/base_scheduler.py | 42 +++---- src/memos/mem_scheduler/general_scheduler.py | 65 +++++++++- src/memos/mem_scheduler/modules/base.py | 8 +- src/memos/mem_scheduler/modules/retriever.py | 27 +++-- .../mem_scheduler/modules/scheduler_logger.py | 6 + .../mem_scheduler/schemas/monitor_schemas.py | 2 +- tests/mem_scheduler/test_scheduler.py | 4 +- 10 files changed, 129 insertions(+), 141 deletions(-) diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml index 1da4dad1..b329e0fc 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml @@ -41,6 +41,7 @@ mem_scheduler: thread_pool_max_workers: 10 consume_interval_seconds: 1 enable_parallel_dispatch: true + enable_act_memory_update: false max_turns_window: 20 top_k: 5 enable_textual_memory: true diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index c00845d5..d67b6715 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -1,25 +1,17 @@ import shutil import sys -from datetime import datetime from pathlib import Path from queue import Queue from typing import TYPE_CHECKING from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig -from memos.configs.mem_scheduler import AuthConfig, SchedulerConfigFactory +from memos.configs.mem_scheduler import AuthConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS from memos.mem_scheduler.general_scheduler import GeneralScheduler -from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.mem_scheduler.schemas.general_schemas import ( - ANSWER_LABEL, - QUERY_LABEL, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.utils.misc_utils import parse_yaml if TYPE_CHECKING: @@ -78,122 +70,56 @@ def init_task(): return conversations, questions -def run_with_automatic_scheduler_init(): +def run_with_scheduler_init(): print("==== run_with_automatic_scheduler_init ====") conversations, questions = init_task() - config = parse_yaml( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml" + # set configs + mos_config = MOSConfig.from_yaml_file( + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml" ) - mos_config = MOSConfig(**config) - mos = MOS(mos_config) - - user_id = "user_1" - mos.create_user(user_id) - - config = GeneralMemCubeConfig.from_yaml_file( + mem_cube_config = GeneralMemCubeConfig.from_yaml_file( f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" ) - mem_cube_id = "mem_cube_5" - mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" - if Path(mem_cube_name_or_path).exists(): - shutil.rmtree(mem_cube_name_or_path) - print(f"{mem_cube_name_or_path} is not empty, and has been 
removed.") # default local graphdb uri if AuthConfig.default_config_exists(): auth_config = AuthConfig.from_local_yaml() - config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri - mem_cube = GeneralMemCube(config) - mem_cube.dump(mem_cube_name_or_path) - mos.register_mem_cube( - mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id - ) - mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) + mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key + mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url - for item in questions: - query = item["question"] - response = mos.chat(query, user_id=user_id) - print(f"Query:\n {query}\n\nAnswer:\n {response}") + mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri - show_web_logs(mem_scheduler=mos.mem_scheduler) - - mos.mem_scheduler.stop() - - -def run_with_manual_scheduler_init(): - print("==== run_with_manual_scheduler_init ====") - conversations, questions = init_task() - - config = parse_yaml( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_wo_scheduler.yaml" - ) - - mos_config = MOSConfig(**config) + # Initialization mos = MOS(mos_config) user_id = "user_1" mos.create_user(user_id) - config = GeneralMemCubeConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml" - ) mem_cube_id = "mem_cube_5" mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" + if Path(mem_cube_name_or_path).exists(): shutil.rmtree(mem_cube_name_or_path) print(f"{mem_cube_name_or_path} is not empty, and has been removed.") - # default local graphdb uri - if AuthConfig.default_config_exists(): - auth_config = AuthConfig.from_local_yaml() - config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri - - mem_cube = GeneralMemCube(config) + mem_cube = GeneralMemCube(mem_cube_config) mem_cube.dump(mem_cube_name_or_path) mos.register_mem_cube( mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id ) - example_scheduler_config_path = ( - f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml" - ) - scheduler_config = SchedulerConfigFactory.from_yaml_file( - yaml_path=example_scheduler_config_path - ) - mem_scheduler = SchedulerFactory.from_config(scheduler_config) - mem_scheduler.initialize_modules(chat_llm=mos.chat_llm) - - mos.mem_scheduler = mem_scheduler - - mos.mem_scheduler.start() - mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) for item in questions: + print("===== Chat Start =====") query = item["question"] - message_item = ScheduleMessageItem( - user_id=user_id, - mem_cube_id=mem_cube_id, - label=QUERY_LABEL, - mem_cube=mos.mem_cubes[mem_cube_id], - content=query, - timestamp=datetime.now(), - ) - mos.mem_scheduler.submit_messages(messages=message_item) - response = mos.chat(query, user_id=user_id) - message_item = ScheduleMessageItem( - user_id=user_id, - mem_cube_id=mem_cube_id, - label=ANSWER_LABEL, - mem_cube=mos.mem_cubes[mem_cube_id], - content=response, - timestamp=datetime.now(), - ) - mos.mem_scheduler.submit_messages(messages=message_item) - print(f"Query:\n {query}\n\nAnswer:\n {response}") + print(f"Query:\n {query}\n") + response = mos.chat(query=query, user_id=user_id) + print(f"Answer:\n {response}") + print("===== Chat End =====") show_web_logs(mem_scheduler=mos.mem_scheduler) @@ -236,6 +162,4 @@ def show_web_logs(mem_scheduler: GeneralScheduler): if __name__ == 
"__main__": - run_with_automatic_scheduler_init() - - run_with_manual_scheduler_init() + run_with_scheduler_init() diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index d419ee70..14b85d5e 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -17,6 +17,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, + QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_user.user_manager import UserManager, UserRole @@ -267,7 +268,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = user_id=target_user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, - label=ADD_LABEL, + label=QUERY_LABEL, content=query, timestamp=datetime.now(), ) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index c37b9155..295090f9 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -20,7 +20,9 @@ DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_THREAD__POOL_MAX_WORKERS, + MemCubeID, TreeTextMemory_SEARCH_METHOD, + UserID, ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, @@ -81,7 +83,7 @@ def __init__(self, config: BaseSchedulerConfig): # other attributes self._context_lock = threading.Lock() - self._current_user_id: str | None = None + self.current_user_id: UserID | str | None = None self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None) self.auth_config = None self.rabbitmq_config = None @@ -113,20 +115,20 @@ def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = No @property def mem_cube(self) -> GeneralMemCube: """The memory cube associated with this MemChat.""" - return self._current_mem_cube + return self.current_mem_cube @mem_cube.setter def mem_cube(self, value: GeneralMemCube) -> None: """The memory cube associated with this MemChat.""" - self._current_mem_cube = value + self.current_mem_cube = value self.retriever.mem_cube = value def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None: """Update current user/cube context from the incoming message (thread-safe).""" with self._context_lock: - self._current_user_id = msg.user_id - self._current_mem_cube_id = msg.mem_cube_id - self._current_mem_cube = msg.mem_cube + self.current_user_id = msg.user_id + self.current_mem_cube_id = msg.mem_cube_id + self.current_mem_cube = msg.mem_cube def transform_memories_to_monitors( self, memories: list[TextualMemoryItem] @@ -181,9 +183,8 @@ def transform_memories_to_monitors( def replace_working_memory( self, - queries: list[str], - user_id: str, - mem_cube_id: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, mem_cube: GeneralMemCube, original_memory: list[TextualMemoryItem], new_memory: list[TextualMemoryItem], @@ -246,8 +247,8 @@ def replace_working_memory( def initialize_working_memory_monitors( self, - user_id: str, - mem_cube_id: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, mem_cube: GeneralMemCube, ): text_mem_base: TreeTextMemory = mem_cube.text_mem @@ -267,8 +268,8 @@ def update_activation_memory( self, new_memories: list[str | TextualMemoryItem], label: str, - user_id: str, - mem_cube_id: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, mem_cube: GeneralMemCube, ) -> None: """ @@ -344,16 +345,17 @@ def update_activation_memory_periodically( self, interval_seconds: int, label: str, - user_id: str, - 
mem_cube_id: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, mem_cube: GeneralMemCube, ): - new_activation_memories = [] - try: - if self.monitor.timed_trigger( - last_time=self.monitor.last_activation_mem_update_time, - interval_seconds=interval_seconds, + if ( + self.monitor.last_activation_mem_update_time == datetime.min + or self.monitor.timed_trigger( + last_time=self.monitor.last_activation_mem_update_time, + interval_seconds=interval_seconds, + ) ): logger.info( f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}" diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index d0b06733..d0b2d61e 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -9,6 +9,8 @@ ANSWER_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, QUERY_LABEL, + MemCubeID, + UserID, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem @@ -31,6 +33,51 @@ def __init__(self, config: GeneralSchedulerConfig): } self.dispatcher.register_handlers(handlers) + # for evaluation + def search_for_eval( + self, + query: str, + user_id: UserID | str, + top_k: int, + ) -> list[str]: + query_keywords = self.monitor.extract_query_keywords(query=query) + logger.info(f'Extract keywords "{query_keywords}" from query "{query}"') + + item = QueryMonitorItem( + query_text=query, + keywords=query_keywords, + max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, + ) + self.monitor.query_monitors.put(item=item) + logger.debug( + f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}." + ) + + queries = [query] + + # recall + cur_working_memory, new_candidates = self.process_session_turn( + queries=queries, + user_id=user_id, + mem_cube_id=self.current_mem_cube_id, + mem_cube=self.current_mem_cube, + top_k=self.top_k, + ) + logger.info(f"Processed {queries} and get {len(new_candidates)} new candidate memories.") + + # rerank + new_order_working_memory = self.replace_working_memory( + user_id=user_id, + mem_cube_id=self.current_mem_cube_id, + mem_cube=self.current_mem_cube, + original_memory=cur_working_memory, + new_memory=new_candidates, + ) + new_order_working_memory = new_order_working_memory[:top_k] + logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}") + + return [m.memory for m in new_order_working_memory] + def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ Process and handle query trigger messages from the queue. @@ -88,7 +135,6 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: # rerank new_order_working_memory = self.replace_working_memory( - queries=queries, user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, @@ -119,7 +165,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: # for status update self._set_current_context_from_message(msg=messages[0]) - # update acivation memories + # update activation memories if self.enable_act_memory_update: if ( len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) @@ -157,7 +203,12 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: # submit logs for msg in messages: - userinput_memory_ids = json.loads(msg.content) + try: + userinput_memory_ids = json.loads(msg.content) + except Exception as e: + logger.error(f"Error: {e}. 
Content: {msg.content}", exc_info=True) + userinput_memory_ids = [] + mem_cube = msg.mem_cube for memory_id in userinput_memory_ids: mem_item: TextualMemoryItem = mem_cube.text_mem.get(memory_id=memory_id) @@ -188,8 +239,8 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: def process_session_turn( self, queries: str | list[str], - user_id: str, - mem_cube_id: str, + user_id: UserID | str, + mem_cube_id: MemCubeID | str, mem_cube: GeneralMemCube, top_k: int = 10, ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None: @@ -241,7 +292,9 @@ def process_session_turn( results: list[TextualMemoryItem] = self.retriever.search( query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method ) - logger.info(f"search results for {missing_evidences}: {results}") + logger.info( + f"search results for {missing_evidences}: {[one.memory for one in results]}" + ) new_candidates.extend(results) if len(new_candidates) == 0: diff --git a/src/memos/mem_scheduler/modules/base.py b/src/memos/mem_scheduler/modules/base.py index 58c64587..37538b0f 100644 --- a/src/memos/mem_scheduler/modules/base.py +++ b/src/memos/mem_scheduler/modules/base.py @@ -17,8 +17,8 @@ def __init__(self): self._chat_llm = None self._process_llm = None - self._current_mem_cube_id: str | None = None - self._current_mem_cube: GeneralMemCube | None = None + self.current_mem_cube_id: str | None = None + self.current_mem_cube: GeneralMemCube | None = None self.mem_cubes: dict[str, GeneralMemCube] = {} def load_template(self, template_name: str) -> str: @@ -75,9 +75,9 @@ def process_llm(self, value: BaseLLM) -> None: @property def mem_cube(self) -> GeneralMemCube: """The memory cube associated with this MemChat.""" - return self._current_mem_cube + return self.current_mem_cube @mem_cube.setter def mem_cube(self, value: GeneralMemCube) -> None: """The memory cube associated with this MemChat.""" - self._current_mem_cube = value + self.current_mem_cube = value diff --git a/src/memos/mem_scheduler/modules/retriever.py b/src/memos/mem_scheduler/modules/retriever.py index 3430288d..2cf3bd04 100644 --- a/src/memos/mem_scheduler/modules/retriever.py +++ b/src/memos/mem_scheduler/modules/retriever.py @@ -83,21 +83,22 @@ def rerank_memories( If LLM reranking fails, falls back to original order (truncated to top_k) """ success_flag = False - try: - logger.info(f"Starting memory reranking for {len(original_memories)} memories") - # Build LLM prompt for memory reranking - prompt = self.build_prompt( - "memory_reranking", - queries=[f"[0] {queries[0]}"], - current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)], - ) - logger.debug(f"Generated reranking prompt: {prompt[:200]}...") # Log first 200 chars + logger.info(f"Starting memory reranking for {len(original_memories)} memories") - # Get LLM response - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - logger.debug(f"Received LLM response: {response[:200]}...") # Log first 200 chars + # Build LLM prompt for memory reranking + prompt = self.build_prompt( + "memory_reranking", + queries=[f"[0] {queries[0]}"], + current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)], + ) + logger.debug(f"Generated reranking prompt: {prompt[:200]}...") # Log first 200 chars + # Get LLM response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug(f"Received LLM response: {response[:200]}...") # Log first 200 chars + + try: # Parse JSON response response = 
extract_json_dict(response) new_order = response["new_order"][:top_k] @@ -109,7 +110,7 @@ def rerank_memories( success_flag = True except Exception as e: logger.error( - f"Failed to rerank memories with LLM;\nException: {e}. ", + f"Failed to rerank memories with LLM. Exception: {e}. Raw response: {response} ", exc_info=True, ) text_memories_with_new_order = original_memories[:top_k] diff --git a/src/memos/mem_scheduler/modules/scheduler_logger.py b/src/memos/mem_scheduler/modules/scheduler_logger.py index d0eecac9..3b2986e9 100644 --- a/src/memos/mem_scheduler/modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/modules/scheduler_logger.py @@ -68,8 +68,14 @@ def create_autofilled_log_item( ): activation_monitor = self.monitor.activation_memory_monitors[user_id][mem_cube_id] transformed_act_memory_size = len(activation_monitor.memories) + logger.info( + f'activation_memory_monitors currently has "{transformed_act_memory_size}" transformed memory size' + ) else: transformed_act_memory_size = 0 + logger.info( + f'activation_memory_monitors is not initialized for user "{user_id}" and mem_cube "{mem_cube_id}' + ) current_memory_sizes["transformed_act_memory_size"] = transformed_act_memory_size current_memory_sizes["parameter_memory_size"] = 1 diff --git a/src/memos/mem_scheduler/schemas/monitor_schemas.py b/src/memos/mem_scheduler/schemas/monitor_schemas.py index 6cba48fa..68d53f55 100644 --- a/src/memos/mem_scheduler/schemas/monitor_schemas.py +++ b/src/memos/mem_scheduler/schemas/monitor_schemas.py @@ -107,7 +107,7 @@ def get_keywords_collections(self) -> Counter: all_keywords = [kw for item in self.queue for kw in item.keywords] return Counter(all_keywords) - def get_queries_with_timesort(self, reverse: bool = True) -> list[dict]: + def get_queries_with_timesort(self, reverse: bool = True) -> list[str]: """ Retrieve all queries sorted by timestamp. 
diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index 30ccc934..b35b7b17 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -49,8 +49,8 @@ def setUp(self): self.scheduler.mem_cube = self.mem_cube # Set current user and memory cube ID for testing - self.scheduler._current_user_id = "test_user" - self.scheduler._current_mem_cube_id = "test_cube" + self.scheduler.current_user_id = "test_user" + self.scheduler.current_mem_cube_id = "test_cube" def test_initialization(self): """Test that scheduler initializes with correct default values and handlers.""" From 0b74801bdfc3db4682d625f943756dc46d7742c8 Mon Sep 17 00:00:00 2001 From: chentang Date: Wed, 23 Jul 2025 20:23:10 +0800 Subject: [PATCH 2/2] fix bugs: modify mos product, and add more exception catch code --- src/memos/mem_os/product.py | 7 ++-- .../mem_scheduler/modules/scheduler_logger.py | 13 +++++-- src/memos/mem_scheduler/utils/misc_utils.py | 39 ++++++++++++++++++- 3 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 6a45e1ee..07e765b7 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -704,7 +704,9 @@ def chat_with_references( """ self._load_user_cubes(user_id, self.default_cube_config) - + self._send_message_to_scheduler( + user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_LABEL + ) time_start = time.time() memories_list = [] memories_result = super().search( @@ -808,9 +810,6 @@ def chat_with_references( yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" total_time = round(float(time_end - time_start), 1) yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': '23%'}})}\n\n" - self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_LABEL - ) self._send_message_to_scheduler( user_id=user_id, mem_cube_id=cube_id, query=full_response, label=ANSWER_LABEL ) diff --git a/src/memos/mem_scheduler/modules/scheduler_logger.py b/src/memos/mem_scheduler/modules/scheduler_logger.py index 3b2986e9..e41b4822 100644 --- a/src/memos/mem_scheduler/modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/modules/scheduler_logger.py @@ -21,6 +21,7 @@ from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, ) +from memos.mem_scheduler.utils.misc_utils import log_exceptions from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory @@ -34,6 +35,7 @@ def __init__(self): """ super().__init__() + @log_exceptions(logger=logger) def create_autofilled_log_item( self, log_content: str, @@ -47,9 +49,9 @@ def create_autofilled_log_item( text_mem_base: TreeTextMemory = mem_cube.text_mem current_memory_sizes = text_mem_base.get_current_memory_size() current_memory_sizes = { - "long_term_memory_size": current_memory_sizes["LongTermMemory"], - "user_memory_size": current_memory_sizes["UserMemory"], - "working_memory_size": current_memory_sizes["WorkingMemory"], + "long_term_memory_size": current_memory_sizes.get("LongTermMemory", 0), + "user_memory_size": current_memory_sizes.get("UserMemory", 0), + "working_memory_size": current_memory_sizes.get("WorkingMemory", 0), "transformed_act_memory_size": NOT_INITIALIZED, "parameter_memory_size": NOT_INITIALIZED, } @@ -96,6 +98,7 @@ def create_autofilled_log_item( ) return log_message + @log_exceptions(logger=logger) def log_working_memory_replacement( self, original_memory: 
list[TextualMemoryItem], @@ -148,6 +151,7 @@ def log_working_memory_replacement( f"transformed to {WORKING_MEMORY_TYPE} memories." ) + @log_exceptions(logger=logger) def log_activation_memory_update( self, original_text_memories: list[str], @@ -191,6 +195,7 @@ def log_activation_memory_update( f"transformed to {WORKING_MEMORY_TYPE} memories." ) + @log_exceptions(logger=logger) def log_adding_memory( self, memory: str, @@ -216,6 +221,7 @@ def log_adding_memory( f"converted to {memory_type} memory in mem_cube {mem_cube_id}: {memory}" ) + @log_exceptions(logger=logger) def validate_schedule_message(self, message: ScheduleMessageItem, label: str): """Validate if the message matches the expected label. @@ -231,6 +237,7 @@ def validate_schedule_message(self, message: ScheduleMessageItem, label: str): return False return True + @log_exceptions(logger=logger) def validate_schedule_messages(self, messages: list[ScheduleMessageItem], label: str): """Validate if all messages match the expected label. diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py index 92b56944..4a9c0246 100644 --- a/src/memos/mem_scheduler/utils/misc_utils.py +++ b/src/memos/mem_scheduler/utils/misc_utils.py @@ -1,9 +1,15 @@ import json +from functools import wraps from pathlib import Path import yaml +from memos.log import get_logger + + +logger = get_logger(__name__) + def extract_json_dict(text: str): text = text.strip() @@ -14,8 +20,7 @@ def extract_json_dict(text: str): return res -def parse_yaml(yaml_file): - yaml_path = Path(yaml_file) +def parse_yaml(yaml_file: str | Path): yaml_path = Path(yaml_file) if not yaml_path.is_file(): raise FileNotFoundError(f"No such file: {yaml_file}") @@ -24,3 +29,33 @@ def parse_yaml(yaml_file): data = yaml.safe_load(fr) return data + + +def log_exceptions(logger=logger): + """ + Exception-catching decorator that automatically logs errors (including stack traces) + + Args: + logger: Optional logger object (default: module-level logger) + + Example: + @log_exceptions() + def risky_function(): + raise ValueError("Oops!") + + @log_exceptions(logger=custom_logger) + def another_risky_function(): + might_fail() + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger.error(f"Error in {func.__name__}: {e}", exc_info=True) + + return wrapper + + return decorator
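
A brief usage note on the second commit's main addition: the sketch below shows how the new log_exceptions decorator from src/memos/mem_scheduler/utils/misc_utils.py can wrap an arbitrary function, mirroring how the patch applies it to the scheduler logger methods (create_autofilled_log_item, log_working_memory_replacement, and so on). The import paths and the get_logger call are taken from the patch itself; parse_payload and its JSON failure are hypothetical stand-ins added for illustration only, not code from this PR.

    import json

    from memos.log import get_logger
    from memos.mem_scheduler.utils.misc_utils import log_exceptions

    logger = get_logger(__name__)


    @log_exceptions(logger=logger)
    def parse_payload(raw: str) -> dict:
        # Any exception raised here is caught by the decorator and logged via
        # logger.error(..., exc_info=True), so the full stack trace is recorded;
        # the wrapper then returns None instead of propagating the error.
        return json.loads(raw)


    parse_payload("{not valid json}")  # logs the JSONDecodeError and returns None

Because the wrapper swallows the exception and returns None, decorated logging helpers degrade gracefully instead of interrupting the scheduler loop, which is consistent with the commit's stated goal of adding more exception-catching code around log submission.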