Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/memos/api/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import json
import os

from typing import Any

import requests

from memos.log import get_logger


logger = get_logger(__name__)

MAX_RETRY_COUNT = 3


class MemOSResponse:
"""Response wrapper to support dot notation access"""

def __init__(self, data):
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, dict):
setattr(self, key, MemOSResponse(value))
else:
setattr(self, key, value)
else:
self.data = data


class MemOSClient:
"""MemOS API client"""

def __init__(self, api_key: str | None = None, base_url: str | None = None):
self.base_url = (
base_url or os.getenv("MEMOS_BASE_URL") or "https://memos.memtensor.cn/api/openmem"
)
api_key = api_key or os.getenv("MEMOS_API_KEY")

if not api_key:
raise ValueError("MemOS API key is required")

self.headers = {"Content-Type": "application/json", "Authorization": f"Token {api_key}"}

def _validate_required_params(self, **params):
"""Validate required parameters - if passed, they must not be empty"""
for param_name, param_value in params.items():
if not param_value:
raise ValueError(f"{param_name} is required")

def add(
self, messages: list[dict[str, Any]], user_id: str, conversation_id: str
) -> MemOSResponse:
"""Add memories"""
# Validate required parameters
self._validate_required_params(
messages=messages, user_id=user_id, conversation_id=conversation_id
)

url = f"{self.base_url}/add/message"
payload = {"messages": messages, "userId": user_id, "conversationId": conversation_id}

for retry in range(MAX_RETRY_COUNT):
try:
response = requests.post(
url, data=json.dumps(payload), headers=self.headers, timeout=30
)
response.raise_for_status()
response_data = json.loads(response.text)
return MemOSResponse(response_data)
except Exception as e:
logger.error(f"Failed to add memory (retry {retry + 1}/3): {e}")
if retry == MAX_RETRY_COUNT - 1:
raise

def search(
self, query: str, user_id: str, conversation_id: str, memory_limit_number: int = 6
) -> MemOSResponse:
"""Search memories"""
# Validate required parameters
self._validate_required_params(query=query, user_id=user_id)

url = f"{self.base_url}/search/memory"
payload = {
"query": query,
"userId": user_id,
"conversationId": conversation_id,
"memoryLimitNumber": memory_limit_number,
}

for retry in range(MAX_RETRY_COUNT):
try:
response = requests.post(
url, data=json.dumps(payload), headers=self.headers, timeout=30
)
response.raise_for_status()
response_data = json.loads(response.text)
return MemOSResponse(response_data)
except Exception as e:
logger.error(f"Failed to search memory (retry {retry + 1}/3): {e}")
if retry == MAX_RETRY_COUNT - 1:
raise
4 changes: 1 addition & 3 deletions src/memos/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import time

from memos import settings
from memos.log import get_logger


Expand All @@ -14,8 +13,7 @@ def wrapper(*args, **kwargs):
start = time.perf_counter()
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
if settings.DEBUG:
logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
return result

return wrapper
Loading