Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions examples/basic_modules/reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import os
import uuid

from dotenv import load_dotenv

from memos import log
from memos.configs.embedder import EmbedderConfigFactory
from memos.configs.reranker import RerankerConfigFactory
from memos.embedders.factory import EmbedderFactory
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.reranker.factory import RerankerFactory


# Runs at import time, before main() reads any os.getenv(...) values.
# Presumably this loads variables from a local .env file (python-dotenv) — confirm.
load_dotenv()
# Module-level logger named after this module; nothing in the visible example uses it.
logger = log.get_logger(__name__)


def make_item(text: str) -> TextualMemoryItem:
    """Wrap ``text`` in a demo TextualMemoryItem carrying placeholder metadata.

    Every item gets a fresh random UUID as its id and the same fixed demo
    metadata. The ``embedding`` field starts out as an empty list so the
    caller can attach a vector afterwards.
    """
    # Fixed placeholder values shared by every demo item.
    metadata_fields = {
        "user_id": None,
        "session_id": None,
        "status": "activated",
        "type": "fact",
        "memory_time": "2024-01-01",
        "source": "conversation",
        "confidence": 100.0,
        "tags": [],
        "visibility": "public",
        "updated_at": "2025-01-01T00:00:00",
        "memory_type": "LongTermMemory",
        "key": "demo_key",
        "sources": ["demo://example"],
        "embedding": [],  # filled in later by the caller
        "background": "demo background...",
    }
    demo_metadata = TreeNodeTextualMemoryMetadata(**metadata_fields)
    return TextualMemoryItem(id=str(uuid.uuid4()), memory=text, metadata=demo_metadata)


def show_ranked(title: str, ranked: list[tuple[TextualMemoryItem, float]], top_n: int = 5) -> None:
    """Print a titled listing of the first ``top_n`` (item, score) pairs.

    Each output line shows the 1-based rank, the score formatted with six
    decimal places, and the item's ``memory`` text. Text longer than 80
    characters is cut to its first 80 characters followed by "...".
    """
    print(f"\n=== {title} ===")
    rank = 0
    for entry, relevance in ranked[:top_n]:
        rank += 1
        snippet = entry.memory
        if len(snippet) > 80:
            # Keep the console output compact for long memories.
            snippet = snippet[:80] + "..."
        print(f"[#{rank}] score={relevance:.6f} | {snippet}")


def main():
    """Run the reranker demo end to end.

    Builds an embedder from environment settings, embeds one query and five
    sample documents, reranks the documents with the ``cosine_local``
    backend, and then, only when ``BGE_RERANKER_URL`` is set, reranks them
    again through the ``http_bge`` backend. Results are printed to stdout.
    """
    # --- Step 1: embedder that produces real vectors ---------------------
    embedder_settings = {
        "backend": "universal_api",
        "config": {
            "provider": "openai",  # or "azure"
            "api_key": os.getenv("OPENAI_API_KEY"),
            "model_name_or_path": "text-embedding-3-large",
            "base_url": os.getenv("OPENAI_API_BASE"),  # optional
        },
    }
    embedder = EmbedderFactory.from_config(EmbedderConfigFactory.model_validate(embedder_settings))

    # --- Step 2: query and candidate documents ---------------------------
    query = "What is the capital of France?"
    sentences = (
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "The capital of Brazil is Brasilia.",
        "Apples and bananas are common fruits.",
        "The Eiffel Tower is a famous landmark in Paris.",
    )
    items = [make_item(sentence) for sentence in sentences]

    # --- Step 3: embed the query and documents in a single call ----------
    # The query goes first so its vector can be split off from the documents'.
    vectors = embedder.embed([query, *(doc.memory for doc in items)])
    query_embedding, doc_embeddings = vectors[0], vectors[1:]

    # Store each document's vector on its metadata.
    for doc, vector in zip(items, doc_embeddings, strict=False):
        doc.metadata.embedding = vector

    # --- Step 4: rerank with cosine_local --------------------------------
    cosine_settings = {
        "backend": "cosine_local",
        "config": {
            # optional structural boosts, keyed on metadata.background
            "level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
            "level_field": "background",
        },
    }
    cosine_reranker = RerankerFactory.from_config(
        RerankerConfigFactory.model_validate(cosine_settings)
    )
    ranked_cosine = cosine_reranker.rerank(
        query=query,
        graph_results=items,
        top_k=10,
        query_embedding=query_embedding,  # supplied for the cosine_local backend
    )
    show_ranked("CosineLocal Reranker (with real embeddings)", ranked_cosine, top_n=5)

    # --- Step 5 (optional): rerank through an HTTP BGE service -----------
    # Only the service URL is needed here; no embeddings are passed.
    bge_url = os.getenv("BGE_RERANKER_URL")  # e.g., "http://xxx.x.xxxxx.xxx:xxxx/v1/rerank"
    if not bge_url:
        print("\n[Info] Skipped HTTP BGE scenario because BGE_RERANKER_URL is not set.")
        return

    http_settings = {
        "backend": "http_bge",
        "config": {
            "url": bge_url,
            "model": os.getenv("BGE_RERANKER_MODEL", "bge-reranker-v2-m3"),
            "timeout": int(os.getenv("BGE_RERANKER_TIMEOUT", "10")),
            # "headers_extra": {"Authorization": f"Bearer {os.getenv('BGE_RERANKER_TOKEN')}"}
        },
    }
    http_reranker = RerankerFactory.from_config(RerankerConfigFactory.model_validate(http_settings))
    ranked_http = http_reranker.rerank(
        query=query,
        graph_results=items,  # presumably item.memory is sent as the document text
        top_k=10,
    )
    show_ranked("HTTP BGE Reranker (OpenAI-style API)", ranked_http, top_n=5)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
