diff --git a/examples/data/config/mem_scheduler/general_scheduler_config.yaml b/examples/data/config/mem_scheduler/general_scheduler_config.yaml index 5c82db24..39065590 100644 --- a/examples/data/config/mem_scheduler/general_scheduler_config.yaml +++ b/examples/data/config/mem_scheduler/general_scheduler_config.yaml @@ -1,9 +1,9 @@ backend: general_scheduler config: top_k: 10 - top_n: 5 + top_n: 10 act_mem_update_interval: 30 - context_window_size: 5 + context_window_size: 10 thread_pool_max_workers: 5 - consume_interval_seconds: 3 + consume_interval_seconds: 1 enable_parallel_dispatch: true diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml index a8d46ae1..e898c826 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml @@ -33,11 +33,11 @@ mem_scheduler: backend: "general_scheduler" config: top_k: 10 - top_n: 5 + top_n: 10 act_mem_update_interval: 30 - context_window_size: 5 + context_window_size: 10 thread_pool_max_workers: 10 - consume_interval_seconds: 3 + consume_interval_seconds: 1 enable_parallel_dispatch: true max_turns_window: 20 top_k: 5 diff --git a/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml b/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml index cdebb9af..1da4dad1 100644 --- a/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml +++ b/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml @@ -34,12 +34,12 @@ mem_reader: mem_scheduler: backend: "general_scheduler" config: - top_k: 2 - top_n: 5 + top_k: 10 + top_n: 10 act_mem_update_interval: 30 - context_window_size: 5 + context_window_size: 10 thread_pool_max_workers: 10 - consume_interval_seconds: 3 + consume_interval_seconds: 1 enable_parallel_dispatch: true max_turns_window: 20 top_k: 5 diff --git a/examples/mem_os/chat_w_scheduler.py b/examples/mem_os/chat_w_scheduler.py index a73f9132..c5d84efc 100644 --- a/examples/mem_os/chat_w_scheduler.py +++ b/examples/mem_os/chat_w_scheduler.py @@ -7,7 +7,7 @@ from memos.configs.mem_os import MOSConfig from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS -from memos.mem_scheduler.utils import parse_yaml +from memos.mem_scheduler.utils.misc_utils import parse_yaml # init MOS diff --git a/examples/mem_scheduler/schedule_w_memos.py b/examples/mem_scheduler/memos_w_scheduler.py similarity index 81% rename from examples/mem_scheduler/schedule_w_memos.py rename to examples/mem_scheduler/memos_w_scheduler.py index 9032ded1..c00845d5 100644 --- a/examples/mem_scheduler/schedule_w_memos.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -3,6 +3,8 @@ from datetime import datetime from pathlib import Path +from queue import Queue +from typing import TYPE_CHECKING from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig @@ -10,13 +12,20 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.general_scheduler import GeneralScheduler +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, QUERY_LABEL, - ScheduleMessageItem, ) -from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.mem_scheduler.utils import parse_yaml +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.utils.misc_utils import parse_yaml + + +if TYPE_CHECKING: + from memos.mem_scheduler.schemas import ( + ScheduleLogForWebItem, + ) FILE_PATH = Path(__file__).absolute() @@ -109,6 +118,8 @@ def run_with_automatic_scheduler_init(): response = mos.chat(query, user_id=user_id) print(f"Query:\n {query}\n\nAnswer:\n {response}") + show_web_logs(mem_scheduler=mos.mem_scheduler) + mos.mem_scheduler.stop() @@ -184,9 +195,46 @@ def run_with_manual_scheduler_init(): mos.mem_scheduler.submit_messages(messages=message_item) print(f"Query:\n {query}\n\nAnswer:\n {response}") + show_web_logs(mem_scheduler=mos.mem_scheduler) + mos.mem_scheduler.stop() +def show_web_logs(mem_scheduler: GeneralScheduler): + """Display all web log entries from the scheduler's log queue. + + Args: + mem_scheduler: The scheduler instance containing web logs to display + """ + if mem_scheduler._web_log_message_queue.empty(): + print("Web log queue is currently empty.") + return + + print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) + + # Create a temporary queue to preserve the original queue contents + temp_queue = Queue() + log_count = 0 + + while not mem_scheduler._web_log_message_queue.empty(): + log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() + temp_queue.put(log_item) + log_count += 1 + + # Print log entry details + print(f"\nLog Entry #{log_count}:") + print(f'- "{log_item.label}" log: {log_item}') + + print("-" * 50) + + # Restore items back to the original queue + while not temp_queue.empty(): + mem_scheduler._web_log_message_queue.put(temp_queue.get()) + + print(f"\nTotal {log_count} web log entries displayed.") + print("=" * 110 + "\n") + + if __name__ == "__main__": run_with_automatic_scheduler_init() diff --git a/examples/mem_scheduler/schedule_chat_and_web.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py similarity index 74% rename from examples/mem_scheduler/schedule_chat_and_web.py rename to examples/mem_scheduler/memos_w_scheduler_for_test.py index 666f15e0..2aae09ec 100644 --- a/examples/mem_scheduler/schedule_chat_and_web.py +++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py @@ -1,25 +1,17 @@ +import json import shutil import sys from pathlib import Path -from queue import Queue -from typing import TYPE_CHECKING from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig from memos.configs.mem_scheduler import AuthConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler -if TYPE_CHECKING: - from memos.mem_scheduler.modules.schemas import ( - ScheduleLogForWebItem, - ) - - FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory @@ -90,41 +82,6 @@ def init_task(): return conversations, questions -def show_web_logs(mem_scheduler: GeneralScheduler): - """Display all web log entries from the scheduler's log queue. - - Args: - mem_scheduler: The scheduler instance containing web logs to display - """ - if mem_scheduler._web_log_message_queue.empty(): - print("Web log queue is currently empty.") - return - - print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) - - # Create a temporary queue to preserve the original queue contents - temp_queue = Queue() - log_count = 0 - - while not mem_scheduler._web_log_message_queue.empty(): - log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() - temp_queue.put(log_item) - log_count += 1 - - # Print log entry details - print(f"\nLog Entry #{log_count}:") - print(f'- "{log_item.label}" log: {log_item}') - - print("-" * 50) - - # Restore items back to the original queue - while not temp_queue.empty(): - mem_scheduler._web_log_message_queue.put(temp_queue.get()) - - print(f"\nTotal {log_count} web log entries displayed.") - print("=" * 110 + "\n") - - if __name__ == "__main__": # set up data conversations, questions = init_task() @@ -168,12 +125,18 @@ def show_web_logs(mem_scheduler: GeneralScheduler): mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) + # Add interfering conversations + file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json") + scene_data = json.load(file_path.open("r", encoding="utf-8")) + mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) + mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) + for item in questions: + print("===== Chat Start =====") query = item["question"] - + print(f"Query:\n {query}\n") response = mos.chat(query=query, user_id=user_id) - print(f"Query:\n {query}\n\nAnswer:\n {response}") - - show_web_logs(mos.mem_scheduler) + print(f"Answer:\n {response}") + print("===== Chat End =====") mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/rabbitmq_example.py b/examples/mem_scheduler/rabbitmq_example.py index 39b98234..1343e0ec 100644 --- a/examples/mem_scheduler/rabbitmq_example.py +++ b/examples/mem_scheduler/rabbitmq_example.py @@ -8,7 +8,7 @@ def publish_message(rabbitmq_module, message): """Function to publish a message.""" rabbitmq_module.rabbitmq_publish_message(message) - print(f"Published message: {message}") + print(f"Published message: {message}\n") def main(): @@ -24,8 +24,7 @@ def main(): rabbitmq_module.initialize_rabbitmq(config=AuthConfig.from_local_yaml().rabbitmq) try: - # Start consumer - rabbitmq_module.rabbitmq_start_consuming() + rabbitmq_module.wait_for_connection_ready() # === Publish some test messages === # List to hold thread references @@ -38,6 +37,9 @@ def main(): thread.start() threads.append(thread) + # Start consumer + rabbitmq_module.rabbitmq_start_consuming() + # Join threads to ensure all messages are published before proceeding for thread in threads: thread.join() @@ -47,7 +49,7 @@ def main(): finally: # Give some time for cleanup - time.sleep(5) + time.sleep(3) # Close connections rabbitmq_module.rabbitmq_close() diff --git a/examples/mem_scheduler/redis_example.py b/examples/mem_scheduler/redis_example.py index 528893d3..1660d6c0 100644 --- a/examples/mem_scheduler/redis_example.py +++ b/examples/mem_scheduler/redis_example.py @@ -8,8 +8,9 @@ from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.schemas import QUERY_LABEL, ScheduleMessageItem from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.general_schemas import QUERY_LABEL +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem if TYPE_CHECKING: diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index e735ebfe..976e506b 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -13,12 +13,14 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_scheduler import GeneralScheduler -from memos.mem_scheduler.modules.schemas import NOT_APPLICABLE_TYPE from memos.mem_scheduler.mos_for_test_scheduler import MOSForTestScheduler +from memos.mem_scheduler.schemas.general_schemas import ( + NOT_APPLICABLE_TYPE, +) if TYPE_CHECKING: - from memos.mem_scheduler.modules.schemas import ( + from memos.mem_scheduler.schemas import ( ScheduleLogForWebItem, ) @@ -175,14 +177,14 @@ def show_web_logs(mem_scheduler: GeneralScheduler): query = item["question"] # test process_session_turn - mos.mem_scheduler.process_session_turn( + working_memory, new_candidates = mos.mem_scheduler.process_session_turn( queries=[query], user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, top_k=10, - query_history=None, ) + print(f"\nnew_candidates: {[one.memory for one in new_candidates]}") # test activation memory update mos.mem_scheduler.update_activation_memory_periodically( diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index 83799de7..0c99a57f 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -6,12 +6,12 @@ from pydantic import ConfigDict, Field, field_validator, model_validator from memos.configs.base import BaseConfig -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.modules.misc import DictConversionMixin +from memos.mem_scheduler.schemas.general_schemas import ( BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_THREAD__POOL_MAX_WORKERS, - DictConversionMixin, ) @@ -21,6 +21,7 @@ class BaseSchedulerConfig(BaseConfig): top_k: int = Field( default=10, description="Number of top candidates to consider in initial retrieval" ) + # TODO: The 'top_n' field is deprecated and will be removed in future versions. top_n: int = Field(default=5, description="Number of final results to return after processing") enable_parallel_dispatch: bool = Field( default=True, description="Whether to enable parallel message processing using thread pool" @@ -48,7 +49,7 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): default=300, description="Interval in seconds for updating activation memory" ) context_window_size: int | None = Field( - default=5, description="Size of the context window for conversation history" + default=10, description="Size of the context window for conversation history" ) act_mem_dump_path: str | None = Field( default=DEFAULT_ACT_MEM_DUMP_PATH, # Replace with DEFAULT_ACT_MEM_DUMP_PATH @@ -105,7 +106,20 @@ class RabbitMQConfig( class GraphDBAuthConfig(BaseConfig): - uri: str = Field(default="localhost", description="URI for graph database access") + uri: str = Field( + default="bolt://localhost:7687", + description="URI for graph database access (e.g., bolt://host:port)", + ) + user: str = Field(default="neo4j", description="Username for graph database authentication") + password: str = Field( + default="", + description="Password for graph database authentication", + min_length=8, # 建议密码最小长度 + ) + db_name: str = Field(default="neo4j", description="Database name to connect to") + auto_create: bool = Field( + default=True, description="Whether to automatically create the database if it doesn't exist" + ) class OpenAIConfig(BaseConfig): diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 906420e2..0dcb72d4 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -13,12 +13,12 @@ from memos.mem_cube.general import GeneralMemCube from memos.mem_reader.factory import MemReaderFactory from memos.mem_scheduler.general_scheduler import GeneralScheduler -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, - ScheduleMessageItem, ) -from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_user.user_manager import UserManager, UserRole from memos.memories.activation.item import ActivationMemoryItem from memos.memories.parametric.item import ParametricMemoryItem @@ -58,10 +58,15 @@ def __init__(self, config: MOSConfig, user_manager: UserManager | None = None): f"User '{self.user_id}' does not exist or is inactive. Please create user first." ) - # Lazy initialization marker + # Initialize mem_scheduler self._mem_scheduler_lock = Lock() self.enable_mem_scheduler = self.config.get("enable_mem_scheduler", False) - self._mem_scheduler: GeneralScheduler = None + if self.enable_mem_scheduler: + self._mem_scheduler = self._initialize_mem_scheduler() + self._mem_scheduler.mem_cubes = self.mem_cubes + else: + self._mem_scheduler: GeneralScheduler = None + logger.info(f"MOS initialized for user: {self.user_id}") @property @@ -93,14 +98,16 @@ def mem_scheduler(self, value: GeneralScheduler | None) -> None: else: logger.debug("Memory scheduler cleared") - def _initialize_mem_scheduler(self): + def _initialize_mem_scheduler(self) -> GeneralScheduler: """Initialize the memory scheduler on first access.""" if not self.config.enable_mem_scheduler: logger.debug("Memory scheduler is disabled in config") self._mem_scheduler = None + return self._mem_scheduler elif not hasattr(self.config, "mem_scheduler"): logger.error("Config of Memory scheduler is not available") self._mem_scheduler = None + return self._mem_scheduler else: logger.info("Initializing memory scheduler...") scheduler_config = self.config.mem_scheduler @@ -111,13 +118,16 @@ def _initialize_mem_scheduler(self): f"Memory reader of type {type(self.mem_reader).__name__} " "missing required 'llm' attribute" ) - self._mem_scheduler.initialize_modules(chat_llm=self.chat_llm) + self._mem_scheduler.initialize_modules( + chat_llm=self.chat_llm, process_llm=self.chat_llm + ) else: # Configure scheduler modules self._mem_scheduler.initialize_modules( chat_llm=self.chat_llm, process_llm=self.mem_reader.llm ) self._mem_scheduler.start() + return self._mem_scheduler def mem_scheduler_on(self) -> bool: if not self.config.enable_mem_scheduler or self._mem_scheduler is None: @@ -576,6 +586,7 @@ def add( user_id (str, optional): The identifier of the user to add the memories to. If None, the default user is used. """ + # user input messages assert (messages is not None) or (memory_content is not None) or (doc_path is not None), ( "messages_or_doc_path or memory_content or doc_path must be provided." ) @@ -615,23 +626,26 @@ def add( type="chat", info={"user_id": target_user_id, "session_id": str(uuid.uuid4())}, ) + + mem_ids = [] for mem in memories: - self.mem_cubes[mem_cube_id].text_mem.add(mem) + mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem) + mem_ids.extend(mem_id_list) # submit messages for scheduler - mem_cube = self.mem_cubes[mem_cube_id] if self.enable_mem_scheduler and self.mem_scheduler is not None: - text_messages = [message["content"] for message in messages] + mem_cube = self.mem_cubes[mem_cube_id] message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, label=ADD_LABEL, - content=json.dumps(text_messages), + content=json.dumps(mem_ids), timestamp=datetime.now(), ) self.mem_scheduler.submit_messages(messages=[message_item]) + # user profile if ( (memory_content is not None) and self.config.enable_textual_memory @@ -653,21 +667,56 @@ def add( type="chat", info={"user_id": target_user_id, "session_id": str(uuid.uuid4())}, ) + + mem_ids = [] for mem in memories: - self.mem_cubes[mem_cube_id].text_mem.add(mem) + mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem) + mem_ids.extend(mem_id_list) + + # submit messages for scheduler + if self.enable_mem_scheduler and self.mem_scheduler is not None: + mem_cube = self.mem_cubes[mem_cube_id] + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.now(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + + # user doc input if ( (doc_path is not None) and self.config.enable_textual_memory and self.mem_cubes[mem_cube_id].text_mem ): documents = self._get_all_documents(doc_path) - doc_memory = self.mem_reader.get_memory( + doc_memories = self.mem_reader.get_memory( documents, type="doc", info={"user_id": target_user_id, "session_id": str(uuid.uuid4())}, ) - for mem in doc_memory: - self.mem_cubes[mem_cube_id].text_mem.add(mem) + + mem_ids = [] + for mem in doc_memories: + mem_id_list: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(mem) + mem_ids.extend(mem_id_list) + + # submit messages for scheduler + if self.enable_mem_scheduler and self.mem_scheduler is not None: + mem_cube = self.mem_cubes[mem_cube_id] + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.now(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + logger.info(f"Add memory to {mem_cube_id} successfully") def get( diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index edb3a333..2cead7c8 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -208,7 +208,7 @@ def _chat_with_cot_enhancement( if self.enable_mem_scheduler and self.mem_scheduler is not None: from datetime import datetime - from memos.mem_scheduler.modules.schemas import ( + from memos.mem_scheduler.schemas import ( ANSWER_LABEL, ScheduleMessageItem, ) diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index ccb0e52f..31ab55a9 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -23,7 +23,7 @@ remove_embedding_recursive, sort_children_by_memory_type, ) -from memos.mem_scheduler.modules.schemas import ANSWER_LABEL, QUERY_LABEL, ScheduleMessageItem +from memos.mem_scheduler.schemas import ANSWER_LABEL, QUERY_LABEL, ScheduleMessageItem from memos.mem_user.persistent_user_manager import PersistentUserManager from memos.mem_user.user_manager import UserRole from memos.memories.textual.item import ( diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 4f8c8c80..0335bd5e 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -15,24 +15,21 @@ from memos.mem_scheduler.modules.rabbitmq_service import RabbitMQSchedulerModule from memos.mem_scheduler.modules.redis_service import RedisSchedulerModule from memos.mem_scheduler.modules.retriever import SchedulerRetriever -from memos.mem_scheduler.modules.schemas import ( - ACTIVATION_MEMORY_TYPE, - ADD_LABEL, +from memos.mem_scheduler.modules.scheduler_logger import SchedulerLoggerModule +from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_THREAD__POOL_MAX_WORKERS, - LONG_TERM_MEMORY_TYPE, - NOT_INITIALIZED, - PARAMETER_MEMORY_TYPE, - QUERY_LABEL, - TEXT_MEMORY_TYPE, - USER_INPUT_TYPE, - WORKING_MEMORY_TYPE, + TreeTextMemory_SEARCH_METHOD, +) +from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, ScheduleMessageItem, - TreeTextMemory_SEARCH_METHOD, ) -from memos.mem_scheduler.utils import transform_name_to_key +from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.utils.filter_utils import ( + transform_name_to_key, +) from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory @@ -42,7 +39,7 @@ logger = get_logger(__name__) -class BaseScheduler(RabbitMQSchedulerModule, RedisSchedulerModule): +class BaseScheduler(RabbitMQSchedulerModule, RedisSchedulerModule, SchedulerLoggerModule): """Base class for all mem_scheduler.""" def __init__(self, config: BaseSchedulerConfig): @@ -51,12 +48,11 @@ def __init__(self, config: BaseSchedulerConfig): self.config = config # hyper-parameters - self.top_k = self.config.get("top_k", 5) + self.top_k = self.config.get("top_k", 10) self.context_window_size = self.config.get("context_window_size", 5) self.enable_act_memory_update = self.config.get("enable_act_memory_update", False) self.act_mem_dump_path = self.config.get("act_mem_dump_path", DEFAULT_ACT_MEM_DUMP_PATH) self.search_method = TreeTextMemory_SEARCH_METHOD - self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False) self.max_workers = self.config.get( "thread_pool_max_workers", DEFAULT_THREAD__POOL_MAX_WORKERS @@ -99,7 +95,6 @@ def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = No self.process_llm = process_llm self.monitor = SchedulerMonitor(process_llm=self.process_llm, config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) - self.retriever.log_working_memory_replacement = self.log_working_memory_replacement # initialize with auth_cofig if self.auth_config_path is not None and Path(self.auth_config_path).exists(): @@ -133,36 +128,140 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None: self._current_mem_cube_id = msg.mem_cube_id self._current_mem_cube = msg.mem_cube - def _validate_messages(self, messages: list[ScheduleMessageItem], label: str): - """Validate if all messages match the expected label. + def transform_memories_to_monitors( + self, memories: list[TextualMemoryItem] + ) -> list[MemoryMonitorItem]: + """ + Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects + with importance scores based on keyword matching. Args: - messages: List of message items to validate. - label: Expected message label (e.g., QUERY_LABEL/ANSWER_LABEL). + memories: List of TextualMemoryItem objects to be transformed. Returns: - bool: True if all messages passed validation, False if any failed. + List of MemoryMonitorItem objects with computed importance scores. """ - for message in messages: - if not self._validate_message(message, label): - return False - logger.error("Message batch contains invalid labels, aborting processing") - return True + query_keywords = self.monitor.query_monitors.get_keywords_collections() + logger.debug( + f"Processing {len(memories)} memories with {len(query_keywords)} query keywords" + ) - def _validate_message(self, message: ScheduleMessageItem, label: str): - """Validate if the message matches the expected label. + result = [] + mem_length = len(memories) + for idx, mem in enumerate(memories): + text_mem = mem.memory + mem_key = transform_name_to_key(name=text_mem) + + # Calculate importance score based on keyword matches + keywords_score = 0 + if query_keywords and text_mem: + for keyword, count in query_keywords.items(): + keyword_count = text_mem.count(keyword) + if keyword_count > 0: + keywords_score += keyword_count * count + logger.debug( + f"Matched keyword '{keyword}' {keyword_count} times, added {keywords_score} to keywords_score" + ) + + # rank score + sorting_score = mem_length - idx + + mem_monitor = MemoryMonitorItem( + memory_text=text_mem, + tree_memory_item=mem, + tree_memory_item_mapping_key=mem_key, + sorting_score=sorting_score, + keywords_score=keywords_score, + recording_count=1, + ) + result.append(mem_monitor) - Args: - message: Incoming message item to validate. - label: Expected message label (e.g., QUERY_LABEL/ANSWER_LABEL). + logger.debug(f"Transformed {len(result)} memories to monitors") + return result - Returns: - bool: True if validation passed, False otherwise. - """ - if message.label != label: - logger.error(f"Handler validation failed: expected={label}, actual={message.label}") - return False - return True + def replace_working_memory( + self, + queries: list[str], + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + ) -> None | list[TextualMemoryItem]: + """Replace working memory with new memories after reranking.""" + text_mem_base = mem_cube.text_mem + if isinstance(text_mem_base, TreeTextMemory): + text_mem_base: TreeTextMemory = text_mem_base + + # process rerank memories with llm + quey_history = self.monitor.query_monitors.get_queries_with_timesort() + memories_with_new_order, rerank_success_flag = ( + self.retriever.process_and_rerank_memories( + queries=quey_history, + original_memory=original_memory, + new_memory=new_memory, + top_k=self.top_k, + ) + ) + + # update working memory monitors + new_working_memory_monitors = self.transform_memories_to_monitors( + memories=memories_with_new_order + ) + + if not rerank_success_flag: + for one in new_working_memory_monitors: + one.sorting_score = 0 + + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=new_working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + mem_monitors: list[MemoryMonitorItem] = self.monitor.working_memory_monitors[user_id][ + mem_cube_id + ].get_sorted_mem_monitors(reverse=True) + new_working_memories = [mem_monitor.tree_memory_item for mem_monitor in mem_monitors] + + text_mem_base.replace_working_memory(memories=new_working_memories) + + logger.info( + f"The working memory has been replaced with {len(memories_with_new_order)} new memories." + ) + self.log_working_memory_replacement( + original_memory=original_memory, + new_memory=new_working_memories, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + log_func_callback=self._submit_web_logs, + ) + else: + logger.error("memory_base is not supported") + memories_with_new_order = new_memory + + return memories_with_new_order + + def initialize_working_memory_monitors( + self, + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + ): + text_mem_base: TreeTextMemory = mem_cube.text_mem + working_memories = text_mem_base.get_working_memory() + + working_memory_monitors = self.transform_memories_to_monitors( + memories=working_memories, + ) + self.monitor.update_working_memory_monitors( + new_working_memory_monitors=working_memory_monitors, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) def update_activation_memory( self, @@ -195,7 +294,7 @@ def update_activation_memory( logger.error("Not Implemented.") return - text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( + new_text_memory = MEMORY_ASSEMBLY_TEMPLATE.format( memory_text="".join( [ f"{i + 1}. {sentence.strip()}\n" @@ -211,9 +310,18 @@ def update_activation_memory( if len(original_cache_items) > 0: pre_cache_item: VLLMKVCacheItem = original_cache_items[-1] original_text_memories = pre_cache_item.records.text_memories + original_composed_text_memory = pre_cache_item.records.composed_text_memory + if original_composed_text_memory == new_text_memory: + logger.warning( + "Skipping memory update - new composition matches existing cache: %s", + new_text_memory[:50] + "..." + if len(new_text_memory) > 50 + else new_text_memory, + ) + return act_mem.delete_all() - cache_item = act_mem.extract(text_memory) + cache_item = act_mem.extract(new_text_memory) cache_item.records.text_memories = new_text_memories act_mem.add([cache_item]) @@ -226,6 +334,7 @@ def update_activation_memory( user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, + log_func_callback=self._submit_web_logs, ) except Exception as e: @@ -242,12 +351,22 @@ def update_activation_memory_periodically( new_activation_memories = [] if self.monitor.timed_trigger( - last_time=self.monitor._last_activation_mem_update_time, + last_time=self.monitor.last_activation_mem_update_time, interval_seconds=interval_seconds, ): logger.info(f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}") - self.monitor.update_memory_monitors( + if len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) == 0: + logger.warning( + "No memories found in working_memory_monitors, initializing from current working_memories" + ) + self.initialize_working_memory_monitors( + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + + self.monitor.update_activation_memory_monitors( user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube ) @@ -268,15 +387,15 @@ def update_activation_memory_periodically( mem_cube=mem_cube, ) - self.monitor._last_activation_mem_update_time = datetime.now() + self.monitor.last_activation_mem_update_time = datetime.now() logger.debug( - f"Activation memory update completed at {self.monitor._last_activation_mem_update_time}" + f"Activation memory update completed at {self.monitor.last_activation_mem_update_time}" ) else: logger.info( f"Skipping update - {interval_seconds} second interval not yet reached. " - f"Last update time is {self.monitor._last_activation_mem_update_time} and now is" + f"Last update time is {self.monitor.last_activation_mem_update_time} and now is" f"{datetime.now()}" ) @@ -289,7 +408,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt self.memos_message_queue.put(message) logger.info(f"Submitted message: {message.label} - {message.content}") - def _submit_web_logs(self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem]): + def _submit_web_logs( + self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem] + ) -> None: """Submit log messages to the web log queue and optionally to RabbitMQ. Args: @@ -300,176 +421,14 @@ def _submit_web_logs(self, messages: ScheduleLogForWebItem | list[ScheduleLogFor for message in messages: self._web_log_message_queue.put(message) - logger.info(f"Submitted Scheduling log for web: {message.log_content}") + message_info = message.debug_info() + logger.debug(f"Submitted Scheduling log for web: {message_info}") if self.is_rabbitmq_connected(): - logger.info("Submitted Scheduling log to rabbitmq") + logger.info(f"Submitted Scheduling log to rabbitmq: {message_info}") self.rabbitmq_publish_message(message=message.to_dict()) logger.debug(f"{len(messages)} submitted. {self._web_log_message_queue.qsize()} in queue.") - def log_activation_memory_update( - self, - original_text_memories: list[str], - new_text_memories: list[str], - label: str, - user_id: str, - mem_cube_id: str, - mem_cube: GeneralMemCube, - ): - """Log changes when activation memory is updated. - - Args: - original_text_memories: List of original memory texts - new_text_memories: List of new memory texts - """ - original_set = set(original_text_memories) - new_set = set(new_text_memories) - - # Identify changes - added_memories = list(new_set - original_set) # Present in new but not original - - # recording messages - for mem in added_memories: - log_message_a = self.create_autofilled_log_item( - log_content=mem, - label=label, - from_memory_type=TEXT_MEMORY_TYPE, - to_memory_type=ACTIVATION_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - log_message_b = self.create_autofilled_log_item( - log_content=mem, - label=label, - from_memory_type=ACTIVATION_MEMORY_TYPE, - to_memory_type=PARAMETER_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - self._submit_web_logs(messages=[log_message_a, log_message_b]) - logger.info( - f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) " - f"transformed to {WORKING_MEMORY_TYPE} memories." - ) - - def log_working_memory_replacement( - self, - original_memory: list[TextualMemoryItem], - new_memory: list[TextualMemoryItem], - user_id: str, - mem_cube_id: str, - mem_cube: GeneralMemCube, - ): - """Log changes when working memory is replaced.""" - memory_type_map = { - transform_name_to_key(name=m.memory): m.metadata.memory_type - for m in original_memory + new_memory - } - - original_text_memories = [m.memory for m in original_memory] - new_text_memories = [m.memory for m in new_memory] - - # Convert to sets for efficient difference operations - original_set = set(original_text_memories) - new_set = set(new_text_memories) - - # Identify changes - added_memories = list(new_set - original_set) # Present in new but not original - - # recording messages - for mem in added_memories: - normalized_mem = transform_name_to_key(name=mem) - if normalized_mem not in memory_type_map: - logger.error(f"Memory text not found in type mapping: {mem[:50]}...") - # Get the memory type from the map, default to LONG_TERM_MEMORY_TYPE if not found - mem_type = memory_type_map.get(normalized_mem, LONG_TERM_MEMORY_TYPE) - - if mem_type == WORKING_MEMORY_TYPE: - logger.warning(f"Memory already in working memory: {mem[:50]}...") - continue - - log_message = self.create_autofilled_log_item( - log_content=mem, - label=QUERY_LABEL, - from_memory_type=mem_type, - to_memory_type=WORKING_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - self._submit_web_logs(messages=log_message) - logger.info( - f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) " - f"transformed to {WORKING_MEMORY_TYPE} memories." - ) - - def log_adding_user_inputs( - self, - user_inputs: list[str], - user_id: str, - mem_cube_id: str, - mem_cube: GeneralMemCube, - ): - """Log changes when working memory is replaced.""" - - # recording messages - for input_str in user_inputs: - log_message = self.create_autofilled_log_item( - log_content=input_str, - label=ADD_LABEL, - from_memory_type=USER_INPUT_TYPE, - to_memory_type=TEXT_MEMORY_TYPE, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - self._submit_web_logs(messages=log_message) - logger.info( - f"{len(user_inputs)} {USER_INPUT_TYPE} memorie(s) " - f"transformed to {TEXT_MEMORY_TYPE} memories." - ) - - def create_autofilled_log_item( - self, - log_content: str, - label: str, - from_memory_type: str, - to_memory_type: str, - user_id: str, - mem_cube_id: str, - mem_cube: GeneralMemCube, - ) -> ScheduleLogForWebItem: - text_mem_base: TreeTextMemory = mem_cube.text_mem - current_memory_sizes = text_mem_base.get_current_memory_size() - current_memory_sizes = { - "long_term_memory_size": current_memory_sizes["LongTermMemory"], - "user_memory_size": current_memory_sizes["UserMemory"], - "working_memory_size": current_memory_sizes["WorkingMemory"], - "transformed_act_memory_size": NOT_INITIALIZED, - "parameter_memory_size": NOT_INITIALIZED, - } - memory_capacities = { - "long_term_memory_capacity": text_mem_base.memory_manager.memory_size["LongTermMemory"], - "user_memory_capacity": text_mem_base.memory_manager.memory_size["UserMemory"], - "working_memory_capacity": text_mem_base.memory_manager.memory_size["WorkingMemory"], - "transformed_act_memory_capacity": NOT_INITIALIZED, - "parameter_memory_capacity": NOT_INITIALIZED, - } - - log_message = ScheduleLogForWebItem( - user_id=user_id, - mem_cube_id=mem_cube_id, - label=label, - from_memory_type=from_memory_type, - to_memory_type=to_memory_type, - log_content=log_content, - current_memory_sizes=current_memory_sizes, - memory_capacities=memory_capacities, - ) - return log_message - def get_web_log_messages(self) -> list[dict]: """ Retrieves all web log messages from the queue and returns them as a list of JSON-serializable dictionaries. diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 42cebc67..f4de9e8f 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -4,12 +4,14 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.base_scheduler import BaseScheduler -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.schemas.general_schemas import ( ADD_LABEL, ANSWER_LABEL, + DEFAULT_MAX_QUERY_KEY_WORDS, QUERY_LABEL, - ScheduleMessageItem, ) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory @@ -36,12 +38,12 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: Args: messages: List of query messages to process """ - logger.debug(f"Messages {messages} assigned to {QUERY_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") # Process the query in a session turn grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) - self._validate_messages(messages=messages, label=QUERY_LABEL) + self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: @@ -49,16 +51,51 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: if len(messages) == 0: return + mem_cube = messages[0].mem_cube + # for status update self._set_current_context_from_message(msg=messages[0]) - self.process_session_turn( - queries=[msg.content for msg in messages], + # update query monitors + for msg in messages: + query = msg.content + query_keywords = self.monitor.extract_query_keywords(query=query) + logger.info(f'Extract keywords "{query_keywords}" from query "{query}"') + + item = QueryMonitorItem( + query_text=query, + keywords=query_keywords, + max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS, + ) + self.monitor.query_monitors.put(item=item) + logger.debug( + f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}." + ) + + queries = [msg.content for msg in messages] + + # recall + cur_working_memory, new_candidates = self.process_session_turn( + queries=queries, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=messages[0].mem_cube, + mem_cube=mem_cube, top_k=self.top_k, ) + logger.info( + f"Processed {queries} and get {len(new_candidates)} new candidate memories." + ) + + # rerank + new_order_working_memory = self.replace_working_memory( + queries=queries, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + original_memory=cur_working_memory, + new_memory=new_candidates, + ) + logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}") def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ @@ -67,11 +104,11 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: Args: messages: List of answer messages to process """ - logger.debug(f"Messages {messages} assigned to {ANSWER_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.") # Process the query in a session turn grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) - self._validate_messages(messages=messages, label=ANSWER_LABEL) + self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL) for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: @@ -84,6 +121,16 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: # update acivation memories if self.enable_act_memory_update: + if ( + len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) + == 0 + ): + self.initialize_working_memory_monitors( + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=messages[0].mem_cube, + ) + self.update_activation_memory_periodically( interval_seconds=self.monitor.act_mem_update_interval, label=ANSWER_LABEL, @@ -93,11 +140,11 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: ) def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: - logger.debug(f"Messages {messages} assigned to {ADD_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.") # Process the query in a session turn grouped_messages = self.dispatcher.group_messages_by_user_and_cube(messages=messages) - self._validate_messages(messages=messages, label=ADD_LABEL) + self.validate_schedule_messages(messages=messages, label=ADD_LABEL) for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: @@ -110,15 +157,23 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: # submit logs for msg in messages: - user_inputs = json.loads(msg.content) - self.log_adding_user_inputs( - user_inputs=user_inputs, - user_id=msg.user_id, - mem_cube_id=msg.mem_cube_id, - mem_cube=msg.mem_cube, - ) - - # update acivation memories + userinput_memory_ids = json.loads(msg.content) + mem_cube = msg.mem_cube + for memory_id in userinput_memory_ids: + mem_item: TextualMemoryItem = mem_cube.text_mem.get(memory_id=memory_id) + mem_type = mem_item.meta_data.memory_type + mem_content = mem_item.memory + + self.log_adding_memory( + memory=mem_content, + memory_type=mem_type, + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + mem_cube=msg.mem_cube, + log_func_callback=self._submit_web_logs, + ) + + # update activation memories if self.enable_act_memory_update: self.update_activation_memory_periodically( interval_seconds=self.monitor.act_mem_update_interval, @@ -135,52 +190,62 @@ def process_session_turn( mem_cube_id: str, mem_cube: GeneralMemCube, top_k: int = 10, - query_history: list[str] | None = None, - ) -> None: + ) -> tuple[list[TextualMemoryItem], list[TextualMemoryItem]] | None: """ Process a dialog turn: - If q_list reaches window size, trigger retrieval; - Immediately switch to the new memory if retrieval is triggered. """ - if isinstance(queries, str): - queries = [queries] - - if query_history is None: - query_history = queries - else: - query_history.extend(queries) text_mem_base = mem_cube.text_mem if not isinstance(text_mem_base, TreeTextMemory): logger.error("Not implemented!", exc_info=True) return - working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() - text_working_memory: list[str] = [w_m.memory for w_m in working_memory] + logger.info(f"Processing {len(queries)} queries.") + + cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() + text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory] intent_result = self.monitor.detect_intent( - q_list=query_history, text_working_memory=text_working_memory + q_list=queries, text_working_memory=text_working_memory ) - if intent_result["trigger_retrieval"]: - missing_evidences = intent_result["missing_evidences"] - num_evidence = len(missing_evidences) - k_per_evidence = max(1, top_k // max(1, num_evidence)) - new_candidates = [] - for item in missing_evidences: - logger.debug(f"missing_evidences: {item}") - results = self.retriever.search( - query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method - ) - logger.debug(f"search results for {missing_evidences}: {results}") - new_candidates.extend(results) - - new_order_working_memory = self.retriever.replace_working_memory( - queries=queries, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - original_memory=working_memory, - new_memory=new_candidates, - top_k=top_k, + time_trigger_flag = False + if self.monitor.timed_trigger( + last_time=self.monitor.last_query_consume_time, + interval_seconds=self.monitor.query_trigger_interval, + ): + time_trigger_flag = True + + if (not intent_result["trigger_retrieval"]) and (not time_trigger_flag): + logger.info(f"Query schedule not triggered. Intent_result: {intent_result}") + return + elif (not intent_result["trigger_retrieval"]) and time_trigger_flag: + logger.info("Query schedule is forced to trigger due to time ticker") + intent_result["trigger_retrieval"] = True + intent_result["missing_evidences"] = queries + else: + logger.info( + f'Query schedule triggered for user "{user_id}" and mem_cube "{mem_cube_id}".' + f" Missing evidences: {intent_result['missing_evidences']}" + ) + + missing_evidences = intent_result["missing_evidences"] + num_evidence = len(missing_evidences) + k_per_evidence = max(1, top_k // max(1, num_evidence)) + new_candidates = [] + for item in missing_evidences: + logger.info(f"missing_evidences: {item}") + results: list[TextualMemoryItem] = self.retriever.search( + query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method + ) + logger.info(f"search results for {missing_evidences}: {results}") + new_candidates.extend(results) + + if len(new_candidates) == 0: + logger.warning( + f"As new_candidates is empty, new_candidates is set same to working_memory.\n" + f"time_trigger_flag: {time_trigger_flag}; intent_result: {intent_result}" ) - logger.debug(f"size of new_order_working_memory: {len(new_order_working_memory)}") + new_candidates = cur_working_memory + return cur_working_memory, new_candidates diff --git a/src/memos/mem_scheduler/modules/base.py b/src/memos/mem_scheduler/modules/base.py index 2bc39716..58c64587 100644 --- a/src/memos/mem_scheduler/modules/base.py +++ b/src/memos/mem_scheduler/modules/base.py @@ -3,7 +3,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.modules.schemas import BASE_DIR +from memos.mem_scheduler.schemas.general_schemas import BASE_DIR from memos.templates.mem_scheduler_prompts import PROMPT_MAPPING diff --git a/src/memos/mem_scheduler/modules/dispatcher.py b/src/memos/mem_scheduler/modules/dispatcher.py index 945a74ac..0a3ff196 100644 --- a/src/memos/mem_scheduler/modules/dispatcher.py +++ b/src/memos/mem_scheduler/modules/dispatcher.py @@ -4,7 +4,7 @@ from memos.log import get_logger from memos.mem_scheduler.modules.base import BaseSchedulerModule -from memos.mem_scheduler.modules.schemas import ScheduleMessageItem +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem logger = get_logger(__name__) @@ -22,7 +22,7 @@ class SchedulerDispatcher(BaseSchedulerModule): - Bulk handler registration """ - def __init__(self, max_workers=3, enable_parallel_dispatch=False): + def __init__(self, max_workers=30, enable_parallel_dispatch=False): super().__init__() # Main dispatcher thread pool self.max_workers = max_workers @@ -128,16 +128,13 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): else: handler = self.handlers[label] # dispatch to different handler - logger.debug(f"Dispatch {len(msgs)} messages to {label} handler.") + logger.debug(f"Dispatch {len(msgs)} message(s) to {label} handler.") if self.enable_parallel_dispatch and self.dispatcher_executor is not None: # Capture variables in lambda to avoid loop variable issues - # TODO check this - future = self.dispatcher_executor.submit(handler, msgs) - logger.debug(f"Dispatched {len(msgs)} messages as future task") - return future + self.dispatcher_executor.submit(handler, msgs) + logger.info(f"Dispatched {len(msgs)} message(s) as future task") else: handler(msgs) - return None def join(self, timeout: float | None = None) -> bool: """Wait for all dispatched tasks to complete. @@ -159,7 +156,7 @@ def shutdown(self) -> None: if self.dispatcher_executor is not None: self.dispatcher_executor.shutdown(wait=True) self._running = False - logger.info("Dispatcher has been shutdown") + logger.info("Dispatcher has been shutdown.") def __enter__(self): self._running = True diff --git a/src/memos/mem_scheduler/modules/misc.py b/src/memos/mem_scheduler/modules/misc.py index b0555295..41ebdfd4 100644 --- a/src/memos/mem_scheduler/modules/misc.py +++ b/src/memos/mem_scheduler/modules/misc.py @@ -1,20 +1,85 @@ -import threading +import json +from contextlib import suppress +from datetime import datetime from queue import Empty, Full, Queue -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar +from pydantic import field_serializer + + +if TYPE_CHECKING: + from pydantic import BaseModel T = TypeVar("T") +BaseModelType = TypeVar("T", bound="BaseModel") + + +class DictConversionMixin: + """ + Provides conversion functionality between Pydantic models and dictionaries, + including datetime serialization handling. + """ + + @field_serializer("timestamp", check_fields=False) + def serialize_datetime(self, dt: datetime | None, _info) -> str | None: + """ + Custom datetime serialization logic. + - Supports timezone-aware datetime objects + - Compatible with models without timestamp field (via check_fields=False) + """ + if dt is None: + return None + return dt.isoformat() + + def to_dict(self) -> dict: + """ + Convert model instance to dictionary. + - Uses model_dump to ensure field consistency + - Prioritizes custom serializer for timestamp handling + """ + dump_data = self.model_dump() + if hasattr(self, "timestamp") and self.timestamp is not None: + dump_data["timestamp"] = self.serialize_datetime(self.timestamp, None) + return dump_data + + @classmethod + def from_dict(cls: type[BaseModelType], data: dict) -> BaseModelType: + """ + Create model instance from dictionary. + - Automatically converts timestamp strings to datetime objects + """ + data_copy = data.copy() # Avoid modifying original dictionary + if "timestamp" in data_copy and isinstance(data_copy["timestamp"], str): + try: + data_copy["timestamp"] = datetime.fromisoformat(data_copy["timestamp"]) + except ValueError: + # Handle invalid time formats - adjust as needed (e.g., log warning or set to None) + data_copy["timestamp"] = None + + return cls(**data_copy) + + def __str__(self) -> str: + """ + Convert to formatted JSON string. + - Used for user-friendly display in print() or str() calls + """ + return json.dumps( + self.to_dict(), + indent=4, + ensure_ascii=False, + default=lambda o: str(o), # Handle other non-serializable objects + ) + class AutoDroppingQueue(Queue[T]): """A thread-safe queue that automatically drops the oldest item when full.""" def __init__(self, maxsize: int = 0): super().__init__(maxsize=maxsize) - self._lock = threading.Lock() # Additional lock to prevent race conditions - def put(self, item: T, block: bool = True, timeout: float | None = None) -> None: + def put(self, item: T, block: bool = False, timeout: float | None = None) -> None: """Put an item into the queue. If the queue is full, the oldest item will be automatically removed to make space. @@ -25,15 +90,15 @@ def put(self, item: T, block: bool = True, timeout: float | None = None) -> None block: Ignored (kept for compatibility with Queue interface) timeout: Ignored (kept for compatibility with Queue interface) """ - with self._lock: # Ensure atomic operation - try: - # First try non-blocking put - super().put(item, block=False) - except Full: - # If queue is full, remove the oldest item - from contextlib import suppress - - with suppress(Empty): - self.get_nowait() # Remove oldest item - # Retry putting the new item - super().put(item, block=False) + try: + # First try non-blocking put + super().put(item, block=block, timeout=timeout) + except Full: + with suppress(Empty): + self.get_nowait() # Remove oldest item + # Retry putting the new item + super().put(item, block=block, timeout=timeout) + + def get_queue_content_without_pop(self) -> list[T]: + """Return a copy of the queue's contents without modifying it.""" + return list(self.queue) diff --git a/src/memos/mem_scheduler/modules/monitor.py b/src/memos/mem_scheduler/modules/monitor.py index 366f0c20..0e7d4c39 100644 --- a/src/memos/mem_scheduler/modules/monitor.py +++ b/src/memos/mem_scheduler/modules/monitor.py @@ -6,18 +6,23 @@ from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.modules.base import BaseSchedulerModule -from memos.mem_scheduler.modules.misc import AutoDroppingQueue as Queue -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, + DEFAULT_WEIGHT_VECTOR_FOR_RANKING, DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, MONITOR_ACTIVATION_MEMORY_TYPE, MONITOR_WORKING_MEMORY_TYPE, MemCubeID, - MemoryMonitorManager, UserID, ) -from memos.mem_scheduler.utils import extract_json_dict -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.mem_scheduler.schemas.monitor_schemas import ( + MemoryMonitorItem, + MemoryMonitorManager, + QueryMonitorItem, + QueryMonitorQueue, +) +from memos.mem_scheduler.utils.misc_utils import extract_json_dict +from memos.memories.textual.tree import TreeTextMemory logger = get_logger(__name__) @@ -31,7 +36,8 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): # hyper-parameters self.config: BaseSchedulerConfig = config - self.act_mem_update_interval = self.config.get("act_mem_update_interval", 300) + self.act_mem_update_interval = self.config.get("act_mem_update_interval", 30) + self.query_trigger_interval = self.config.get("query_trigger_interval", 10) # Partial Retention Strategy self.partial_retention_number = 2 @@ -39,16 +45,39 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): self.activation_mem_monitor_capacity = DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT # attributes - self.query_history = Queue(maxsize=self.config.context_window_size) - self.intent_history = Queue(maxsize=self.config.context_window_size) + # recording query_messages + self.query_monitors: QueryMonitorQueue[QueryMonitorItem] = QueryMonitorQueue( + maxsize=self.config.context_window_size + ) + self.working_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {} self.activation_memory_monitors: dict[UserID, dict[MemCubeID, MemoryMonitorManager]] = {} # Lifecycle monitor - self._last_activation_mem_update_time = datetime.min + self.last_activation_mem_update_time = datetime.min + self.last_query_consume_time = datetime.min self._process_llm = process_llm + def extract_query_keywords(self, query: str) -> list: + """Extracts core keywords from a user query based on specific semantic rules.""" + prompt_name = "query_keywords_extraction" + prompt = self.build_prompt( + template_name=prompt_name, + query=query, + ) + llm_response = self._process_llm.generate([{"role": "user", "content": prompt}]) + try: + # Parse JSON output from LLM response + keywords = extract_json_dict(llm_response) + assert isinstance(keywords, list) + except Exception as e: + logger.error( + f"Failed to parse keywords from LLM response: {llm_response}. Error: {e!s}" + ) + keywords = [query] + return keywords + def register_memory_manager_if_not_exists( self, user_id: str, @@ -90,13 +119,15 @@ def register_memory_manager_if_not_exists( f"mem_cube_id={mem_cube_id} in the provided memory_monitors dictionary" ) - def update_memory_monitors(self, user_id: str, mem_cube_id: str, mem_cube: GeneralMemCube): + def update_working_memory_monitors( + self, + new_working_memory_monitors: list[MemoryMonitorItem], + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + ): text_mem_base: TreeTextMemory = mem_cube.text_mem - - if not isinstance(text_mem_base, TreeTextMemory): - logger.error("Not Implemented") - return - + assert isinstance(text_mem_base, TreeTextMemory) self.working_mem_monitor_capacity = min( DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, ( @@ -105,17 +136,6 @@ def update_memory_monitors(self, user_id: str, mem_cube_id: str, mem_cube: Gener ), ) - self.update_working_memory_monitors( - user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube - ) - - self.update_activation_memory_monitors( - user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube - ) - - def update_working_memory_monitors( - self, user_id: str, mem_cube_id: str, mem_cube: GeneralMemCube - ): # register monitors self.register_memory_manager_if_not_exists( user_id=user_id, @@ -124,14 +144,8 @@ def update_working_memory_monitors( max_capacity=self.working_mem_monitor_capacity, ) - # === update working memory monitors === - # Retrieve current working memory content - text_mem_base: TreeTextMemory = mem_cube.text_mem - working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory() - text_working_memory: list[str] = [w_m.memory for w_m in working_memory] - self.working_memory_monitors[user_id][mem_cube_id].update_memories( - text_working_memories=text_working_memory, + new_memory_monitors=new_working_memory_monitors, partial_retention_number=self.partial_retention_number, ) @@ -149,16 +163,13 @@ def update_activation_memory_monitors( # Sort by importance_score in descending order and take top k top_k_memories = sorted( self.working_memory_monitors[user_id][mem_cube_id].memories, - key=lambda m: m.get_score(), + key=lambda m: m.get_importance_score(weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING), reverse=True, )[: self.activation_mem_monitor_capacity] - # Extract just the text from these memories - text_top_k_memories = [m.memory_text for m in top_k_memories] - # Update the activation memory monitors with these important memories self.activation_memory_monitors[user_id][mem_cube_id].update_memories( - text_working_memories=text_top_k_memories, + new_memory_monitors=top_k_memories, partial_retention_number=self.partial_retention_number, ) @@ -206,10 +217,10 @@ def get_monitor_memories( ) return [] - manager = monitor_dict[user_id][mem_cube_id] + manager: MemoryMonitorManager = monitor_dict[user_id][mem_cube_id] # Sort memories by recording_count in descending order and return top_k items - sorted_memories = sorted(manager.memories, key=lambda m: m.recording_count, reverse=True) - sorted_text_memories = [m.memory_text for m in sorted_memories[:top_k]] + sorted_memory_monitors = manager.get_sorted_mem_monitors(reverse=True) + sorted_text_memories = [m.memory_text for m in sorted_memory_monitors[:top_k]] return sorted_text_memories def get_monitors_info(self, user_id: str, mem_cube_id: str) -> dict[str, Any]: diff --git a/src/memos/mem_scheduler/modules/rabbitmq_service.py b/src/memos/mem_scheduler/modules/rabbitmq_service.py index c60bc6c2..3106e2ec 100644 --- a/src/memos/mem_scheduler/modules/rabbitmq_service.py +++ b/src/memos/mem_scheduler/modules/rabbitmq_service.py @@ -4,13 +4,13 @@ import time from pathlib import Path -from queue import Queue from memos.configs.mem_scheduler import AuthConfig, RabbitMQConfig from memos.dependency import require_python_package from memos.log import get_logger from memos.mem_scheduler.modules.base import BaseSchedulerModule -from memos.mem_scheduler.modules.schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE +from memos.mem_scheduler.modules.misc import AutoDroppingQueue +from memos.mem_scheduler.schemas.general_schemas import DIRECT_EXCHANGE_TYPE, FANOUT_EXCHANGE_TYPE logger = get_logger(__name__) @@ -38,7 +38,9 @@ def __init__(self): # fixed params self.rabbitmq_message_cache_max_size = 10 # Max 10 messages - self.rabbitmq_message_cache = Queue(maxsize=self.rabbitmq_message_cache_max_size) + self.rabbitmq_message_cache = AutoDroppingQueue( + maxsize=self.rabbitmq_message_cache_max_size + ) self.rabbitmq_connection_attempts = 3 # Max retry attempts on connection failure self.rabbitmq_retry_delay = 5 # Delay (seconds) between retries self.rabbitmq_heartbeat = 60 # Heartbeat interval (seconds) for connectio @@ -214,12 +216,12 @@ def on_rabbitmq_bind_ok(self, frame): def on_rabbitmq_message(self, channel, method, properties, body): """Handle incoming messages. Only for test.""" try: - print(f"Received message: {body.decode()}") - self.rabbitmq_message_cache.put_nowait({"properties": properties, "body": body}) - print(f"message delivery_tag: {method.delivery_tag}") + print(f"Received message: {body.decode()}\n") + self.rabbitmq_message_cache.put({"properties": properties, "body": body}) + print(f"message delivery_tag: {method.delivery_tag}\n") channel.basic_ack(delivery_tag=method.delivery_tag) except Exception as e: - logger.error(f"Message handling failed: {e}") + logger.error(f"Message handling failed: {e}", exc_info=True) def wait_for_connection_ready(self): start_time = time.time() diff --git a/src/memos/mem_scheduler/modules/retriever.py b/src/memos/mem_scheduler/modules/retriever.py index 219863d4..3430288d 100644 --- a/src/memos/mem_scheduler/modules/retriever.py +++ b/src/memos/mem_scheduler/modules/retriever.py @@ -1,20 +1,19 @@ -import logging - from memos.configs.mem_scheduler import BaseSchedulerConfig -from memos.dependency import require_python_package from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.modules.base import BaseSchedulerModule -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.schemas.general_schemas import ( TreeTextMemory_SEARCH_METHOD, ) -from memos.mem_scheduler.utils import ( - extract_json_dict, - is_all_chinese, - is_all_english, +from memos.mem_scheduler.utils.filter_utils import ( + filter_similar_memories, + filter_too_short_memories, transform_name_to_key, ) +from memos.mem_scheduler.utils.misc_utils import ( + extract_json_dict, +) from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory @@ -25,19 +24,16 @@ class SchedulerRetriever(BaseSchedulerModule): def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): super().__init__() - self.config: BaseSchedulerConfig = config - self.process_llm = process_llm - # hyper-parameters self.filter_similarity_threshold = 0.75 self.filter_min_length_threshold = 6 - # log function callbacks - self.log_working_memory_replacement = None + self.config: BaseSchedulerConfig = config + self.process_llm = process_llm def search( self, query: str, mem_cube: GeneralMemCube, top_k: int, method=TreeTextMemory_SEARCH_METHOD - ): + ) -> list[TextualMemoryItem]: """Search in text memory with the given query. Args: @@ -66,203 +62,123 @@ def search( results = [] return results - @require_python_package( - import_name="sklearn", - install_command="pip install scikit-learn", - install_link="https://scikit-learn.org/stable/install.html", - ) - def filter_similar_memories( - self, text_memories: list[str], similarity_threshold: float = 0.75 - ) -> list[str]: + def rerank_memories( + self, + queries: list[str], + original_memories: list[str], + top_k: int, + ) -> (list[str], bool): """ - Filters out low-quality or duplicate memories based on text similarity. + Rerank memories based on relevance to given queries using LLM. Args: - text_memories: List of text memories to filter - similarity_threshold: Threshold for considering memories duplicates (0.0-1.0) - Higher values mean stricter filtering + queries: List of query strings to determine relevance + original_memories: List of memory strings to be reranked + top_k: Number of top memories to return after reranking Returns: - List of filtered memories with duplicates removed - """ - from sklearn.feature_extraction.text import TfidfVectorizer - from sklearn.metrics.pairwise import cosine_similarity - - if not text_memories: - logging.warning("Received empty memories list - nothing to filter") - return [] - - for idx in range(len(text_memories)): - if not isinstance(text_memories[idx], str): - logger.error( - f"{text_memories[idx]} in memories is not a string," - f" and now has been transformed to be a string." - ) - text_memories[idx] = str(text_memories[idx]) - - try: - # Step 1: Vectorize texts using TF-IDF - vectorizer = TfidfVectorizer() - tfidf_matrix = vectorizer.fit_transform(text_memories) - - # Step 2: Calculate pairwise similarity matrix - similarity_matrix = cosine_similarity(tfidf_matrix) - - # Step 3: Identify duplicates - to_keep = [] - removal_reasons = {} - - for current_idx in range(len(text_memories)): - is_duplicate = False - - # Compare with already kept memories - for kept_idx in to_keep: - similarity_score = similarity_matrix[current_idx, kept_idx] - - if similarity_score > similarity_threshold: - is_duplicate = True - # Generate removal reason with sample text - removal_reasons[current_idx] = ( - f"Memory too similar (score: {similarity_score:.2f}) to kept memory #{kept_idx}. " - f"Kept: '{text_memories[kept_idx][:100]}...' | " - f"Removed: '{text_memories[current_idx][:100]}...'" - ) - logger.info(removal_reasons) - break - - if not is_duplicate: - to_keep.append(current_idx) - - # Return filtered memories - return [text_memories[i] for i in sorted(to_keep)] - - except Exception as e: - logging.error(f"Error filtering memories: {e!s}") - return text_memories # Return original list if error occurs - - def filter_too_short_memories( - self, text_memories: list[str], min_length_threshold: int = 20 - ) -> list[str]: - """ - Filters out text memories that fall below the minimum length requirement. - Handles both English (word count) and Chinese (character count) differently. + List of reranked memory strings (length <= top_k) - Args: - text_memories: List of text memories to be filtered - min_length_threshold: Minimum length required to keep a memory. - For English: word count, for Chinese: character count. - - Returns: - List of filtered memories meeting the length requirement + Note: + If LLM reranking fails, falls back to original order (truncated to top_k) """ - if not text_memories: - logging.debug("Empty memories list received in short memory filter") - return [] - - filtered_memories = [] - removed_count = 0 - - for memory in text_memories: - stripped_memory = memory.strip() - if not stripped_memory: # Skip empty/whitespace memories - removed_count += 1 - continue + success_flag = False + try: + logger.info(f"Starting memory reranking for {len(original_memories)} memories") - # Determine measurement method based on language - if is_all_english(stripped_memory): - length = len(stripped_memory.split()) # Word count for English - elif is_all_chinese(stripped_memory): - length = len(stripped_memory) # Character count for Chinese - else: - logger.debug( - f"Mixed-language memory, using character count: {stripped_memory[:50]}..." - ) - length = len(stripped_memory) # Default to character count + # Build LLM prompt for memory reranking + prompt = self.build_prompt( + "memory_reranking", + queries=[f"[0] {queries[0]}"], + current_order=[f"[{i}] {mem}" for i, mem in enumerate(original_memories)], + ) + logger.debug(f"Generated reranking prompt: {prompt[:200]}...") # Log first 200 chars - if length >= min_length_threshold: - filtered_memories.append(memory) - else: - removed_count += 1 + # Get LLM response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug(f"Received LLM response: {response[:200]}...") # Log first 200 chars - if removed_count > 0: + # Parse JSON response + response = extract_json_dict(response) + new_order = response["new_order"][:top_k] + text_memories_with_new_order = [original_memories[idx] for idx in new_order] logger.info( - f"Filtered out {removed_count} short memories " - f"(below {min_length_threshold} units). " - f"Total remaining: {len(filtered_memories)}" + f"Successfully reranked memories. Returning top {len(text_memories_with_new_order)} items;" + f"Ranking reasoning: {response['reasoning']}" ) + success_flag = True + except Exception as e: + logger.error( + f"Failed to rerank memories with LLM;\nException: {e}. ", + exc_info=True, + ) + text_memories_with_new_order = original_memories[:top_k] + success_flag = False + return text_memories_with_new_order, success_flag - return filtered_memories - - def replace_working_memory( + def process_and_rerank_memories( self, queries: list[str], - user_id: str, - mem_cube_id: str, - mem_cube: GeneralMemCube, original_memory: list[TextualMemoryItem], new_memory: list[TextualMemoryItem], top_k: int = 10, - ) -> None | list[TextualMemoryItem]: - """Replace working memory with new memories after reranking.""" - memories_with_new_order = None - text_mem_base = mem_cube.text_mem - if isinstance(text_mem_base, TreeTextMemory): - text_mem_base: TreeTextMemory = text_mem_base - combined_memory = original_memory + new_memory - memory_map = { - transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory - } - combined_text_memory = [transform_name_to_key(name=m.memory) for m in combined_memory] - - # apply filters - filtered_combined_text_memory = self.filter_similar_memories( - text_memories=combined_text_memory, - similarity_threshold=self.filter_similarity_threshold, - ) - - filtered_combined_text_memory = self.filter_too_short_memories( - text_memories=filtered_combined_text_memory, - min_length_threshold=self.filter_min_length_threshold, - ) + ) -> list[TextualMemoryItem] | None: + """ + Process and rerank memory items by combining original and new memories, + applying filters, and then reranking based on relevance to queries. - unique_memory = list(dict.fromkeys(filtered_combined_text_memory)) + Args: + queries: List of query strings to rerank memories against + original_memory: List of original TextualMemoryItem objects + new_memory: List of new TextualMemoryItem objects to merge + top_k: Maximum number of memories to return after reranking - try: - prompt = self.build_prompt( - "memory_reranking", - queries=queries, - current_order=unique_memory, - staging_buffer=[], + Returns: + List of reranked TextualMemoryItem objects, or None if processing fails + """ + # Combine original and new memories into a single list + combined_memory = original_memory + new_memory + + # Create a mapping from normalized text to memory objects + memory_map = { + transform_name_to_key(name=mem_obj.memory): mem_obj for mem_obj in combined_memory + } + + # Extract normalized text representations from all memory items + combined_text_memory = [m.memory for m in combined_memory] + + # Apply similarity filter to remove overly similar memories + filtered_combined_text_memory = filter_similar_memories( + text_memories=combined_text_memory, + similarity_threshold=self.filter_similarity_threshold, + ) + + # Apply length filter to remove memories that are too short + filtered_combined_text_memory = filter_too_short_memories( + text_memories=filtered_combined_text_memory, + min_length_threshold=self.filter_min_length_threshold, + ) + + # Ensure uniqueness of memory texts using dictionary keys (preserves order) + unique_memory = list(dict.fromkeys(filtered_combined_text_memory)) + + # Rerank the filtered memories based on relevance to the queries + text_memories_with_new_order, success_flag = self.rerank_memories( + queries=queries, + original_memories=unique_memory, + top_k=top_k, + ) + + # Map reranked text entries back to their original memory objects + memories_with_new_order = [] + for text in text_memories_with_new_order: + normalized_text = transform_name_to_key(name=text) + if normalized_text in memory_map: # Ensure correct key matching + memories_with_new_order.append(memory_map[normalized_text]) + else: + logger.warning( + f"Memory text not found in memory map. text: {text};\n" + f"Keys of memory_map: {memory_map.keys()}" ) - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - response = extract_json_dict(response) - text_memories_with_new_order = response.get("new_order", [])[:top_k] - except Exception as e: - logger.error(f"Fail to rerank with LLM, Exeption: {e}.", exc_info=True) - text_memories_with_new_order = unique_memory[:top_k] - - memories_with_new_order = [] - for text in text_memories_with_new_order: - normalized_text = transform_name_to_key(name=text) - if text in memory_map: - memories_with_new_order.append(memory_map[normalized_text]) - else: - logger.warning( - f"Memory text not found in memory map. text: {text}; keys of memory_map: {memory_map.keys()}" - ) - - text_mem_base.replace_working_memory(memories_with_new_order) - logger.info( - f"The working memory has been replaced with {len(memories_with_new_order)} new memories." - ) - self.log_working_memory_replacement( - original_memory=original_memory, - new_memory=memories_with_new_order, - user_id=user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - ) - else: - logger.error("memory_base is not supported") - return memories_with_new_order + return memories_with_new_order, success_flag diff --git a/src/memos/mem_scheduler/modules/scheduler_logger.py b/src/memos/mem_scheduler/modules/scheduler_logger.py new file mode 100644 index 00000000..d0eecac9 --- /dev/null +++ b/src/memos/mem_scheduler/modules/scheduler_logger.py @@ -0,0 +1,242 @@ +from collections.abc import Callable + +from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.modules.base import BaseSchedulerModule +from memos.mem_scheduler.schemas.general_schemas import ( + ACTIVATION_MEMORY_TYPE, + ADD_LABEL, + LONG_TERM_MEMORY_TYPE, + NOT_INITIALIZED, + PARAMETER_MEMORY_TYPE, + QUERY_LABEL, + TEXT_MEMORY_TYPE, + USER_INPUT_TYPE, + WORKING_MEMORY_TYPE, +) +from memos.mem_scheduler.schemas.message_schemas import ( + ScheduleLogForWebItem, + ScheduleMessageItem, +) +from memos.mem_scheduler.utils.filter_utils import ( + transform_name_to_key, +) +from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory + + +logger = get_logger(__name__) + + +class SchedulerLoggerModule(BaseSchedulerModule): + def __init__(self): + """ + Initialize RabbitMQ connection settings. + """ + super().__init__() + + def create_autofilled_log_item( + self, + log_content: str, + label: str, + from_memory_type: str, + to_memory_type: str, + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + ) -> ScheduleLogForWebItem: + text_mem_base: TreeTextMemory = mem_cube.text_mem + current_memory_sizes = text_mem_base.get_current_memory_size() + current_memory_sizes = { + "long_term_memory_size": current_memory_sizes["LongTermMemory"], + "user_memory_size": current_memory_sizes["UserMemory"], + "working_memory_size": current_memory_sizes["WorkingMemory"], + "transformed_act_memory_size": NOT_INITIALIZED, + "parameter_memory_size": NOT_INITIALIZED, + } + memory_capacities = { + "long_term_memory_capacity": text_mem_base.memory_manager.memory_size["LongTermMemory"], + "user_memory_capacity": text_mem_base.memory_manager.memory_size["UserMemory"], + "working_memory_capacity": text_mem_base.memory_manager.memory_size["WorkingMemory"], + "transformed_act_memory_capacity": NOT_INITIALIZED, + "parameter_memory_capacity": NOT_INITIALIZED, + } + + if hasattr(self, "monitor"): + if ( + user_id in self.monitor.activation_memory_monitors + and mem_cube_id in self.monitor.activation_memory_monitors[user_id] + ): + activation_monitor = self.monitor.activation_memory_monitors[user_id][mem_cube_id] + transformed_act_memory_size = len(activation_monitor.memories) + else: + transformed_act_memory_size = 0 + current_memory_sizes["transformed_act_memory_size"] = transformed_act_memory_size + current_memory_sizes["parameter_memory_size"] = 1 + + memory_capacities["transformed_act_memory_capacity"] = ( + self.monitor.activation_mem_monitor_capacity + ) + memory_capacities["parameter_memory_capacity"] = 1 + + log_message = ScheduleLogForWebItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + label=label, + from_memory_type=from_memory_type, + to_memory_type=to_memory_type, + log_content=log_content, + current_memory_sizes=current_memory_sizes, + memory_capacities=memory_capacities, + ) + return log_message + + def log_working_memory_replacement( + self, + original_memory: list[TextualMemoryItem], + new_memory: list[TextualMemoryItem], + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], + ): + """Log changes when working memory is replaced.""" + memory_type_map = { + transform_name_to_key(name=m.memory): m.metadata.memory_type + for m in original_memory + new_memory + } + + original_text_memories = [m.memory for m in original_memory] + new_text_memories = [m.memory for m in new_memory] + + # Convert to sets for efficient difference operations + original_set = set(original_text_memories) + new_set = set(new_text_memories) + + # Identify changes + added_memories = list(new_set - original_set) # Present in new but not original + + # recording messages + for memory in added_memories: + normalized_mem = transform_name_to_key(name=memory) + if normalized_mem not in memory_type_map: + logger.error(f"Memory text not found in type mapping: {memory[:50]}...") + # Get the memory type from the map, default to LONG_TERM_MEMORY_TYPE if not found + mem_type = memory_type_map.get(normalized_mem, LONG_TERM_MEMORY_TYPE) + + if mem_type == WORKING_MEMORY_TYPE: + logger.warning(f"Memory already in working memory: {memory[:50]}...") + continue + + log_message = self.create_autofilled_log_item( + log_content=memory, + label=QUERY_LABEL, + from_memory_type=mem_type, + to_memory_type=WORKING_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + log_func_callback([log_message]) + logger.info( + f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) " + f"transformed to {WORKING_MEMORY_TYPE} memories." + ) + + def log_activation_memory_update( + self, + original_text_memories: list[str], + new_text_memories: list[str], + label: str, + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], + ): + """Log changes when activation memory is updated.""" + original_set = set(original_text_memories) + new_set = set(new_text_memories) + + # Identify changes + added_memories = list(new_set - original_set) # Present in new but not original + + # recording messages + for mem in added_memories: + log_message_a = self.create_autofilled_log_item( + log_content=mem, + label=label, + from_memory_type=TEXT_MEMORY_TYPE, + to_memory_type=ACTIVATION_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + log_message_b = self.create_autofilled_log_item( + log_content=mem, + label=label, + from_memory_type=ACTIVATION_MEMORY_TYPE, + to_memory_type=PARAMETER_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + log_func_callback([log_message_a, log_message_b]) + logger.info( + f"{len(added_memories)} {LONG_TERM_MEMORY_TYPE} memorie(s) " + f"transformed to {WORKING_MEMORY_TYPE} memories." + ) + + def log_adding_memory( + self, + memory: str, + memory_type: str, + user_id: str, + mem_cube_id: str, + mem_cube: GeneralMemCube, + log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], + ): + """Log changes when working memory is replaced.""" + log_message = self.create_autofilled_log_item( + log_content=memory, + label=ADD_LABEL, + from_memory_type=USER_INPUT_TYPE, + to_memory_type=memory_type, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + ) + log_func_callback([log_message]) + logger.info( + f"{USER_INPUT_TYPE} memory for user {user_id} " + f"converted to {memory_type} memory in mem_cube {mem_cube_id}: {memory}" + ) + + def validate_schedule_message(self, message: ScheduleMessageItem, label: str): + """Validate if the message matches the expected label. + + Args: + message: Incoming message item to validate. + label: Expected message label (e.g., QUERY_LABEL/ANSWER_LABEL). + + Returns: + bool: True if validation passed, False otherwise. + """ + if message.label != label: + logger.error(f"Handler validation failed: expected={label}, actual={message.label}") + return False + return True + + def validate_schedule_messages(self, messages: list[ScheduleMessageItem], label: str): + """Validate if all messages match the expected label. + + Args: + messages: List of message items to validate. + label: Expected message label (e.g., QUERY_LABEL/ANSWER_LABEL). + + Returns: + bool: True if all messages passed validation, False if any failed. + """ + for message in messages: + if not self.validate_schedule_message(message, label): + logger.error("Message batch contains invalid labels, aborting processing") + return False + return True diff --git a/src/memos/mem_scheduler/modules/schemas.py b/src/memos/mem_scheduler/modules/schemas.py deleted file mode 100644 index 51e13d01..00000000 --- a/src/memos/mem_scheduler/modules/schemas.py +++ /dev/null @@ -1,328 +0,0 @@ -import json - -from datetime import datetime -from pathlib import Path -from typing import ClassVar, NewType, TypeVar -from uuid import uuid4 - -from pydantic import BaseModel, Field, computed_field -from typing_extensions import TypedDict - -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube - - -logger = get_logger(__name__) - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent - -QUERY_LABEL = "query" -ANSWER_LABEL = "answer" -ADD_LABEL = "add" - -TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" -TextMemory_SEARCH_METHOD = "text_memory_search" -DIRECT_EXCHANGE_TYPE = "direct" -FANOUT_EXCHANGE_TYPE = "fanout" -DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT = 20 -DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 5 -DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" -DEFAULT_THREAD__POOL_MAX_WORKERS = 5 -DEFAULT_CONSUME_INTERVAL_SECONDS = 3 -NOT_INITIALIZED = -1 -BaseModelType = TypeVar("T", bound="BaseModel") - -# web log -LONG_TERM_MEMORY_TYPE = "LongTermMemory" -USER_MEMORY_TYPE = "UserMemory" -WORKING_MEMORY_TYPE = "WorkingMemory" -TEXT_MEMORY_TYPE = "TextMemory" -ACTIVATION_MEMORY_TYPE = "ActivationMemory" -PARAMETER_MEMORY_TYPE = "ParameterMemory" -USER_INPUT_TYPE = "UserInput" -NOT_APPLICABLE_TYPE = "NotApplicable" - -# monitors -MONITOR_WORKING_MEMORY_TYPE = "MonitorWorkingMemoryType" -MONITOR_ACTIVATION_MEMORY_TYPE = "MonitorActivationMemoryType" - - -# new types -UserID = NewType("UserID", str) -MemCubeID = NewType("CubeID", str) - - -# ************************* Public ************************* -class DictConversionMixin: - def to_dict(self) -> dict: - """Convert the instance to a dictionary.""" - return { - **self.model_dump(), # 替换 self.dict() - "timestamp": self.timestamp.isoformat() if hasattr(self, "timestamp") else None, - } - - @classmethod - def from_dict(cls: type[BaseModelType], data: dict) -> BaseModelType: - """Create an instance from a dictionary.""" - if "timestamp" in data: - data["timestamp"] = datetime.fromisoformat(data["timestamp"]) - return cls(**data) - - def __str__(self) -> str: - """Convert the instance to a JSON string with indentation of 4 spaces. - This will be used when str() or print() is called on the instance. - - Returns: - str: A JSON string representation of the instance with 4-space indentation. - """ - return json.dumps( - self.to_dict(), - indent=4, - ensure_ascii=False, - default=str, # 处理无法序列化的对象 - ) - - class Config: - json_encoders: ClassVar[dict[type, object]] = {datetime: lambda v: v.isoformat()} - - -# ************************* Messages ************************* -class ScheduleMessageItem(BaseModel, DictConversionMixin): - item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) - user_id: str = Field(..., description="user id") - mem_cube_id: str = Field(..., description="memcube id") - label: str = Field(..., description="Label of the schedule message") - mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") - content: str = Field(..., description="Content of the schedule message") - timestamp: datetime = Field( - default_factory=datetime.now, description="submit time for schedule_messages" - ) - - class Config: - arbitrary_types_allowed = True - json_encoders: ClassVar[dict[type, object]] = { - datetime: lambda v: v.isoformat(), - GeneralMemCube: lambda v: f"", - } - - def to_dict(self) -> dict: - """Convert model to dictionary suitable for Redis Stream""" - return { - "item_id": self.item_id, - "user_id": self.user_id, - "cube_id": self.mem_cube_id, - "label": self.label, - "cube": "Not Applicable", # Custom cube serialization - "content": self.content, - "timestamp": self.timestamp.isoformat(), - } - - @classmethod - def from_dict(cls, data: dict) -> "ScheduleMessageItem": - """Create model from Redis Stream dictionary""" - return cls( - item_id=data.get("item_id", str(uuid4())), - user_id=data["user_id"], - cube_id=data["cube_id"], - label=data["label"], - cube="Not Applicable", # Custom cube deserialization - content=data["content"], - timestamp=datetime.fromisoformat(data["timestamp"]), - ) - - -class MemorySizes(TypedDict): - long_term_memory_size: int - user_memory_size: int - working_memory_size: int - transformed_act_memory_size: int - - -class MemoryCapacities(TypedDict): - long_term_memory_capacity: int - user_memory_capacity: int - working_memory_capacity: int - transformed_act_memory_capacity: int - - -DEFAULT_MEMORY_SIZES = { - "long_term_memory_size": NOT_INITIALIZED, - "user_memory_size": NOT_INITIALIZED, - "working_memory_size": NOT_INITIALIZED, - "transformed_act_memory_size": NOT_INITIALIZED, - "parameter_memory_size": NOT_INITIALIZED, -} - -DEFAULT_MEMORY_CAPACITIES = { - "long_term_memory_capacity": 10000, - "user_memory_capacity": 10000, - "working_memory_capacity": 20, - "transformed_act_memory_capacity": NOT_INITIALIZED, - "parameter_memory_capacity": NOT_INITIALIZED, -} - - -class ScheduleLogForWebItem(BaseModel, DictConversionMixin): - item_id: str = Field( - description="Unique identifier for the log entry", default_factory=lambda: str(uuid4()) - ) - user_id: str = Field(..., description="Identifier for the user associated with the log") - mem_cube_id: str = Field( - ..., description="Identifier for the memcube associated with this log entry" - ) - label: str = Field(..., description="Label categorizing the type of log") - from_memory_type: str = Field(..., description="Source memory type") - to_memory_type: str = Field(..., description="Destination memory type") - log_content: str = Field(..., description="Detailed content of the log entry") - current_memory_sizes: MemorySizes = Field( - default_factory=lambda: dict(DEFAULT_MEMORY_SIZES), - description="Current utilization of memory partitions", - ) - memory_capacities: MemoryCapacities = Field( - default_factory=lambda: dict(DEFAULT_MEMORY_CAPACITIES), - description="Maximum capacities of memory partitions", - ) - timestamp: datetime = Field( - default_factory=datetime.now, - description="Timestamp indicating when the log entry was created", - ) - - -# ************************* Monitor ************************* -class MemoryMonitorItem(BaseModel, DictConversionMixin): - item_id: str = Field( - description="Unique identifier for the memory item", default_factory=lambda: str(uuid4()) - ) - memory_text: str = Field( - ..., - description="The actual content of the memory", - min_length=1, - max_length=10000, # Prevent excessively large memory texts - ) - importance_score: float = Field( - default=NOT_INITIALIZED, - description="Numerical score representing the memory's importance", - ge=NOT_INITIALIZED, # Minimum value of 0 - ) - recording_count: int = Field( - default=1, - description="How many times this memory has been recorded", - ge=1, # Greater than or equal to 1 - ) - - def get_score(self) -> float: - """ - Calculate the effective score for the memory item. - - Returns: - float: The importance_score if it has been initialized (>=0), - otherwise the recording_count converted to float. - - Note: - This method provides a unified way to retrieve a comparable score - for memory items, regardless of whether their importance has been explicitly set. - """ - if self.importance_score == NOT_INITIALIZED: - # Return recording_count as float when importance_score is not initialized - return float(self.recording_count) - else: - # Return the initialized importance_score - return self.importance_score - - -class MemoryMonitorManager(BaseModel, DictConversionMixin): - user_id: str = Field(..., description="Required user identifier", min_length=1) - mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1) - memories: list[MemoryMonitorItem] = Field( - default_factory=list, description="Collection of memory items" - ) - max_capacity: int | None = Field( - default=None, description="Maximum number of memories allowed (None for unlimited)", ge=1 - ) - - @computed_field - @property - def memory_size(self) -> int: - """Automatically calculated count of memory items.""" - return len(self.memories) - - def update_memories( - self, text_working_memories: list[str], partial_retention_number: int - ) -> MemoryMonitorItem: - """ - Update memories based on text_working_memories. - - Args: - text_working_memories: List of memory texts to update - partial_retention_number: Number of top memories to keep by recording count - - Returns: - List of added or updated MemoryMonitorItem instances - """ - - # Validate partial_retention_number - if partial_retention_number < 0: - raise ValueError("partial_retention_number must be non-negative") - - # Create text lookup set - working_memory_set = set(text_working_memories) - - # Step 1: Update existing memories or add new ones - added_or_updated = [] - memory_text_map = {item.memory_text: item for item in self.memories} - - for text in text_working_memories: - if text in memory_text_map: - # Update existing memory - memory = memory_text_map[text] - memory.recording_count += 1 - added_or_updated.append(memory) - else: - # Add new memory - new_memory = MemoryMonitorItem(memory_text=text, recording_count=1) - self.memories.append(new_memory) - added_or_updated.append(new_memory) - - # Step 2: Identify memories to remove - # Sort memories by recording_count in descending order - sorted_memories = sorted(self.memories, key=lambda item: item.recording_count, reverse=True) - - # Keep the top N memories by recording_count - records_to_keep = { - memory.memory_text for memory in sorted_memories[:partial_retention_number] - } - - # Collect memories to remove: not in current working memory and not in top N - memories_to_remove = [ - memory - for memory in self.memories - if memory.memory_text not in working_memory_set - and memory.memory_text not in records_to_keep - ] - - # Step 3: Remove identified memories - for memory in memories_to_remove: - self.memories.remove(memory) - - # Step 4: Enforce max_capacity if set - if self.max_capacity is not None and len(self.memories) > self.max_capacity: - # Sort by importance and then recording count - sorted_memories = sorted( - self.memories, - key=lambda item: (item.importance_score, item.recording_count), - reverse=True, - ) - # Keep only the top max_capacity memories - self.memories = sorted_memories[: self.max_capacity] - - # Log the update result - logger.info( - f"Updated monitor manager for user {self.user_id}, mem_cube {self.mem_cube_id}: " - f"Total memories: {len(self.memories)}, " - f"Added/Updated: {len(added_or_updated)}, " - f"Removed: {len(memories_to_remove)} (excluding top {partial_retention_number} by recording_count)" - ) - - return added_or_updated diff --git a/src/memos/mem_scheduler/mos_for_test_scheduler.py b/src/memos/mem_scheduler/mos_for_test_scheduler.py index 600d6cf5..d796f824 100644 --- a/src/memos/mem_scheduler/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/mos_for_test_scheduler.py @@ -3,12 +3,12 @@ from memos.configs.mem_os import MOSConfig from memos.log import get_logger from memos.mem_os.main import MOS -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, MONITOR_WORKING_MEMORY_TYPE, QUERY_LABEL, - ScheduleMessageItem, ) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem logger = get_logger(__name__) @@ -54,24 +54,20 @@ def chat(self, query: str, user_id: str | None = None) -> str: if not mem_cube.text_mem: continue - # submit message to scheduler - if self.enable_mem_scheduler and self.mem_scheduler is not None: - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=QUERY_LABEL, - content=query, - timestamp=datetime.now(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) - - self.mem_scheduler.monitor.register_memory_manager_if_not_exists( - user_id=user_id, + message_item = ScheduleMessageItem( + user_id=target_user_id, mem_cube_id=mem_cube_id, - memory_monitors=self.mem_scheduler.monitor.working_memory_monitors, - max_capacity=self.mem_scheduler.monitor.working_mem_monitor_capacity, + mem_cube=mem_cube, + label=QUERY_LABEL, + content=query, + timestamp=datetime.now(), ) + cur_working_memories = [m.memory for m in mem_cube.text_mem.get_working_memory()] + print(f"Working memories before schedule: {cur_working_memories}") + + # --- force to run mem_scheduler --- + self.mem_scheduler.monitor.query_trigger_interval = 0 + self.mem_scheduler._query_message_consumer(messages=[message_item]) # from scheduler scheduler_memories = self.mem_scheduler.monitor.get_monitor_memories( @@ -80,6 +76,7 @@ def chat(self, query: str, user_id: str | None = None) -> str: memory_type=MONITOR_WORKING_MEMORY_TYPE, top_k=topk_for_scheduler, ) + print(f"Working memories after schedule: {scheduler_memories}") memories_all.extend(scheduler_memories) # from mem_cube @@ -87,6 +84,7 @@ def chat(self, query: str, user_id: str | None = None) -> str: query, top_k=self.config.top_k - topk_for_scheduler ) text_memories = [m.memory for m in memories] + print(f"Search results with new working memories: {text_memories}") memories_all.extend(text_memories) memories_all = list(set(memories_all)) @@ -139,5 +137,4 @@ def chat(self, query: str, user_id: str | None = None) -> str: timestamp=datetime.now(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - return response diff --git a/src/memos/mem_scheduler/schemas/__init__.py b/src/memos/mem_scheduler/schemas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py new file mode 100644 index 00000000..a5dcc34b --- /dev/null +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -0,0 +1,43 @@ +from pathlib import Path +from typing import NewType + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent + +QUERY_LABEL = "query" +ANSWER_LABEL = "answer" +ADD_LABEL = "add" + +TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" +TextMemory_SEARCH_METHOD = "text_memory_search" +DIRECT_EXCHANGE_TYPE = "direct" +FANOUT_EXCHANGE_TYPE = "fanout" +DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT = 20 +DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 5 +DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" +DEFAULT_THREAD__POOL_MAX_WORKERS = 5 +DEFAULT_CONSUME_INTERVAL_SECONDS = 3 +NOT_INITIALIZED = -1 + + +# web log +LONG_TERM_MEMORY_TYPE = "LongTermMemory" +USER_MEMORY_TYPE = "UserMemory" +WORKING_MEMORY_TYPE = "WorkingMemory" +TEXT_MEMORY_TYPE = "TextMemory" +ACTIVATION_MEMORY_TYPE = "ActivationMemory" +PARAMETER_MEMORY_TYPE = "ParameterMemory" +USER_INPUT_TYPE = "UserInput" +NOT_APPLICABLE_TYPE = "NotApplicable" + +# monitors +MONITOR_WORKING_MEMORY_TYPE = "MonitorWorkingMemoryType" +MONITOR_ACTIVATION_MEMORY_TYPE = "MonitorActivationMemoryType" +DEFAULT_MAX_QUERY_KEY_WORDS = 1000 +DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05] + + +# new types +UserID = NewType("UserID", str) +MemCubeID = NewType("CubeID", str) diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py new file mode 100644 index 00000000..2c34b078 --- /dev/null +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -0,0 +1,148 @@ +from datetime import datetime +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, field_serializer +from typing_extensions import TypedDict + +from memos.log import get_logger +from memos.mem_cube.general import GeneralMemCube +from memos.mem_scheduler.modules.misc import DictConversionMixin + +from .general_schemas import NOT_INITIALIZED + + +logger = get_logger(__name__) + +DEFAULT_MEMORY_SIZES = { + "long_term_memory_size": NOT_INITIALIZED, + "user_memory_size": NOT_INITIALIZED, + "working_memory_size": NOT_INITIALIZED, + "transformed_act_memory_size": NOT_INITIALIZED, + "parameter_memory_size": NOT_INITIALIZED, +} + +DEFAULT_MEMORY_CAPACITIES = { + "long_term_memory_capacity": 10000, + "user_memory_capacity": 10000, + "working_memory_capacity": 20, + "transformed_act_memory_capacity": NOT_INITIALIZED, + "parameter_memory_capacity": NOT_INITIALIZED, +} + + +class ScheduleMessageItem(BaseModel, DictConversionMixin): + item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) + user_id: str = Field(..., description="user id") + mem_cube_id: str = Field(..., description="memcube id") + label: str = Field(..., description="Label of the schedule message") + mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") + content: str = Field(..., description="Content of the schedule message") + timestamp: datetime = Field( + default_factory=lambda: datetime.utcnow(), description="submit time for schedule_messages" + ) + + # Pydantic V2 model configuration + model_config = ConfigDict( + # Allows arbitrary Python types as model fields without validation + # Required when using custom types like GeneralMemCube that aren't Pydantic models + arbitrary_types_allowed=True, + # Additional metadata for JSON Schema generation + json_schema_extra={ + # Example payload demonstrating the expected structure and sample values + # Used for API documentation, testing, and developer reference + "example": { + "item_id": "123e4567-e89b-12d3-a456-426614174000", # Sample UUID + "user_id": "user123", # Example user identifier + "mem_cube_id": "cube456", # Sample memory cube ID + "label": "sample_label", # Demonstration label value + "mem_cube": "obj of GeneralMemCube", # Added mem_cube example + "content": "sample content", # Example message content + "timestamp": "2024-07-22T12:00:00Z", # Added timestamp example + } + }, + ) + + @field_serializer("mem_cube") + def serialize_mem_cube(self, cube: GeneralMemCube | str, _info) -> str: + """Custom serializer for GeneralMemCube objects to string representation""" + if isinstance(cube, str): + return cube + return f"" + + def to_dict(self) -> dict: + """Convert model to dictionary suitable for Redis Stream""" + return { + "item_id": self.item_id, + "user_id": self.user_id, + "cube_id": self.mem_cube_id, + "label": self.label, + "cube": "Not Applicable", # Custom cube serialization + "content": self.content, + "timestamp": self.timestamp.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict) -> "ScheduleMessageItem": + """Create model from Redis Stream dictionary""" + return cls( + item_id=data.get("item_id", str(uuid4())), + user_id=data["user_id"], + cube_id=data["cube_id"], + label=data["label"], + cube="Not Applicable", # Custom cube deserialization + content=data["content"], + timestamp=datetime.fromisoformat(data["timestamp"]), + ) + + +class MemorySizes(TypedDict): + long_term_memory_size: int + user_memory_size: int + working_memory_size: int + transformed_act_memory_size: int + + +class MemoryCapacities(TypedDict): + long_term_memory_capacity: int + user_memory_capacity: int + working_memory_capacity: int + transformed_act_memory_capacity: int + + +class ScheduleLogForWebItem(BaseModel, DictConversionMixin): + item_id: str = Field( + description="Unique identifier for the log entry", default_factory=lambda: str(uuid4()) + ) + user_id: str = Field(..., description="Identifier for the user associated with the log") + mem_cube_id: str = Field( + ..., description="Identifier for the memcube associated with this log entry" + ) + label: str = Field(..., description="Label categorizing the type of log") + from_memory_type: str = Field(..., description="Source memory type") + to_memory_type: str = Field(..., description="Destination memory type") + log_content: str = Field(..., description="Detailed content of the log entry") + current_memory_sizes: MemorySizes = Field( + default_factory=lambda: dict(DEFAULT_MEMORY_SIZES), + description="Current utilization of memory partitions", + ) + memory_capacities: MemoryCapacities = Field( + default_factory=lambda: dict(DEFAULT_MEMORY_CAPACITIES), + description="Maximum capacities of memory partitions", + ) + timestamp: datetime = Field( + default_factory=lambda: datetime.utcnow(), + description="Timestamp indicating when the log entry was created", + ) + + def debug_info(self) -> dict[str, Any]: + """Return structured debug information for logging purposes.""" + return { + "log_id": self.item_id, + "user_id": self.user_id, + "mem_cube_id": self.mem_cube_id, + "operation": f"{self.from_memory_type} → {self.to_memory_type}", + "label": self.label, + "content_length": len(self.log_content), + "timestamp": self.timestamp.isoformat(), + } diff --git a/src/memos/mem_scheduler/schemas/monitor_schemas.py b/src/memos/mem_scheduler/schemas/monitor_schemas.py new file mode 100644 index 00000000..6cba48fa --- /dev/null +++ b/src/memos/mem_scheduler/schemas/monitor_schemas.py @@ -0,0 +1,329 @@ +from collections import Counter +from datetime import datetime +from pathlib import Path +from typing import ClassVar +from uuid import uuid4 + +from pydantic import BaseModel, Field, computed_field, field_validator + +from memos.log import get_logger +from memos.mem_scheduler.modules.misc import AutoDroppingQueue, DictConversionMixin +from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_MAX_QUERY_KEY_WORDS, + DEFAULT_WEIGHT_VECTOR_FOR_RANKING, + NOT_INITIALIZED, +) +from memos.mem_scheduler.utils.filter_utils import transform_name_to_key +from memos.memories.textual.tree import TextualMemoryItem + + +logger = get_logger(__name__) + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent + + +# ============== Queries ============== +class QueryMonitorItem(BaseModel, DictConversionMixin): + item_id: str = Field( + description="Unique identifier for the query item", default_factory=lambda: str(uuid4()) + ) + query_text: str = Field( + ..., + description="The actual user query text content", + min_length=1, + ) + keywords: list[str] | None = Field( + default=None, + min_length=1, # If provided, shouldn't be empty + description="Semantic keywords extracted from the query text", + ) + max_keywords: ClassVar[int] = DEFAULT_MAX_QUERY_KEY_WORDS + + timestamp: datetime = Field( + default_factory=datetime.now, description="Timestamp indicating when query was submitted" + ) + + @field_validator("keywords", mode="before") + @classmethod + def validate_keywords(cls, v, values): + if v is None: + return None + + if not isinstance(v, list): + raise ValueError("Keywords must be a list") + + if len(v) > cls.max_keywords: + logger.warning( + f"Keywords list truncated from {len(v)} to {cls.max_keywords} items. " + f"Configure max_keywords class attribute to adjust this limit." + ) + return v[: cls.max_keywords] + return v + + @classmethod + def with_max_keywords(cls, limit: int): + """Create a new class with custom keywords limit.""" + if not isinstance(limit, int) or limit <= 0: + raise ValueError("Max keywords limit must be positive integer") + + return type(f"{cls.__name__}_MaxKeywords{limit}", (cls,), {"max_keywords": limit}) + + +class QueryMonitorQueue(AutoDroppingQueue[QueryMonitorItem]): + """ + A thread-safe queue for monitoring queries with timestamp and keyword tracking. + Each item is expected to be a dictionary containing: + """ + + def put(self, item: QueryMonitorItem, block: bool = True, timeout: float | None = None) -> None: + """ + Add a query item to the queue. Ensures the item is of correct type. + + Args: + item: A QueryMonitorItem instance + """ + if not isinstance(item, QueryMonitorItem): + raise ValueError("Item must be an instance of QueryMonitorItem") + super().put(item, block, timeout) + + def get_queries_by_timestamp( + self, start_time: datetime, end_time: datetime + ) -> list[QueryMonitorItem]: + """ + Retrieve queries added between the specified time range. + """ + with self.mutex: + return [item for item in self.queue if start_time <= item.timestamp <= end_time] + + def get_keywords_collections(self) -> Counter: + """ + Generate a Counter containing keyword frequencies across all queries. + + Returns: + Counter object with keyword counts + """ + with self.mutex: + all_keywords = [kw for item in self.queue for kw in item.keywords] + return Counter(all_keywords) + + def get_queries_with_timesort(self, reverse: bool = True) -> list[dict]: + """ + Retrieve all queries sorted by timestamp. + + Args: + reverse: If True, sort in descending order (newest first), + otherwise sort in ascending order (oldest first) + + Returns: + List of query items sorted by timestamp + """ + with self.mutex: + return [ + monitor.query_text + for monitor in sorted(self.queue, key=lambda x: x.timestamp, reverse=reverse) + ] + + +# ============== Memories ============== +class MemoryMonitorItem(BaseModel, DictConversionMixin): + item_id: str = Field( + description="Unique identifier for the memory item", default_factory=lambda: str(uuid4()) + ) + memory_text: str = Field( + ..., + description="The actual content of the memory", + min_length=1, + ) + tree_memory_item: TextualMemoryItem | None = Field( + default=None, description="Optional textual memory item" + ) + tree_memory_item_mapping_key: str = Field( + description="Key generated from memory_text using transform_name_to_key", + ) + keywords_score: float = Field( + default=NOT_INITIALIZED, + description="The score generate by counting keywords in queries", + ge=NOT_INITIALIZED, # Minimum value of 0 + ) + sorting_score: float = Field( + default=NOT_INITIALIZED, + description="The score generate from rerank process", + ge=NOT_INITIALIZED, # Minimum value of 0 + ) + importance_score: float = Field( + default=NOT_INITIALIZED, + description="Numerical score representing the memory's importance", + ge=NOT_INITIALIZED, # Minimum value of 0 + ) + recording_count: int = Field( + default=1, + description="How many times this memory has been recorded", + ge=1, # Greater than or equal to 1 + ) + + @field_validator("tree_memory_item_mapping_key", mode="before") + def generate_mapping_key(cls, v, values): # noqa: N805 + if v is None and "memory_text" in values: + return transform_name_to_key(values["memory_text"]) + return v + + def get_importance_score(self, weight_vector: list[float] | None = None) -> float: + """ + Calculate the effective score for the memory item. + + Returns: + float: The importance_score if it has been initialized (>=0), + otherwise the recording_count converted to float. + + Note: + This method provides a unified way to retrieve a comparable score + for memory items, regardless of whether their importance has been explicitly set. + """ + if weight_vector is None: + logger.warning("weight_vector of get_importance_score is None.") + weight_vector = DEFAULT_WEIGHT_VECTOR_FOR_RANKING + assert sum(weight_vector) == 1 + normalized_keywords_score = min(self.keywords_score * weight_vector[1], 5) + normalized_recording_count_score = min(self.recording_count * weight_vector[2], 2) + self.importance_score = ( + self.sorting_score * weight_vector[0] + + normalized_keywords_score + + normalized_recording_count_score + ) + return self.importance_score + + +class MemoryMonitorManager(BaseModel, DictConversionMixin): + user_id: str = Field(..., description="Required user identifier", min_length=1) + mem_cube_id: str = Field(..., description="Required memory cube identifier", min_length=1) + memories: list[MemoryMonitorItem] = Field( + default_factory=list, description="Collection of memory items" + ) + max_capacity: int | None = Field( + default=None, description="Maximum number of memories allowed (None for unlimited)", ge=1 + ) + + @computed_field + @property + def memory_size(self) -> int: + """Automatically calculated count of memory items.""" + return len(self.memories) + + @property + def memories_mapping_dict(self) -> dict[str, MemoryMonitorItem]: + """ + Generate a mapping dictionary for the memories in MemoryMonitorManager, + using tree_memory_item_mapping_key as the key and MemoryMonitorItem as the value. + + Returns: + Dict[str, MemoryMonitorItem]: A dictionary where keys are + tree_memory_item_mapping_key values from MemoryMonitorItem, + and values are the corresponding MemoryMonitorItem objects. + """ + mapping_dict = { + mem_item.tree_memory_item_mapping_key: mem_item for mem_item in self.memories + } + + logger.debug( + f"Generated memories mapping dict for user_id={self.user_id}, " + f"mem_cube_id={self.mem_cube_id}, " + f"total_items={len(mapping_dict)}, " + f"source_memory_count={len(self.memories)}" + ) + return mapping_dict + + def get_sorted_mem_monitors(self, reverse=True) -> list[MemoryMonitorItem]: + """ + Retrieve memory monitors sorted by their ranking score in descending order. + + Returns: + list[MemoryMonitorItem]: Sorted list of memory monitor items. + """ + return sorted( + self.memories, + key=lambda item: item.get_importance_score( + weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING + ), + reverse=reverse, + ) + + def update_memories( + self, new_memory_monitors: list[MemoryMonitorItem], partial_retention_number: int + ) -> MemoryMonitorItem: + """ + Update memories based on monitor_working_memories. + """ + + # Validate partial_retention_number + if partial_retention_number < 0: + raise ValueError("partial_retention_number must be non-negative") + + # Step 1: Update existing memories or add new ones + added_count = 0 + memories_mapping_dict = self.memories_mapping_dict + new_mem_set = set() + for memory_monitor in new_memory_monitors: + if memory_monitor.tree_memory_item_mapping_key in memories_mapping_dict: + # Update existing memory + item: MemoryMonitorItem = memories_mapping_dict[ + memory_monitor.tree_memory_item_mapping_key + ] + item.recording_count += 1 + item.keywords_score = memory_monitor.keywords_score + item.sorting_score = memory_monitor.sorting_score + else: + # Add new memory + self.memories.append(memory_monitor) + added_count += 1 + + new_mem_set.add(memory_monitor.tree_memory_item_mapping_key) + + # Step 2: Identify memories to remove + old_mem_monitor_list = [] + for mem_monitor in self.memories: + if mem_monitor.tree_memory_item_mapping_key not in new_mem_set: + old_mem_monitor_list.append(mem_monitor) + + # Sort memories by recording_count in descending order + sorted_old_mem_monitors = sorted( + old_mem_monitor_list, + key=lambda item: item.get_importance_score( + weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING + ), + reverse=True, + ) + + # Keep the top N old memories + memories_to_remove = sorted_old_mem_monitors[partial_retention_number:] + memories_to_change_score = sorted_old_mem_monitors[:partial_retention_number] + + # Step 3: Remove identified memories and change the scores of left old memories + for memory in memories_to_remove: + self.memories.remove(memory) + + for memory in memories_to_change_score: + memory.sorting_score = 0 + memory.recording_count = 0 + memory.keywords_score = 0 + + # Step 4: Enforce max_capacity if set + sorted_memories = sorted( + self.memories, + key=lambda item: item.get_importance_score( + weight_vector=DEFAULT_WEIGHT_VECTOR_FOR_RANKING + ), + reverse=True, + ) + # Keep only the top max_capacity memories + self.memories = sorted_memories[: self.max_capacity] + + # Log the update result + logger.info( + f"Updated monitor manager for user {self.user_id}, mem_cube {self.mem_cube_id}: " + f"Total memories: {len(self.memories)}, " + f"Added/Updated: {added_count}, " + f"Removed: {len(memories_to_remove)} (excluding top {partial_retention_number} by recording_count)" + ) + + return self.memories diff --git a/src/memos/mem_scheduler/utils.py b/src/memos/mem_scheduler/utils.py deleted file mode 100644 index 3675b855..00000000 --- a/src/memos/mem_scheduler/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -import re - -from pathlib import Path - -import yaml - - -def extract_json_dict(text: str): - text = text.strip() - patterns_to_remove = ["json```", "```json", "latex```", "```latex", "```"] - for pattern in patterns_to_remove: - text = text.replace(pattern, "") - res = json.loads(text.strip()) - return res - - -def transform_name_to_key(name): - """ - Normalize text by removing all punctuation marks, keeping only letters, numbers, and word characters. - - Args: - name (str): Input text to be processed - - Returns: - str: Processed text with all punctuation removed - """ - # Match all characters that are NOT: - # \w - word characters (letters, digits, underscore) - # \u4e00-\u9fff - Chinese/Japanese/Korean characters - # \s - whitespace - pattern = r"[^\w\u4e00-\u9fff\s]" - - # Substitute all matched punctuation marks with empty string - # re.UNICODE flag ensures proper handling of Unicode characters - normalized = re.sub(pattern, "", name, flags=re.UNICODE) - - # Optional: Collapse multiple whitespaces into single space - normalized = "_".join(normalized.split()) - - normalized = normalized.lower() - - return normalized - - -def parse_yaml(yaml_file): - yaml_path = Path(yaml_file) - yaml_path = Path(yaml_file) - if not yaml_path.is_file(): - raise FileNotFoundError(f"No such file: {yaml_file}") - - with yaml_path.open("r", encoding="utf-8") as fr: - data = yaml.safe_load(fr) - - return data - - -def is_all_english(input_string: str) -> bool: - """Determine if the string consists entirely of English characters (including spaces)""" - return all(char.isascii() or char.isspace() for char in input_string) - - -def is_all_chinese(input_string: str) -> bool: - """Determine if the string consists entirely of Chinese characters (including Chinese punctuation and spaces)""" - return all( - ("\u4e00" <= char <= "\u9fff") # Basic Chinese characters - or ("\u3400" <= char <= "\u4dbf") # Extension A - or ("\u20000" <= char <= "\u2a6df") # Extension B - or ("\u2a700" <= char <= "\u2b73f") # Extension C - or ("\u2b740" <= char <= "\u2b81f") # Extension D - or ("\u2b820" <= char <= "\u2ceaf") # Extension E - or ("\u2f800" <= char <= "\u2fa1f") # Extension F - or char.isspace() # Spaces - for char in input_string - ) diff --git a/src/memos/mem_scheduler/utils/__init__.py b/src/memos/mem_scheduler/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/memos/mem_scheduler/utils/filter_utils.py b/src/memos/mem_scheduler/utils/filter_utils.py new file mode 100644 index 00000000..6055fe41 --- /dev/null +++ b/src/memos/mem_scheduler/utils/filter_utils.py @@ -0,0 +1,176 @@ +import re + +from memos.dependency import require_python_package +from memos.log import get_logger + + +logger = get_logger(__name__) + + +def transform_name_to_key(name): + """ + Normalize text by removing all punctuation marks, keeping only letters, numbers, and word characters. + + Args: + name (str): Input text to be processed + + Returns: + str: Processed text with all punctuation removed + """ + # Match all characters that are NOT: + # \w - word characters (letters, digits, underscore) + # \u4e00-\u9fff - Chinese/Japanese/Korean characters + # \s - whitespace + pattern = r"[^\w\u4e00-\u9fff\s]" + + # Substitute all matched punctuation marks with empty string + # re.UNICODE flag ensures proper handling of Unicode characters + normalized = re.sub(pattern, "", name, flags=re.UNICODE) + + # Optional: Collapse multiple whitespaces into single space + normalized = "_".join(normalized.split()) + + normalized = normalized.lower() + + return normalized + + +def is_all_english(input_string: str) -> bool: + """Determine if the string consists entirely of English characters (including spaces)""" + return all(char.isascii() or char.isspace() for char in input_string) + + +def is_all_chinese(input_string: str) -> bool: + """Determine if the string consists entirely of Chinese characters (including Chinese punctuation and spaces)""" + return all( + ("\u4e00" <= char <= "\u9fff") # Basic Chinese characters + or ("\u3400" <= char <= "\u4dbf") # Extension A + or ("\u20000" <= char <= "\u2a6df") # Extension B + or ("\u2a700" <= char <= "\u2b73f") # Extension C + or ("\u2b740" <= char <= "\u2b81f") # Extension D + or ("\u2b820" <= char <= "\u2ceaf") # Extension E + or ("\u2f800" <= char <= "\u2fa1f") # Extension F + or char.isspace() # Spaces + for char in input_string + ) + + +@require_python_package( + import_name="sklearn", + install_command="pip install scikit-learn", + install_link="https://scikit-learn.org/stable/install.html", +) +def filter_similar_memories( + text_memories: list[str], similarity_threshold: float = 0.75 +) -> list[str]: + """ + Filters out low-quality or duplicate memories based on text similarity. + + Args: + text_memories: List of text memories to filter + similarity_threshold: Threshold for considering memories duplicates (0.0-1.0) + Higher values mean stricter filtering + + Returns: + List of filtered memories with duplicates removed + """ + from sklearn.feature_extraction.text import TfidfVectorizer + from sklearn.metrics.pairwise import cosine_similarity + + if not text_memories: + logger.warning("Received empty memories list - nothing to filter") + return [] + + for idx in range(len(text_memories)): + if not isinstance(text_memories[idx], str): + logger.error( + f"{text_memories[idx]} in memories is not a string," + f" and now has been transformed to be a string." + ) + text_memories[idx] = str(text_memories[idx]) + + try: + # Step 1: Vectorize texts using TF-IDF + vectorizer = TfidfVectorizer() + tfidf_matrix = vectorizer.fit_transform(text_memories) + + # Step 2: Calculate pairwise similarity matrix + similarity_matrix = cosine_similarity(tfidf_matrix) + + # Step 3: Identify duplicates + to_keep = set(range(len(text_memories))) # Start with all indices + for i in range(len(similarity_matrix)): + if i not in to_keep: + continue # Already marked for removal + + # Find all similar items to this one (excluding self and already removed) + similar_indices = [ + j + for j in range(i + 1, len(similarity_matrix)) + if similarity_matrix[i][j] >= similarity_threshold and j in to_keep + ] + similar_indices = set(similar_indices) + + # Remove all similar items (keeping the first one - i) + to_keep -= similar_indices + + # Return filtered memories + filtered_memories = [text_memories[i] for i in sorted(to_keep)] + logger.debug(f"filtered_memories: {filtered_memories}") + return filtered_memories + + except Exception as e: + logger.error(f"Error filtering memories: {e!s}") + return text_memories # Return original list if error occurs + + +def filter_too_short_memories( + text_memories: list[str], min_length_threshold: int = 20 +) -> list[str]: + """ + Filters out text memories that fall below the minimum length requirement. + Handles both English (word count) and Chinese (character count) differently. + + Args: + text_memories: List of text memories to be filtered + min_length_threshold: Minimum length required to keep a memory. + For English: word count, for Chinese: character count. + + Returns: + List of filtered memories meeting the length requirement + """ + if not text_memories: + logger.debug("Empty memories list received in short memory filter") + return [] + + filtered_memories = [] + removed_count = 0 + + for memory in text_memories: + stripped_memory = memory.strip() + if not stripped_memory: # Skip empty/whitespace memories + removed_count += 1 + continue + + # Determine measurement method based on language + if is_all_english(stripped_memory): + length = len(stripped_memory.split()) # Word count for English + elif is_all_chinese(stripped_memory): + length = len(stripped_memory) # Character count for Chinese + else: + logger.debug(f"Mixed-language memory, using character count: {stripped_memory[:50]}...") + length = len(stripped_memory) # Default to character count + + if length >= min_length_threshold: + filtered_memories.append(memory) + else: + removed_count += 1 + + if removed_count > 0: + logger.info( + f"Filtered out {removed_count} short memories " + f"(below {min_length_threshold} units). " + f"Total remaining: {len(filtered_memories)}" + ) + + return filtered_memories diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py new file mode 100644 index 00000000..92b56944 --- /dev/null +++ b/src/memos/mem_scheduler/utils/misc_utils.py @@ -0,0 +1,26 @@ +import json + +from pathlib import Path + +import yaml + + +def extract_json_dict(text: str): + text = text.strip() + patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"] + for pattern in patterns_to_remove: + text = text.replace(pattern, "") + res = json.loads(text.strip()) + return res + + +def parse_yaml(yaml_file): + yaml_path = Path(yaml_file) + yaml_path = Path(yaml_file) + if not yaml_path.is_file(): + raise FileNotFoundError(f"No such file: {yaml_file}") + + with yaml_path.open("r", encoding="utf-8") as fr: + data = yaml.safe_load(fr) + + return data diff --git a/src/memos/memories/activation/item.py b/src/memos/memories/activation/item.py index fd8c1ce6..8babfeef 100644 --- a/src/memos/memories/activation/item.py +++ b/src/memos/memories/activation/item.py @@ -18,6 +18,10 @@ class KVCacheRecords(BaseModel): default=[], description="The list of text memories transformed to the activation memory.", ) + composed_text_memory: str = Field( + default="", + description="Single string combining all text_memories using assembly template", + ) timestamp: datetime = Field( default_factory=datetime.now, description="submit time for schedule_messages" ) diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index f3597780..31e82c3c 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -24,7 +24,7 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: """ @abstractmethod - def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> None: + def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: """Add memories. Args: diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py index 83b984bf..c2688879 100644 --- a/src/memos/templates/mem_scheduler_prompts.py +++ b/src/memos/templates/mem_scheduler_prompts.py @@ -72,13 +72,18 @@ 3. Sorting evidence in descending order of relevance 4. Maintaining all original items (no additions or deletions) +## Temporal Priority Rules +- Query recency matters: Index 0 is the MOST RECENT query +- Evidence matching recent queries gets higher priority +- For equal relevance scores: Favor items matching newer queries + ## Input Format - Queries: Recent user questions/requests (list) -- Current Order: Existing memory sequence (list) +- Current Order: Existing memory sequence (list of strings with indices) ## Output Requirements Return a JSON object with: -- "new_order": The reordered list (maintaining all original items) +- "new_order": The reordered indices (array of integers) - "reasoning": Brief explanation of your ranking logic (1-2 sentences) ## Processing Guidelines @@ -89,26 +94,55 @@ - Shows temporal relevance (newer > older) 2. For ambiguous cases, maintain original relative ordering +## Scoring Priorities (Descending Order) +1. Direct matches to newer queries +2. Exact keyword matches in recent queries +3. Contextual support for recent topics +4. General relevance to older queries + ## Example -Input queries: ["python threading best practices"] -Input order: ["basic python syntax", "thread safety patterns", "data structures"] +Input queries: ["[0] python threading", "[1] data visualization"] +Input order: ["[0] syntax", "[1] matplotlib", "[2] threading"] Output: {{ - "new_order": ["thread safety patterns", "data structures", "basic python syntax"], - "reasoning": "Prioritized threading-related content while maintaining general python references" + "new_order": [2, 1, 0], + "reasoning": "Threading (2) prioritized for matching newest query, followed by matplotlib (1) for older visualization query", }} ## Current Task -Queries: {queries} +Queries: {queries} (recency-ordered) Current order: {current_order} Please provide your reorganization: """ +QUERY_KEYWORDS_EXTRACTION_PROMPT = """ +## Role +You are an intelligent keyword extraction system. Your task is to identify and extract the most important words or short phrases from user queries. + +## Instructions +- They have to be single words or short phrases that make sense. +- Only nouns (naming words) or verbs (action words) are allowed. +- Don't include stop words (like "the", "is") or adverbs (words that describe verbs, like "quickly"). +- Keep them as the smallest possible units that still have meaning. + +## Example +- Input Query: "What breed is Max?" +- Output Keywords (list of string): ["breed", "Max"] + +## Current Task +- Query: {query} +- Output Format: A Json list of keywords. + +Answer: +""" + + PROMPT_MAPPING = { "intent_recognizing": INTENT_RECOGNIZING_PROMPT, "memory_reranking": MEMORY_RERANKING_PROMPT, + "query_keywords_extraction": QUERY_KEYWORDS_EXTRACTION_PROMPT, } MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}""" diff --git a/tests/mem_scheduler/test_retriever.py b/tests/mem_scheduler/test_retriever.py index 6d2e79ed..ff40f339 100644 --- a/tests/mem_scheduler/test_retriever.py +++ b/tests/mem_scheduler/test_retriever.py @@ -8,6 +8,10 @@ from memos.llms.base import BaseLLM from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.utils.filter_utils import ( + filter_similar_memories, + filter_too_short_memories, +) from memos.memories.textual.tree import TreeTextMemory @@ -53,11 +57,8 @@ def tearDown(self): def test_filter_similar_memories_empty_input(self): """Test filter_similar_memories with empty input list.""" - result = self.retriever.filter_similar_memories([]) + result = filter_similar_memories([]) self.assertEqual(result, []) - self.mock_logging_warning.assert_called_with( - "Received empty memories list - nothing to filter" - ) def test_filter_similar_memories_no_duplicates(self): """Test filter_similar_memories with no duplicate memories.""" @@ -67,7 +68,7 @@ def test_filter_similar_memories_no_duplicates(self): "And this third one has nothing in common with the others", ] - result = self.retriever.filter_similar_memories(memories) + result = filter_similar_memories(memories) self.assertEqual(len(result), 3) self.assertEqual(set(result), set(memories)) @@ -78,22 +79,19 @@ def test_filter_similar_memories_with_duplicates(self): "The user is planning to move to Chicago next month, which reflects a significant change in their living situation.", "The user is planning to move to Chicago in the upcoming month, indicating a significant change in their living situation.", ] - result = self.retriever.filter_similar_memories(memories, similarity_threshold=0.75) + result = filter_similar_memories(memories, similarity_threshold=0.75) self.assertLess(len(result), len(memories)) - # Verify logging was called for removed items - self.assertGreater(self.mock_logger_info.call_count, 0) - def test_filter_similar_memories_error_handling(self): """Test filter_similar_memories error handling.""" # Test with non-string input (should return original list due to error) memories = ["valid text", 12345, "another valid text"] - result = self.retriever.filter_similar_memories(memories) + result = filter_similar_memories(memories) self.assertEqual(result, memories) def test_filter_too_short_memories_empty_input(self): """Test filter_too_short_memories with empty input list.""" - result = self.retriever.filter_too_short_memories([]) + result = filter_too_short_memories([]) self.assertEqual(result, []) def test_filter_too_short_memories_all_valid(self): @@ -104,7 +102,7 @@ def test_filter_too_short_memories_all_valid(self): "And this third memory meets the minimum length requirements too", ] - result = self.retriever.filter_too_short_memories(memories, min_length_threshold=5) + result = filter_too_short_memories(memories, min_length_threshold=5) self.assertEqual(len(result), 3) self.assertEqual(result, memories) @@ -119,7 +117,7 @@ def test_filter_too_short_memories_with_short_ones(self): ] # Test with word count threshold of 3 - result = self.retriever.filter_too_short_memories(memories, min_length_threshold=3) + result = filter_too_short_memories(memories, min_length_threshold=3) self.assertEqual(len(result), 3) self.assertNotIn("Too short", result) self.assertNotIn("Nope", result) @@ -130,7 +128,7 @@ def test_filter_too_short_memories_edge_case(self): # Test with threshold exactly matching some memories # The implementation uses word count, not character count - result = self.retriever.filter_too_short_memories(memories, min_length_threshold=3) + result = filter_too_short_memories(memories, min_length_threshold=3) self.assertEqual( len(result), 3 ) # "Exactly three words here", "Two words only", "Four words right here" diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index cf750a75..30ccc934 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -3,22 +3,23 @@ from datetime import datetime from pathlib import Path -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import MagicMock from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.llms.base import BaseLLM from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.modules.monitor import SchedulerMonitor from memos.mem_scheduler.modules.retriever import SchedulerRetriever -from memos.mem_scheduler.modules.schemas import ( +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, QUERY_LABEL, - ScheduleLogForWebItem, - ScheduleMessageItem, TreeTextMemory_SEARCH_METHOD, ) -from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.mem_scheduler.schemas.message_schemas import ( + ScheduleLogForWebItem, +) +from memos.memories.textual.tree import TreeTextMemory FILE_PATH = Path(__file__).absolute() @@ -63,101 +64,6 @@ def test_initialize_modules(self): self.assertIsInstance(self.scheduler.monitor, SchedulerMonitor) self.assertIsInstance(self.scheduler.retriever, SchedulerRetriever) - def test_query_message_consumer(self): - # Create test message with all required fields - message = ScheduleMessageItem( - user_id="test_user", - mem_cube_id="test_cube", - mem_cube=self.mem_cube, # or could be str like "test_cube" - label=QUERY_LABEL, - content="Test query", - ) - - # Mock the detect_intent method to return a valid result - mock_intent_result = {"trigger_retrieval": False, "missing_evidences": []} - - # Mock the process_session_turn method - with ( - patch.object(self.scheduler, "process_session_turn") as mock_process_session_turn, - patch.object(self.scheduler.monitor, "detect_intent") as mock_detect_intent, - ): - mock_detect_intent.return_value = mock_intent_result - - # Test message handling - self.scheduler._query_message_consumer([message]) - - # Verify method call - updated to match new signature - mock_process_session_turn.assert_called_once_with( - queries=["Test query"], # or ["Test query"] depending on implementation - user_id="test_user", - mem_cube_id="test_cube", - mem_cube=self.mem_cube, - top_k=10, - ) - - def test_process_session_turn(self): - """Test session turn processing with retrieval trigger.""" - # Setup mock working memory - working_memory = [ - TextualMemoryItem(memory="Memory 1"), - TextualMemoryItem(memory="Memory 2"), - ] - self.tree_text_memory.get_working_memory.return_value = working_memory - - # Setup mock memory cube - mem_cube = MagicMock() - mem_cube.text_mem = self.tree_text_memory - - # Setup intent detection result - intent_result = { - "trigger_retrieval": True, - "missing_evidences": ["Evidence 1", "Evidence 2"], - } - - # Create test results that we'll return and expect - result1 = TextualMemoryItem(memory="Result 1") - result2 = TextualMemoryItem(memory="Result 2") - expected_new_memory = [result1, result2] - - # Mock methods - with ( - patch.object(self.scheduler.monitor, "detect_intent") as mock_detect, - patch.object(self.scheduler.retriever, "search") as mock_search, - patch.object(self.scheduler.retriever, "replace_working_memory") as mock_replace, - ): - mock_detect.return_value = intent_result - mock_search.side_effect = [ - [result1], - [result2], - ] - mock_replace.return_value = expected_new_memory - - # Test session turn processing - self.scheduler.process_session_turn( - queries=["Test query"], - user_id="test_user", - mem_cube_id="test_cube", - mem_cube=mem_cube, - top_k=10, - ) - - # Verify method calls - mock_detect.assert_called_once_with( - q_list=["Test query"], text_working_memory=["Memory 1", "Memory 2"] - ) - - # Verify search calls - using ANY for the method since we can't predict the exact value - mock_search.assert_has_calls( - [ - call(query="Evidence 1", mem_cube=mem_cube, top_k=5, method=ANY), - call(query="Evidence 2", mem_cube=mem_cube, top_k=5, method=ANY), - ], - any_order=True, - ) - - # Verify replace call - we'll check the structure but not the exact memory items - self.assertEqual(mock_replace.call_count, 1) - def test_submit_web_logs(self): """Test submission of web logs with updated data structure.""" # Create log message with all required fields