diff --git a/src/agent/pizza-orderer/index.ts b/src/agent/pizza-orderer/index.ts
index 948cbfe..fc1e1f1 100644
--- a/src/agent/pizza-orderer/index.ts
+++ b/src/agent/pizza-orderer/index.ts
@@ -4,6 +4,7 @@ import { GenerativeUIAnnotation } from "../types";
 import { z } from "zod";
 import { AIMessage, ToolMessage } from "@langchain/langgraph-sdk";
 import { v4 as uuidv4 } from "uuid";
+import { ChatOpenAI } from "@langchain/openai";
 
 const PizzaOrdererAnnotation = Annotation.Root({
   messages: GenerativeUIAnnotation.spec.messages,
@@ -30,8 +31,8 @@ const workflow = new StateGraph(PizzaOrdererAnnotation)
       ),
     })
     .describe("The schema for finding a pizza shop for the user");
-  const model = new ChatAnthropic({
-    model: "claude-3-5-sonnet-latest",
+  const model = new ChatOpenAI({
+    model: "gpt-4o",
     temperature: 0,
   }).withStructuredOutput(findShopSchema, {
     name: "find_pizza_shop",
@@ -76,8 +77,8 @@ const workflow = new StateGraph(PizzaOrdererAnnotation)
       order: z.string().describe("The full pizza order for the user"),
     })
     .describe("The schema for ordering a pizza for the user");
-  const model = new ChatAnthropic({
-    model: "claude-3-5-sonnet-latest",
+  const model = new ChatOpenAI({
+    model: "gpt-4o",
     temperature: 0,
   }).withStructuredOutput(placeOrderSchema, {
     name: "place_pizza_order",
diff --git a/src/agent/supervisor/nodes/general-input.ts b/src/agent/supervisor/nodes/general-input.ts
index 9e24123..6f49cf2 100644
--- a/src/agent/supervisor/nodes/general-input.ts
+++ b/src/agent/supervisor/nodes/general-input.ts
@@ -2,13 +2,39 @@ import { SupervisorState, SupervisorUpdate } from "../types";
 import { ALL_TOOL_DESCRIPTIONS } from "../index";
 import { ChatOpenAI } from "@langchain/openai";
 
+const USER_PROFILE = {
+  "id": "user123",
+  "name": "Alice",
+  "profile": {
+    "personalHistory": "Alice was born in Munich and moved to Berlin for her studies. She holds a Master’s degree in Computer Science from TU Berlin. From a young age, she was fascinated by how machines learn and evolve, which led her into the field of artificial intelligence. She started her career as a backend developer and gradually transitioned into AI-focused roles. Outside of work, Alice enjoys hiking in the Alps, painting, and participating in tech meetups.",
+    "education": "M.Sc. in Computer Science, Technische Universität Berlin",
+    "expertise": "Machine learning, deep learning, distributed systems, transformer models, and open-source AI contributions.",
+    "projectA": {
+      "title": "Multilingual Generative Language Model",
+      "description": "Led the development of a multilingual transformer-based model capable of generating coherent text in over 15 languages. The project aimed to enhance NLP capabilities in underrepresented languages.",
+      "technologies": ["PyTorch", "Transformers", "HuggingFace", "TensorBoard"]
+    },
+    "projectB": {
+      "title": "Federated Learning System",
+      "description": "Built a federated learning pipeline enabling training on decentralized datasets across devices to maintain data privacy while improving model generalizability.",
+      "technologies": ["Python", "gRPC", "TensorFlow Federated", "Docker", "Kubernetes"]
+    },
+    "ongoingProjects": [
+      "Experimenting with sparse attention mechanisms for large language models.",
+      "Mentoring junior developers in the open-source AI community.",
+      "Collaborating with a university lab on adversarial robustness in deep models."
+    ]
+  }
+};
+
 export async function generalInput(
   state: SupervisorState,
 ): Promise<SupervisorUpdate> {
-  const GENERAL_INPUT_SYSTEM_PROMPT = `You are an AI assistant.
+  const GENERAL_INPUT_SYSTEM_PROMPT = `You are a friendly and cheerful assistant, here to help your friend. Here is your friend's background: ${JSON.stringify(USER_PROFILE, null, 2)}. Add a personal touch to your language to show friendliness.
 
 If the user asks what you can do, describe these tools.
 ${ALL_TOOL_DESCRIPTIONS}
 
+If the last message is a tool result, describe what the action was, congratulate the user, or send a friendly followup in response to the tool action. Ensure this is a clear and concise message.
 Otherwise, just answer as normal.`;
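The pizza-orderer hunks above swap `ChatAnthropic` for `ChatOpenAI` while keeping the `withStructuredOutput` wiring intact. As a reference, a minimal sketch of the resulting call pattern (with a simplified one-field schema; the real `findShopSchema` has more fields than shown here) looks like this:

```ts
// Minimal sketch, assuming a simplified schema; the real findShopSchema is larger.
import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";

const findShopSchema = z
  .object({
    location: z.string().describe("Where to look for a pizza shop"),
  })
  .describe("The schema for finding a pizza shop for the user");

const model = new ChatOpenAI({ model: "gpt-4o", temperature: 0 })
  .withStructuredOutput(findShopSchema, { name: "find_pizza_shop" });

// invoke() now resolves to a parsed, schema-validated object
// rather than a raw AIMessage.
const result = await model.invoke("Find me a pizza place in Brooklyn");
console.log(result.location);
```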
diff --git a/src/agent/supervisor/nodes/router.ts b/src/agent/supervisor/nodes/router.ts
index 5203e96..5229001 100644
--- a/src/agent/supervisor/nodes/router.ts
+++ b/src/agent/supervisor/nodes/router.ts
@@ -3,6 +3,7 @@ import { ChatGoogleGenerativeAI } from "@langchain/google-genai";
 import { ALL_TOOL_DESCRIPTIONS } from "../index";
 import { SupervisorState, SupervisorUpdate } from "../types";
 import { formatMessages } from "@/agent/utils/format-messages";
+import { ChatOpenAI } from "@langchain/openai";
 
 export async function router(
   state: SupervisorState,
@@ -29,8 +30,8 @@ ${ALL_TOOL_DESCRIPTIONS}
     schema: routerSchema,
   };
 
-  const llm = new ChatGoogleGenerativeAI({
-    model: "gemini-2.0-flash",
+  const llm = new ChatOpenAI({
+    model: "gpt-4o",
     temperature: 0,
   })
     .bindTools([routerTool], { tool_choice: "router" })
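The router change is the same model swap, but through `bindTools` with a forced `tool_choice` rather than `withStructuredOutput`. A sketch of how the forced call is read back; the real `routerSchema` is defined above this hunk and not shown in the diff, so the shape below is an assumption:

```ts
// Sketch; routerSchema's real shape is not visible in the diff.
import { ChatOpenAI } from "@langchain/openai";
import { z } from "zod";

const routerSchema = z.object({
  route: z.string().describe("Which agent should handle the request"),
});
const routerTool = {
  name: "router",
  description: "Route the user's request to the right agent.",
  schema: routerSchema,
};

const llm = new ChatOpenAI({ model: "gpt-4o", temperature: 0 }).bindTools(
  [routerTool],
  { tool_choice: "router" }, // the model must call the router tool
);

const response = await llm.invoke("Plan me a trip to Lisbon");
// With tool_choice forced, the first tool call carries the routing decision.
const decision = response.tool_calls?.[0]?.args as z.infer<typeof routerSchema>;
```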
diff --git a/src/agent/trip-planner/nodes/extraction.ts b/src/agent/trip-planner/nodes/extraction.ts
index cdf7001..81785c7 100644
--- a/src/agent/trip-planner/nodes/extraction.ts
+++ b/src/agent/trip-planner/nodes/extraction.ts
@@ -76,20 +76,79 @@ export async function extraction(
     },
   ]);
 
-  const prompt = `You're an AI assistant for planning trips. The user has requested information about a trip they want to go on.
-Before you can help them, you need to extract the following information from their request:
-- location - The location to plan the trip for. Can be a city, state, or country.
-- startDate - The start date of the trip. Should be in YYYY-MM-DD format. Optional
-- endDate - The end date of the trip. Should be in YYYY-MM-DD format. Optional
-- numberOfGuests - The number of guests for the trip. Optional
+  const USER_PROFILE = {
+    "id": "user123",
+    "name": "Alice",
+    "profile": {
+      "personalHistory": "Alice was born in Munich and moved to Berlin for her studies. She holds a Master’s degree in Computer Science from TU Berlin. From a young age, she was fascinated by how machines learn and evolve, which led her into the field of artificial intelligence. She started her career as a backend developer and gradually transitioned into AI-focused roles. Outside of work, Alice enjoys hiking in the Alps, painting, and participating in tech meetups.",
+      "education": "M.Sc. in Computer Science, Technische Universität Berlin",
+      "expertise": "Machine learning, deep learning, distributed systems, transformer models, and open-source AI contributions.",
+      "projectA": {
+        "title": "Multilingual Generative Language Model",
+        "description": "Led the development of a multilingual transformer-based model capable of generating coherent text in over 15 languages. The project aimed to enhance NLP capabilities in underrepresented languages.",
+        "technologies": ["PyTorch", "Transformers", "HuggingFace", "TensorBoard"]
+      },
+      "projectB": {
+        "title": "Federated Learning System",
+        "description": "Built a federated learning pipeline enabling training on decentralized datasets across devices to maintain data privacy while improving model generalizability.",
+        "technologies": ["Python", "gRPC", "TensorFlow Federated", "Docker", "Kubernetes"]
+      },
+      "currentLocation": "Delhi",
+      "ongoingProjects": [
+        "Experimenting with sparse attention mechanisms for large language models.",
+        "Mentoring junior developers in the open-source AI community.",
+        "Collaborating with a university lab on adversarial robustness in deep models."
+      ]
+    }
+  };
 
-You are provided with the ENTIRE conversation history between you, and the user. Use these messages to extract the necessary information.
+  const prompt = `You’re an enthusiastic travel buddy named "Sunny" who chats with the user like a friend. You’ve been given the user’s profile (${JSON.stringify(USER_PROFILE, null, 2)}) so you can sprinkle in personal touches and even guess why they’re planning this trip.
 
-Do NOT guess, or make up any information. If the user did NOT specify a location, please respond with a request for them to specify the location.
-You should ONLY send a clarification message if the user did not provide the location. You do NOT need any of the other fields, so if they're missing, proceed without them.
-It should be a single sentence, along the lines of "Please specify the location for the trip you want to go on".
+Conversation Starter:
-Extract only what is specified by the user. It is okay to leave fields blank if the user did not specify them.
+
+Greet warmly and mention something from their profile:
+
+“Hey there! I remember you love street photography—ready to plan a trip full of colorful corners and candid moments?”
+
+Spark curiosity about their dream destination:
+
+“If you could teleport anywhere this instant, where would you land? A bustling city, a quiet beach, or maybe a mountain retreat?”
+
+Uncover their why:
+
+“I’m guessing you’re after this getaway to recharge after that big project at work—or is it to celebrate something special?”
+
+Explore travel style and companions:
+
+“Are you flying solo, road-tripping with friends, or bringing the whole family along?”
+
+“Do you lean more toward adventurous hikes, lazy beach days, foodie tours, or cultural explorations?”
+
+Nail down the basics (location, dates, guests):
+
+“Which place are we looking at?” (city/state/country)
+
+“Do you have dates in mind, or an ideal season?” (ask for YYYY-MM-DD if they know)
+
+“How many people should I plan for?”
+
+Bucket-list moments:
+
+“Any must-do experiences on your list? Hot-air balloon ride, local cooking class, or maybe dancing under the northern lights?”
+
+Extraction Rules:
+
+location: The trip destination (city/state/country).
+
+startDate/endDate: YYYY-MM-DD (optional; if unknown, ask for a season or month).
+
+numberOfGuests: Integer (optional; if unknown, ask for a headcount).
+
+Use the full conversation history to reuse provided details. Don’t guess or invent anything—if some info is missing, ask the user in your next message with a friendly follow-up.
+
+Once all details are collected, reply with a warm confirmation like:
+
+"Awesome! Planning a trip to [location] from [startDate] to [endDate] for [numberOfGuests] people. Let’s make it unforgettable! 🎉"
+`;
 
   const humanMessage = `Here is the entire conversation so far:\n${formatMessages(state.messages)}`;
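The extraction schema this prompt feeds sits above the hunk and is not part of the diff; a plausible zod shape matching the prompt's extraction rules would be:

```ts
// Hypothetical schema inferred from the prompt's extraction rules;
// the actual schema in extraction.ts is outside this hunk.
import { z } from "zod";

const tripDetailsSchema = z.object({
  location: z
    .string()
    .describe("The trip destination: a city, state, or country"),
  startDate: z.string().optional().describe("Start date in YYYY-MM-DD format"),
  endDate: z.string().optional().describe("End date in YYYY-MM-DD format"),
  numberOfGuests: z
    .number()
    .int()
    .optional()
    .describe("Headcount for the trip"),
});
```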
diff --git a/src/agent/trip-planner/nodes/tools.ts b/src/agent/trip-planner/nodes/tools.ts
index d1e9016..3dd457c 100644
--- a/src/agent/trip-planner/nodes/tools.ts
+++ b/src/agent/trip-planner/nodes/tools.ts
@@ -41,47 +41,85 @@ export async function callTools(
     ACCOMMODATIONS_TOOLS,
   );
 
-  const response = await llm.invoke([
-    {
-      role: "system",
-      content:
-        "You are an AI assistant who helps users book trips. Use the user's most recent message(s) to contextually generate a response.",
-    },
-    ...state.messages,
-  ]);
+  const systemPrompt = {
+    role: "system",
+    content: `You are an AI assistant who helps users book trips. When a user asks about a trip or destination, you should:
+1. Use the list-accommodations tool to show available accommodations.
+2. Use the list-restaurants tool to show local dining options.
 
-  const listAccommodationsToolCall = response.tool_calls?.find(
-    findToolCall("list-accommodations"),
-  );
-  const listRestaurantsToolCall = response.tool_calls?.find(
-    findToolCall("list-restaurants"),
-  );
+After using these tools, always summarize the results in a friendly, readable text reply for the user.
+- List the top accommodations and restaurants you found, including their names, prices, and ratings if available.
+- If no results are found, say so.
+- Always include this summary in your reply, even if the user did not specifically ask for it.
 
-  if (!listAccommodationsToolCall && !listRestaurantsToolCall) {
-    throw new Error("No tool calls found");
-  }
+These tools should be used for ANY trip-related query, even if the user hasn't specifically asked about accommodations or restaurants yet.
 
-  if (listAccommodationsToolCall) {
-    ui.push(
-      {
-        name: "accommodations-list",
-        props: {
-          toolCallId: listAccommodationsToolCall.id ?? "",
-          ...getAccommodationsListProps(state.tripDetails),
-        },
-      },
-      { message: response },
-    );
-  }
+Current trip details:
+- Location: ${state.tripDetails.location}
+- Start Date: ${state.tripDetails.startDate}
+- End Date: ${state.tripDetails.endDate}
+- Number of Guests: ${state.tripDetails.numberOfGuests}`,
+  };
+
+  let messages: any[] = [systemPrompt, ...state.messages];
+  let response = await llm.invoke(messages);
+
+  // Tool call loop
+  while (response.tool_calls && response.tool_calls.length > 0) {
+    // 1. Generate tool messages for each tool call
+    const toolMessages = response.tool_calls
+      .map((toolCall: any) => {
+        if (toolCall.name === "list-accommodations") {
+          const accommodationsData = getAccommodationsListProps(
+            state.tripDetails as import("../types").TripDetails,
+          );
+          ui.push(
+            {
+              name: "accommodations-list",
+              props: {
+                toolCallId: toolCall.id ?? "",
+                ...accommodationsData,
+              },
+            },
+            { message: response },
+          );
+          return {
+            role: "tool",
+            tool_call_id: toolCall.id,
+            name: toolCall.name,
+            content: JSON.stringify(accommodationsData.accommodations),
+          };
+        } else if (toolCall.name === "list-restaurants") {
+          ui.push(
+            {
+              name: "restaurants-list",
+              props: {
+                tripDetails:
+                  state.tripDetails as import("../types").TripDetails,
+              },
+            },
+            { message: response },
+          );
+          return {
+            role: "tool",
+            tool_call_id: toolCall.id,
+            name: toolCall.name,
+            content: "(Restaurant data coming soon!)",
+          };
+        }
+        return null;
+      })
+      .filter(Boolean);
 
-  if (listRestaurantsToolCall) {
-    ui.push(
-      {
-        name: "restaurants-list",
-        props: { tripDetails: state.tripDetails },
-      },
-      { message: response },
-    );
-  }
+    // 2. Only send: systemPrompt, prior state messages, last AI message, tool messages
+    messages = [systemPrompt, ...state.messages, response, ...toolMessages];
+    response = await llm.invoke(messages);
+    if (
+      typeof response.content === "string" &&
+      response.content.trim() !== ""
+    ) {
+      break;
+    }
+  }
 
   return {
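The new loop re-invokes the model until it returns non-empty text, which means a model that keeps emitting tool calls could spin indefinitely. A bounded variant (a sketch, not part of the diff; the helper names are illustrative) caps the round trips:

```ts
// Sketch of a bounded tool-call loop; llm/handleToolCalls are illustrative.
import type { AIMessage, BaseMessageLike } from "@langchain/core/messages";

async function runToolLoop(
  llm: { invoke(msgs: BaseMessageLike[]): Promise<AIMessage> },
  seed: BaseMessageLike[],
  handleToolCalls: (msg: AIMessage) => Promise<BaseMessageLike[]>,
  maxRounds = 5, // hypothetical cap
): Promise<AIMessage> {
  let response = await llm.invoke(seed);
  for (let round = 0; response.tool_calls?.length; round++) {
    if (round >= maxRounds) {
      throw new Error("Tool-call loop exceeded maxRounds");
    }
    const toolMessages = await handleToolCalls(response);
    // Same shape the diff uses: system + history + last AI turn + tool results.
    response = await llm.invoke([...seed, response, ...toolMessages]);
  }
  return response;
}
```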
diff --git a/src/agent/writer-agent/index.ts b/src/agent/writer-agent/index.ts
index a7380bc..a6ee907 100644
--- a/src/agent/writer-agent/index.ts
+++ b/src/agent/writer-agent/index.ts
@@ -19,8 +19,9 @@ import { findToolCall } from "../find-tool-call";
 import { GenerativeUIAnnotation } from "../types";
 import type ComponentMap from "../../agent-uis/index";
+import { ChatOpenAI } from "@langchain/openai";
 
-const MODEL_NAME = "claude-3-5-sonnet-latest";
+const MODEL_NAME = "gpt-4o";
 
 const WriterAnnotation = Annotation.Root({
   messages: GenerativeUIAnnotation.spec.messages,
@@ -150,7 +151,7 @@ async function suggestions(state: WriterState): Promise<WriterUpdate> {
     messages.push({ type: "tool", content: "Finished", tool_call_id: tool.id });
   }
 
-  const model = new ChatAnthropic({ model: MODEL_NAME });
+  const model = new ChatOpenAI({ model: MODEL_NAME });
   const finish = await model.invoke(messages);
   messages.push(finish);
diff --git a/src/v2v_realtime/.dockerignore b/src/v2v_realtime/.dockerignore
new file mode 100644
index 0000000..e69de29
diff --git a/src/v2v_realtime/.env.example b/src/v2v_realtime/.env.example
new file mode 100644
index 0000000..65e98d1
--- /dev/null
+++ b/src/v2v_realtime/.env.example
@@ -0,0 +1,13 @@
+AZURE_OPENAI_ENDPOINT=
+AZURE_OPENAI_REALTIME_DEPLOYMENT=
+AZURE_OPENAI_REALTIME_VOICE_CHOICE=
+AZURE_OPENAI_API_KEY=
+AZURE_SEARCH_ENDPOINT=
+AZURE_SEARCH_INDEX=<......>
+AZURE_SEARCH_API_KEY=<.....>
+
+OPENAI_API_KEY=""
+
+
+MONGO_URI=
+
diff --git a/src/v2v_realtime/Dockerfile b/src/v2v_realtime/Dockerfile
new file mode 100644
index 0000000..0300103
--- /dev/null
+++ b/src/v2v_realtime/Dockerfile
@@ -0,0 +1,25 @@
+# Use the official Python image from the Docker Hub
+FROM python:3.10-slim-bookworm
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Copy the requirements file into the container
+COPY requirements.txt .
+
+# Create and activate a virtual environment
+RUN uv venv --python 3.10
+ENV PATH="/app/.venv/bin:$PATH"
+
+# Install the dependencies
+RUN uv pip install --no-cache-dir -r requirements.txt
+
+# Copy the app code into the container
+COPY . .
+
+# Expose the port that the app runs on
+EXPOSE 8765
+
+# Run the aiohttp app (app.py starts its own event loop via web.run_app)
+CMD ["python", "app.py"]
\ No newline at end of file
diff --git a/src/v2v_realtime/__init__.py b/src/v2v_realtime/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/v2v_realtime/app.py b/src/v2v_realtime/app.py
new file mode 100644
index 0000000..b8d0444
--- /dev/null
+++ b/src/v2v_realtime/app.py
@@ -0,0 +1,84 @@
+import logging
+import os
+from pathlib import Path
+
+from aiohttp import web
+from azure.core.credentials import AzureKeyCredential
+from azure.identity import AzureDeveloperCliCredential, DefaultAzureCredential
+from dotenv import load_dotenv
+
+from ragtools import attach_rag_tools
+from rtmt import RTMiddleTier
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("voicerag")
+
+
+async def create_app():
+    if not os.environ.get("RUNNING_IN_PRODUCTION"):
+        logger.info("Running in development mode, loading from .env file")
+        load_dotenv()
+
+    llm_key = os.environ.get("AZURE_OPENAI_API_KEY")
+    search_key = os.environ.get("AZURE_SEARCH_API_KEY")
+
+    credential = None
+    if not llm_key or not search_key:
+        if tenant_id := os.environ.get("AZURE_TENANT_ID"):
+            logger.info(
+                "Using AzureDeveloperCliCredential with tenant_id %s", tenant_id
+            )
+            credential = AzureDeveloperCliCredential(
+                tenant_id=tenant_id, process_timeout=60
+            )
+        else:
+            logger.info("Using DefaultAzureCredential")
+            credential = DefaultAzureCredential()
+    llm_credential = AzureKeyCredential(llm_key) if llm_key else credential
+    search_credential = AzureKeyCredential(search_key) if search_key else credential
+
+    app = web.Application()
+
+    rtmt = RTMiddleTier(
+        credentials=llm_credential,
+        endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
+        deployment=os.environ["AZURE_OPENAI_REALTIME_DEPLOYMENT"],
+        voice_choice=os.environ.get("AZURE_OPENAI_REALTIME_VOICE_CHOICE") or "alloy",
+    )
+    rtmt.system_message = """
+You are an assistant that must use the 'langgraph_tool' to handle all user queries.
+
+Instructions:
+1. For every user message, immediately forward it to the LangGraph backend using the 'langgraph_tool'.
+2. Always include the full user message as the 'message' parameter. If a 'thread_id' is available from previous interactions, include it to maintain context; otherwise, omit it to start a new thread.
+3. If the backend returns an error or is unavailable, inform the user: "Sorry, the backend is temporarily unavailable. Please try again later."
+4. For multi-turn conversations, always preserve and use the thread context for accuracy.
+
+Be efficient, accurate, and transparent in relaying information between the user and the backend.
+""".strip()
+
+    attach_rag_tools(
+        rtmt,
+        credentials=search_credential,
+        search_endpoint=os.environ.get("AZURE_SEARCH_ENDPOINT"),
+        search_index=os.environ.get("AZURE_SEARCH_INDEX"),
+        semantic_configuration=os.environ.get("AZURE_SEARCH_SEMANTIC_CONFIGURATION", "")
+        or None,
+        identifier_field=os.environ.get("AZURE_SEARCH_IDENTIFIER_FIELD", "")
+        or "chunk_id",
+        content_field=os.environ.get("AZURE_SEARCH_CONTENT_FIELD", "") or "chunk",
+        embedding_field=os.environ.get("AZURE_SEARCH_EMBEDDING_FIELD", "")
+        or "text_vector",
+        title_field=os.environ.get("AZURE_SEARCH_TITLE_FIELD", "") or "title",
+        use_vector_query=(os.getenv("AZURE_SEARCH_USE_VECTOR_QUERY", "true") == "true"),
+    )
+
+    rtmt.attach_to_app(app, "/realtime")
+
+    return app
+
+
+if __name__ == "__main__":
+    host = "localhost"
+    port = 8765
+    web.run_app(create_app(), host=host, port=port)
diff --git a/src/v2v_realtime/ragtools.py b/src/v2v_realtime/ragtools.py
new file mode 100644
index 0000000..533b4a7
--- /dev/null
+++ b/src/v2v_realtime/ragtools.py
@@ -0,0 +1,69 @@
+import re
+import json
+from typing import Any
+
+from azure.core.credentials import AzureKeyCredential
+from azure.identity import DefaultAzureCredential
+from azure.search.documents.aio import SearchClient
+from rtmt import RTMiddleTier, Tool, ToolResult, ToolResultDirection
+
+from utils.rag_summary_data import get_answer, QARequest
+from utils.langgraph_client import send_message, create_thread
+
+# New tool schema for langgraph
+_langgraph_tool_schema = {
+    "type": "function",
+    "name": "langgraph_tool",
+    "description": "Use this tool to send a message to a LangGraph backend and get a response.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "thread_id": {
+                "type": "string",
+                "description": "Thread ID to use for the message (optional).",
+            },
+            "message": {
+                "type": "string",
+                "description": "Message to send to the backend.",
+            },
+        },
+        "required": ["message"],  # Only message is required
+        "additionalProperties": False,
+    },
+}
+
+
+async def _langgraph_tool(args: Any) -> ToolResult:
+    message = args["message"]
+    thread_id = args.get("thread_id") or create_thread()
+    result = send_message(thread_id, message)
+    return ToolResult(
+        json.dumps({"thread_id": thread_id, "response": result}),
+        ToolResultDirection.TO_SERVER,
+    )
+
+
+def attach_rag_tools(
+    rtmt: RTMiddleTier,
+    credentials: AzureKeyCredential | DefaultAzureCredential,
+    search_endpoint: str,
+    search_index: str,
+    semantic_configuration: str | None,
+    identifier_field: str,
+    content_field: str,
+    embedding_field: str,
+    title_field: str,
+    use_vector_query: bool,
+) -> None:
+    if not isinstance(credentials, AzureKeyCredential):
+        credentials.get_token(
+            "https://search.azure.com/.default"
+        )  # warm this up before we start getting requests
+    search_client = SearchClient(
+        search_endpoint, search_index, credentials, user_agent="RTMiddleTier"
+    )
+
+    # rtmt.tools["search"] = Tool(schema=_search_tool_schema, target=lambda args: _search_tool(search_client, semantic_configuration, identifier_field, content_field, embedding_field, use_vector_query, args))
+    rtmt.tools["langgraph_tool"] = Tool(
+        schema=_langgraph_tool_schema, target=lambda args: _langgraph_tool(args)
+    )
diff --git a/src/v2v_realtime/requirements.txt b/src/v2v_realtime/requirements.txt
new file mode 100644
index 0000000..1f2acba
Binary files /dev/null and b/src/v2v_realtime/requirements.txt differ
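`_langgraph_tool` serializes `{thread_id, response}` into the tool output. If a client ever needs to consume that payload (for example via a `TO_CLIENT` result), a zod guard on the TypeScript side might look like this sketch; the shape of each `response` element depends on what the LangGraph agent returns, so it is kept loose:

```ts
// Sketch of validating the langgraph_tool payload on a TS consumer.
import { z } from "zod";

const langgraphToolResult = z.object({
  thread_id: z.string(),
  // send_message returns a list of assistant message contents,
  // whose exact shape depends on the agent; keep it loose here.
  response: z.array(z.unknown()),
});

export function parseLanggraphToolResult(raw: string) {
  return langgraphToolResult.parse(JSON.parse(raw));
}
```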
diff --git a/src/v2v_realtime/rtmt.py b/src/v2v_realtime/rtmt.py
new file mode 100644
index 0000000..6ec5d2d
--- /dev/null
+++ b/src/v2v_realtime/rtmt.py
@@ -0,0 +1,285 @@
+import asyncio
+import json
+import logging
+from enum import Enum
+from typing import Any, Callable, Optional
+
+import aiohttp
+from aiohttp import web, WSMessage, ClientWebSocketResponse
+from azure.core.credentials import AzureKeyCredential
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+logger = logging.getLogger("voicerag")
+
+
+class ToolResultDirection(Enum):
+    TO_SERVER = 1
+    TO_CLIENT = 2
+
+
+class ToolResult:
+    text: str
+    destination: ToolResultDirection
+
+    def __init__(self, text: str, destination: ToolResultDirection):
+        self.text = text
+        self.destination = destination
+
+    def to_text(self) -> str:
+        if self.text is None:
+            return ""
+        return self.text if isinstance(self.text, str) else json.dumps(self.text)
+
+
+class Tool:
+    target: Callable[..., ToolResult]
+    schema: Any
+
+    def __init__(self, target: Any, schema: Any):
+        self.target = target
+        self.schema = schema
+
+
+class RTToolCall:
+    tool_call_id: str
+    previous_id: str
+
+    def __init__(self, tool_call_id: str, previous_id: str):
+        self.tool_call_id = tool_call_id
+        self.previous_id = previous_id
+
+
+class RTMiddleTier:
+    endpoint: str
+    deployment: str
+    key: Optional[str] = None
+
+    # Tools are server-side only for now, though the case could be made for client-side tools
+    # in addition to server-side tools that are invisible to the client
+    tools: dict[str, Tool] = {}
+
+    # Server-enforced configuration; if set, these will override the client's configuration.
+    # Typically at least the model name and system message will be set by the server.
+    model: Optional[str] = None
+    system_message: Optional[str] = None
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    disable_audio: Optional[bool] = None
+    voice_choice: Optional[str] = None
+    api_version: str = "2024-10-01-preview"
+    _tools_pending = {}
+    _token_provider = None
+
+    def __init__(
+        self,
+        endpoint: str,
+        deployment: str,
+        credentials: AzureKeyCredential | DefaultAzureCredential,
+        voice_choice: Optional[str] = None,
+    ):
+        self.endpoint = endpoint
+        self.deployment = deployment
+        self.voice_choice = voice_choice
+        if voice_choice is not None:
+            logger.info("Realtime voice choice set to %s", voice_choice)
+        if isinstance(credentials, AzureKeyCredential):
+            self.key = credentials.key
+        else:
+            self._token_provider = get_bearer_token_provider(
+                credentials, "https://cognitiveservices.azure.com/.default"
+            )
+            self._token_provider()  # Warm up during startup so we have a token cached when the first request arrives
+
+    async def _process_message_to_client(
+        self,
+        msg: WSMessage,
+        client_ws: web.WebSocketResponse,
+        server_ws: web.WebSocketResponse | ClientWebSocketResponse,
+    ) -> Optional[str]:
+        message = json.loads(msg.data)
+        updated_message = msg.data
+        if message is not None:
+            match message["type"]:
+                case "session.created":
+                    session = message["session"]
+                    # Hide the instructions, tools and max tokens from clients; if we ever
+                    # allow client-side tools, this will need updating
+                    session["instructions"] = ""
+                    session["tools"] = []
+                    session["voice"] = self.voice_choice
+                    session["tool_choice"] = "none"
+                    session["max_response_output_tokens"] = None
+                    updated_message = json.dumps(message)
+
+                case "response.output_item.added":
+                    if "item" in message and message["item"]["type"] == "function_call":
+                        updated_message = None
+
+                case "conversation.item.created":
+                    if "item" in message and message["item"]["type"] == "function_call":
+                        item = message["item"]
+                        if item["call_id"] not in self._tools_pending:
+                            self._tools_pending[item["call_id"]] = RTToolCall(
+                                item["call_id"], message["previous_item_id"]
+                            )
+                        updated_message = None
+                    elif (
+                        "item" in message
+                        and message["item"]["type"] == "function_call_output"
+                    ):
+                        updated_message = None
+
+                case "response.function_call_arguments.delta":
+                    updated_message = None
+
+                case "response.function_call_arguments.done":
+                    updated_message = None
+
+                case "response.output_item.done":
+                    if "item" in message and message["item"]["type"] == "function_call":
+                        item = message["item"]
+                        tool_call = self._tools_pending[message["item"]["call_id"]]
+                        tool = self.tools[item["name"]]
+                        args = item["arguments"]
+                        try:
+                            parsed_args = json.loads(args)
+                        except json.JSONDecodeError as e:
+                            logger.error(
+                                f"Invalid JSON in function call arguments: {args} ({e})"
+                            )
+                            updated_message = None
+                            return updated_message
+                        maybe_coro = tool.target(parsed_args)
+                        if asyncio.iscoroutine(maybe_coro):
+                            result = await maybe_coro
+                        else:
+                            result = maybe_coro
+                        await server_ws.send_json(
+                            {
+                                "type": "conversation.item.create",
+                                "item": {
+                                    "type": "function_call_output",
+                                    "call_id": item["call_id"],
+                                    "output": (
+                                        result.to_text()
+                                        if result.destination
+                                        == ToolResultDirection.TO_SERVER
+                                        else ""
+                                    ),
+                                },
+                            }
+                        )
+                        if result.destination == ToolResultDirection.TO_CLIENT:
+                            # TODO: this will break clients that don't know about this extra message;
+                            # rewrite this to be a regular text message with a special marker of some sort
+                            await client_ws.send_json(
+                                {
+                                    "type": "extension.middle_tier_tool_response",
+                                    "previous_item_id": tool_call.previous_id,
+                                    "tool_name": item["name"],
+                                    "tool_result": result.to_text(),
+                                }
+                            )
+                        updated_message = None
+                case "response.done":
+                    if len(self._tools_pending) > 0:
+                        self._tools_pending.clear()  # Any chance tool calls could be interleaved across different outstanding responses?
+                        await server_ws.send_json({"type": "response.create"})
+                    if "response" in message:
+                        # Strip function_call items from the output before forwarding to the client
+                        output = message["response"]["output"]
+                        filtered = [o for o in output if o["type"] != "function_call"]
+                        if len(filtered) != len(output):
+                            message["response"]["output"] = filtered
+                            updated_message = json.dumps(message)
+
+        return updated_message
+
+    async def _process_message_to_server(
+        self, msg: WSMessage, ws: web.WebSocketResponse
+    ) -> Optional[str]:
+        message = json.loads(msg.data)
+        updated_message = msg.data
+        if message is not None:
+            match message["type"]:
+                case "session.update":
+                    session = message["session"]
+                    if self.system_message is not None:
+                        session["instructions"] = self.system_message
+                    if self.temperature is not None:
+                        session["temperature"] = self.temperature
+                    if self.max_tokens is not None:
+                        session["max_response_output_tokens"] = self.max_tokens
+                    if self.disable_audio is not None:
+                        session["disable_audio"] = self.disable_audio
+                    if self.voice_choice is not None:
+                        session["voice"] = self.voice_choice
+                    session["tool_choice"] = "auto" if len(self.tools) > 0 else "none"
+                    session["tools"] = [tool.schema for tool in self.tools.values()]
+                    updated_message = json.dumps(message)
+
+        return updated_message
+
+    async def _forward_messages(self, ws: web.WebSocketResponse):
+        async with aiohttp.ClientSession(base_url=self.endpoint) as session:
+            params = {"api-version": self.api_version, "deployment": self.deployment}
+            headers = {}
+            if "x-ms-client-request-id" in ws.headers:
+                headers["x-ms-client-request-id"] = ws.headers["x-ms-client-request-id"]
+            if self.key is not None:
+                headers["api-key"] = self.key
+            else:
+                if self._token_provider is None:
+                    raise RuntimeError("Token provider is not initialized.")
+                # NOTE: no async version of token provider; maybe refresh the token on a timer?
+                headers["Authorization"] = f"Bearer {self._token_provider()}"
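Based on the message types `RTMiddleTier` actually forwards, a minimal browser-side client for the `/realtime` endpoint could look like the sketch below (not part of the diff). The middle tier injects `instructions`, `tools`, and `tool_choice` on `session.update`, so the client sends only transport-level settings, and tool results routed `TO_CLIENT` arrive as the custom `extension.middle_tier_tool_response` message:

```ts
// Sketch of a browser client for the middle tier on localhost:8765.
const ws = new WebSocket("ws://localhost:8765/realtime");

ws.addEventListener("open", () => {
  // The middle tier overwrites instructions/tools/tool_choice server-side.
  ws.send(JSON.stringify({ type: "session.update", session: {} }));
});

ws.addEventListener("message", (event) => {
  const msg = JSON.parse(event.data);
  if (msg.type === "extension.middle_tier_tool_response") {
    // Tool results routed TO_CLIENT arrive via this custom message type.
    console.log(msg.tool_name, JSON.parse(msg.tool_result));
  }
});
```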
+            async with session.ws_connect(
+                "/openai/realtime", headers=headers, params=params
+            ) as target_ws:
+
+                async def from_client_to_server():
+                    async for msg in ws:
+                        if msg.type == aiohttp.WSMsgType.TEXT:
+                            new_msg = await self._process_message_to_server(msg, ws)
+                            if new_msg is not None:
+                                await target_ws.send_str(new_msg)
+                        else:
+                            print("Error: unexpected message type:", msg.type)
+
+                    # The client closed the socket gracefully, so close the target_ws as well
+                    if target_ws:
+                        print("Closing OpenAI's realtime socket connection.")
+                        await target_ws.close()
+
+                async def from_server_to_client():
+                    async for msg in target_ws:
+                        if msg.type == aiohttp.WSMsgType.TEXT:
+                            new_msg = await self._process_message_to_client(
+                                msg, ws, target_ws
+                            )
+                            if new_msg is not None:
+                                await ws.send_str(new_msg)
+                        else:
+                            print("Error: unexpected message type:", msg.type)
+
+                try:
+                    await asyncio.gather(
+                        from_client_to_server(), from_server_to_client()
+                    )
+                except ConnectionResetError:
+                    # Ignore the errors resulting from the client disconnecting the socket
+                    pass
+
+    async def _websocket_handler(self, request: web.Request):
+        ws = web.WebSocketResponse()
+        await ws.prepare(request)
+        await self._forward_messages(ws)
+        return ws
+
+    def attach_to_app(self, app, path):
+        app.router.add_get(path, self._websocket_handler)
diff --git a/src/v2v_realtime/utils/langgraph_client.py b/src/v2v_realtime/utils/langgraph_client.py
new file mode 100644
index 0000000..6716055
--- /dev/null
+++ b/src/v2v_realtime/utils/langgraph_client.py
@@ -0,0 +1,68 @@
+import requests
+import uuid
+import time
+
+API_URL = "http://localhost:2024"
+ASSISTANT_ID = "agent"
+
+
+def create_thread():
+    resp = requests.post(f"{API_URL}/threads", json={"metadata": {}})
+    resp.raise_for_status()
+    return resp.json()["thread_id"]
+
+
+def send_message(thread_id: str, user_message: str, max_wait: int = 15):
+    """
+    Sends a message to the specified thread and returns all assistant ('ai') replies
+    found in the history (non-streaming). Polls the history endpoint until at least
+    one 'ai' message is found or until max_wait seconds have passed.
+    """
+    message_id = str(uuid.uuid4())
+    payload = {
+        "assistant_id": ASSISTANT_ID,
+        "input": {
+            "messages": [
+                {
+                    "type": "human",
+                    "id": message_id,
+                    "content": [{"type": "text", "text": user_message}],
+                }
+            ]
+        },
+    }
+    response = requests.post(
+        f"{API_URL}/threads/{thread_id}/runs",
+        json=payload,
+        headers={"Content-Type": "application/json"},
+    )
+    response.raise_for_status()
+
+    payload = {"limit": 1000}
+    ai_messages = []
+    for _ in range(max_wait):
+        response = requests.post(
+            f"{API_URL}/threads/{thread_id}/history",
+            json=payload,
+            headers={"Content-Type": "application/json"},
+        )
+        response.raise_for_status()
+        history = response.json()
+        ai_messages = []
+        for step in history:
+            messages = step.get("values", {}).get("messages", [])
+            for msg in messages:
+                if msg.get("type") == "ai" and "content" in msg:
+                    ai_messages.append(msg["content"])
+        if ai_messages:
+            break
+        time.sleep(1)
+    return ai_messages
+
+
+if __name__ == "__main__":
+    thread_id = create_thread()
+    print(f"Created thread: {thread_id}")
+    user_message = input("Enter your message: ")
+    ai_messages = send_message(thread_id, user_message)
+    print("\n--- Assistant Replies ---")
+    print(ai_messages)
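For callers on the TypeScript side of the repo, a fetch-based mirror of `utils/langgraph_client.py` (a sketch with the same endpoints, the same once-per-second polling, and the same `localhost:2024` dev-server assumption) would be:

```ts
// TypeScript mirror of utils/langgraph_client.py (sketch).
const API_URL = "http://localhost:2024";
const ASSISTANT_ID = "agent";

export async function createThread(): Promise<string> {
  const resp = await fetch(`${API_URL}/threads`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ metadata: {} }),
  });
  if (!resp.ok) throw new Error(`createThread failed: ${resp.status}`);
  return (await resp.json()).thread_id;
}

export async function sendMessage(
  threadId: string,
  text: string,
  maxWait = 15,
): Promise<unknown[]> {
  await fetch(`${API_URL}/threads/${threadId}/runs`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      assistant_id: ASSISTANT_ID,
      input: {
        messages: [
          {
            type: "human",
            id: crypto.randomUUID(),
            content: [{ type: "text", text }],
          },
        ],
      },
    }),
  });

  // Poll history once per second until an AI reply shows up,
  // as the Python client does.
  for (let i = 0; i < maxWait; i++) {
    const resp = await fetch(`${API_URL}/threads/${threadId}/history`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ limit: 1000 }),
    });
    const history: any[] = await resp.json();
    const aiMessages = history.flatMap((step) =>
      (step.values?.messages ?? [])
        .filter((m: any) => m.type === "ai" && "content" in m)
        .map((m: any) => m.content),
    );
    if (aiMessages.length > 0) return aiMessages;
    await new Promise((r) => setTimeout(r, 1000));
  }
  return [];
}
```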