Skip to content
172 changes: 148 additions & 24 deletions src/memos/graph_dbs/nebular.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from contextlib import suppress
from datetime import datetime
from queue import Empty, Queue
from threading import Lock
from typing import TYPE_CHECKING, Any, ClassVar, Literal

Expand Down Expand Up @@ -86,6 +87,137 @@ def _normalize_datetime(val):
return str(val)


class SessionPoolError(Exception):
pass


class SessionPool:
@require_python_package(
import_name="nebulagraph_python",
install_command="pip install ... @Tianxing",
install_link=".....",
)
def __init__(
self,
hosts: list[str],
user: str,
password: str,
minsize: int = 1,
maxsize: int = 10000,
):
self.hosts = hosts
self.user = user
self.password = password
self.minsize = minsize
self.maxsize = maxsize
self.pool = Queue(maxsize)
self.lock = Lock()

self.clients = []

for _ in range(minsize):
self._create_and_add_client()

@timed
def _create_and_add_client(self):
from nebulagraph_python import NebulaClient

client = NebulaClient(self.hosts, self.user, self.password)
self.pool.put(client)
self.clients.append(client)

@timed
def get_client(self, timeout: float = 5.0):
try:
return self.pool.get(timeout=timeout)
except Empty:
with self.lock:
if len(self.clients) < self.maxsize:
from nebulagraph_python import NebulaClient

client = NebulaClient(self.hosts, self.user, self.password)
self.clients.append(client)
return client
raise RuntimeError("NebulaClientPool exhausted") from None

@timed
def return_client(self, client):
try:
client.execute("YIELD 1")
self.pool.put(client)
except Exception:
logger.info("[Pool] Client dead, replacing...")
self.replace_client(client)

@timed
def close(self):
for client in self.clients:
with suppress(Exception):
client.close()
self.clients.clear()

@timed
def get(self):
"""
Context manager: with pool.get() as client:
"""

class _ClientContext:
def __init__(self, outer):
self.outer = outer
self.client = None

def __enter__(self):
self.client = self.outer.get_client()
return self.client

def __exit__(self, exc_type, exc_val, exc_tb):
if self.client:
self.outer.return_client(self.client)

return _ClientContext(self)

@timed
def reset_pool(self):
"""⚠️ Emergency reset: Close all clients and clear the pool."""
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
with self.lock:
for client in self.clients:
try:
client.close()
except Exception:
logger.error("Fail to close!!!")
self.clients.clear()
while not self.pool.empty():
try:
self.pool.get_nowait()
except Empty:
break
for _ in range(self.minsize):
self._create_and_add_client()
logger.info("[Pool] Pool has been reset successfully.")

@timed
def replace_client(self, client):
try:
client.close()
except Exception:
logger.error("Fail to close client")

if client in self.clients:
self.clients.remove(client)

from nebulagraph_python import NebulaClient

new_client = NebulaClient(self.hosts, self.user, self.password)
self.clients.append(new_client)

self.pool.put(new_client)

logger.info("[Pool] Replaced dead client with a new one.")
return new_client


class NebulaGraphDB(BaseGraphDB):
"""
NebulaGraph-based implementation of a graph memory store.
Expand Down Expand Up @@ -125,19 +257,18 @@ def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
Get a shared NebulaPool from cache or create one if missing.
Thread-safe with a lock; maintains a simple refcount.
"""
from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig

key = cls._make_pool_key(cfg)

with cls._POOL_LOCK:
pool = cls._POOL_CACHE.get(key)
if pool is None:
# Create a new pool and put into cache
pool = NebulaPool(
pool = SessionPool(
hosts=cfg.get("uri"),
username=cfg.get("user"),
user=cfg.get("user"),
password=cfg.get("password"),
pool_config=NebulaPoolConfig(max_client_size=cfg.get("max_client", 1000)),
minsize=1,
maxsize=cfg.get("max_client", 1000),
)
cls._POOL_CACHE[key] = pool
cls._POOL_REFCOUNT[key] = 0
Expand Down Expand Up @@ -256,17 +387,18 @@ def __init__(self, config: NebulaGraphDBConfig):

@timed
def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
use_prefix = f"USE `{self.db_name}` " if auto_set_db and needs_use_prefix else ""

ngql = use_prefix + gql
with self.pool.get() as client:
try:
if auto_set_db and self.db_name:
client.execute(f"SESSION SET GRAPH `{self.db_name}`")
return client.execute(gql, timeout=timeout)

try:
with self.pool.borrow() as client:
return client.execute(ngql, timeout=timeout)
except Exception as e:
logger.error(f"[execute_query] Failed: {e}")
raise
except Exception as e:
if "Session not found" in str(e) or "Connection not established" in str(e):
logger.warning(f"[execute_query] {e!s}, replacing client...")
self.pool.replace_client(client)
return self.execute_query(gql, timeout, auto_set_db)
raise

@timed
def close(self):
Expand Down Expand Up @@ -940,20 +1072,12 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
"""
where_clauses = []

def _escape_value(value):
if isinstance(value, str):
return f'"{value}"'
elif isinstance(value, list):
return "[" + ", ".join(_escape_value(v) for v in value) + "]"
else:
return str(value)

for _i, f in enumerate(filters):
field = f["field"]
op = f.get("op", "=")
value = f["value"]

escaped_value = _escape_value(value)
escaped_value = self._format_value(value)

# Build WHERE clause
if op == "=":
Expand Down
Loading
Loading