Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/mem_scheduler/memos_w_scheduler_for_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def init_task():
query = item["question"]
print(f"Query:\n {query}\n")
response = mos.chat(query=query, user_id=user_id)
print(f"Answer:\n {response}")
print("===== Chat End =====")
print(f"Answer:\n {response}\n")

mos.mem_scheduler.stop()
2 changes: 1 addition & 1 deletion examples/mem_scheduler/rabbitmq_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time

from memos.configs.mem_scheduler import AuthConfig
from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule
from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule


def publish_message(rabbitmq_module, message):
Expand Down
8 changes: 4 additions & 4 deletions scripts/check_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def extract_top_level_modules(tree: ast.Module) -> set[str]:
"""
Extract all top-level imported modules (excluding relative imports).
Extract all top-level imported general_modules (excluding relative imports).
"""
modules = set()
for node in tree.body:
Expand All @@ -27,12 +27,12 @@ def extract_top_level_modules(tree: ast.Module) -> set[str]:
def check_importable(modules: set[str], filename: str) -> list[str]:
"""
Attempt to import each module in the current environment.
Return a list of modules that fail to import.
Return a list of general_modules that fail to import.
"""
failed = []
for mod in sorted(modules):
if mod in EXCLUDE_MODULES:
# Skip excluded modules such as your own package
# Skip excluded general_modules such as your own package
continue
try:
importlib.import_module(mod)
Expand Down Expand Up @@ -70,7 +70,7 @@ def main():

if has_error:
print(
"\n💥 Top-level imports failed. These modules may not be main dependencies."
"\n💥 Top-level imports failed. These general_modules may not be main dependencies."
" Try moving the imports to a function or class scope, and decorate it with @require_python_package."
)
sys.exit(1)
Expand Down
2 changes: 1 addition & 1 deletion src/memos/api/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def set_trace_id_getter(getter: TraceIdGetter) -> None:
Set a custom trace_id getter function.

