Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mem_scheduler:
thread_pool_max_workers: 10
consume_interval_seconds: 1
enable_parallel_dispatch: true
enable_act_memory_update: false
max_turns_window: 20
top_k: 5
enable_textual_memory: true
Expand Down
112 changes: 18 additions & 94 deletions examples/mem_scheduler/memos_w_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
import shutil
import sys

from datetime import datetime
from pathlib import Path
from queue import Queue
from typing import TYPE_CHECKING

from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
from memos.configs.mem_scheduler import AuthConfig, SchedulerConfigFactory
from memos.configs.mem_scheduler import AuthConfig
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_os.main import MOS
from memos.mem_scheduler.general_scheduler import GeneralScheduler
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
from memos.mem_scheduler.schemas.general_schemas import (
ANSWER_LABEL,
QUERY_LABEL,
)
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.utils.misc_utils import parse_yaml


if TYPE_CHECKING:
Expand Down Expand Up @@ -78,122 +70,56 @@ def init_task():
return conversations, questions


def run_with_automatic_scheduler_init():
def run_with_scheduler_init():
print("==== run_with_automatic_scheduler_init ====")
conversations, questions = init_task()

config = parse_yaml(
f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
# set configs
mos_config = MOSConfig.from_yaml_file(
f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml"
)

mos_config = MOSConfig(**config)
mos = MOS(mos_config)

user_id = "user_1"
mos.create_user(user_id)

config = GeneralMemCubeConfig.from_yaml_file(
mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
)
mem_cube_id = "mem_cube_5"
mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
if Path(mem_cube_name_or_path).exists():
shutil.rmtree(mem_cube_name_or_path)
print(f"{mem_cube_name_or_path} is not empty, and has been removed.")

# default local graphdb uri
if AuthConfig.default_config_exists():
auth_config = AuthConfig.from_local_yaml()
config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri

mem_cube = GeneralMemCube(config)
mem_cube.dump(mem_cube_name_or_path)
mos.register_mem_cube(
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
)
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url

for item in questions:
query = item["question"]
response = mos.chat(query, user_id=user_id)
print(f"Query:\n {query}\n\nAnswer:\n {response}")
mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri

show_web_logs(mem_scheduler=mos.mem_scheduler)

mos.mem_scheduler.stop()


def run_with_manual_scheduler_init():
print("==== run_with_manual_scheduler_init ====")
conversations, questions = init_task()

config = parse_yaml(
f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_wo_scheduler.yaml"
)

mos_config = MOSConfig(**config)
# Initialization
mos = MOS(mos_config)

user_id = "user_1"
mos.create_user(user_id)

config = GeneralMemCubeConfig.from_yaml_file(
f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config.yaml"
)
mem_cube_id = "mem_cube_5"
mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"

if Path(mem_cube_name_or_path).exists():
shutil.rmtree(mem_cube_name_or_path)
print(f"{mem_cube_name_or_path} is not empty, and has been removed.")

# default local graphdb uri
if AuthConfig.default_config_exists():
auth_config = AuthConfig.from_local_yaml()
config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri

mem_cube = GeneralMemCube(config)
mem_cube = GeneralMemCube(mem_cube_config)
mem_cube.dump(mem_cube_name_or_path)
mos.register_mem_cube(
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
)

example_scheduler_config_path = (
f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml"
)
scheduler_config = SchedulerConfigFactory.from_yaml_file(
yaml_path=example_scheduler_config_path
)
mem_scheduler = SchedulerFactory.from_config(scheduler_config)
mem_scheduler.initialize_modules(chat_llm=mos.chat_llm)

mos.mem_scheduler = mem_scheduler

mos.mem_scheduler.start()

mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)

for item in questions:
print("===== Chat Start =====")
query = item["question"]
message_item = ScheduleMessageItem(
user_id=user_id,
mem_cube_id=mem_cube_id,
label=QUERY_LABEL,
mem_cube=mos.mem_cubes[mem_cube_id],
content=query,
timestamp=datetime.now(),
)
mos.mem_scheduler.submit_messages(messages=message_item)
response = mos.chat(query, user_id=user_id)
message_item = ScheduleMessageItem(
user_id=user_id,
mem_cube_id=mem_cube_id,
label=ANSWER_LABEL,
mem_cube=mos.mem_cubes[mem_cube_id],
content=response,
timestamp=datetime.now(),
)
mos.mem_scheduler.submit_messages(messages=message_item)
print(f"Query:\n {query}\n\nAnswer:\n {response}")
print(f"Query:\n {query}\n")
response = mos.chat(query=query, user_id=user_id)
print(f"Answer:\n {response}")
print("===== Chat End =====")

show_web_logs(mem_scheduler=mos.mem_scheduler)

Expand Down Expand Up @@ -236,6 +162,4 @@ def show_web_logs(mem_scheduler: GeneralScheduler):


