Skip to content
4 changes: 3 additions & 1 deletion src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ def get_reranker_config() -> dict[str, Any]:
"backend": "http_bge",
"config": {
"url": os.getenv("MOS_RERANKER_URL"),
"model": "bge-reranker-v2-m3",
"model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
"timeout": 10,
"headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
"rerank_source": os.getenv("MOS_RERANK_SOURCE"),
},
}
else:
Expand Down
59 changes: 59 additions & 0 deletions src/memos/reranker/concat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import re

from typing import Any


_TAG1 = re.compile(r"^\s*\[[^\]]*\]\s*")


def process_source(
items: list[tuple[Any, str | dict[str, Any] | list[Any]]] | None = None, recent_num: int = 3
) -> str:
"""
Args:
items: List of tuples where each tuple contains (memory, source).
source can be str, Dict, or List.
recent_num: Number of recent items to concatenate.
Returns:
str: Concatenated source.
"""
if items is None:
items = []
concat_data = []
memory = None
for item in items:
memory, source = item
for content in source:
if isinstance(content, str):
if "assistant:" in content:
continue
concat_data.append(content)
if memory is not None:
concat_data = [memory, *concat_data]
return "\n".join(concat_data)


def concat_original_source(
graph_results: list,
merge_field: list[str] | None = None,
) -> list[str]:
"""
Merge memory items with original dialogue.
Args:
graph_results (list[TextualMemoryItem]): List of memory items with embeddings.
merge_field (List[str]): List of fields to merge.
Returns:
list[str]: List of memory and concat orginal memory.
"""
if merge_field is None:
merge_field = ["sources"]
documents = []
for item in graph_results:
memory = _TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m
sources = []
for field in merge_field:
source = getattr(item.metadata, field, "")
sources.append((memory, source))
concat_string = process_source(sources)
documents.append(concat_string)
return documents
1 change: 1 addition & 0 deletions src/memos/reranker/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def from_config(cfg: RerankerConfigFactory | None) -> BaseReranker | None:
model=c.get("model", "bge-reranker-v2-m3"),
timeout=int(c.get("timeout", 10)),
headers_extra=c.get("headers_extra"),
rerank_source=c.get("rerank_source"),
)

if backend in {"cosine_local", "cosine"}:
Expand Down
27 changes: 21 additions & 6 deletions src/memos/reranker/http_bge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@

import requests

from memos.log import get_logger

from .base import BaseReranker
from .concat import concat_original_source


logger = get_logger(__name__)


if TYPE_CHECKING:
Expand All @@ -28,6 +34,7 @@ def __init__(
model: str = "bge-reranker-v2-m3",
timeout: int = 10,
headers_extra: dict | None = None,
rerank_source: list[str] | None = None,
):
if not reranker_url:
raise ValueError("reranker_url must not be empty")
Expand All @@ -36,6 +43,7 @@ def __init__(
self.model = model
self.timeout = timeout
self.headers_extra = headers_extra or {}
self.concat_source = rerank_source

def rerank(
self,
Expand All @@ -47,11 +55,18 @@ def rerank(
if not graph_results:
return []

documents = [
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
for item in graph_results
]
documents = [d for d in documents if isinstance(d, str) and d]
documents = []
if self.concat_source:
documents = concat_original_source(graph_results, self.concat_source)
else:
documents = [
(_TAG1.sub("", m) if isinstance((m := getattr(item, "memory", None)), str) else m)
for item in graph_results
]
documents = [d for d in documents if isinstance(d, str) and d]

logger.info(f"[HTTPBGERerankerSample] query: {query} , documents: {documents[:5]}...")

if not documents:
return []

Expand Down Expand Up @@ -95,5 +110,5 @@ def rerank(
return [(item, 0.0) for item in graph_results[:top_k]]

except Exception as e:
print(f"[HTTPBGEReranker] request failed: {e}")
logger.error(f"[HTTPBGEReranker] request failed: {e}")
return [(item, 0.0) for item in graph_results[:top_k]]
Loading