This allows the logging system to retrieve trace_id without importing
API-specific modules.
API-specific general_modules.
"""
global _trace_id_getter
_trace_id_getter = getter
Expand Down
2 changes: 1 addition & 1 deletion src/memos/configs/mem_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import ConfigDict, Field, field_validator, model_validator

from memos.configs.base import BaseConfig
from memos.mem_scheduler.modules.misc import DictConversionMixin
from memos.mem_scheduler.general_modules.misc import DictConversionMixin
from memos.mem_scheduler.schemas.general_schemas import (
BASE_DIR,
DEFAULT_ACT_MEM_DUMP_PATH,
Expand Down
14 changes: 7 additions & 7 deletions src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _initialize_mem_scheduler(self) -> GeneralScheduler:
chat_llm=self.chat_llm, process_llm=self.chat_llm
)
else:
# Configure scheduler modules
# Configure scheduler general_modules
self._mem_scheduler.initialize_modules(
chat_llm=self.chat_llm, process_llm=self.mem_reader.llm
)
Expand Down Expand Up @@ -185,7 +185,7 @@ def _register_chat_history(self, user_id: str | None = None) -> None:
self.chat_history_manager[user_id] = ChatHistory(
user_id=user_id,
session_id=self.session_id,
created_at=datetime.now(),
created_at=datetime.utcnow(),
total_messages=0,
chat_history=[],
)
Expand Down Expand Up @@ -279,7 +279,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
mem_cube=mem_cube,
label=QUERY_LABEL,
content=query,
timestamp=datetime.now(),
timestamp=datetime.utcnow(),
)
self.mem_scheduler.submit_messages(messages=[message_item])

Expand Down Expand Up @@ -338,7 +338,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
mem_cube=mem_cube,
label=ANSWER_LABEL,
content=response,
timestamp=datetime.now(),
timestamp=datetime.utcnow(),
)
self.mem_scheduler.submit_messages(messages=[message_item])

Expand Down Expand Up @@ -681,7 +681,7 @@ def add(
mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.now(),
timestamp=datetime.utcnow(),
)
self.mem_scheduler.submit_messages(messages=[message_item])

Expand Down Expand Up @@ -725,7 +725,7 @@ def add(
mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.now(),
timestamp=datetime.utcnow(),
)
self.mem_scheduler.submit_messages(messages=[message_item])

Expand Down Expand Up @@ -756,7 +756,7 @@ def add(
mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.now(),
timestamp=datetime.utcnow(),
)
self.mem_scheduler.submit_messages(messages=[message_item])

Expand Down
39 changes: 24 additions & 15 deletions src/memos/mem_scheduler/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.modules.monitor import SchedulerMonitor
from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule
from memos.mem_scheduler.modules.redis_service import RedisSchedulerModule
from memos.mem_scheduler.modules.retriever import SchedulerRetriever
from memos.mem_scheduler.modules.scheduler_logger import SchedulerLoggerModule
from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.general_modules.rabbitmq_service import RabbitMQSchedulerModule
from memos.mem_scheduler.general_modules.redis_service import RedisSchedulerModule
from memos.mem_scheduler.general_modules.retriever import SchedulerRetriever
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.schemas.general_schemas import (
DEFAULT_ACT_MEM_DUMP_PATH,
DEFAULT_CONSUME_INTERVAL_SECONDS,
Expand Down Expand Up @@ -56,15 +57,16 @@ def __init__(self, config: BaseSchedulerConfig):
self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH)
self.search_method = TreeTextMemory_SEARCH_METHOD
self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False)
self.max_workers = self.config.get(
self.thread_pool_max_workers = self.config.get(
"thread_pool_max_workers", DEFAULT_THREAD__POOL_MAX_WORKERS
)

self.retriever: SchedulerRetriever | None = None
self.monitor: SchedulerMonitor | None = None

self.monitor: SchedulerGeneralMonitor | None = None
self.thread_pool_monitor: SchedulerDispatcherMonitor | None = None
self.dispatcher = SchedulerDispatcher(
max_workers=self.max_workers, enable_parallel_dispatch=self.enable_parallel_dispatch
max_workers=self.thread_pool_max_workers,
enable_parallel_dispatch=self.enable_parallel_dispatch,
)

# internal message queue
Expand Down Expand Up @@ -97,9 +99,14 @@ def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = No
# initialize submodules
self.chat_llm = chat_llm
self.process_llm = process_llm
self.monitor = SchedulerMonitor(process_llm=self.process_llm, config=self.config)
self.monitor = SchedulerGeneralMonitor(process_llm=self.process_llm, config=self.config)
self.thread_pool_monitor = SchedulerDispatcherMonitor(config=self.config)
self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config)

if self.enable_parallel_dispatch:
self.thread_pool_monitor.initialize(dispatcher=self.dispatcher)
self.thread_pool_monitor.start()

# initialize with auth_cofig
if self.auth_config_path is not None and Path(self.auth_config_path).exists():
self.auth_config = AuthConfig.from_local_yaml(config_path=self.auth_config_path)
Expand Down Expand Up @@ -377,7 +384,7 @@ def update_activation_memory_periodically(
mem_cube=mem_cube,
)

self.monitor.last_activation_mem_update_time = datetime.now()
self.monitor.last_activation_mem_update_time = datetime.utcnow()

logger.debug(
f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}"
Expand All @@ -386,7 +393,7 @@ def update_activation_memory_periodically(
logger.info(
f"Skipping update - {interval_seconds} second interval not yet reached. "
f"Last update time is {self.monitor.last_activation_mem_update_time} and now is"
f"{datetime.now()}"
f"{datetime.utcnow()}"
)
except Exception as e:
logger.error(f"Error: {e}", exc_info=True)
Expand Down Expand Up @@ -487,7 +494,9 @@ def start(self) -> None:

# Initialize dispatcher resources
if self.enable_parallel_dispatch:
logger.info(f"Initializing dispatcher thread pool with {self.max_workers} workers")
logger.info(
f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers"
)

# Start consumer thread
self._running = True
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import concurrent

from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor

from memos.log import get_logger
from memos.mem_scheduler.modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem


Expand All @@ -26,20 +28,27 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False):
super().__init__()
# Main dispatcher thread pool
self.max_workers = max_workers

# Only initialize thread pool if in parallel mode
self.enable_parallel_dispatch = enable_parallel_dispatch
self.thread_name_prefix = "dispatcher"
if self.enable_parallel_dispatch:
self.dispatcher_executor = ThreadPoolExecutor(
max_workers=self.max_workers, thread_name_prefix="dispatcher"
max_workers=self.max_workers, thread_name_prefix=self.thread_name_prefix
)
else:
self.dispatcher_executor = None
logger.info(f"enable_parallel_dispatch is set to {self.enable_parallel_dispatch}")

# Registered message handlers
self.handlers: dict[str, Callable] = {}

# Dispatcher running state
self._running = False

# Set to track active futures for monitoring purposes
self._futures = set()

def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
"""
Register a handler function for a specific message label.
Expand Down Expand Up @@ -105,33 +114,40 @@ def group_messages_by_user_and_cube(
# Convert defaultdict to regular dict for cleaner output
return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()}