if __name__ == "__main__":
run_with_automatic_scheduler_init()

run_with_manual_scheduler_init()
run_with_scheduler_init()
3 changes: 2 additions & 1 deletion src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from memos.mem_scheduler.schemas.general_schemas import (
ADD_LABEL,
ANSWER_LABEL,
QUERY_LABEL,
)
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_user.user_manager import UserManager, UserRole
Expand Down Expand Up @@ -267,7 +268,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
user_id=target_user_id,
mem_cube_id=mem_cube_id,
mem_cube=mem_cube,
label=ADD_LABEL,
label=QUERY_LABEL,
content=query,
timestamp=datetime.now(),
)
Expand Down
7 changes: 3 additions & 4 deletions src/memos/mem_os/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,9 @@ def chat_with_references(
"""

self._load_user_cubes(user_id, self.default_cube_config)

self._send_message_to_scheduler(
user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_LABEL
)
time_start = time.time()
memories_list = []
memories_result = super().search(
Expand Down Expand Up @@ -808,9 +810,6 @@ def chat_with_references(
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
total_time = round(float(time_end - time_start), 1)
yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': '23%'}})}\n\n"
self._send_message_to_scheduler(
user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_LABEL
)
self._send_message_to_scheduler(
user_id=user_id, mem_cube_id=cube_id, query=full_response, label=ANSWER_LABEL
)
Expand Down
42 changes: 22 additions & 20 deletions src/memos/mem_scheduler/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
DEFAULT_ACT_MEM_DUMP_PATH,
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_THREAD__POOL_MAX_WORKERS,
MemCubeID,
TreeTextMemory_SEARCH_METHOD,
UserID,
)
from memos.mem_scheduler.schemas.message_schemas import (
ScheduleLogForWebItem,
Expand Down Expand Up @@ -81,7 +83,7 @@ def __init__(self, config: BaseSchedulerConfig):

# other attributes
self._context_lock = threading.Lock()
self._current_user_id: str | None = None
self.current_user_id: UserID | str | None = None
self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None)
self.auth_config = None
self.rabbitmq_config = None
Expand Down Expand Up @@ -113,20 +115,20 @@ def initialize_modules(self, chat_llm: BaseLLM, process_llm: BaseLLM | None = No
@property
def mem_cube(self) -> GeneralMemCube:
"""The memory cube associated with this MemChat."""
return self._current_mem_cube
return self.current_mem_cube

@mem_cube.setter
def mem_cube(self, value: GeneralMemCube) -> None:
"""The memory cube associated with this MemChat."""
self._current_mem_cube = value
self.current_mem_cube = value
self.retriever.mem_cube = value

def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
"""Update current user/cube context from the incoming message (thread-safe)."""
with self._context_lock:
self._current_user_id = msg.user_id
self._current_mem_cube_id = msg.mem_cube_id
self._current_mem_cube = msg.mem_cube
self.current_user_id = msg.user_id
self.current_mem_cube_id = msg.mem_cube_id
self.current_mem_cube = msg.mem_cube

def transform_memories_to_monitors(
self, memories: list[TextualMemoryItem]
Expand Down Expand Up @@ -181,9 +183,8 @@ def transform_memories_to_monitors(

def replace_working_memory(
self,
queries: list[str],
user_id: str,
mem_cube_id: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube: GeneralMemCube,
original_memory: list[TextualMemoryItem],
new_memory: list[TextualMemoryItem],
Expand Down Expand Up @@ -246,8 +247,8 @@ def replace_working_memory(

def initialize_working_memory_monitors(
self,
user_id: str,
mem_cube_id: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube: GeneralMemCube,
):
text_mem_base: TreeTextMemory = mem_cube.text_mem
Expand All @@ -267,8 +268,8 @@ def update_activation_memory(
self,
new_memories: list[str | TextualMemoryItem],
label: str,
user_id: str,
mem_cube_id: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube: GeneralMemCube,
) -> None:
"""
Expand Down Expand Up @@ -344,16 +345,17 @@ def update_activation_memory_periodically(
self,
interval_seconds: int,
label: str,
user_id: str,
mem_cube_id: str,
user_id: UserID | str,
mem_cube_id: MemCubeID | str,
mem_cube: GeneralMemCube,
):
new_activation_memories = []

try:
if self.monitor.timed_trigger(
last_time=self.monitor.last_activation_mem_update_time,
interval_seconds=interval_seconds,
if (
self.monitor.last_activation_mem_update_time == datetime.min
or self.monitor.timed_trigger(
last_time=self.monitor.last_activation_mem_update_time,
interval_seconds=interval_seconds,
)
):
logger.info(
f"Updating activation memory for user {user_id} and mem_cube {mem_cube_id}"
Expand Down
Loading