173 changes: 138 additions & 35 deletions src/memos/graph_dbs/nebular.py
@@ -1,8 +1,10 @@
import json
import traceback

from contextlib import suppress
from datetime import datetime
from typing import Any, Literal
from threading import Lock
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import numpy as np

@@ -13,6 +15,10 @@
from memos.utils import timed


if TYPE_CHECKING:
from nebulagraph_python.client.pool import NebulaPool


logger = get_logger(__name__)


@@ -85,6 +91,95 @@ class NebulaGraphDB(BaseGraphDB):
NebulaGraph-based implementation of a graph memory store.
"""

# ====== shared pool cache & refcount ======
# These are process-local; in a multi-process model each process will
# have its own cache.
_POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
_POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
_POOL_LOCK: ClassVar[Lock] = Lock()

@staticmethod
def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
"""
Build a cache key that captures all connection-affecting options.
Keep this key stable and include fields that change the underlying pool behavior.
"""
# NOTE: Do not include tenant-like or query-scope-only fields here.
# Only include things that affect the actual TCP/auth/session pool.
return "|".join(
[
"nebula",
str(getattr(cfg, "uri", "")),
str(getattr(cfg, "user", "")),
str(getattr(cfg, "password", "")),
# pool sizing / TLS / timeouts, if present in the config:
str(getattr(cfg, "max_client", 1000)),
# multi-db mode can impact how we use sessions; keep it to be safe
str(getattr(cfg, "use_multi_db", False)),
]
)

@classmethod
def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
"""
Get a shared NebulaPool from cache or create one if missing.
Thread-safe with a lock; maintains a simple refcount.
"""
from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig

key = cls._make_pool_key(cfg)

with cls._POOL_LOCK:
pool = cls._POOL_CACHE.get(key)
if pool is None:
# Create a new pool and put it into the cache
pool = NebulaPool(
hosts=cfg.get("uri"),
username=cfg.get("user"),
password=cfg.get("password"),
pool_config=NebulaPoolConfig(max_client_size=cfg.get("max_client", 1000)),
)
cls._POOL_CACHE[key] = pool
cls._POOL_REFCOUNT[key] = 0
logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")

# Increase refcount for the caller
cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
return key, pool

@classmethod
def _release_shared_pool(cls, key: str):
"""
Decrease refcount for the given pool key; only close when refcount hits zero.
"""
with cls._POOL_LOCK:
if key not in cls._POOL_CACHE:
return
cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
if cls._POOL_REFCOUNT[key] == 0:
try:
cls._POOL_CACHE[key].close()
except Exception as e:
logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
finally:
cls._POOL_CACHE.pop(key, None)
cls._POOL_REFCOUNT.pop(key, None)
logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")

@classmethod
def close_all_shared_pools(cls):
"""Force close all cached pools. Call this on graceful shutdown."""
with cls._POOL_LOCK:
for key, pool in list(cls._POOL_CACHE.items()):
try:
pool.close()
except Exception as e:
logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
finally:
logger.info(f"[NebulaGraphDB] Closed pool key={key}")
cls._POOL_CACHE.clear()
cls._POOL_REFCOUNT.clear()

@require_python_package(
import_name="nebulagraph_python",
install_command="pip install ... @Tianxing",
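The refcounting contract above is easiest to check in isolation. Below is a minimal sketch of the same get-or-create/release pattern, with a stand-in pool class so it runs without a NebulaGraph server (`_FakePool`, `acquire`, and `release` are hypothetical names, not part of the driver):

```python
from threading import Lock

class _FakePool:
    """Stand-in for NebulaPool; records whether close() was called."""
    def __init__(self) -> None:
        self.closed = False
    def close(self) -> None:
        self.closed = True

_CACHE: dict[str, _FakePool] = {}
_REFS: dict[str, int] = {}
_LOCK = Lock()

def acquire(key: str) -> _FakePool:
    with _LOCK:
        pool = _CACHE.get(key)
        if pool is None:
            pool = _CACHE[key] = _FakePool()
            _REFS[key] = 0
        _REFS[key] += 1  # one refcount per owning instance
        return pool

def release(key: str) -> None:
    with _LOCK:
        if key not in _CACHE:
            return
        _REFS[key] = max(0, _REFS[key] - 1)
        if _REFS[key] == 0:  # last owner closes the shared pool
            _CACHE.pop(key).close()
            _REFS.pop(key)

a = acquire("nebula|host:9669|root")  # creates the pool, refcount 1
b = acquire("nebula|host:9669|root")  # cache hit, refcount 2
assert a is b
release("nebula|host:9669|root")      # refcount 1, pool stays open
assert not a.closed
release("nebula|host:9669|root")      # refcount 0, pool closed
assert a.closed
```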
@@ -108,7 +203,6 @@ def __init__(self, config: NebulaGraphDBConfig):
"space": "test"
}
"""
from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig

self.config = config
self.db_name = config.space
@@ -135,19 +229,21 @@ def __init__(self, config: NebulaGraphDBConfig):
"usage",
"background",
}
self.base_fields = set(self.common_fields) - {"usage"}
self.heavy_fields = {"usage"}
self.dim_field = (
f"embedding_{self.embedding_dimension}"
if (str(self.embedding_dimension) != str(self.default_memory_dimension))
else "embedding"
)
self.system_db_name = "system" if config.use_multi_db else config.space
self.pool = NebulaPool(
hosts=config.get("uri"),
username=config.get("user"),
password=config.get("password"),
pool_config=NebulaPoolConfig(max_client_size=config.get("max_client", 1000)),
)

# ---- NEW: pool acquisition strategy
# Get or create a shared pool from the class-level cache
self._pool_key, self.pool = self._get_or_create_shared_pool(config)
self._owns_pool = True # We manage refcount for this instance

# auto-create graph type / graph / index if needed
if config.auto_create:
self._ensure_database_exists()

@@ -159,7 +255,7 @@ def __init__(self, config: NebulaGraphDBConfig):
logger.info("Connected to NebulaGraph successfully.")

@timed
def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
use_prefix = f"USE `{self.db_name}` " if auto_set_db and needs_use_prefix else ""

@@ -174,7 +270,25 @@ def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True

@timed
def close(self):
self.pool.close()
"""
Close the connection resource if this instance owns it.

- If pool was injected (`shared_pool`), do nothing.
- If pool was acquired via shared cache, decrement refcount and close
when the last owner releases it.
"""
if not self._owns_pool:
logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
return
if self._pool_key:
self._release_shared_pool(self._pool_key)
self._pool_key = None
self.pool = None

# NOTE: __del__ is best-effort; do not rely on GC order.
def __del__(self):
with suppress(Exception):
self.close()

@timed
def create_index(
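The `USE` auto-prefixing in `execute_query` can be exercised as a pure function. Here is a sketch of the same string logic (the `build_query` helper is hypothetical, shown only for illustration):

```python
def build_query(gql: str, db_name: str, auto_set_db: bool = True) -> str:
    # Prepend USE only when the statement does not already pin its scope.
    needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
    use_prefix = f"USE `{db_name}` " if auto_set_db and needs_use_prefix else ""
    return use_prefix + gql

assert build_query("MATCH (n@Memory) RETURN n.id", "test") == \
    "USE `test` MATCH (n@Memory) RETURN n.id"
# Statements that already select a graph are left untouched.
assert build_query("USE `other` MATCH (n) RETURN n", "test") == \
    "USE `other` MATCH (n) RETURN n"
```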
@@ -253,12 +367,10 @@ def node_not_exist(self, scope: str) -> int:
filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
else:
filter_clause = f'n.memory_type = "{scope}"'
return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)

query = f"""
MATCH (n@Memory)
WHERE {filter_clause}
RETURN {return_fields}
RETURN n.id AS id
LIMIT 1
"""

@@ -455,10 +567,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] |
try:
result = self.execute_query(gql)
for row in result:
if include_embedding:
props = row.values()[0].as_node().get_properties()
else:
props = {k: v.value for k, v in row.items()}
props = {k: v.value for k, v in row.items()}
node = self._parse_node(props)
return node
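Because `_build_return_fields` now projects named columns even when embeddings are requested, every row decodes through the same `{k: v.value ...}` path. A small sketch with a stand-in value wrapper (`_Val` mimics the driver's wrapped value type, which exposes `.value`):

```python
class _Val:
    """Stand-in for the driver's wrapped value type."""
    def __init__(self, value):
        self.value = value

row = {"id": _Val("m1"), "memory": _Val("hello"), "embedding": _Val([0.1, 0.2])}
props = {k: v.value for k, v in row.items()}
assert props == {"id": "m1", "memory": "hello", "embedding": [0.1, 0.2]}
```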

@@ -507,10 +616,7 @@ def get_nodes(
try:
results = self.execute_query(query)
for row in results:
if include_embedding:
props = row.values()[0].as_node().get_properties()
else:
props = {k: v.value for k, v in row.items()}
props = {k: v.value for k, v in row.items()}
nodes.append(self._parse_node(props))
except Exception as e:
logger.error(
@@ -579,6 +685,7 @@ def get_neighbors_by_tag(
exclude_ids: list[str],
top_k: int = 5,
min_overlap: int = 1,
include_embedding: bool = False,
) -> list[dict[str, Any]]:
"""
Find top-K neighbor nodes with maximum tag overlap.
@@ -588,6 +695,7 @@ def get_neighbors_by_tag(
exclude_ids: Node IDs to exclude (e.g., local cluster).
top_k: Max number of neighbors to return.
min_overlap: Minimum number of overlapping tags required.
include_embedding: Whether to include the embedding field in the returned nodes.

Returns:
List of dicts with node details and overlap count.
@@ -609,12 +717,13 @@
where_clause = " AND ".join(where_clauses)
tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"

return_fields = self._build_return_fields(include_embedding)
query = f"""
LET tag_list = {tag_list_literal}

MATCH (n@Memory)
WHERE {where_clause}
RETURN n,
RETURN {return_fields},
size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
ORDER BY overlap_count DESC
LIMIT {top_k}
@@ -623,9 +732,8 @@
result = self.execute_query(query)
neighbors: list[dict[str, Any]] = []
for r in result:
node_props = r["n"].as_node().get_properties()
parsed = self._parse_node(node_props) # --> {id, memory, metadata}

props = {k: v.value for k, v in r.items() if k != "overlap_count"}
parsed = self._parse_node(props)
parsed["overlap_count"] = r["overlap_count"].value
neighbors.append(parsed)
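The overlap query builds a GQL list literal from escaped tags and counts the intersection server-side with `size(filter(...))`. The same computation in plain Python, assuming a simple quote-escaping helper (the real module defines its own `_escape_str`):

```python
def _escape_str(s: str) -> str:
    # Hypothetical escaper, for illustration only.
    return s.replace('"', '\\"')

tags = ["graph", "memory"]
tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
assert tag_list_literal == '["graph", "memory"]'

# Server-side: size(filter(n.tags, t -> t IN tag_list)). Equivalent here:
node_tags = ["memory", "index", "graph"]
overlap_count = len([t for t in node_tags if t in tags])
assert overlap_count == 2
```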

@@ -1112,10 +1220,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (
try:
results = self.execute_query(query)
for row in results:
if include_embedding:
props = row.values()[0].as_node().get_properties()
else:
props = {k: v.value for k, v in row.items()}
props = {k: v.value for k, v in row.items()}
nodes.append(self._parse_node(props))
except Exception as e:
logger.error(f"Failed to get memories: {e}")
@@ -1154,10 +1259,7 @@ def get_structure_optimization_candidates(
try:
results = self.execute_query(query)
for row in results:
if include_embedding:
props = row.values()[0].as_node().get_properties()
else:
props = {k: v.value for k, v in row.items()}
props = {k: v.value for k, v in row.items()}
candidates.append(self._parse_node(props))
except Exception as e:
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
@@ -1527,6 +1629,7 @@ def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
return filtered_metadata

def _build_return_fields(self, include_embedding: bool = False) -> str:
fields = set(self.base_fields)
if include_embedding:
return "n"
return ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
fields.add(self.dim_field)
return ", ".join(f"n.{f} AS {f}" for f in fields)
46 changes: 33 additions & 13 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -41,6 +41,10 @@ def __init__(
self.internet_retriever = internet_retriever
self.moscube = moscube

self._usage_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=4, thread_name_prefix="usage"
)

@timed
def search(
self, query: str, top_k: int, info=None, mode="fast", memory_type="All"
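The dedicated executor keeps usage-history writes off the search path. A minimal sketch of the fire-and-forget pattern (the `write_usage` stub stands in for `graph_store.update_node`):

```python
import concurrent.futures

usage_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=4, thread_name_prefix="usage"
)

def write_usage(item_id: str, usage: list[str]) -> None:
    # Stand-in for graph_store.update_node(item_id, {"usage": usage}).
    print(f"persisting {len(usage)} usage records for {item_id}")

# search() can return immediately; the write happens on a "usage-*" thread.
usage_executor.submit(write_usage, "node-1", ['{"time": "...", "info": {}}'])
usage_executor.shutdown(wait=True)  # only needed at process shutdown
```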
@@ -225,7 +229,7 @@ def _retrieve_from_long_term_and_user(
query=query,
query_embedding=query_embedding[0],
graph_results=results,
top_k=top_k * 2,
top_k=top_k,
parsed_goal=parsed_goal,
)

@@ -244,7 +248,7 @@ def _retrieve_from_memcubes(
query=query,
query_embedding=query_embedding[0],
graph_results=results,
top_k=top_k * 2,
top_k=top_k,
parsed_goal=parsed_goal,
)

@@ -303,14 +307,30 @@ def _sort_and_trim(self, results, top_k):
def _update_usage_history(self, items, info):
"""Update usage history in graph DB"""
now_time = datetime.now().isoformat()
info.pop("chat_history", None)
# `info` should be a serializable dict or string
usage_record = json.dumps({"time": now_time, "info": info})
for item in items:
if (
hasattr(item, "id")
and hasattr(item, "metadata")
and hasattr(item.metadata, "usage")
):
item.metadata.usage.append(usage_record)
self.graph_store.update_node(item.id, {"usage": item.metadata.usage})
info_copy = dict(info or {})
info_copy.pop("chat_history", None)
usage_record = json.dumps({"time": now_time, "info": info_copy})
payload = []
for it in items:
try:
item_id = getattr(it, "id", None)
md = getattr(it, "metadata", None)
if md is None:
continue
if not hasattr(md, "usage") or md.usage is None:
md.usage = []
md.usage.append(usage_record)
if item_id:
payload.append((item_id, list(md.usage)))
except Exception:
logger.exception("[USAGE] snapshot item failed")

if payload:
self._usage_executor.submit(self._update_usage_history_worker, payload, usage_record)

def _update_usage_history_worker(self, payload, usage_record: str):
try:
for item_id, usage_list in payload:
self.graph_store.update_node(item_id, {"usage": usage_list})
except Exception:
logger.exception("[USAGE] update usage failed")