def _handle_future_result(self, future):
self._futures.remove(future)
try:
future.result() # this will throw exception
except Exception as e:
logger.error(f"Handler execution failed: {e!s}", exc_info=True)

def dispatch(self, msg_list: list[ScheduleMessageItem]):
"""
Dispatch a list of messages to their respective handlers.

Args:
msg_list: List of ScheduleMessageItem objects to process
"""
if not msg_list:
logger.debug("Received empty message list, skipping dispatch")
return

# Group messages by their labels
# Group messages by their labels, and organize messages by label
label_groups = defaultdict(list)

# Organize messages by label
for message in msg_list:
label_groups[message.label].append(message)

# Process each label group
for label, msgs in label_groups.items():
if label not in self.handlers:
logger.error(f"No handler registered for label: {label}")
handler = self._default_message_handler
else:
handler = self.handlers[label]
handler = self.handlers.get(label, self._default_message_handler)

# dispatch to different handler
logger.debug(f"Dispatch {len(msgs)} message(s) to {label} handler.")
if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
# Capture variables in lambda to avoid loop variable issues
self.dispatcher_executor.submit(handler, msgs)
future = self.dispatcher_executor.submit(handler, msgs)
self._futures.add(future)
future.add_done_callback(self._handle_future_result)
logger.info(f"Dispatched {len(msgs)} message(s) as future task")
else:
handler(msgs)
Expand All @@ -148,15 +164,38 @@ def join(self, timeout: float | None = None) -> bool:
if not self.enable_parallel_dispatch or self.dispatcher_executor is None:
return True # 串行模式无需等待

self.dispatcher_executor.shutdown(wait=True, timeout=timeout)
return True
done, not_done = concurrent.futures.wait(
self._futures, timeout=timeout, return_when=concurrent.futures.ALL_COMPLETED
)

# Check for exceptions in completed tasks
for future in done:
try:
future.result()
except Exception:
logger.error("Handler failed during shutdown", exc_info=True)

return len(not_done) == 0

def shutdown(self) -> None:
"""Gracefully shutdown the dispatcher."""
self._running = False

if self.dispatcher_executor is not None:
# Cancel pending tasks
cancelled = 0
for future in self._futures:
if future.cancel():
cancelled += 1
logger.info(f"Cancelled {cancelled}/{len(self._futures)} pending tasks")

# Shutdown executor
try:
self.dispatcher_executor.shutdown(wait=True)
self._running = False
logger.info("Dispatcher has been shutdown.")
except Exception as e:
logger.error(f"Executor shutdown error: {e}", exc_info=True)
finally:
self._futures.clear()

def __enter__(self):
self._running = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig
from memos.dependency import require_python_package
from memos.log import get_logger
from memos.mem_scheduler.modules.base import BaseSchedulerModule
from memos.mem_scheduler.modules.misc import AutoDroppingQueue
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue
from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from memos.dependency import require_python_package
from memos.log import get_logger
from memos.mem_scheduler.modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule


logger = get_logger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.general_schemas import (
TreeTextMemory_FINE_SEARCH_METHOD,
TreeTextMemory_SEARCH_METHOD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.general_schemas import (
ACTIVATION_MEMORY_TYPE,
ADD_LABEL,
Expand Down
Loading