diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py index 87646013..28f3c3b3 100644 --- a/examples/mem_scheduler/memos_w_scheduler_for_test.py +++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py @@ -142,7 +142,6 @@ def init_task(): query = item["question"] print(f"Query:\n {query}\n") response = mos.chat(query=query, user_id=user_id) - print(f"Answer:\n {response}") - print("===== Chat End =====") + print(f"Answer:\n {response}\n") mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/rabbitmq_example.py b/examples/mem_scheduler/rabbitmq_example.py index 1343e0ec..6b04111a 100644 --- a/examples/mem_scheduler/rabbitmq_example.py +++ b/examples/mem_scheduler/rabbitmq_example.py @@ -2,7 +2,7 @@ import time from memos.configs.mem_scheduler import AuthConfig -from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule +from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule def publish_message(rabbitmq_module, message): diff --git a/scripts/check_dependencies.py b/scripts/check_dependencies.py index f25d4255..9cb70402 100644 --- a/scripts/check_dependencies.py +++ b/scripts/check_dependencies.py @@ -11,7 +11,7 @@ def extract_top_level_modules(tree: ast.Module) -> set[str]: """ - Extract all top-level imported modules (excluding relative imports). + Extract all top-level imported modules (excluding relative imports). """ modules = set() for node in tree.body: @@ -27,12 +27,12 @@ def extract_top_level_modules(tree: ast.Module) -> set[str]: def check_importable(modules: set[str], filename: str) -> list[str]: """ Attempt to import each module in the current environment. - Return a list of modules that fail to import. + Return a list of modules that fail to import. 
""" failed = [] for mod in sorted(modules): if mod in EXCLUDE_MODULES: - # Skip excluded modules such as your own package + # Skip excluded general_modules such as your own package continue try: importlib.import_module(mod) @@ -70,7 +70,7 @@ def main(): if has_error: print( - "\n💥 Top-level imports failed. These modules may not be main dependencies." + "\n💥 Top-level imports failed. These general_modules may not be main dependencies." " Try moving the imports to a function or class scope, and decorate it with @require_python_package." ) sys.exit(1) diff --git a/src/memos/api/context/context.py b/src/memos/api/context/context.py index 557ec84d..8aee2cfe 100644 --- a/src/memos/api/context/context.py +++ b/src/memos/api/context/context.py @@ -122,7 +122,7 @@ def set_trace_id_getter(getter: TraceIdGetter) -> None: Set a custom trace_id getter function. This allows the logging system to retrieve trace_id without importing - API-specific modules. + API-specific general_modules. """ global _trace_id_getter _trace_id_getter = getter diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 0c99a57f..c1166a03 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -6,7 +6,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator from memos.configs.base import BaseConfig -from memos.mem_scheduler.modules.misc import DictConversionMixin +from memos.mem_scheduler.general_modules.misc import DictConversionMixin from memos.mem_scheduler.schemas.general_schemas import ( BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 601d2035..f7d14f3a 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -124,7 +124,7 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler: chat_llm=self.chat_llm, process_llm=self.chat_llm ) else: - # Configure scheduler modules + # Configure scheduler general_modules 
self._mem_scheduler.initialize_modules( chat_llm=self.chat_llm, process_llm=self.mem_reader.llm ) @@ -185,7 +185,7 @@ def _register_chat_history(self, user_id: str | None = None) -> None: self.chat_history_manager[user_id] = ChatHistory( user_id=user_id, session_id=self.session_id, - created_at=datetime.now(), + created_at=datetime.utcnow(), total_messages=0, chat_history=[], ) @@ -279,7 +279,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = mem_cube=mem_cube, label=QUERY_LABEL, content=query, - timestamp=datetime.now(), + timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) @@ -338,7 +338,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = mem_cube=mem_cube, label=ANSWER_LABEL, content=response, - timestamp=datetime.now(), + timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) @@ -681,7 +681,7 @@ def add( mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), - timestamp=datetime.now(), + timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) @@ -725,7 +725,7 @@ def add( mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), - timestamp=datetime.now(), + timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) @@ -756,7 +756,7 @@ def add( mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), - timestamp=datetime.now(), + timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index be74d617..1daf87b4 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -9,13 +9,14 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.dispatcher import 
SchedulerDispatcher -from memos.mem_scheduler.modules.misc import AutoDroppingQueue as Queue -from memos.mem_scheduler.modules.monitor import SchedulerMonitor -from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule -from memos.mem_scheduler.modules.redis_service import RedisSchedulerModule -from memos.mem_scheduler.modules.retriever import SchedulerRetriever -from memos.mem_scheduler.modules.scheduler_logger import SchedulerLoggerModule +from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue +from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule +from memos.mem_scheduler.general_modules.redis_service import RedisSchedulerModule +from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever +from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule +from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor +from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, @@ -56,15 +57,16 @@ def __init__(self, config: BaseSchedulerConfig): self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH) self.search_method = TreeTextMemory_SEARCH_METHOD self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False) - self.max_workers = self.config.get( + self.thread_pool_max_workers = self.config.get( "thread_pool_max_workers", DEFAULT_THREAD__POOL_MAX_WORKERS ) self.retriever: SchedulerRetriever | None = None - self.monitor: SchedulerMonitor | None = None - + self.monitor: SchedulerGeneralMonitor | None = None + self.thread_pool_monitor: SchedulerDispatcherMonitor | None = None self.dispatcher = SchedulerDispatcher( - max_workers=self.max_workers, 
enable_parallel_dispatch=self.enable_parallel_dispatch + max_workers=self.thread_pool_max_workers, + enable_parallel_dispatch=self.enable_parallel_dispatch, ) # internal message queue @@ -97,9 +99,14 @@ def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = No # initialize submodules self.chat_llm = chat_llm self.process_llm = process_llm - self.monitor = SchedulerMonitor(process_llm=self.process_llm, config=self.config) + self.monitor = SchedulerGeneralMonitor(process_llm=self.process_llm, config=self.config) + self.thread_pool_monitor = SchedulerDispatcherMonitor(config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) + if self.enable_parallel_dispatch: + self.thread_pool_monitor.initialize(dispatcher=self.dispatcher) + self.thread_pool_monitor.start() + # initialize with auth_cofig if self.auth_config_path is not None and Path(self.auth_config_path).exists(): self.auth_config = AuthConfig.from_local_yaml(config_path=self.auth_config_path) @@ -377,7 +384,7 @@ def update_activation_memory_periodically( mem_cube=mem_cube, ) - self.monitor.last_activation_mem_update_time = datetime.now() + self.monitor.last_activation_mem_update_time = datetime.utcnow() logger.debug( f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}" @@ -386,7 +393,7 @@ def update_activation_memory_periodically( logger.info( f"Skipping update - {interval_seconds} second interval not yet reached. 
" f"Last update time is {self.monitor.last_activation_mem_update_time} and now is" - f"{datetime.now()}" + f"{datetime.utcnow()}" ) except Exception as e: logger.error(f"Error: {e}", exc_info=True) @@ -487,7 +494,9 @@ def start(self) -> None: # Initialize dispatcher resources if self.enable_parallel_dispatch: - logger.info(f"Initializing dispatcher thread pool with {self.max_workers} workers") + logger.info( + f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers" + ) # Start consumer thread self._running = True diff --git a/src/memos/mem_scheduler/modules/__init__.py b/src/memos/mem_scheduler/general_modules/__init__.py similarity index 100% rename from src/memos/mem_scheduler/modules/__init__.py rename to src/memos/mem_scheduler/general_modules/__init__.py diff --git a/src/memos/mem_scheduler/modules/base.py b/src/memos/mem_scheduler/general_modules/base.py similarity index 100% rename from src/memos/mem_scheduler/modules/base.py rename to src/memos/mem_scheduler/general_modules/base.py diff --git a/src/memos/mem_scheduler/modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py similarity index 74% rename from src/memos/mem_scheduler/modules/dispatcher.py rename to src/memos/mem_scheduler/general_modules/dispatcher.py index 0a3ff196..b2cb4bd3 100644 --- a/src/memos/mem_scheduler/modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -1,9 +1,11 @@ +import concurrent + from collections import defaultdict from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from memos.log import get_logger -from memos.mem_scheduler.modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -26,20 +28,27 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False): super().__init__() # Main dispatcher thread pool self.max_workers = 
max_workers + # Only initialize thread pool if in parallel mode self.enable_parallel_dispatch = enable_parallel_dispatch + self.thread_name_prefix = "dispatcher" if self.enable_parallel_dispatch: self.dispatcher_executor = ThreadPoolExecutor( - max_workers=self.max_workers, thread_name_prefix="dispatcher" + max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix ) else: self.dispatcher_executor = None logger.info(f"enable_parallel_dispatch is set to {self.enable_parallel_dispatch}") + # Registered message handlers self.handlers: dict[str, Callable] = {} + # Dispatcher running state self._running = False + # Set to track active futures for monitoring purposes + self._futures = set() + def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]): """ Register a handler function for a specific message label. @@ -105,6 +114,13 @@ def group_messages_by_user_and_cube( # Convert defaultdict to regular dict for cleaner output return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()} + def _handle_future_result(self, future): + self._futures.remove(future) + try: + future.result() # this will throw exception + except Exception as e: + logger.error(f"Handler execution failed: {e!s}", exc_info=True) + def dispatch(self, msg_list: list[ScheduleMessageItem]): """ Dispatch a list of messages to their respective handlers. 
@@ -112,26 +128,26 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): Args: msg_list: List of ScheduleMessageItem objects to process """ + if not msg_list: + logger.debug("Received empty message list, skipping dispatch") + return - # Group messages by their labels + # Group messages by their labels, and organize messages by label label_groups = defaultdict(list) - - # Organize messages by label for message in msg_list: label_groups[message.label].append(message) # Process each label group for label, msgs in label_groups.items(): - if label not in self.handlers: - logger.error(f"No handler registered for label: {label}") - handler = self._default_message_handler - else: - handler = self.handlers[label] + handler = self.handlers.get(label, self._default_message_handler) + # dispatch to different handler logger.debug(f"Dispatch {len(msgs)} message(s) to {label} handler.") if self.enable_parallel_dispatch and self.dispatcher_executor is not None: # Capture variables in lambda to avoid loop variable issues - self.dispatcher_executor.submit(handler, msgs) + future = self.dispatcher_executor.submit(handler, msgs) + self._futures.add(future) + future.add_done_callback(self._handle_future_result) logger.info(f"Dispatched {len(msgs)} message(s) as future task") else: handler(msgs) @@ -148,15 +164,38 @@ def join(self, timeout: float | None = None) -> bool: if not self.enable_parallel_dispatch or self.dispatcher_executor is None: return True # 串行模式无需等待 - self.dispatcher_executor.shutdown(wait=True, timeout=timeout) - return True + done, not_done = concurrent.futures.wait( + self._futures, timeout=timeout, return_when=concurrent.futures.ALL_COMPLETED + ) + + # Check for exceptions in completed tasks + for future in done: + try: + future.result() + except Exception: + logger.error("Handler failed during shutdown", exc_info=True) + + return len(not_done) == 0 def shutdown(self) -> None: """Gracefully shutdown the dispatcher.""" + self._running = False + if 
self.dispatcher_executor is not None: + # Cancel pending tasks + cancelled = 0 + for future in self._futures: + if future.cancel(): + cancelled += 1 + logger.info(f"Cancelled {cancelled}/{len(self._futures)} pending tasks") + + # Shutdown executor + try: self.dispatcher_executor.shutdown(wait=True) - self._running = False - logger.info("Dispatcher has been shutdown.") + except Exception as e: + logger.error(f"Executor shutdown error: {e}", exc_info=True) + finally: + self._futures.clear() def __enter__(self): self._running = True diff --git a/src/memos/mem_scheduler/modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py similarity index 100% rename from src/memos/mem_scheduler/modules/misc.py rename to src/memos/mem_scheduler/general_modules/misc.py diff --git a/src/memos/mem_scheduler/modules/rabbitmq_service.py b/src/memos/mem_scheduler/general_modules/rabbitmq_service.py similarity index 98% rename from src/memos/mem_scheduler/modules/rabbitmq_service.py rename to src/memos/mem_scheduler/general_modules/rabbitmq_service.py index 3106e2ec..c782c309 100644 --- a/src/memos/mem_scheduler/modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/general_modules/rabbitmq_service.py @@ -8,8 +8,8 @@ from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig from memos.dependency import require_python_package from memos.log import get_logger -from memos.mem_scheduler.modules.base import BaseSchedulerModule -from memos.mem_scheduler.modules.misc import AutoDroppingQueue +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE diff --git a/src/memos/mem_scheduler/modules/redis_service.py b/src/memos/mem_scheduler/general_modules/redis_service.py similarity index 98% rename from src/memos/mem_scheduler/modules/redis_service.py rename to 
src/memos/mem_scheduler/general_modules/redis_service.py index 4a6ad35b..5b04ec28 100644 --- a/src/memos/mem_scheduler/modules/redis_service.py +++ b/src/memos/mem_scheduler/general_modules/redis_service.py @@ -6,7 +6,7 @@ from memos.dependency import require_python_package from memos.log import get_logger -from memos.mem_scheduler.modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/modules/retriever.py b/src/memos/mem_scheduler/general_modules/retriever.py similarity index 99% rename from src/memos/mem_scheduler/modules/retriever.py rename to src/memos/mem_scheduler/general_modules/retriever.py index 595e8e85..3732078d 100644 --- a/src/memos/mem_scheduler/modules/retriever.py +++ b/src/memos/mem_scheduler/general_modules/retriever.py @@ -2,7 +2,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.general_schemas import ( TreeTextMemory_FINE_SEARCH_METHOD, TreeTextMemory_SEARCH_METHOD, diff --git a/src/memos/mem_scheduler/modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py similarity index 99% rename from src/memos/mem_scheduler/modules/scheduler_logger.py rename to src/memos/mem_scheduler/general_modules/scheduler_logger.py index e41b4822..cb5a122a 100644 --- a/src/memos/mem_scheduler/modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -2,7 +2,7 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.general_schemas 
import ( ACTIVATION_MEMORY_TYPE, ADD_LABEL, diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 390a1445..a4768625 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -201,16 +201,6 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: # update activation memories if self.enable_act_memory_update: - if ( - len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) - == 0 - ): - self.initialize_working_memory_monitors( - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=messages[0].mem_cube, - ) - self.update_activation_memory_periodically( interval_seconds=self.monitor.act_mem_update_interval, label=ANSWER_LABEL, diff --git a/src/memos/mem_scheduler/monitors/__init__.py b/src/memos/mem_scheduler/monitors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py new file mode 100644 index 00000000..6e820980 --- /dev/null +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -0,0 +1,286 @@ +import threading +import time + +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from time import perf_counter + +from memos.configs.mem_scheduler import BaseSchedulerConfig +from memos.log import get_logger +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher + + +logger = get_logger(__name__) + + +class SchedulerDispatcherMonitor(BaseSchedulerModule): + """Monitors and manages scheduling operations with LLM integration.""" + + def __init__(self, config: BaseSchedulerConfig): + super().__init__() + self.config: BaseSchedulerConfig = config + + self.check_interval = self.config.get("thread_pool_monitor_check_interval", 60) + self.max_failures = 
self.config.get("thread_pool_monitor_max_failures", 2) + + # Registry of monitored thread pools + self._pools: dict[str, dict] = {} + self._pool_lock = threading.Lock() + + # thread pool monitor + self._monitor_thread: threading.Thread | None = None + self._running = False + self._restart_in_progress = False + + # modules with thread pool + self.dispatcher: SchedulerDispatcher | None = None + self.dispatcher_pool_name = "dispatcher" + + def initialize(self, dispatcher: SchedulerDispatcher): + self.dispatcher = dispatcher + self.register_pool( + name=self.dispatcher_pool_name, + executor=self.dispatcher.dispatcher_executor, + max_workers=self.dispatcher.max_workers, + restart_on_failure=True, + ) + + def register_pool( + self, + name: str, + executor: ThreadPoolExecutor, + max_workers: int, + restart_on_failure: bool = True, + ) -> bool: + """ + Register a thread pool for monitoring. + + Args: + name: Unique identifier for the pool + executor: ThreadPoolExecutor instance to monitor + max_workers: Expected maximum worker count + restart_on_failure: Whether to restart if pool fails + + Returns: + bool: True if registration succeeded, False if pool already registered + """ + with self._pool_lock: + if name in self._pools: + logger.warning(f"Thread pool '{name}' is already registered") + return False + + self._pools[name] = { + "executor": executor, + "max_workers": max_workers, + "restart": restart_on_failure, + "failure_count": 0, + "last_active": datetime.utcnow(), + "healthy": True, + } + logger.info(f"Registered thread pool '{name}' for monitoring") + return True + + def unregister_pool(self, name: str) -> bool: + """ + Remove a thread pool from monitoring. 
+ + Args: + name: Identifier of the pool to remove + + Returns: + bool: True if removal succeeded, False if pool not found + """ + with self._pool_lock: + if name not in self._pools: + logger.warning(f"Thread pool '{name}' not found in registry") + return False + + del self._pools[name] + logger.info(f"Unregistered thread pool '{name}'") + return True + + def _monitor_loop(self) -> None: + """Main monitoring loop that periodically checks all registered pools.""" + logger.info(f"Starting monitor loop with {self.check_interval} second interval") + + while self._running: + time.sleep(self.check_interval) + try: + self._check_pools_health() + except Exception as e: + logger.error(f"Error during health check: {e!s}", exc_info=True) + + logger.debug("Monitor loop exiting") + + def start(self) -> bool: + """ + Start the monitoring thread. + + Returns: + bool: True if monitor started successfully, False if already running + """ + if self._running: + logger.warning("Dispatcher Monitor is already running") + return False + + self._running = True + self._monitor_thread = threading.Thread( + target=self._monitor_loop, name="threadpool_monitor", daemon=True + ) + self._monitor_thread.start() + logger.info("Dispatcher Monitor monitor started") + return True + + def stop(self) -> None: + """Stop the monitoring thread gracefully.""" + if not self._running: + return + + self._running = False + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=5) + logger.info("Thread pool monitor stopped") + + def _check_pools_health(self) -> None: + """Check health of all registered thread pools.""" + for name, pool_info in list(self._pools.items()): + is_healthy, reason = self._check_pool_health( + pool_info=pool_info, + stuck_max_interval=4, + ) + logger.info(f"Pool '{name}'. is_healthy: {is_healthy}. 
pool_info: {pool_info}") + with self._pool_lock: + if is_healthy: + pool_info["failure_count"] = 0 + pool_info["healthy"] = True + return + else: + pool_info["failure_count"] += 1 + pool_info["healthy"] = False + logger.warning( + f"Pool '{name}' unhealthy ({pool_info['failure_count']}/{self.max_failures}): {reason}" + ) + + if ( + pool_info["failure_count"] >= self.max_failures + and pool_info["restart"] + and not self._restart_in_progress + ): + self._restart_pool(name, pool_info) + + def _check_pool_health(self, pool_info: dict, stuck_max_interval=4) -> tuple[bool, str]: + """ + Check health of a single thread pool. + + Args: + pool_info: Dictionary containing pool configuration + + Returns: + Tuple: (is_healthy, reason) where reason explains failure if not healthy + """ + executor = pool_info["executor"] + + # Check if executor is shutdown + if executor._shutdown: # pylint: disable=protected-access + return False, "Executor is shutdown" + + # Check thread activity + active_threads = sum( + 1 + for t in threading.enumerate() + if t.name.startswith(executor._thread_name_prefix) # pylint: disable=protected-access + ) + + # Check if no threads are active but should be + if active_threads == 0 and pool_info["max_workers"] > 0: + return False, "No active worker threads" + + # Check if threads are stuck (no activity for 2 intervals) + time_delta = (datetime.utcnow() - pool_info["last_active"]).total_seconds() + if time_delta >= self.check_interval * stuck_max_interval: + return False, "No recent activity" + + # If we got here, pool appears healthy + pool_info["last_active"] = datetime.utcnow() + return True, "" + + def _restart_pool(self, name: str, pool_info: dict) -> None: + """ + Attempt to restart a failed thread pool. 
+ + Args: + name: Name of the pool to restart + pool_info: Dictionary containing pool configuration + """ + if self._restart_in_progress: + return + + self._restart_in_progress = True + logger.warning(f"Attempting to restart thread pool '{name}'") + + try: + old_executor = pool_info["executor"] + self.dispatcher.shutdown() + + # Create new executor with same parameters + new_executor = ThreadPoolExecutor( + max_workers=pool_info["max_workers"], + thread_name_prefix=self.dispatcher.thread_name_prefix, # pylint: disable=protected-access + ) + self.unregister_pool(name=self.dispatcher_pool_name) + self.dispatcher.dispatcher_executor = new_executor + self.register_pool( + name=self.dispatcher_pool_name, + executor=self.dispatcher.dispatcher_executor, + max_workers=self.dispatcher.max_workers, + restart_on_failure=True, + ) + + # Replace in registry + start_time = perf_counter() + with self._pool_lock: + pool_info["executor"] = new_executor + pool_info["failure_count"] = 0 + pool_info["healthy"] = True + pool_info["last_active"] = datetime.utcnow() + + elapsed_time = perf_counter() - start_time + if elapsed_time > 1: + logger.warning(f"Long lock wait: {elapsed_time:.3f}s") + + # Shutdown old executor + try: + old_executor.shutdown(wait=False) + except Exception as e: + logger.error(f"Error shutting down old executor: {e!s}", exc_info=True) + + logger.info(f"Successfully restarted thread pool '{name}'") + except Exception as e: + logger.error(f"Failed to restart pool '{name}': {e!s}", exc_info=True) + finally: + self._restart_in_progress = False + + def get_status(self, name: str | None = None) -> dict: + """ + Get status of monitored pools. 
+ + Args: + name: Optional specific pool name to check + + Returns: + Dictionary of status information + """ + with self._pool_lock: + if name: + return {name: self._pools.get(name, {}).copy()} + return {k: v.copy() for k, v in self._pools.items()} + + def __enter__(self): + """Context manager entry point.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit point.""" + self.stop() diff --git a/src/memos/mem_scheduler/modules/monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py similarity index 98% rename from src/memos/mem_scheduler/modules/monitor.py rename to src/memos/mem_scheduler/monitors/general_monitor.py index 04b8620b..0c4174fa 100644 --- a/src/memos/mem_scheduler/modules/monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -6,7 +6,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, DEFAULT_WEIGHT_VECTOR_FOR_RANKING, @@ -29,7 +29,7 @@ logger = get_logger(__name__) -class SchedulerMonitor(BaseSchedulerModule): +class SchedulerGeneralMonitor(BaseSchedulerModule): """Monitors and manages scheduling operations with LLM integration.""" def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): @@ -208,11 +208,11 @@ def update_activation_memory_monitors( ) def timed_trigger(self, last_time: datetime, interval_seconds: float) -> bool: - now = datetime.now() + now = datetime.utcnow() elapsed = (now - last_time).total_seconds() if elapsed >= interval_seconds: return True - logger.debug(f"Time trigger not ready, {elapsed:.1f}s elapsed (needs {interval_seconds}s)") + logger.info(f"Time trigger not ready, {elapsed:.1f}s elapsed (needs {interval_seconds}s)") 
return False def get_monitor_memories( diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 2c34b078..0a540574 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -7,7 +7,7 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.misc import DictConversionMixin +from memos.mem_scheduler.general_modules.misc import DictConversionMixin from .general_schemas import NOT_INITIALIZED diff --git a/src/memos/mem_scheduler/schemas/monitor_schemas.py b/src/memos/mem_scheduler/schemas/monitor_schemas.py index a25015cc..65238d72 100644 --- a/src/memos/mem_scheduler/schemas/monitor_schemas.py +++ b/src/memos/mem_scheduler/schemas/monitor_schemas.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, computed_field, field_validator from memos.log import get_logger -from memos.mem_scheduler.modules.misc import AutoDroppingQueue, DictConversionMixin +from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue, DictConversionMixin from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_MAX_QUERY_KEY_WORDS, DEFAULT_WEIGHT_VECTOR_FOR_RANKING, diff --git a/tests/mem_reader/test_memory.py b/tests/mem_reader/test_memory.py index 03b50e4b..a0091ade 100644 --- a/tests/mem_reader/test_memory.py +++ b/tests/mem_reader/test_memory.py @@ -7,7 +7,7 @@ def test_memory_initialization(): """Test initialization of Memory class.""" user_id = "user123" session_id = "session456" - created_at = datetime.now() + created_at = datetime.utcnow() memory = Memory(user_id=user_id, session_id=session_id, created_at=created_at) diff --git a/tests/mem_scheduler/test_retriever.py b/tests/mem_scheduler/test_retriever.py index ff40f339..0ef6eb8e 100644 --- a/tests/mem_scheduler/test_retriever.py +++ b/tests/mem_scheduler/test_retriever.py @@ -37,7 +37,7 @@ def setUp(self): 
self.mem_cube.text_mem = self.tree_text_memory self.mem_cube.act_mem = MagicMock() - # Initialize modules with mock LLM + # Initialize modules with mock LLM self.scheduler.initialize_modules(chat_llm=self.llm, process_llm=self.llm) self.scheduler.mem_cube = self.mem_cube @@ -47,7 +47,7 @@ def setUp(self): self.logging_warning_patch = patch("logging.warning") self.mock_logging_warning = self.logging_warning_patch.start() - self.logger_info_patch = patch("memos.mem_scheduler.modules.retriever.logger.info") + self.logger_info_patch = patch("memos.mem_scheduler.general_modules.retriever.logger.info") self.mock_logger_info = self.logger_info_patch.start() def tearDown(self): diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index 88f3a8f0..97377738 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -8,8 +8,8 @@ from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.llms.base import BaseLLM from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.monitor import SchedulerMonitor -from memos.mem_scheduler.modules.retriever import SchedulerRetriever +from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever +from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, @@ -43,7 +43,7 @@ def setUp(self): self.mem_cube.text_mem = self.tree_text_memory self.mem_cube.act_mem = MagicMock() - # Initialize modules with mock LLM + # Initialize modules with mock LLM self.scheduler.initialize_modules(chat_llm=self.llm, process_llm=self.llm) self.scheduler.mem_cube = self.mem_cube @@ -60,7 +60,7 @@ def test_initialization(self): def test_initialize_modules(self): """Test module initialization with proper component assignments.""" 
self.assertEqual(self.scheduler.chat_llm, self.llm) - self.assertIsInstance(self.scheduler.monitor, SchedulerMonitor) + self.assertIsInstance(self.scheduler.monitor, SchedulerGeneralMonitor) self.assertIsInstance(self.scheduler.retriever, SchedulerRetriever) def test_submit_web_logs(self):