Skip to content

Commit d58e548

Browse files
authored
feat: modify nebula session pool (#259)
* fix: mem-reader bug * fix: test mem reader * feat: modify nebula session pool
1 parent c14868b commit d58e548

File tree

1 file changed

+14
-149
lines changed

1 file changed

+14
-149
lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 14 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import json
22
import traceback
33

4-
from contextlib import suppress
54
from datetime import datetime
6-
from queue import Empty, Queue
7-
from threading import Lock
85
from typing import Any, Literal
96

107
import numpy as np
@@ -83,137 +80,6 @@ def _normalize_datetime(val):
8380
return str(val)
8481

8582

86-
class SessionPoolError(Exception):
87-
pass
88-
89-
90-
class SessionPool:
91-
@require_python_package(
92-
import_name="nebulagraph_python",
93-
install_command="pip install ... @Tianxing",
94-
install_link=".....",
95-
)
96-
def __init__(
97-
self,
98-
hosts: list[str],
99-
user: str,
100-
password: str,
101-
minsize: int = 1,
102-
maxsize: int = 10000,
103-
):
104-
self.hosts = hosts
105-
self.user = user
106-
self.password = password
107-
self.minsize = minsize
108-
self.maxsize = maxsize
109-
self.pool = Queue(maxsize)
110-
self.lock = Lock()
111-
112-
self.clients = []
113-
114-
for _ in range(minsize):
115-
self._create_and_add_client()
116-
117-
@timed
118-
def _create_and_add_client(self):
119-
from nebulagraph_python import NebulaClient
120-
121-
client = NebulaClient(self.hosts, self.user, self.password)
122-
self.pool.put(client)
123-
self.clients.append(client)
124-
125-
@timed
126-
def get_client(self, timeout: float = 5.0):
127-
try:
128-
return self.pool.get(timeout=timeout)
129-
except Empty:
130-
with self.lock:
131-
if len(self.clients) < self.maxsize:
132-
from nebulagraph_python import NebulaClient
133-
134-
client = NebulaClient(self.hosts, self.user, self.password)
135-
self.clients.append(client)
136-
return client
137-
raise RuntimeError("NebulaClientPool exhausted") from None
138-
139-
@timed
140-
def return_client(self, client):
141-
try:
142-
client.execute("YIELD 1")
143-
self.pool.put(client)
144-
except Exception:
145-
logger.info("[Pool] Client dead, replacing...")
146-
self.replace_client(client)
147-
148-
@timed
149-
def close(self):
150-
for client in self.clients:
151-
with suppress(Exception):
152-
client.close()
153-
self.clients.clear()
154-
155-
@timed
156-
def get(self):
157-
"""
158-
Context manager: with pool.get() as client:
159-
"""
160-
161-
class _ClientContext:
162-
def __init__(self, outer):
163-
self.outer = outer
164-
self.client = None
165-
166-
def __enter__(self):
167-
self.client = self.outer.get_client()
168-
return self.client
169-
170-
def __exit__(self, exc_type, exc_val, exc_tb):
171-
if self.client:
172-
self.outer.return_client(self.client)
173-
174-
return _ClientContext(self)
175-
176-
@timed
177-
def reset_pool(self):
178-
"""⚠️ Emergency reset: Close all clients and clear the pool."""
179-
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
180-
with self.lock:
181-
for client in self.clients:
182-
try:
183-
client.close()
184-
except Exception:
185-
logger.error("Fail to close!!!")
186-
self.clients.clear()
187-
while not self.pool.empty():
188-
try:
189-
self.pool.get_nowait()
190-
except Empty:
191-
break
192-
for _ in range(self.minsize):
193-
self._create_and_add_client()
194-
logger.info("[Pool] Pool has been reset successfully.")
195-
196-
@timed
197-
def replace_client(self, client):
198-
try:
199-
client.close()
200-
except Exception:
201-
logger.error("Fail to close client")
202-
203-
if client in self.clients:
204-
self.clients.remove(client)
205-
206-
from nebulagraph_python import NebulaClient
207-
208-
new_client = NebulaClient(self.hosts, self.user, self.password)
209-
self.clients.append(new_client)
210-
211-
self.pool.put(new_client)
212-
213-
logger.info("[Pool] Replaced dead client with a new one.")
214-
return new_client
215-
216-
21783
class NebulaGraphDB(BaseGraphDB):
21884
"""
21985
NebulaGraph-based implementation of a graph memory store.
@@ -242,6 +108,7 @@ def __init__(self, config: NebulaGraphDBConfig):
242108
"space": "test"
243109
}
244110
"""
111+
from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig
245112

246113
self.config = config
247114
self.db_name = config.space
@@ -274,12 +141,11 @@ def __init__(self, config: NebulaGraphDBConfig):
274141
else "embedding"
275142
)
276143
self.system_db_name = "system" if config.use_multi_db else config.space
277-
self.pool = SessionPool(
144+
self.pool = NebulaPool(
278145
hosts=config.get("uri"),
279-
user=config.get("user"),
146+
username=config.get("user"),
280147
password=config.get("password"),
281-
minsize=1,
282-
maxsize=config.get("max_client", 1000),
148+
pool_config=NebulaPoolConfig(max_client_size=config.get("max_client", 1000)),
283149
)
284150

285151
if config.auto_create:
@@ -294,18 +160,17 @@ def __init__(self, config: NebulaGraphDBConfig):
294160

295161
@timed
296162
def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
297-
with self.pool.get() as client:
298-
try:
299-
if auto_set_db and self.db_name:
300-
client.execute(f"SESSION SET GRAPH `{self.db_name}`")
301-
return client.execute(gql, timeout=timeout)
163+
needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
164+
use_prefix = f"USE `{self.db_name}` " if auto_set_db and needs_use_prefix else ""
302165

303-
except Exception as e:
304-
if "Session not found" in str(e) or "Connection not established" in str(e):
305-
logger.warning(f"[execute_query] {e!s}, replacing client...")
306-
self.pool.replace_client(client)
307-
return self.execute_query(gql, timeout, auto_set_db)
308-
raise
166+
ngql = use_prefix + gql
167+
168+
try:
169+
with self.pool.borrow() as client:
170+
return client.execute(ngql, timeout=timeout)
171+
except Exception as e:
172+
logger.error(f"[execute_query] Failed: {e}")
173+
raise
309174

310175
@timed
311176
def close(self):

0 commit comments

Comments
 (0)