
Commit 79ad733

feat: add reranker (#277)
* feat: parallel doc process
* add memory size config for tree
* feat: add reranker Factory
* feat: pass reranker from tree config
* feat: add reranker config in mos product
* style: modify annotation
* feat: slightly adjust similarity threshold for returned memories
* test: fix researcher test script
1 parent 296bc92 commit 79ad733

File tree

16 files changed: +581, -157 lines changed


examples/basic_modules/reranker.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
import os
import uuid

from dotenv import load_dotenv

from memos import log
from memos.configs.embedder import EmbedderConfigFactory
from memos.configs.reranker import RerankerConfigFactory
from memos.embedders.factory import EmbedderFactory
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.reranker.factory import RerankerFactory


load_dotenv()
logger = log.get_logger(__name__)


def make_item(text: str) -> TextualMemoryItem:
    """Build a minimal TextualMemoryItem; embedding will be populated later."""
    return TextualMemoryItem(
        id=str(uuid.uuid4()),
        memory=text,
        metadata=TreeNodeTextualMemoryMetadata(
            user_id=None,
            session_id=None,
            status="activated",
            type="fact",
            memory_time="2024-01-01",
            source="conversation",
            confidence=100.0,
            tags=[],
            visibility="public",
            updated_at="2025-01-01T00:00:00",
            memory_type="LongTermMemory",
            key="demo_key",
            sources=["demo://example"],
            embedding=[],
            background="demo background...",
        ),
    )


def show_ranked(title: str, ranked: list[tuple[TextualMemoryItem, float]], top_n: int = 5) -> None:
    print(f"\n=== {title} ===")
    for i, (item, score) in enumerate(ranked[:top_n], start=1):
        preview = (item.memory[:80] + "...") if len(item.memory) > 80 else item.memory
        print(f"[#{i}] score={score:.6f} | {preview}")


def main():
    # -------------------------------
    # 1) Build the embedder (real vectors)
    # -------------------------------
    embedder_cfg = EmbedderConfigFactory.model_validate(
        {
            "backend": "universal_api",
            "config": {
                "provider": "openai",  # or "azure"
                "api_key": os.getenv("OPENAI_API_KEY"),
                "model_name_or_path": "text-embedding-3-large",
                "base_url": os.getenv("OPENAI_API_BASE"),  # optional
            },
        }
    )
    embedder = EmbedderFactory.from_config(embedder_cfg)

    # -------------------------------
    # 2) Prepare query + documents
    # -------------------------------
    query = "What is the capital of France?"
    items = [
        make_item("Paris is the capital of France."),
        make_item("Berlin is the capital of Germany."),
        make_item("The capital of Brazil is Brasilia."),
        make_item("Apples and bananas are common fruits."),
        make_item("The Eiffel Tower is a famous landmark in Paris."),
    ]

    # -------------------------------
    # 3) Embed query + docs with real embeddings
    # -------------------------------
    texts_to_embed = [query] + [it.memory for it in items]
    vectors = embedder.embed(texts_to_embed)  # real vectors from your provider/model
    query_embedding = vectors[0]
    doc_embeddings = vectors[1:]

    # attach real embeddings back to items
    for it, emb in zip(items, doc_embeddings, strict=False):
        it.metadata.embedding = emb

    # -------------------------------
    # 4) Rerank with cosine_local (uses your real embeddings)
    # -------------------------------
    cosine_cfg = RerankerConfigFactory.model_validate(
        {
            "backend": "cosine_local",
            "config": {
                # structural boosts (optional): uses metadata.background
                "level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
                "level_field": "background",
            },
        }
    )
    cosine_reranker = RerankerFactory.from_config(cosine_cfg)

    ranked_cosine = cosine_reranker.rerank(
        query=query,
        graph_results=items,
        top_k=10,
        query_embedding=query_embedding,  # required by cosine_local
    )
    show_ranked("CosineLocal Reranker (with real embeddings)", ranked_cosine, top_n=5)

    # -------------------------------
    # 5) (Optional) Rerank with HTTP BGE (OpenAI-style /query+documents)
    #    Requires the service URL; no need for embeddings here
    # -------------------------------
    bge_url = os.getenv("BGE_RERANKER_URL")  # e.g., "http://xxx.x.xxxxx.xxx:xxxx/v1/rerank"
    if bge_url:
        http_cfg = RerankerConfigFactory.model_validate(
            {
                "backend": "http_bge",
                "config": {
                    "url": bge_url,
                    "model": os.getenv("BGE_RERANKER_MODEL", "bge-reranker-v2-m3"),
                    "timeout": int(os.getenv("BGE_RERANKER_TIMEOUT", "10")),
                    # "headers_extra": {"Authorization": f"Bearer {os.getenv('BGE_RERANKER_TOKEN')}"}
                },
            }
        )
        http_reranker = RerankerFactory.from_config(http_cfg)

        ranked_http = http_reranker.rerank(
            query=query,
            graph_results=items,  # uses item.memory internally as documents
            top_k=10,
        )
        show_ranked("HTTP BGE Reranker (OpenAI-style API)", ranked_http, top_n=5)
    else:
        print("\n[Info] Skipped HTTP BGE scenario because BGE_RERANKER_URL is not set.")


if __name__ == "__main__":
    main()
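With OPENAI_API_KEY (and optionally OPENAI_API_BASE) set, both rerankers should place the Paris-related sentences at the top of the ranking, since they match the query most closely. The HTTP BGE scenario runs only when BGE_RERANKER_URL points at a deployed rerank service; otherwise it is skipped with an informational message.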

examples/basic_modules/tree_textual_memory_reranker.py

Lines changed: 0 additions & 121 deletions
This file was deleted.

src/memos/api/config.py

Lines changed: 25 additions & 0 deletions
@@ -90,6 +90,29 @@ def get_activation_vllm_config() -> dict[str, Any]:
         },
     }
 
+    @staticmethod
+    def get_reranker_config() -> dict[str, Any]:
+        """Get reranker configuration."""
+        reranker_backend = os.getenv("MOS_RERANKER_BACKEND", "http_bge")
+
+        if reranker_backend == "http_bge":
+            return {
+                "backend": "http_bge",
+                "config": {
+                    "url": os.getenv("MOS_RERANKER_URL"),
+                    "model": "bge-reranker-v2-m3",
+                    "timeout": 10,
+                },
+            }
+        else:
+            return {
+                "backend": "cosine_local",
+                "config": {
+                    "level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
+                    "level_field": "background",
+                },
+            }
+
     @staticmethod
     def get_embedder_config() -> dict[str, Any]:
         """Get embedder configuration."""
@@ -492,6 +515,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
                 },
                 "embedder": APIConfig.get_embedder_config(),
                 "internet_retriever": internet_config,
+                "reranker": APIConfig.get_reranker_config(),
             },
         },
         "act_mem": {}
@@ -545,6 +569,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
                 "config": graph_db_backend_map[graph_db_backend],
             },
             "embedder": APIConfig.get_embedder_config(),
+            "reranker": APIConfig.get_reranker_config(),
             "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
             == "true",
             "internet_retriever": internet_config,

src/memos/configs/memory.py

Lines changed: 13 additions & 0 deletions
@@ -7,6 +7,7 @@
 from memos.configs.graph_db import GraphDBConfigFactory
 from memos.configs.internet_retriever import InternetRetrieverConfigFactory
 from memos.configs.llm import LLMConfigFactory
+from memos.configs.reranker import RerankerConfigFactory
 from memos.configs.vec_db import VectorDBConfigFactory
 from memos.exceptions import ConfigurationError
@@ -151,6 +152,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
         default_factory=EmbedderConfigFactory,
         description="Embedder configuration for the memory embedding",
     )
+    reranker: RerankerConfigFactory | None = Field(
+        None,
+        description="Reranker configuration (optional, defaults to cosine_local).",
+    )
     graph_db: GraphDBConfigFactory = Field(
         ...,
         default_factory=GraphDBConfigFactory,
@@ -166,6 +171,14 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
         description="Optional description for this memory configuration.",
     )
 
+    memory_size: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Maximum item counts per memory bucket, e.g.: "
+            '{"WorkingMemory": 20, "LongTermMemory": 10000, "UserMemory": 10000}'
+        ),
+    )
+
 
 # ─── 3. Global Memory Config Factory ──────────────────────────────────────────
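To show how the two new fields sit in a tree memory config, a small illustrative fragment (placeholder values; the other required TreeTextMemoryConfig sections such as the embedder and graph_db factories shown above are omitted):

tree_config_fragment = {
    # Optional reranker; when omitted, the memory falls back to cosine_local
    "reranker": {
        "backend": "http_bge",
        "config": {"url": "http://localhost:8080/v1/rerank", "model": "bge-reranker-v2-m3"},
    },
    # Maximum item counts per memory bucket (example values from the field description)
    "memory_size": {"WorkingMemory": 20, "LongTermMemory": 10000, "UserMemory": 10000},
}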

src/memos/configs/reranker.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# memos/configs/reranker.py
from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field


class RerankerConfigFactory(BaseModel):
    """
    {
        "backend": "http_bge" | "cosine_local" | "noop",
        "config": { ... backend-specific ... }
    }
    """

    backend: str = Field(..., description="Reranker backend id")
    config: dict[str, Any] = Field(default_factory=dict, description="Backend-specific options")
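Because RerankerConfigFactory is a plain pydantic model, validating a backend selection is a one-liner. A minimal sketch using the "noop" backend named in the docstring (not otherwise exercised in this commit):

from memos.configs.reranker import RerankerConfigFactory

cfg = RerankerConfigFactory.model_validate({"backend": "noop", "config": {}})
print(cfg.backend)  # "noop"
print(cfg.config)   # {} -- backend-specific options are carried through as-is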

src/memos/mem_os/product.py

Lines changed: 2 additions & 2 deletions
@@ -557,7 +557,7 @@ async def _post_chat_processing(
             send_online_bot_notification_async,
         )
 
-        # 准备通知数据
+        # Prepare notification data
         chat_data = {
             "query": query,
             "user_id": user_id,
@@ -693,7 +693,7 @@ def run_async_in_thread():
         thread.start()
 
     def _filter_memories_by_threshold(
-        self, memories: list[TextualMemoryItem], threshold: float = 0.50, min_num: int = 3
+        self, memories: list[TextualMemoryItem], threshold: float = 0.52, min_num: int = 0
     ) -> list[TextualMemoryItem]:
         """
         Filter memories by threshold.
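The hunk above only shows the signature change (threshold 0.50 -> 0.52, min_num 3 -> 0). As a rough, hypothetical sketch of what such a filter does (the real method body is not part of this diff, and the score attribute below is an assumption):

def filter_by_threshold(memories, threshold=0.52, min_num=0):
    """Keep memories whose relevance score meets the threshold, but never fewer than min_num."""
    ranked = sorted(memories, key=lambda m: m.score, reverse=True)  # 'score' is an assumed attribute
    kept = [m for m in ranked if m.score >= threshold]
    if len(kept) < min_num:
        kept = ranked[:min_num]  # back-fill from the top of the ranking
    return kept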
