import json
import traceback

+from contextlib import suppress
from datetime import datetime
-from typing import Any, Literal
+from threading import Lock
+from typing import TYPE_CHECKING, Any, ClassVar, Literal

import numpy as np

from memos.utils import timed


+if TYPE_CHECKING:
+    from nebulagraph_python.client.pool import NebulaPool
+
+
logger = get_logger(__name__)

@@ -85,6 +91,95 @@ class NebulaGraphDB(BaseGraphDB):
    NebulaGraph-based implementation of a graph memory store.
    """

+    # ====== shared pool cache & refcount ======
+    # These are process-local; in a multi-process model each process will
+    # have its own cache.
+    _POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
+    _POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
+    _POOL_LOCK: ClassVar[Lock] = Lock()
+
+    @staticmethod
+    def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
+        """
+        Build a cache key that captures all connection-affecting options.
+        Keep this key stable and include fields that change the underlying pool behavior.
+        """
+        # NOTE: Do not include tenant-like or query-scope-only fields here.
+        # Only include things that affect the actual TCP/auth/session pool.
+        return "|".join(
+            [
+                "nebula",
+                str(getattr(cfg, "uri", "")),
+                str(getattr(cfg, "user", "")),
+                str(getattr(cfg, "password", "")),
+                # pool sizing / tls / timeouts if you have them in config:
+                str(getattr(cfg, "max_client", 1000)),
+                # multi-db mode can impact how we use sessions; keep it to be safe
+                str(getattr(cfg, "use_multi_db", False)),
+            ]
+        )
+
+    @classmethod
+    def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
+        """
+        Get a shared NebulaPool from cache or create one if missing.
+        Thread-safe with a lock; maintains a simple refcount.
+        """
+        from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig
+
+        key = cls._make_pool_key(cfg)
+
+        with cls._POOL_LOCK:
+            pool = cls._POOL_CACHE.get(key)
+            if pool is None:
+                # Create a new pool and put into cache
+                pool = NebulaPool(
+                    hosts=cfg.get("uri"),
+                    username=cfg.get("user"),
+                    password=cfg.get("password"),
+                    pool_config=NebulaPoolConfig(max_client_size=cfg.get("max_client", 1000)),
+                )
+                cls._POOL_CACHE[key] = pool
+                cls._POOL_REFCOUNT[key] = 0
+                logger.info(f"[NebulaGraphDB] Created new shared NebulaPool for key={key}")
+
+            # Increase refcount for the caller
+            cls._POOL_REFCOUNT[key] = cls._POOL_REFCOUNT.get(key, 0) + 1
+            return key, pool
+
+    @classmethod
+    def _release_shared_pool(cls, key: str):
+        """
+        Decrease refcount for the given pool key; only close when refcount hits zero.
+        """
+        with cls._POOL_LOCK:
+            if key not in cls._POOL_CACHE:
+                return
+            cls._POOL_REFCOUNT[key] = max(0, cls._POOL_REFCOUNT.get(key, 0) - 1)
+            if cls._POOL_REFCOUNT[key] == 0:
+                try:
+                    cls._POOL_CACHE[key].close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing shared pool: {e}")
+                finally:
+                    cls._POOL_CACHE.pop(key, None)
+                    cls._POOL_REFCOUNT.pop(key, None)
+                    logger.info(f"[NebulaGraphDB] Closed and removed shared pool key={key}")
+
+    @classmethod
+    def close_all_shared_pools(cls):
+        """Force close all cached pools. Call this on graceful shutdown."""
+        with cls._POOL_LOCK:
+            for key, pool in list(cls._POOL_CACHE.items()):
+                try:
+                    pool.close()
+                except Exception as e:
+                    logger.warning(f"[NebulaGraphDB] Error closing pool key={key}: {e}")
+                finally:
+                    logger.info(f"[NebulaGraphDB] Closed pool key={key}")
+            cls._POOL_CACHE.clear()
+            cls._POOL_REFCOUNT.clear()
+
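# Usage sketch of the shared-pool cache (illustrative; assumes the NebulaGraphDB class
# above is importable and that the stand-in object below carries the same attributes
# that _make_pool_key() reads; a real NebulaGraphDBConfig would be used in practice).
from types import SimpleNamespace

cfg_a = SimpleNamespace(uri="127.0.0.1:9669", user="root", password="nebula",
                        max_client=1000, use_multi_db=False)
cfg_b = SimpleNamespace(uri="127.0.0.1:9669", user="root", password="nebula",
                        max_client=1000, use_multi_db=False)

# Identical connection options produce identical cache keys, so both configs would map
# to one shared NebulaPool: _get_or_create_shared_pool() returns the same pool object
# for both and leaves its refcount at 2; each _release_shared_pool(key) decrements it,
# and the pool is only closed when the count reaches zero.
assert NebulaGraphDB._make_pool_key(cfg_a) == NebulaGraphDB._make_pool_key(cfg_b)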
    @require_python_package(
        import_name="nebulagraph_python",
        install_command="pip install ... @Tianxing",
@@ -108,7 +203,6 @@ def __init__(self, config: NebulaGraphDBConfig):
108203 "space": "test"
109204 }
110205 """
111- from nebulagraph_python .client .pool import NebulaPool , NebulaPoolConfig
112206
113207 self .config = config
114208 self .db_name = config .space
@@ -135,19 +229,21 @@ def __init__(self, config: NebulaGraphDBConfig):
135229 "usage" ,
136230 "background" ,
137231 }
232+ self .base_fields = set (self .common_fields ) - {"usage" }
233+ self .heavy_fields = {"usage" }
138234 self .dim_field = (
139235 f"embedding_{ self .embedding_dimension } "
140236 if (str (self .embedding_dimension ) != str (self .default_memory_dimension ))
141237 else "embedding"
142238 )
143239 self .system_db_name = "system" if config .use_multi_db else config .space
144- self .pool = NebulaPool (
145- hosts = config .get ("uri" ),
146- username = config .get ("user" ),
147- password = config .get ("password" ),
148- pool_config = NebulaPoolConfig (max_client_size = config .get ("max_client" , 1000 )),
149- )
150240
241+ # ---- NEW: pool acquisition strategy
242+ # Get or create a shared pool from the class-level cache
243+ self ._pool_key , self .pool = self ._get_or_create_shared_pool (config )
244+ self ._owns_pool = True # We manage refcount for this instance
245+
246+ # auto-create graph type / graph / index if needed
151247 if config .auto_create :
152248 self ._ensure_database_exists ()
153249
@@ -159,7 +255,7 @@ def __init__(self, config: NebulaGraphDBConfig):
        logger.info("Connected to NebulaGraph successfully.")

    @timed
-    def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
+    def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
        needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
        use_prefix = f"USE `{self.db_name}` " if auto_set_db and needs_use_prefix else ""

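# Sketch of the USE-prefix rule above (plain string handling, runnable as-is; "test"
# stands in for self.db_name): statements that already contain "SESSION SET GRAPH" or
# "USE " are sent unchanged, everything else gets the graph prefix prepended.
gql = "MATCH (n@Memory) RETURN n.id LIMIT 1"
needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
use_prefix = "USE `test` " if needs_use_prefix else ""
print(use_prefix + gql)  # -> USE `test` MATCH (n@Memory) RETURN n.id LIMIT 1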
@@ -174,7 +270,25 @@ def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True

    @timed
    def close(self):
-        self.pool.close()
+        """
+        Close the connection resource if this instance owns it.
+
+        - If pool was injected (`shared_pool`), do nothing.
+        - If pool was acquired via shared cache, decrement refcount and close
+          when the last owner releases it.
+        """
+        if not self._owns_pool:
+            logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
+            return
+        if self._pool_key:
+            self._release_shared_pool(self._pool_key)
+            self._pool_key = None
+        self.pool = None
+
+    # NOTE: __del__ is best-effort; do not rely on GC order.
+    def __del__(self):
+        with suppress(Exception):
+            self.close()

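# Lifecycle sketch (illustrative; assumes a reachable NebulaGraph endpoint and a valid
# NebulaGraphDBConfig named `config`). close() only releases this instance's reference;
# the shared NebulaPool itself is closed when the last holder of the same pool key
# lets go, or when close_all_shared_pools() is called at process shutdown.
db_a = NebulaGraphDB(config)   # pool created for this key, refcount -> 1
db_b = NebulaGraphDB(config)   # same key, pool reused, refcount -> 2
db_a.close()                   # refcount -> 1, pool stays open for db_b
db_b.close()                   # refcount -> 0, pool actually closed
NebulaGraphDB.close_all_shared_pools()  # no-op here; belt-and-braces call at shutdown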
    @timed
    def create_index(
@@ -253,12 +367,10 @@ def node_not_exist(self, scope: str) -> int:
            filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"'
        else:
            filter_clause = f'n.memory_type = "{scope}"'
-        return_fields = ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
-
        query = f"""
        MATCH (n@Memory)
        WHERE {filter_clause}
-        RETURN {return_fields}
+        RETURN n.id AS id
        LIMIT 1
        """

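# For scope="WorkingMemory" and user_name="alice" (hypothetical values), the existence
# probe above now projects a single column instead of every common field:
#
#     MATCH (n@Memory)
#     WHERE n.memory_type = "WorkingMemory" AND n.user_name = "alice"
#     RETURN n.id AS id
#     LIMIT 1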
@@ -455,10 +567,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] |
        try:
            result = self.execute_query(gql)
            for row in result:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                node = self._parse_node(props)
                return node

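# Sketch of the uniform row handling above: every projected column comes back as a
# value wrapper, so a single dict comprehension now covers both the with- and
# without-embedding paths (the embedding column is simply absent when not requested).
# _FakeValue is a hypothetical stand-in for the driver's wrapper type, used only to
# keep the snippet runnable without a live server.
class _FakeValue:
    def __init__(self, value):
        self.value = value

row = {"id": _FakeValue("node-1"), "memory": _FakeValue("hello world")}
props = {k: v.value for k, v in row.items()}
assert props == {"id": "node-1", "memory": "hello world"}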
@@ -507,10 +616,7 @@ def get_nodes(
        try:
            results = self.execute_query(query)
            for row in results:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                nodes.append(self._parse_node(props))
        except Exception as e:
            logger.error(
@@ -579,6 +685,7 @@ def get_neighbors_by_tag(
        exclude_ids: list[str],
        top_k: int = 5,
        min_overlap: int = 1,
+        include_embedding: bool = False,
    ) -> list[dict[str, Any]]:
        """
        Find top-K neighbor nodes with maximum tag overlap.
@@ -588,6 +695,7 @@ def get_neighbors_by_tag(
            exclude_ids: Node IDs to exclude (e.g., local cluster).
            top_k: Max number of neighbors to return.
            min_overlap: Minimum number of overlapping tags required.
+            include_embedding: Whether to include embedding vectors in the returned nodes.

        Returns:
            List of dicts with node details and overlap count.
@@ -609,12 +717,13 @@ def get_neighbors_by_tag(
        where_clause = " AND ".join(where_clauses)
        tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"

+        return_fields = self._build_return_fields(include_embedding)
        query = f"""
        LET tag_list = {tag_list_literal}

        MATCH (n@Memory)
        WHERE {where_clause}
-        RETURN n,
+        RETURN {return_fields},
            size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
        ORDER BY overlap_count DESC
        LIMIT {top_k}
@@ -623,9 +732,8 @@ def get_neighbors_by_tag(
        result = self.execute_query(query)
        neighbors: list[dict[str, Any]] = []
        for r in result:
-            node_props = r["n"].as_node().get_properties()
-            parsed = self._parse_node(node_props)  # --> {id, memory, metadata}
-
+            props = {k: v.value for k, v in r.items() if k != "overlap_count"}
+            parsed = self._parse_node(props)
            parsed["overlap_count"] = r["overlap_count"].value
            neighbors.append(parsed)

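# Shape of the tag-overlap query produced above: _build_return_fields() expands to
# "n.<field> AS <field>" for each base field (plus the embedding column when
# include_embedding=True), and overlap_count rides along as one extra aliased column,
# e.g. for tags=["a", "b"] and top_k=5 (where clause abbreviated):
#
#     LET tag_list = ["a", "b"]
#
#     MATCH (n@Memory)
#     WHERE <where_clause>
#     RETURN <return_fields>,
#         size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
#     ORDER BY overlap_count DESC
#     LIMIT 5
#
# The parsing loop then drops overlap_count before handing the row to _parse_node()
# and re-attaches it to the parsed dict afterwards.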
@@ -1112,10 +1220,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (
        try:
            results = self.execute_query(query)
            for row in results:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                nodes.append(self._parse_node(props))
        except Exception as e:
            logger.error(f"Failed to get memories: {e}")
@@ -1154,10 +1259,7 @@ def get_structure_optimization_candidates(
        try:
            results = self.execute_query(query)
            for row in results:
-                if include_embedding:
-                    props = row.values()[0].as_node().get_properties()
-                else:
-                    props = {k: v.value for k, v in row.items()}
+                props = {k: v.value for k, v in row.items()}
                candidates.append(self._parse_node(props))
        except Exception as e:
            logger.error(f"Failed: {e}, traceback: {traceback.format_exc()}")
@@ -1527,6 +1629,7 @@ def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
        return filtered_metadata

    def _build_return_fields(self, include_embedding: bool = False) -> str:
+        fields = set(self.base_fields)
        if include_embedding:
-            return "n"
-        return ", ".join(f"n.{field} AS {field}" for field in self.common_fields)
+            fields.add(self.dim_field)
+        return ", ".join(f"n.{f} AS {f}" for f in fields)
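# Expansion sketch for _build_return_fields(), assuming (hypothetically) that
# base_fields is {"id", "memory"} and dim_field is "embedding":
#
#     _build_return_fields(False) -> "n.id AS id, n.memory AS memory"
#     _build_return_fields(True)  -> "n.id AS id, n.memory AS memory, n.embedding AS embedding"
#
# Iterating a set gives no guaranteed column order, which is harmless here because
# callers read every column back by its alias rather than by position.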