|
18 | 18 | from authentication import get_auth_dependency |
19 | 19 | from authentication.interface import AuthTuple |
20 | 20 | from authorization.middleware import authorize |
21 | | -from configuration import configuration |
| 21 | +from configuration import AppConfig, configuration |
22 | 22 | import metrics |
23 | 23 | from models.config import Action |
24 | 24 | from models.requests import QueryRequest |
@@ -300,6 +300,7 @@ async def query_endpoint_handler_v2( |
300 | 300 | get_topic_summary_func=get_topic_summary, |
301 | 301 | ) |
302 | 302 |
|
| 303 | + |
303 | 304 | async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments |
304 | 305 | client: AsyncLlamaStackClient, |
305 | 306 | model_id: str, |
@@ -349,31 +350,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche |
349 | 350 | validate_attachments_metadata(query_request.attachments) |
350 | 351 |
|
351 | 352 | # Prepare tools for responses API |
352 | | - toolgroups: list[dict[str, Any]] | None = None |
353 | | - if not query_request.no_tools: |
354 | | - toolgroups = [] |
355 | | - # Get vector stores for RAG tools |
356 | | - vector_store_ids = [ |
357 | | - vector_store.id for vector_store in (await client.vector_stores.list()).data |
358 | | - ] |
359 | | - |
360 | | - # Add RAG tools if vector stores are available |
361 | | - rag_tools = get_rag_tools(vector_store_ids) |
362 | | - if rag_tools: |
363 | | - toolgroups.extend(rag_tools) |
364 | | - |
365 | | - # Add MCP server tools |
366 | | - mcp_tools = get_mcp_tools(configuration.mcp_servers, token, mcp_headers) |
367 | | - if mcp_tools: |
368 | | - toolgroups.extend(mcp_tools) |
369 | | - logger.debug( |
370 | | - "Configured %d MCP tools: %s", |
371 | | - len(mcp_tools), |
372 | | - [tool.get("server_label", "unknown") for tool in mcp_tools], |
373 | | - ) |
374 | | - # Convert empty list to None for consistency with existing behavior |
375 | | - if not toolgroups: |
376 | | - toolgroups = None |
| 353 | + toolgroups = await prepare_tools_for_responses_api( |
| 354 | + client, query_request, token, configuration, mcp_headers |
| 355 | + ) |
377 | 356 |
|
378 | 357 | # Prepare input for Responses API |
379 | 358 | # Convert attachments to text and concatenate with query |
@@ -619,11 +598,71 @@ def get_mcp_tools( |
619 | 598 | "require_approval": "never", |
620 | 599 | } |
621 | 600 |
|
622 | | - # Add authentication if headers or token provided (Response API format) |
623 | | - headers = (mcp_headers or {}).get(mcp_server.url) |
624 | | - if headers: |
| 601 | + # Build headers: start with token auth, then merge in per-server headers |
| 602 | + if token or mcp_headers: |
| 603 | + headers = {} |
| 604 | + # Add token-based auth if available |
| 605 | + if token: |
| 606 | + headers["Authorization"] = f"Bearer {token}" |
| 607 | + # Merge in per-server headers (can override Authorization if needed) |
| 608 | + server_headers = (mcp_headers or {}).get(mcp_server.url) |
| 609 | + if server_headers: |
| 610 | + headers.update(server_headers) |
625 | 611 | tool_def["headers"] = headers |
626 | | - elif token: |
627 | | - tool_def["headers"] = {"Authorization": f"Bearer {token}"} |
| 612 | + |
628 | 613 | tools.append(tool_def) |
629 | 614 | return tools |
| 615 | + |
| 616 | + |
| 617 | +async def prepare_tools_for_responses_api( |
| 618 | + client: AsyncLlamaStackClient, |
| 619 | + query_request: QueryRequest, |
| 620 | + token: str, |
| 621 | + config: AppConfig, |
| 622 | + mcp_headers: dict[str, dict[str, str]] | None = None, |
| 623 | +) -> list[dict[str, Any]] | None: |
| 624 | + """ |
| 625 | + Prepare tools for Responses API including RAG and MCP tools. |
| 626 | +
|
| 627 | + This function retrieves vector stores and combines them with MCP |
| 628 | + server tools to create a unified toolgroups list for the Responses API. |
| 629 | +
|
| 630 | + Args: |
| 631 | + client: The Llama Stack client instance |
| 632 | + query_request: The user's query request |
| 633 | + token: Authentication token for MCP tools |
| 634 | + config: Configuration object containing MCP server settings |
| 635 | + mcp_headers: Per-request headers for MCP servers |
| 636 | +
|
| 637 | + Returns: |
| 638 | + list[dict[str, Any]] | None: List of tool configurations for the |
| 639 | + Responses API, or None if no_tools is True or no tools are available |
| 640 | + """ |
| 641 | + if query_request.no_tools: |
| 642 | + return None |
| 643 | + |
| 644 | + toolgroups = [] |
| 645 | + # Get vector stores for RAG tools |
| 646 | + vector_store_ids = [ |
| 647 | + vector_store.id for vector_store in (await client.vector_stores.list()).data |
| 648 | + ] |
| 649 | + |
| 650 | + # Add RAG tools if vector stores are available |
| 651 | + rag_tools = get_rag_tools(vector_store_ids) |
| 652 | + if rag_tools: |
| 653 | + toolgroups.extend(rag_tools) |
| 654 | + |
| 655 | + # Add MCP server tools |
| 656 | + mcp_tools = get_mcp_tools(config.mcp_servers, token, mcp_headers) |
| 657 | + if mcp_tools: |
| 658 | + toolgroups.extend(mcp_tools) |
| 659 | + logger.debug( |
| 660 | + "Configured %d MCP tools: %s", |
| 661 | + len(mcp_tools), |
| 662 | + [tool.get("server_label", "unknown") for tool in mcp_tools], |
| 663 | + ) |
| 664 | + # Convert empty list to None for consistency with existing behavior |
| 665 | + if not toolgroups: |
| 666 | + return None |
| 667 | + |
| 668 | + return toolgroups |
0 commit comments