Skip to content

Commit dd1cee0

Browse files
committed
Add suggestions by CodeRabbit review
1 parent 9e02999 commit dd1cee0

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

src/app/endpoints/query_v2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
191191

192192
# Prepare tools for responses API
193193
toolgroups = await prepare_tools_for_responses_api(
194-
client, query_request, token, configuration
194+
client, query_request, token, configuration, mcp_headers
195195
)
196196

197197
# Prepare input for Responses API
@@ -455,6 +455,7 @@ async def prepare_tools_for_responses_api(
455455
query_request: QueryRequest,
456456
token: str,
457457
config: AppConfig,
458+
mcp_headers: dict[str, dict[str, str]] | None = None,
458459
) -> list[dict[str, Any]] | None:
459460
"""
460461
Prepare tools for Responses API including RAG and MCP tools.
@@ -467,6 +468,7 @@ async def prepare_tools_for_responses_api(
467468
query_request: The user's query request
468469
token: Authentication token for MCP tools
469470
config: Configuration object containing MCP server settings
471+
mcp_headers: Per-request headers for MCP servers
470472
471473
Returns:
472474
list[dict[str, Any]] | None: List of tool configurations for the
@@ -489,6 +491,13 @@ async def prepare_tools_for_responses_api(
489491
# Add MCP server tools
490492
mcp_tools = get_mcp_tools(config.mcp_servers, token)
491493
if mcp_tools:
494+
# Merge per-request headers if provided
495+
if mcp_headers:
496+
for tool in mcp_tools:
497+
server_url = tool.get("server_url")
498+
if server_url and server_url in mcp_headers:
499+
tool_headers = tool.setdefault("headers", {})
500+
tool_headers.update(mcp_headers[server_url])
492501
toolgroups.extend(mcp_tools)
493502
logger.debug(
494503
"Configured %d MCP tools: %s",

src/app/endpoints/streaming_query_v2.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,21 +171,13 @@ async def response_generator(
171171
event_type = getattr(chunk, "type", None)
172172
logger.debug("Processing chunk %d, type: %s", chunk_id, event_type)
173173

174-
# Emit start and persist on response.created
174+
# Emit start on response.created
175175
if event_type == "response.created":
176176
try:
177177
conv_id = getattr(chunk, "response").id
178178
except Exception: # pylint: disable=broad-except
179179
conv_id = ""
180180
yield stream_start_event(conv_id)
181-
if conv_id:
182-
persist_user_conversation_details(
183-
user_id=user_id,
184-
conversation_id=conv_id,
185-
model=model_id,
186-
provider_id=provider_id,
187-
topic_summary=None,
188-
)
189181
continue
190182

191183
# Text streaming
@@ -448,11 +440,21 @@ async def retrieve_response(
448440

449441
# Prepare tools for responses API
450442
toolgroups = await prepare_tools_for_responses_api(
451-
client, query_request, token, configuration
443+
client, query_request, token, configuration, mcp_headers
452444
)
453445

446+
# Prepare input for Responses API
447+
# Convert attachments to text and concatenate with query
448+
input_text = query_request.query
449+
if query_request.attachments:
450+
for attachment in query_request.attachments:
451+
input_text += (
452+
f"\n\n[Attachment: {attachment.attachment_type}]\n"
453+
f"{attachment.content}"
454+
)
455+
454456
response = await client.responses.create(
455-
input=query_request.query,
457+
input=input_text,
456458
model=model_id,
457459
instructions=system_prompt,
458460
previous_response_id=query_request.conversation_id,

tests/unit/app/endpoints/test_streaming_query_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ async def fake_stream():
212212
assert events[-1] == "END\n"
213213

214214
# Verify conversation persistence was invoked with the created id
215-
# Called twice: once on response.created, once at the end
216-
assert persist_spy.call_count == 2
215+
# Called once at the end (after topic summary logic)
216+
assert persist_spy.call_count == 1
217217
persist_spy.assert_called_with(
218218
user_id="user123",
219219
conversation_id="conv-xyz",

0 commit comments

Comments
 (0)