Refactor application structure to use LangGraph DeepAgents #423
base: enhance
@@ -2,9 +2,9 @@
 from json import dumps, loads
 from pathlib import Path
-from textwrap import dedent
-from typing import TYPE_CHECKING, AsyncGenerator, Self, Sequence
+from typing import TYPE_CHECKING, AsyncGenerator, Self
+from deepagents import create_deep_agent
 from gradio import (
     Audio,
     Blocks,

@@ -22,31 +22,22 @@
     Video,
 )
 from gradio.components.chatbot import MetadataDict
-from langchain_community.embeddings import FastEmbedEmbeddings
+from langchain.chat_models import init_chat_model
 from langchain_core.messages import (
     AIMessage,
-    AnyMessage,
-    BaseMessage,
     HumanMessage,
     ToolMessage,
 )
-from langchain_core.runnables import Runnable
+from langchain_core.prompt_values import StringPromptValue
 from langchain_core.runnables.config import RunnableConfig
 from langchain_core.tools import BaseTool
 from langchain_mcp_adapters.client import MultiServerMCPClient
-from langchain_openai import ChatOpenAI
-from langgraph.graph import START, StateGraph
+from langchain_openai.chat_models.base import BaseChatModel
+from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
 from langgraph.graph.state import CompiledStateGraph
-from langgraph.prebuilt import ToolNode, tools_condition
 from m3u8 import M3U8, load
-from mem0 import Memory
-from mem0.configs.base import MemoryConfig
-from mem0.embeddings.configs import EmbedderConfig
-from mem0.llms.configs import LlmConfig
-from mem0.vector_stores.configs import VectorStoreConfig
-from openai import OpenAIError
 from poml.integration.langchain import LangchainPomlTemplate
 from pydantic import FilePath, HttpUrl, ValidationError
-from qdrant_client.http.exceptions import ResponseHandlingException
 from requests import Session

 from chattr.app.settings import Settings, logger

@@ -60,11 +51,9 @@ class App:
     """Main application class for the Chattr Multi-agent system app."""

     settings: Settings
-    _llm: ChatOpenAI
-    _model: Runnable
     _tools: list[BaseTool]
-    _memory: Memory
-    _graph: CompiledStateGraph
+    _checkpointer: AsyncMongoDBSaver
+    _deep_agent: CompiledStateGraph

     @classmethod
     async def create(cls, settings: Settings) -> Self:

@@ -82,131 +71,35 @@ async def create(cls, settings: Settings) -> Self:
                 logger.warning(f"Failed to parse MCP config JSON: {e}")
         except Exception as e:
             logger.warning(f"Failed to setup MCP tools: {e}")
-        cls._llm = cls._setup_llm()
-        cls._model = cls._llm.bind_tools(cls._tools, parallel_tool_calls=False)
-        cls._memory = await cls._setup_memory()
-        cls._graph = cls._setup_graph()
+        cls._checkpointer = AsyncMongoDBSaver.from_conn_string("localhost:27017")
Contributor

suggestion: Hardcoded MongoDB connection string may reduce deployment flexibility. Consider sourcing the connection string from configuration or environment variables to support multiple deployment environments.
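A minimal sketch of that suggestion (the `MONGODB__URL` variable name is illustrative, not an existing project setting; the `from_conn_string` usage simply mirrors the call in this PR):

    # Hypothetical sketch: source the MongoDB URI from the environment
    # instead of hardcoding "localhost:27017".
    from os import environ

    from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver

    mongo_uri = environ.get("MONGODB__URL", "mongodb://localhost:27017")
    checkpointer = AsyncMongoDBSaver.from_conn_string(mongo_uri)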
+        cls._deep_agent = cls._setup_deepagents()
         return cls()

     @classmethod
-    def _setup_graph(cls) -> CompiledStateGraph:
-        """
-        Construct and compile the state graph for the Chattr application.
-
-        This method defines the nodes and edges for the conversational agent
-        and tool interactions.
-
-        Returns:
-            CompiledStateGraph: The compiled state graph is ready for execution.
-        """
-
-        def _clean_old_files(state: State) -> State:
-            """Clean up temporary old audio and video files."""
-            if any(cls.settings.directory.audio.iterdir()):
-                for file in cls.settings.directory.audio.iterdir():
-                    try:
-                        file.unlink()
-                    except OSError as e:
-                        logger.error(f"Failed to delete audio file {file}: {e}")
-            if any(cls.settings.directory.video.iterdir()):
-                for file in cls.settings.directory.video.iterdir():
-                    try:
-                        file.unlink()
-                    except OSError as e:
-                        logger.error(f"Failed to delete video file {file}: {e}")
-            return state
-
-        async def _call_model(state: State) -> State:
-            """
-            Generate a model response based on the current state and user memory.
-
-            This asynchronous function retrieves relevant memories,
-            constructs a system message, and invokes the language model.
-
-            Args:
-                state: The current State object containing messages and user ID.
-
-            Returns:
-                State: The updated State object with the model's response message.
-            """
-            messages = state["messages"]
-            user_id = state["mem0_user_id"]
-
-            try:
-                if not user_id:
-                    logger.warning("No user_id found in state")
-                    user_id = "default"
-                memory = cls._retrieve_memory(messages, user_id)
-                system_messages = cls._setup_prompt(memory)
-                response = await cls._model.ainvoke([*system_messages, *messages])
-                cls._update_memory(messages, response, user_id)
-            except Exception as e:
-                _msg = f"Error in chatbot: {e}"
-                logger.error(_msg)
-                raise Error(_msg) from e
-            return State(messages=[response], mem0_user_id=user_id)
-
-        graph_builder: StateGraph = StateGraph(State)
-        graph_builder.add_node("clean_old_files", _clean_old_files)
-        graph_builder.add_node("agent", _call_model)
-        graph_builder.add_node("tools", ToolNode(cls._tools))
-        graph_builder.add_edge(START, "clean_old_files")
-        graph_builder.add_edge("clean_old_files", "agent")
-        graph_builder.add_conditional_edges("agent", tools_condition)
-        graph_builder.add_edge("tools", "agent")
-        return graph_builder.compile(debug=True)
-
-    @classmethod
-    def _retrieve_memory(cls, messages: list[AnyMessage], user_id: str) -> str:
-        memories = cls._memory.search(messages[-1].content, user_id=user_id)
-        memory_list: list[str] = memories["results"]
-        logger.info(f"Retrieved {len(memory_list)} relevant memories")
-        logger.debug(f"Memories: {memories}")
-
-        if len(memory_list):
-            memory_list = "\n".join(
-                [f"\t- {memory.get('memory')}" for memory in memory_list],
-            )
-            memory = dedent(
-                f"""
-                Relevant information from previous conversations:
-                {memory_list}
-                """,
-            )
-        else:
-            memory = "No previous conversation history available."
-        logger.debug(f"Memory context:\n{memory}")
-        return memory
+    def _setup_deepagents(cls) -> CompiledStateGraph:
+        """Return the DeepAgents multi-agent system graph."""
+        return create_deep_agent(
+            model=cls._setup_model(),
+            system_prompt=cls._setup_prompt(),
+            tools=cls._tools,
+            checkpointer=cls._checkpointer,
+            use_longterm_memory=True,
+        )
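For orientation: per the annotations above, `create_deep_agent` returns a standard LangGraph `CompiledStateGraph`, so it can be driven like any other graph. A minimal standalone sketch, with placeholder model, prompt, and tool list (not the app's actual wiring):

    # Standalone sketch, not the project's code: drive the compiled deep
    # agent via ainvoke and read the final message from the result.
    import asyncio

    from deepagents import create_deep_agent
    from langchain.chat_models import init_chat_model
    from langchain_core.messages import HumanMessage


    async def main() -> None:
        agent = create_deep_agent(
            model=init_chat_model("gemini-2.5-flash", model_provider="google_genai"),
            system_prompt="You are Napoleon. Answer in character.",
            tools=[],
        )
        result = await agent.ainvoke(
            {"messages": [HumanMessage(content="Describe Austerlitz.")]},
        )
        print(result["messages"][-1].content)


    asyncio.run(main())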

     @classmethod
-    def _setup_prompt(cls, memory: str) -> Sequence[BaseMessage]:
+    def _setup_prompt(cls) -> str:
         prompt_template = LangchainPomlTemplate.from_file(
             cls.settings.directory.prompts / "template.poml",
-            speaker_mode=True,
+            speaker_mode=False,
         )
-        prompt = prompt_template.format(character="Napoleon", context=memory)
-        system_messages: Sequence[BaseMessage] = prompt.messages
-        return system_messages
-
-    @classmethod
-    def _update_memory(
-        cls,
-        messages: list[AnyMessage],
-        response: BaseMessage,
-        user_id: str,
-    ) -> None:
-        try:
-            interaction = [
-                {"role": "user", "content": messages[-1].content},
-                {"role": "assistant", "content": response.content},
-            ]
-            mem0_result = cls._memory.add(interaction, user_id=user_id)
-            logger.info(f"Memory saved: {len(mem0_result.get('results', []))}")
-        except Exception as e:
-            logger.exception(f"Error saving memory: {e}")
+        prompt = prompt_template.format(character="Napoleon")
+        if not isinstance(prompt, StringPromptValue):
+            msg = "Prompt must be a StringPromptValue in non-speaker mode."
+            raise TypeError(msg)
+        return prompt.to_string()

     @classmethod
-    def _setup_llm(cls) -> ChatOpenAI:
+    def _setup_model(cls) -> BaseChatModel:
         """
         Initialize the ChatOpenAI language model using the provided settings.

@@ -220,63 +113,12 @@ def _setup_llm(cls) -> ChatOpenAI:
             Exception: If the model initialization fails.
         """
         try:
-            return ChatOpenAI(
-                base_url=str(cls.settings.model.url),
-                model=cls.settings.model.name,
-                api_key=cls.settings.model.api_key,
-                temperature=cls.settings.model.temperature,
-            )
+            return init_chat_model("gemini-2.5-flash", model_provider="google_genai")

Contributor

suggestion: Model initialization is now hardcoded to Gemini; making the model and provider configurable will allow easier support for future changes or additional options.

Suggested implementation:

        try:
            return init_chat_model(
                cls.settings.model.name,
                model_provider=cls.settings.model.provider,
            )
        except Exception as e:
            _msg = f"Failed to initialize ChatOpenAI model: {e}"
            logger.error(_msg)
            raise Error(_msg) from e

Ensure that `cls.settings.model` defines a `provider` field for this suggestion to work.
Contributor

The chat model is hardcoded to use `gemini-2.5-flash` with the `google_genai` provider instead of reading the model name from settings.
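On the settings side, that suggestion presumes a provider field; a hypothetical sketch of the extended model settings (field names beyond those shown in the diff are assumptions, not the project's actual schema):

    # Hypothetical sketch: add a provider field to the model settings so
    # init_chat_model stays configuration-driven.
    from pydantic import BaseModel


    class ModelSettings(BaseModel):
        name: str = "gemini-2.5-flash"
        provider: str = "google_genai"  # e.g. "openai", "anthropic"
        temperature: float = 0.7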
         except Exception as e:
             _msg = f"Failed to initialize ChatOpenAI model: {e}"
             logger.error(_msg)
             raise Error(_msg) from e

-    @classmethod
-    async def _setup_memory(cls) -> Memory:
-        """
-        Initialize and set up the Memory for state persistence.
-
-        Returns:
-            Memory: Configured memory instances.
-        """
-        try:
-            return Memory(
-                MemoryConfig(
-                    vector_store=VectorStoreConfig(
-                        provider="qdrant",
-                        config={
-                            "host": cls.settings.vector_database.url.host,
-                            "port": cls.settings.vector_database.url.port,
-                            "collection_name": cls.settings.memory.collection_name,
-                            "embedding_model_dims": cls.settings.memory.embedding_dims,
-                        },
-                    ),
-                    llm=LlmConfig(
-                        provider="langchain",
-                        config={"model": cls._llm},
-                    ),
-                    embedder=EmbedderConfig(
-                        provider="langchain",
-                        config={"model": FastEmbedEmbeddings()},
-                    ),
-                ),
-            )
-        except ResponseHandlingException as e:
-            _msg = f"Failed to connect to Qdrant server: {e}"
-            logger.error(_msg)
-            raise Error(_msg) from e
-        except OpenAIError as e:
-            _msg = (
-                "Failed to connect to Chat Model server: "
-                "setting the `MODEL__API_KEY` environment variable"
-            )
-            logger.error(_msg)
-            raise Error(_msg) from e
-        except ValueError as e:
-            _msg = f"Failed to initialize memory: {e}"
-            logger.exception(_msg)
-            raise Error(_msg) from e

     @staticmethod
     async def _setup_tools(_mcp_client: MultiServerMCPClient) -> list[BaseTool]:
         """

@@ -298,7 +140,7 @@ async def _setup_tools(_mcp_client: MultiServerMCPClient) -> list[BaseTool]:
     @classmethod
     def draw_graph(cls) -> Path:
         """Render the compiled state graph as a Mermaid PNG image and save it."""
-        cls._graph.get_graph().draw_mermaid_png(
+        cls._deep_agent.get_graph().draw_mermaid_png(
             output_file_path=cls.settings.directory.assets / "graph.png",
         )
         return cls.settings.directory.assets / "graph.png"

@@ -321,7 +163,7 @@ def gui(cls) -> Blocks:
             with Column():
                 Markdown("---")
                 Markdown("# Model Prompt")
-                Markdown(cls._setup_prompt("")[-1].content)
+                Markdown(cls._setup_prompt())
         with Row():
             with Column():
                 video = Video(

@@ -384,9 +226,10 @@ async def generate_response(
         """
         is_audio_generated: bool = False
         audio_file: FilePath | None = None
-        last_agent_message: AnyMessage | None = None
-        async for response in cls._graph.astream(
+        last_agent_message: AIMessage | None = None
+        async for response in cls._deep_agent.astream(
             State(messages=[HumanMessage(content=message)], mem0_user_id="1"),

Contributor

The `State` passed to `astream` still sets `mem0_user_id="1"`, even though the mem0 memory layer is removed in this PR.

Suggested change:

-            State(messages=[HumanMessage(content=message)], mem0_user_id="1"),
+            State(messages=[HumanMessage(content=message)]),
             RunnableConfig(configurable={"thread_id": "1"}),
Contributor

The `thread_id` is hardcoded to `"1"`, so all users and sessions share a single checkpointed conversation thread.
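A minimal sketch of one way to address that (uuid4 is illustrative; a stable per-user identifier from the UI layer would preserve history across reconnects):

    # Hypothetical sketch: derive a session-scoped thread id instead of
    # the constant "1".
    from uuid import uuid4

    from langchain_core.runnables.config import RunnableConfig

    config = RunnableConfig(configurable={"thread_id": str(uuid4())})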
             stream_mode="updates",
         ):
             logger.debug(f"Response type received: {response.keys()}")