121 changes: 0 additions & 121 deletions examples/basic_modules/tree_textual_memory_reranker.py

This file was deleted.

25 changes: 25 additions & 0 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,29 @@ def get_activation_vllm_config() -> dict[str, Any]:
},
}

@staticmethod
def get_reranker_config() -> dict[str, Any]:
    """Get reranker configuration.

    The backend is chosen by the ``MOS_RERANKER_BACKEND`` environment
    variable, which defaults to ``"http_bge"``.

    Returns:
        A ``{"backend": ..., "config": ...}`` dict. For ``"http_bge"`` the
        config holds the service URL read from ``MOS_RERANKER_URL``
        (``None`` when unset), a fixed model name and a fixed timeout. Any
        other value selects ``"cosine_local"`` with equal level weights.
    """
    reranker_backend = os.getenv("MOS_RERANKER_BACKEND", "http_bge")

    if reranker_backend == "http_bge":
        return {
            # This must be a reranker backend id. "universal_api" is the
            # embedder backend id and was returned here by mistake.
            "backend": "http_bge",
            "config": {
                "url": os.getenv("MOS_RERANKER_URL"),
                "model": "bge-reranker-v2-m3",
                "timeout": 10,
            },
        }

    # Every other value falls back to the local cosine reranker.
    return {
        "backend": "cosine_local",
        "config": {
            "level_weights": {"topic": 1.0, "concept": 1.0, "fact": 1.0},
            "level_field": "background",
        },
    }

@staticmethod
def get_embedder_config() -> dict[str, Any]:
"""Get embedder configuration."""
Expand Down Expand Up @@ -492,6 +515,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
},
"embedder": APIConfig.get_embedder_config(),
"internet_retriever": internet_config,
"reranker": APIConfig.get_reranker_config(),
},
},
"act_mem": {}
Expand Down Expand Up @@ -545,6 +569,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
"config": graph_db_backend_map[graph_db_backend],
},
"embedder": APIConfig.get_embedder_config(),
"reranker": APIConfig.get_reranker_config(),
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
== "true",
"internet_retriever": internet_config,
Expand Down
13 changes: 13 additions & 0 deletions src/memos/configs/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from memos.configs.graph_db import GraphDBConfigFactory
from memos.configs.internet_retriever import InternetRetrieverConfigFactory
from memos.configs.llm import LLMConfigFactory
from memos.configs.reranker import RerankerConfigFactory
from memos.configs.vec_db import VectorDBConfigFactory
from memos.exceptions import ConfigurationError

Expand Down Expand Up @@ -151,6 +152,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
default_factory=EmbedderConfigFactory,
description="Embedder configuration for the memory embedding",
)
reranker: RerankerConfigFactory | None = Field(
None,
description="Reranker configuration (optional, defaults to cosine_local).",
)
graph_db: GraphDBConfigFactory = Field(
...,
default_factory=GraphDBConfigFactory,
Expand All @@ -166,6 +171,14 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig):
description="Optional description for this memory configuration.",
)

memory_size: dict[str, Any] | None = Field(
default=None,
description=(
"Maximum item counts per memory bucket, e.g.: "
'{"WorkingMemory": 20, "LongTermMemory": 10000, "UserMemory": 10000}'
),
)


# ─── 3. Global Memory Config Factory ──────────────────────────────────────────

Expand Down
18 changes: 18 additions & 0 deletions src/memos/configs/reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# memos/configs/reranker.py
from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field


class RerankerConfigFactory(BaseModel):
    """Generic reranker configuration: a backend id plus its options.

    Expected shape::

        {
            "backend": "http_bge" | "cosine_local" | "noop",
            "config": { ... backend-specific ... }
        }
    """

    # Required: which reranker backend to use.
    backend: str = Field(
        default=...,
        description="Reranker backend id",
    )
    # Optional: free-form options for that backend; an empty dict when omitted.
    config: dict[str, Any] = Field(
        description="Backend-specific options",
        default_factory=dict,
    )
4 changes: 2 additions & 2 deletions src/memos/mem_os/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ async def _post_chat_processing(
send_online_bot_notification_async,
)

# 准备通知数据
# Prepare notification data
chat_data = {
"query": query,
"user_id": user_id,
Expand Down Expand Up @@ -693,7 +693,7 @@ def run_async_in_thread():
thread.start()

def _filter_memories_by_threshold(
self, memories: list[TextualMemoryItem], threshold: float = 0.50, min_num: int = 3
self, memories: list[TextualMemoryItem], threshold: float = 0.52, min_num: int = 0
) -> list[TextualMemoryItem]:
"""
Filter memories by threshold.
Expand Down
Loading
Loading