Commit 0d85609

feat: enhance NebulaGraph pool management & improve Searcher usage logging (#265)

* feat: timeout for nebula query 5s -> 10s
* feat: exclude heavy fields when calling memories from nebula db
* test: fix tree-text-mem searcher text

1 parent fe0624e · commit 0d85609

File tree

2 files changed, +171 -48 lines

src/memos/graph_dbs/nebular.py

Lines changed: 138 additions & 35 deletions
@@ -1,8 +1,10 @@
 import json
 import traceback
 
+from contextlib import suppress
 from datetime import datetime
-from typing import Any, Literal
+from threading import Lock
+from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
 import numpy as np
 
@@ -13,6 +15,10 @@
 from memos.utils import timed
 
 
+if TYPE_CHECKING:
+    from nebulagraph_python.client.pool import NebulaPool
+
+
 logger = get_logger(__name__)
 
 
@@ -85,6 +91,95 @@ class NebulaGraphDB(BaseGraphDB):
     NebulaGraph-based implementation of a graph memory store.
     """
 
+    # ====== shared pool cache & refcount ======
+    # These are process-local; in a multi-process model each process will
+    # have its own cache.
+    _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
+    _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
+    _POOL_LOCK: ClassVar[Lock] = Lock()
+
+    @staticmethod
+    def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
+        """
+        Build a cache key that captures all connection-affecting options.
+        Keep this key stable and include fields that change the underlying pool behavior.
+        """
+        # NOTE: Do not include tenant-like or query-scope-only fields here.
+        # Only include things that affect the actual TCP/auth/session pool.
+        return "|".join(
+            [
+                "nebula",
+                str(getattr(cfg, "uri", "")),
+                str(getattr(cfg, "user", "")),
+                str(getattr(cfg, "password", "")),
+                # pool sizing / tls / timeouts if you have them in config:
+                str(getattr(cfg, "max_client", 1000)),
+                # multi-db mode can impact how we use sessions; keep it to be safe
+                str(getattr(cfg, "use_multi_db", False)),
+            ]
+        )
+
+    @classmethod
+    def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
+        """
+        Get a shared NebulaPool from cache or create one if missing.
+        Thread-safe with a lock; maintains a simple refcount.
+        """
+        from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig
+
+        key = cls._make_pool_key(cfg)
+
+        with cls._POOL_LOCK:
+            pool = cls._POOL_CACHE.get(key)
+            if pool is None:
+                # Create a new pool and put into cache
+                pool = NebulaPool(
+                    hosts=cfg.get("uri"),
+                    username=cfg.get("user"),
+                    password=cfg.get("password"),
+                    pool_config=NebulaPoolConfig(max_client_size=cfg.get("max_client", 1000)),
+                )
+                cls._POOL_CACHE[key] = pool
+                cls._POOL_REFCOUNT[key] = 0
+                logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
+
+            # Increase refcount for the caller
+            cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
+            return key, pool
+
+    @classmethod
+    def _release_shared_pool(cls, key: str):
+        """
+        Decrease refcount for the given pool key; only close when refcount hits zero.
+        """
+        with cls._POOL_LOCK:
+            if key not in cls._POOL_CACHE:
+                return
+            cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
+            if cls._POOL_REFCOUNT[key] == 0:
+                try:
+                    cls._POOL_CACHE[key].close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
+                finally:
+                    cls._POOL_CACHE.pop(key, None)
+                    cls._POOL_REFCOUNT.pop(key, None)
+                    logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
+
+    @classmethod
+    def close_all_shared_pools(cls):
+        """Force close all cached pools. Call this on graceful shutdown."""
+        with cls._POOL_LOCK:
+            for key, pool in list(cls._POOL_CACHE.items()):
+                try:
+                    pool.close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
+                finally:
+                    logger.info(f"[NebulaGraphDB] Closed pool key={key}")
+            cls._POOL_CACHE.clear()
+            cls._POOL_REFCOUNT.clear()
+
     @require_python_package(
         import_name="nebulagraph_python",
         install_command="pip install ... @Tianxing",
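Note: the shared-pool cache above is a per-process, refcounted singleton keyed by the connection settings. As a rough standalone illustration (not part of the commit; DummyPool and the key string are stand-ins for NebulaPool and _make_pool_key), the acquire/release behavior reduces to:

import threading

class DummyPool:
    """Stand-in for a real connection pool such as NebulaPool."""

    def __init__(self, key: str):
        self.key = key
        self.closed = False

    def close(self):
        self.closed = True

_CACHE: dict[str, DummyPool] = {}
_REFCOUNT: dict[str, int] = {}
_LOCK = threading.Lock()

def acquire(key: str) -> DummyPool:
    # First caller creates the pool; later callers with the same key reuse it.
    with _LOCK:
        pool = _CACHE.get(key)
        if pool is None:
            pool = DummyPool(key)
            _CACHE[key] = pool
            _REFCOUNT[key] = 0
        _REFCOUNT[key] += 1
        return pool

def release(key: str) -> None:
    # Only the last owner actually closes and evicts the pool.
    with _LOCK:
        if key not in _CACHE:
            return
        _REFCOUNT[key] = max(0, _REFCOUNT[key] - 1)
        if _REFCOUNT[key] == 0:
            _CACHE.pop(key).close()
            _REFCOUNT.pop(key, None)

a = acquire("nebula|host:9669|user")
b = acquire("nebula|host:9669|user")
assert a is b          # both owners share one pool
release("nebula|host:9669|user")
assert not a.closed    # still referenced by the second owner
release("nebula|host:9669|user")
assert a.closed        # last release closes it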
@@ -108,7 +203,6 @@ def __init__(self, config: NebulaGraphDBConfig):
             "space": "test"
         }
         """
-        from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig
 
         self.config = config
         self.db_name = config.space
@@ -135,19 +229,21 @@ def __init__(self, config: NebulaGraphDBConfig):
             "usage",
             "background",
         }
+        self.base_fields = set(self.common_fields) - {"usage"}
+        self.heavy_fields = {"usage"}
         self.dim_field = (
             f"embedding_{self.embedding_dimension}"
             if (str(self.embedding_dimension) != str(self.default_memory_dimension))
             else "embedding"
         )
         self.system_db_name = "system" if config.use_multi_db else config.space
-        self.pool = NebulaPool(
-            hosts=config.get("uri"),
-            username=config.get("user"),
-            password=config.get("password"),
-            pool_config=NebulaPoolConfig(max_client_size=config.get("max_client", 1000)),
-        )
 
+        # ---- NEW: pool acquisition strategy
+        # Get or create a shared pool from the class-level cache
+        self._pool_key, self.pool = self._get_or_create_shared_pool(config)
+        self._owns_pool = True  # We manage refcount for this instance
+
+        # auto-create graph type / graph / index if needed
         if config.auto_create:
             self._ensure_database_exists()
 
@@ -159,7 +255,7 @@ def __init__(self, config: NebulaGraphDBConfig):
         logger.info("Connected to NebulaGraph successfully.")
 
     @timed
-    def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
+    def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
         needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
         use_prefix = f"USE `{self.db_name}` " if auto_set_db and needs_use_prefix else ""
 
@@ -174,7 +270,25 @@ def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True
 
     @timed
     def close(self):
-        self.pool.close()
+        """
+        Close the connection resource if this instance owns it.
+
+        - If pool was injected (`shared_pool`), do nothing.
+        - If pool was acquired via shared cache, decrement refcount and close
+          when the last owner releases it.
+        """
+        if not self._owns_pool:
+            logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
+            return
+        if self._pool_key:
+            self._release_shared_pool(self._pool_key)
+            self._pool_key = None
+        self.pool = None
+
+    # NOTE: __del__ is best-effort; do not rely on GC order.
+    def __del__(self):
+        with suppress(Exception):
+            self.close()
 
     @timed
     def create_index(
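The new __del__ hook wraps close() in contextlib.suppress so that a half-torn-down instance never raises during garbage collection or interpreter shutdown. A minimal sketch of that best-effort cleanup pattern (the Resource class here is illustrative, not from the codebase):

from contextlib import suppress

class Resource:
    def __init__(self):
        self._closed = False

    def close(self):
        if self._closed:
            return              # idempotent: safe to call more than once
        self._closed = True
        print("resource released")

    def __del__(self):
        # Best-effort only: attributes or module globals may already be
        # gone during interpreter shutdown, so swallow any error.
        with suppress(Exception):
            self.close()

r = Resource()
r.close()   # explicit close() remains the primary cleanup path
del r       # __del__ calls close() again; the guard makes it a no-op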
@@ -253,12 +367,10 @@ def node_not_exist(self, scope: str) -> int:
             filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
         else:
             filter_clause = f'n.memory_type = "{scope}"'
-        return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
-
         query = f"""
         MATCH (n@Memory)
         WHERE {filter_clause}
-        RETURN {return_fields}
+        RETURN n.id AS id
         LIMIT 1
         """
 
@@ -455,10 +567,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] |
         try:
             result = self.execute_query(gql)
             for row in result:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                 node = self._parse_node(props)
                 return node
 
@@ -507,10 +616,7 @@ def get_nodes(
         try:
             results = self.execute_query(query)
             for row in results:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                 nodes.append(self._parse_node(props))
         except Exception as e:
             logger.error(
@@ -579,6 +685,7 @@ def get_neighbors_by_tag(
         exclude_ids: list[str],
         top_k: int = 5,
         min_overlap: int = 1,
+        include_embedding: bool = False,
     ) -> list[dict[str, Any]]:
         """
         Find top-K neighbor nodes with maximum tag overlap.
@@ -588,6 +695,7 @@ def get_neighbors_by_tag(
             exclude_ids: Node IDs to exclude (e.g., local cluster).
             top_k: Max number of neighbors to return.
             min_overlap: Minimum number of overlapping tags required.
+            include_embedding: with/without embedding
 
         Returns:
             List of dicts with node details and overlap count.
@@ -609,12 +717,13 @@ def get_neighbors_by_tag(
         where_clause = " AND ".join(where_clauses)
         tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
 
+        return_fields = self._build_return_fields(include_embedding)
         query = f"""
         LET tag_list = {tag_list_literal}
 
         MATCH (n@Memory)
         WHERE {where_clause}
-        RETURN n,
+        RETURN {return_fields},
                size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
         ORDER BY overlap_count DESC
         LIMIT {top_k}
@@ -623,9 +732,8 @@ def get_neighbors_by_tag(
         result = self.execute_query(query)
         neighbors: list[dict[str, Any]] = []
         for r in result:
-            node_props = r["n"].as_node().get_properties()
-            parsed = self._parse_node(node_props)  # --> {id, memory, metadata}
-
+            props = {k: v.value for k, v in r.items() if k != "overlap_count"}
+            parsed = self._parse_node(props)
             parsed["overlap_count"] = r["overlap_count"].value
             neighbors.append(parsed)
 
@@ -1112,10 +1220,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (
         try:
            results = self.execute_query(query)
            for row in results:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                nodes.append(self._parse_node(props))
        except Exception as e:
            logger.error(f"Failed to get memories: {e}")
@@ -1154,10 +1259,7 @@ def get_structure_optimization_candidates(
         try:
            results = self.execute_query(query)
            for row in results:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                candidates.append(self._parse_node(props))
        except Exception as e:
            logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
@@ -1527,6 +1629,7 @@ def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
         return filtered_metadata
 
     def _build_return_fields(self, include_embedding: bool = False) -> str:
+        fields = set(self.base_fields)
         if include_embedding:
-            return "n"
-        return ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
+            fields.add(self.dim_field)
+        return ", ".join(f"n.{f} AS {f}" for f in fields)

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 33 additions & 13 deletions
@@ -41,6 +41,10 @@ def __init__(
         self.internet_retriever = internet_retriever
         self.moscube = moscube
 
+        self._usage_executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=4, thread_name_prefix="usage"
+        )
+
     @timed
     def search(
         self, query: str, top_k: int, info=None, mode="fast", memory_type="All"
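The dedicated executor keeps usage-history writes off the search hot path: search() returns as soon as the work is submitted, and the writes run on named "usage-*" threads. A self-contained sketch of the same fire-and-forget pattern (write_usage is a placeholder for the real graph-DB call):

import concurrent.futures
import logging
import time

logger = logging.getLogger("usage")

executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=4, thread_name_prefix="usage"
)

def write_usage(item_id: str, record: str) -> None:
    # Placeholder for a slow DB write. It must handle its own errors,
    # because nobody waits on the Future returned by submit().
    try:
        time.sleep(0.1)
        logger.info("wrote usage for %s: %s", item_id, record)
    except Exception:
        logger.exception("usage write failed")

executor.submit(write_usage, "node-123", '{"time": "...", "info": {}}')
# The caller continues immediately; shutdown(wait=True) is only for this demo.
executor.shutdown(wait=True)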
@@ -225,7 +229,7 @@ def _retrieve_from_long_term_and_user(
             query=query,
             query_embedding=query_embedding[0],
             graph_results=results,
-            top_k=top_k * 2,
+            top_k=top_k,
             parsed_goal=parsed_goal,
         )
 
@@ -244,7 +248,7 @@ def _retrieve_from_memcubes(
             query=query,
             query_embedding=query_embedding[0],
             graph_results=results,
-            top_k=top_k * 2,
+            top_k=top_k,
             parsed_goal=parsed_goal,
         )
 
@@ -303,14 +307,30 @@ def _sort_and_trim(self, results, top_k):
     def _update_usage_history(self, items, info):
         """Update usage history in graph DB"""
         now_time = datetime.now().isoformat()
-        info.pop("chat_history", None)
-        # `info` should be a serializable dict or string
-        usage_record = json.dumps({"time": now_time, "info": info})
-        for item in items:
-            if (
-                hasattr(item, "id")
-                and hasattr(item, "metadata")
-                and hasattr(item.metadata, "usage")
-            ):
-                item.metadata.usage.append(usage_record)
-                self.graph_store.update_node(item.id, {"usage": item.metadata.usage})
+        info_copy = dict(info or {})
+        info_copy.pop("chat_history", None)
+        usage_record = json.dumps({"time": now_time, "info": info_copy})
+        payload = []
+        for it in items:
+            try:
+                item_id = getattr(it, "id", None)
+                md = getattr(it, "metadata", None)
+                if md is None:
+                    continue
+                if not hasattr(md, "usage") or md.usage is None:
+                    md.usage = []
+                md.usage.append(usage_record)
+                if item_id:
+                    payload.append((item_id, list(md.usage)))
+            except Exception:
+                logger.exception("[USAGE] snapshot item failed")
+
+        if payload:
+            self._usage_executor.submit(self._update_usage_history_worker, payload, usage_record)
+
+    def _update_usage_history_worker(self, payload, usage_record: str):
+        try:
+            for item_id, usage_list in payload:
+                self.graph_store.update_node(item_id, {"usage": usage_list})
+        except Exception:
+            logger.exception("[USAGE] update usage failed")
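Worth noting: _update_usage_history snapshots each item's usage list on the caller thread (list(md.usage)) before handing off, so later mutations cannot race with the background write. A condensed sketch of that snapshot-then-submit handoff, with a fake store standing in for graph_store:

import concurrent.futures

class FakeGraphStore:
    """Records update_node calls instead of talking to NebulaGraph."""

    def __init__(self):
        self.updates = []

    def update_node(self, node_id, fields):
        self.updates.append((node_id, fields))

store = FakeGraphStore()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="usage")

def worker(payload):
    for node_id, usage_list in payload:
        store.update_node(node_id, {"usage": usage_list})

usage = ["rec-1", "rec-2"]
payload = [("node-1", list(usage))]   # snapshot taken on the caller thread
usage.append("rec-3")                 # arrives after the snapshot; not in this write

executor.submit(worker, payload).result()   # .result() only so the demo is deterministic
print(store.updates)                        # [('node-1', {'usage': ['rec-1', 'rec-2']})]
executor.shutdown()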
