Merged
9 changes: 7 additions & 2 deletions examples/mem_scheduler/memos_w_scheduler.py
@@ -91,6 +91,12 @@ def run_with_scheduler_init():
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url

mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
mem_cube_config.text_mem.config.graph_db.config.auto_create = (
auth_config.graph_db.auto_create
)

# Initialization
mos = MOS(mos_config)
@@ -118,8 +124,7 @@ def run_with_scheduler_init():
query = item["question"]
print(f"Query:\n {query}\n")
response = mos.chat(query=query, user_id=user_id)
print(f"Answer:\n {response}")
print("===== Chat End =====")
print(f"Answer:\n {response}\n")

show_web_logs(mem_scheduler=mos.mem_scheduler)

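For context: the six new assignments above copy the graph-database settings from the shared auth config into the mem cube's text-memory backend. A minimal sketch of the fields that auth config is assumed to expose (field names are taken from the assignments; the dataclass layout below is an illustrative assumption, not code from this PR):

from dataclasses import dataclass

# Illustration only -- mirrors the fields referenced above; the repository's
# real auth config class may be structured differently.
@dataclass
class GraphDBAuth:
    uri: str           # e.g. "bolt://localhost:7687"
    user: str
    password: str
    db_name: str
    auto_create: bool  # create the database on first use if it does not exist

@dataclass
class OpenAIAuth:
    base_url: str

@dataclass
class AuthConfig:
    openai: OpenAIAuth
    graph_db: GraphDBAuth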
6 changes: 6 additions & 0 deletions examples/mem_scheduler/memos_w_scheduler_for_test.py
@@ -103,6 +103,12 @@ def init_task():
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url

mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
mem_cube_config.text_mem.config.graph_db.config.auto_create = (
auth_config.graph_db.auto_create
)

# Initialization
mos = MOSForTestScheduler(mos_config)
6 changes: 6 additions & 0 deletions examples/mem_scheduler/try_schedule_modules.py
@@ -151,6 +151,12 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url

mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
mem_cube_config.text_mem.config.graph_db.config.auto_create = (
auth_config.graph_db.auto_create
)

# Initialization
mos = MOSForTestScheduler(mos_config)
49 changes: 15 additions & 34 deletions src/memos/mem_scheduler/base_scheduler.py
@@ -84,6 +84,8 @@ def __init__(self, config: BaseSchedulerConfig):
# other attributes
self._context_lock = threading.Lock()
self.current_user_id: UserID | str | None = None
self.current_mem_cube_id: MemCubeID | str | None = None
self.current_mem_cube: GeneralMemCube | None = None
self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None)
self.auth_config = None
self.rabbitmq_config = None
@@ -130,8 +132,8 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
self.current_mem_cube_id = msg.mem_cube_id
self.current_mem_cube = msg.mem_cube

def transform_memories_to_monitors(
self, memories: list[TextualMemoryItem]
def transform_working_memories_to_monitors(
self, query_keywords, memories: list[TextualMemoryItem]
) -> list[MemoryMonitorItem]:
"""
Convert a list of TextualMemoryItem objects into MemoryMonitorItem objects
@@ -143,10 +145,6 @@ def transform_memories_to_monitors(
Returns:
List of MemoryMonitorItem objects with computed importance scores.
"""
query_keywords = self.monitor.query_monitors.get_keywords_collections()
logger.debug(
f"Processing {len(memories)} memories with {len(query_keywords)} query keywords"
)

result = []
mem_length = len(memories)
@@ -195,7 +193,8 @@ def replace_working_memory(
text_mem_base: TreeTextMemory = text_mem_base

# process rerank memories with llm
query_history = self.monitor.query_monitors.get_queries_with_timesort()
query_monitor = self.monitor.query_monitors[user_id][mem_cube_id]
query_history = query_monitor.get_queries_with_timesort()
memories_with_new_order, rerank_success_flag = (
self.retriever.process_and_rerank_memories(
queries=query_history,
@@ -206,8 +205,13 @@ def replace_working_memory(
)

# update working memory monitors
new_working_memory_monitors = self.transform_memories_to_monitors(
memories=memories_with_new_order
query_keywords = query_monitor.get_keywords_collections()
logger.debug(
f"Processing {len(memories_with_new_order)} memories with {len(query_keywords)} query keywords"
)
new_working_memory_monitors = self.transform_working_memories_to_monitors(
query_keywords=query_keywords,
memories=memories_with_new_order,
)

if not rerank_success_flag:
@@ -245,25 +249,6 @@ def replace_working_memory(

return memories_with_new_order

def initialize_working_memory_monitors(
self,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube: GeneralMemCube,
):
text_mem_base: TreeTextMemory = mem_cube.text_mem
working_memories = text_mem_base.get_working_memory()

working_memory_monitors = self.transform_memories_to_monitors(
memories=working_memories,
)
self.monitor.update_working_memory_monitors(
new_working_memory_monitors=working_memory_monitors,
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
)

def update_activation_memory(
self,
new_memories: list[str | TextualMemoryItem],
@@ -367,13 +352,9 @@ def update_activation_memory_periodically(
or len(self.monitor.working_memory_monitors[user_id][mem_cube_id].memories) == 0
):
logger.warning(
"No memories found in working_memory_monitors, initializing from current working_memories"
)
self.initialize_working_memory_monitors(
user_id=user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
"No memories found in working_memory_monitors, activation memory update is skipped"
)
return

self.monitor.update_activation_memory_monitors(
user_id=user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube
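The monitor-related changes above (and the matching ones in general_scheduler.py below) replace the single shared query monitor with one monitor per (user_id, mem_cube_id) pair, looked up as self.monitor.query_monitors[user_id][mem_cube_id] and created on demand via register_query_monitor_if_not_exists. A minimal sketch of how such a keyed registry can behave, assuming a nested-dict layout (the internals below are illustrative, not the repository's monitor implementation):

from collections import defaultdict

# Illustration only: one query monitor per (user_id, mem_cube_id) pair.
class QueryMonitorRegistry:
    def __init__(self, monitor_factory):
        self._factory = monitor_factory
        # user_id -> {mem_cube_id: monitor}
        self._monitors: dict[str, dict[str, object]] = defaultdict(dict)

    def register_if_not_exists(self, user_id: str, mem_cube_id: str) -> None:
        if mem_cube_id not in self._monitors[user_id]:
            self._monitors[user_id][mem_cube_id] = self._factory()

    def __getitem__(self, user_id: str) -> dict[str, object]:
        # enables registry[user_id][mem_cube_id], as used in this PR
        return self._monitors[user_id]

This per-pair lookup is why replace_working_memory now derives query_history and query_keywords from query_monitors[user_id][mem_cube_id] rather than from a single global queue.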
115 changes: 77 additions & 38 deletions src/memos/mem_scheduler/general_scheduler.py
@@ -9,6 +9,7 @@
ANSWER_LABEL,
DEFAULT_MAX_QUERY_KEY_WORDS,
QUERY_LABEL,
WORKING_MEMORY_TYPE,
MemCubeID,
UserID,
)
@@ -35,11 +36,12 @@ def __init__(self, config: GeneralSchedulerConfig):

# for evaluation
def search_for_eval(
self,
query: str,
user_id: UserID | str,
top_k: int,
) -> list[str]:
self, query: str, user_id: UserID | str, top_k: int, scheduler_flag: bool = True
) -> tuple[list[str], bool]:
self.monitor.register_query_monitor_if_not_exists(
user_id=user_id, mem_cube_id=self.current_mem_cube_id
)

query_keywords = self.monitor.extract_query_keywords(query=query)
logger.info(f'Extract keywords "{query_keywords}" from query "{query}"')

@@ -48,35 +50,61 @@ def search_for_eval(
keywords=query_keywords,
max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
)
self.monitor.query_monitors.put(item=item)
logger.debug(
f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}."
)
query_monitor = self.monitor.query_monitors[user_id][self.current_mem_cube_id]
query_monitor.put(item=item)
logger.debug(f"Queries in monitor are {query_monitor.get_queries_with_timesort()}.")

queries = [query]

# recall
cur_working_memory, new_candidates = self.process_session_turn(
queries=queries,
user_id=user_id,
mem_cube_id=self.current_mem_cube_id,
mem_cube=self.current_mem_cube,
top_k=self.top_k,
)
logger.info(f"Processed {queries} and get {len(new_candidates)} new candidate memories.")

# rerank
new_order_working_memory = self.replace_working_memory(
user_id=user_id,
mem_cube_id=self.current_mem_cube_id,
mem_cube=self.current_mem_cube,
original_memory=cur_working_memory,
new_memory=new_candidates,
mem_cube = self.current_mem_cube
text_mem_base = mem_cube.text_mem

cur_working_memory: list[TextualMemoryItem] = text_mem_base.get_working_memory()
text_working_memory: list[str] = [w_m.memory for w_m in cur_working_memory]
intent_result = self.monitor.detect_intent(
q_list=queries, text_working_memory=text_working_memory
)
new_order_working_memory = new_order_working_memory[:top_k]
logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}")

return [m.memory for m in new_order_working_memory]
if not scheduler_flag:
return text_working_memory, intent_result["trigger_retrieval"]
else:
if intent_result["trigger_retrieval"]:
missing_evidences = intent_result["missing_evidences"]
num_evidence = len(missing_evidences)
k_per_evidence = max(1, top_k // max(1, num_evidence))
new_candidates = []
for item in missing_evidences:
logger.info(f"missing_evidences: {item}")
results: list[TextualMemoryItem] = self.retriever.search(
query=item,
mem_cube=mem_cube,
top_k=k_per_evidence,
method=self.search_method,
)
logger.info(
f"search results for {missing_evidences}: {[one.memory for one in results]}"
)
new_candidates.extend(results)
print(
f"missing_evidences: {missing_evidences} and get {len(new_candidates)} new candidate memories."
)
else:
new_candidates = []
print(f"intent_result: {intent_result}. not triggered")

# rerank
new_order_working_memory = self.replace_working_memory(
user_id=user_id,
mem_cube_id=self.current_mem_cube_id,
mem_cube=self.current_mem_cube,
original_memory=cur_working_memory,
new_memory=new_candidates,
)
new_order_working_memory = new_order_working_memory[:top_k]
logger.info(f"size of new_order_working_memory: {len(new_order_working_memory)}")

return [m.memory for m in new_order_working_memory], intent_result["trigger_retrieval"]
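With this signature change, search_for_eval returns both the memory texts and whether intent detection asked for a retrieval, and scheduler_flag=False short-circuits to the current working memory without retrieval or reranking. A hypothetical caller (the scheduler instance, query, and ids below are illustrative only):

# Illustration only: consuming the new two-element return value.
memories, triggered = scheduler.search_for_eval(
    query="What deadlines did the user mention?",
    user_id="user_0",
    top_k=10,
    scheduler_flag=True,  # False -> return current working memory as-is
)
if triggered:
    print(f"Retrieval triggered; reranked working memory has {len(memories)} items.")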

def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
"""
@@ -105,6 +133,10 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:

# update query monitors
for msg in messages:
self.monitor.register_query_monitor_if_not_exists(
user_id=user_id, mem_cube_id=mem_cube_id
)

query = msg.content
query_keywords = self.monitor.extract_query_keywords(query=query)
logger.info(f'Extract keywords "{query_keywords}" from query "{query}"')
@@ -114,9 +146,11 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
keywords=query_keywords,
max_keywords=DEFAULT_MAX_QUERY_KEY_WORDS,
)
self.monitor.query_monitors.put(item=item)

self.monitor.query_monitors[user_id][mem_cube_id].put(item=item)
logger.debug(
f"Queries in monitor are {self.monitor.query_monitors.get_queries_with_timesort()}."
f"Queries in monitor are "
f"{self.monitor.query_monitors[user_id][mem_cube_id].get_queries_with_timesort()}."
)

queries = [msg.content for msg in messages]
@@ -215,6 +249,9 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
mem_type = mem_item.metadata.memory_type
mem_content = mem_item.memory

if mem_type == WORKING_MEMORY_TYPE:
continue

self.log_adding_memory(
memory=mem_content,
memory_type=mem_type,
@@ -289,18 +326,20 @@ def process_session_turn(
new_candidates = []
for item in missing_evidences:
logger.info(f"missing_evidences: {item}")
info = {
"user_id": user_id,
"session_id": "",
}

results: list[TextualMemoryItem] = self.retriever.search(
query=item, mem_cube=mem_cube, top_k=k_per_evidence, method=self.search_method
query=item,
mem_cube=mem_cube,
top_k=k_per_evidence,
method=self.search_method,
info=info,
)
logger.info(
f"search results for {missing_evidences}: {[one.memory for one in results]}"
)
new_candidates.extend(results)

if len(new_candidates) == 0:
logger.warning(
f"As new_candidates is empty, new_candidates is set same to working_memory.\n"
f"time_trigger_flag: {time_trigger_flag}; intent_result: {intent_result}"
)
new_candidates = cur_working_memory
return cur_working_memory, new_candidates
3 changes: 1 addition & 2 deletions src/memos/mem_scheduler/modules/base.py
@@ -17,8 +17,7 @@ def __init__(self):

self._chat_llm = None
self._process_llm = None
self.current_mem_cube_id: str | None = None
self.current_mem_cube: GeneralMemCube | None = None

self.mem_cubes: dict[str, GeneralMemCube] = {}

def load_template(self, template_name: str) -> str: