Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from memos.memories.activation.item import ActivationMemoryItem
from memos.memories.parametric.item import ParametricMemoryItem
from memos.memories.textual.item import TextualMemoryItem, TextualMemoryMetadata
from memos.memos_tools.thread_safe_dict import ThreadSafeDict
from memos.templates.mos_prompts import QUERY_REWRITING_PROMPT
from memos.types import ChatHistory, MessageList, MOSSearchResult

Expand All @@ -42,10 +43,13 @@ def __init__(self, config: MOSConfig, user_manager: UserManager | None = None):
self.config = config
self.user_id = config.user_id
self.session_id = config.session_id
self.mem_cubes: dict[str, GeneralMemCube] = {}
self.chat_llm = LLMFactory.from_config(config.chat_model)
self.mem_reader = MemReaderFactory.from_config(config.mem_reader)
self.chat_history_manager: dict[str, ChatHistory] = {}
# use thread safe dict for multi-user product-server scenario
self.mem_cubes: ThreadSafeDict[str, GeneralMemCube] = (
ThreadSafeDict() if user_manager is not None else {}
)
self._register_chat_history()

# Use provided user_manager or create a new one
Expand Down Expand Up @@ -575,7 +579,13 @@ def search(
}
if install_cube_ids is None:
install_cube_ids = user_cube_ids
for mem_cube_id, mem_cube in self.mem_cubes.items():
# create exist dict in mem_cubes and avoid one search slow
tmp_mem_cubes = {}
for mem_cube_id in install_cube_ids:
if mem_cube_id in self.mem_cubes:
tmp_mem_cubes[mem_cube_id] = self.mem_cubes.get(mem_cube_id)

for mem_cube_id, mem_cube in tmp_mem_cubes.items():
if (
(mem_cube_id in install_cube_ids)
and (mem_cube.text_mem is not None)
Expand Down
120 changes: 120 additions & 0 deletions src/memos/memos_tools/lockfree_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Lock-free dictionary implementation using copy-on-write strategy.
This provides better performance but uses more memory.
"""

import threading

from collections.abc import ItemsView, Iterator, KeysView, ValuesView
from typing import Generic, TypeVar


K = TypeVar("K")
V = TypeVar("V")


class CopyOnWriteDict(Generic[K, V]):
"""
A lock-free dictionary using copy-on-write strategy.

Reads are completely lock-free and very fast.
Writes create a new copy of the dictionary.
Uses more memory but provides excellent read performance.
"""

def __init__(self, initial_dict: dict[K, V] | None = None):
"""Initialize with optional initial dictionary."""
self._dict = initial_dict.copy() if initial_dict else {}
self._write_lock = threading.Lock() # Only for writes

def __getitem__(self, key: K) -> V:
"""Get item by key - completely lock-free."""
return self._dict[key]

def __setitem__(self, key: K, value: V) -> None:
"""Set item by key - uses copy-on-write."""
with self._write_lock:
# Create a new dictionary with the update
new_dict = self._dict.copy()
new_dict[key] = value
# Atomic replacement
self._dict = new_dict

def __delitem__(self, key: K) -> None:
"""Delete item by key - uses copy-on-write."""
with self._write_lock:
new_dict = self._dict.copy()
del new_dict[key]
self._dict = new_dict

def __contains__(self, key: K) -> bool:
"""Check if key exists - completely lock-free."""
return key in self._dict

def __len__(self) -> int:
"""Get length - completely lock-free."""
return len(self._dict)

def __bool__(self) -> bool:
"""Check if not empty - completely lock-free."""
return bool(self._dict)

def __iter__(self) -> Iterator[K]:
"""Iterate over keys - completely lock-free."""
return iter(self._dict.keys())

def get(self, key: K, default: V | None = None) -> V:
"""Get with default - completely lock-free."""
return self._dict.get(key, default)

def keys(self) -> KeysView[K]:
"""Get keys - completely lock-free."""
return self._dict.keys()

def values(self) -> ValuesView[V]:
"""Get values - completely lock-free."""
return self._dict.values()

def items(self) -> ItemsView[K, V]:
"""Get items - completely lock-free."""
return self._dict.items()

def copy(self) -> dict[K, V]:
"""Create a copy - completely lock-free."""
return self._dict.copy()

def update(self, *args, **kwargs) -> None:
"""Update dictionary - uses copy-on-write."""
with self._write_lock:
new_dict = self._dict.copy()
new_dict.update(*args, **kwargs)
self._dict = new_dict

def clear(self) -> None:
"""Clear all items."""
with self._write_lock:
self._dict = {}

def pop(self, key: K, *args) -> V:
"""Pop item by key."""
with self._write_lock:
new_dict = self._dict.copy()
result = new_dict.pop(key, *args)
self._dict = new_dict
return result

def setdefault(self, key: K, default: V | None = None) -> V:
"""Set default value for key if not exists."""
# Fast path for existing keys
if key in self._dict:
return self._dict[key]

with self._write_lock:
# Double-check after acquiring lock
if key in self._dict:
return self._dict[key]

new_dict = self._dict.copy()
result = new_dict.setdefault(key, default)
self._dict = new_dict
return result
Loading
Loading