11import json
22import traceback
33
4- from contextlib import suppress
54from datetime import datetime
6- from queue import Empty , Queue
7- from threading import Lock
85from typing import Any , Literal
96
107import numpy as np
@@ -83,137 +80,6 @@ def _normalize_datetime(val):
8380 return str (val )
8481
8582
86- class SessionPoolError (Exception ):
87- pass
88-
89-
90- class SessionPool :
91- @require_python_package (
92- import_name = "nebulagraph_python" ,
93- install_command = "pip install ... @Tianxing" ,
94- install_link = "....." ,
95- )
96- def __init__ (
97- self ,
98- hosts : list [str ],
99- user : str ,
100- password : str ,
101- minsize : int = 1 ,
102- maxsize : int = 10000 ,
103- ):
104- self .hosts = hosts
105- self .user = user
106- self .password = password
107- self .minsize = minsize
108- self .maxsize = maxsize
109- self .pool = Queue (maxsize )
110- self .lock = Lock ()
111-
112- self .clients = []
113-
114- for _ in range (minsize ):
115- self ._create_and_add_client ()
116-
117- @timed
118- def _create_and_add_client (self ):
119- from nebulagraph_python import NebulaClient
120-
121- client = NebulaClient (self .hosts , self .user , self .password )
122- self .pool .put (client )
123- self .clients .append (client )
124-
125- @timed
126- def get_client (self , timeout : float = 5.0 ):
127- try :
128- return self .pool .get (timeout = timeout )
129- except Empty :
130- with self .lock :
131- if len (self .clients ) < self .maxsize :
132- from nebulagraph_python import NebulaClient
133-
134- client = NebulaClient (self .hosts , self .user , self .password )
135- self .clients .append (client )
136- return client
137- raise RuntimeError ("NebulaClientPool exhausted" ) from None
138-
139- @timed
140- def return_client (self , client ):
141- try :
142- client .execute ("YIELD 1" )
143- self .pool .put (client )
144- except Exception :
145- logger .info ("[Pool] Client dead, replacing..." )
146- self .replace_client (client )
147-
148- @timed
149- def close (self ):
150- for client in self .clients :
151- with suppress (Exception ):
152- client .close ()
153- self .clients .clear ()
154-
155- @timed
156- def get (self ):
157- """
158- Context manager: with pool.get() as client:
159- """
160-
161- class _ClientContext :
162- def __init__ (self , outer ):
163- self .outer = outer
164- self .client = None
165-
166- def __enter__ (self ):
167- self .client = self .outer .get_client ()
168- return self .client
169-
170- def __exit__ (self , exc_type , exc_val , exc_tb ):
171- if self .client :
172- self .outer .return_client (self .client )
173-
174- return _ClientContext (self )
175-
176- @timed
177- def reset_pool (self ):
178- """⚠️ Emergency reset: Close all clients and clear the pool."""
179- logger .warning ("[Pool] Resetting all clients. Existing sessions will be lost." )
180- with self .lock :
181- for client in self .clients :
182- try :
183- client .close ()
184- except Exception :
185- logger .error ("Fail to close!!!" )
186- self .clients .clear ()
187- while not self .pool .empty ():
188- try :
189- self .pool .get_nowait ()
190- except Empty :
191- break
192- for _ in range (self .minsize ):
193- self ._create_and_add_client ()
194- logger .info ("[Pool] Pool has been reset successfully." )
195-
196- @timed
197- def replace_client (self , client ):
198- try :
199- client .close ()
200- except Exception :
201- logger .error ("Fail to close client" )
202-
203- if client in self .clients :
204- self .clients .remove (client )
205-
206- from nebulagraph_python import NebulaClient
207-
208- new_client = NebulaClient (self .hosts , self .user , self .password )
209- self .clients .append (new_client )
210-
211- self .pool .put (new_client )
212-
213- logger .info ("[Pool] Replaced dead client with a new one." )
214- return new_client
215-
216-
21783class NebulaGraphDB (BaseGraphDB ):
21884 """
21985 NebulaGraph-based implementation of a graph memory store.
@@ -242,6 +108,7 @@ def __init__(self, config: NebulaGraphDBConfig):
242108 "space": "test"
243109 }
244110 """
111+ from nebulagraph_python .client .pool import NebulaPool , NebulaPoolConfig
245112
246113 self .config = config
247114 self .db_name = config .space
@@ -274,12 +141,11 @@ def __init__(self, config: NebulaGraphDBConfig):
274141 else "embedding"
275142 )
276143 self .system_db_name = "system" if config .use_multi_db else config .space
277- self .pool = SessionPool (
144+ self .pool = NebulaPool (
278145 hosts = config .get ("uri" ),
279- user = config .get ("user" ),
146+ username = config .get ("user" ),
280147 password = config .get ("password" ),
281- minsize = 1 ,
282- maxsize = config .get ("max_client" , 1000 ),
148+ pool_config = NebulaPoolConfig (max_client_size = config .get ("max_client" , 1000 )),
283149 )
284150
285151 if config .auto_create :
@@ -294,18 +160,17 @@ def __init__(self, config: NebulaGraphDBConfig):
294160
295161 @timed
296162 def execute_query (self , gql : str , timeout : float = 5.0 , auto_set_db : bool = True ):
297- with self .pool .get () as client :
298- try :
299- if auto_set_db and self .db_name :
300- client .execute (f"SESSION SET GRAPH `{ self .db_name } `" )
301- return client .execute (gql , timeout = timeout )
163+ needs_use_prefix = ("SESSION SET GRAPH" not in gql ) and ("USE " not in gql )
164+ use_prefix = f"USE `{ self .db_name } ` " if auto_set_db and needs_use_prefix else ""
302165
303- except Exception as e :
304- if "Session not found" in str (e ) or "Connection not established" in str (e ):
305- logger .warning (f"[execute_query] { e !s} , replacing client..." )
306- self .pool .replace_client (client )
307- return self .execute_query (gql , timeout , auto_set_db )
308- raise
166+ ngql = use_prefix + gql
167+
168+ try :
169+ with self .pool .borrow () as client :
170+ return client .execute (ngql , timeout = timeout )
171+ except Exception as e :
172+ logger .error (f"[execute_query] Failed: { e } " )
173+ raise
309174
310175 @timed
311176 def close (self ):
0 commit comments