Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,10 +747,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
),
}

vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
vector_store_ids = [
vector_store.id for vector_store in (await client.vector_stores.list()).data
]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
toolgroups = (get_rag_toolgroups(vector_store_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
# Convert empty list to None for consistency with existing behavior
Expand Down Expand Up @@ -846,30 +846,30 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:


def get_rag_toolgroups(
vector_db_ids: list[str],
vector_store_ids: list[str],
) -> list[Toolgroup] | None:
"""
Return a list of RAG Tool groups if the given vector DB list is not empty.
Return a list of RAG Tool groups if the given vector store list is not empty.

Generate a list containing a RAG knowledge search toolgroup if
vector database IDs are provided.
vector store IDs are provided.

Parameters:
vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup.
vector_store_ids (list[str]): List of vector store identifiers to include in the toolgroup.

Returns:
list[Toolgroup] | None: A list with a single RAG toolgroup if
vector_db_ids is non-empty; otherwise, None.
vector_store_ids is non-empty; otherwise, None.
"""
return (
[
ToolgroupAgentToolGroupWithArgs(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": vector_db_ids,
"vector_store_ids": vector_store_ids,
},
)
]
if vector_db_ids
if vector_store_ids
else None
)
6 changes: 3 additions & 3 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,10 +1075,10 @@ async def retrieve_response(
),
}

vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
vector_store_ids = [
vector_store.id for vector_store in (await client.vector_stores.list()).data
]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
toolgroups = (get_rag_toolgroups(vector_store_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
# Convert empty list to None for consistency with existing behavior
Expand Down
94 changes: 62 additions & 32 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,11 @@ async def test_retrieve_response_no_returned_message(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message = None
mock_client.shields.list.return_value = []
mock_vector_db = mocker.Mock()
mock_vector_db.identifier = "VectorDB-1"
mock_client.vector_dbs.list.return_value = [mock_vector_db]
mock_vector_store = mocker.Mock()
mock_vector_store.id = "VectorDB-1"
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = [mock_vector_store]
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -570,9 +572,11 @@ async def test_retrieve_response_message_without_content(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = None
mock_client.shields.list.return_value = []
mock_vector_db = mocker.Mock()
mock_vector_db.identifier = "VectorDB-1"
mock_client.vector_dbs.list.return_value = [mock_vector_db]
mock_vector_store = mocker.Mock()
mock_vector_store.id = "VectorDB-1"
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = [mock_vector_store]
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -609,9 +613,11 @@ async def test_retrieve_response_vector_db_available(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_vector_db = mocker.Mock()
mock_vector_db.identifier = "VectorDB-1"
mock_client.vector_dbs.list.return_value = [mock_vector_db]
mock_vector_store = mocker.Mock()
mock_vector_store.id = "VectorDB-1"
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = [mock_vector_store]
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -656,7 +662,9 @@ async def test_retrieve_response_no_available_shields(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -712,7 +720,9 @@ def __repr__(self) -> str:
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = [MockShield("shield1")]
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -771,7 +781,9 @@ def __repr__(self) -> str:
MockShield("shield1"),
MockShield("shield2"),
]
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -832,7 +844,9 @@ def __repr__(self) -> str:
MockShield("output_shield3"),
MockShield("inout_shield4"),
]
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -887,7 +901,9 @@ async def test_retrieve_response_with_one_attachment(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -943,7 +959,9 @@ async def test_retrieve_response_with_two_attachments(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -1125,7 +1143,9 @@ async def test_retrieve_response_with_mcp_servers(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with MCP servers
mcp_servers = [
Expand Down Expand Up @@ -1206,7 +1226,9 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with MCP servers
mcp_servers = [
Expand Down Expand Up @@ -1265,7 +1287,9 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_client.vector_dbs.list.return_value = []
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = []
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with MCP servers
mcp_servers = [
Expand Down Expand Up @@ -1376,9 +1400,11 @@ async def test_retrieve_response_shield_violation(
text="LLM answer", type="text"
)
mock_client.shields.list.return_value = []
mock_vector_db = mocker.Mock()
mock_vector_db.identifier = "VectorDB-1"
mock_client.vector_dbs.list.return_value = [mock_vector_db]
mock_vector_store = mocker.Mock()
mock_vector_store.id = "VectorDB-1"
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = [mock_vector_store]
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with empty MCP servers
mock_config = mocker.Mock()
Expand Down Expand Up @@ -1415,16 +1441,16 @@ async def test_retrieve_response_shield_violation(

def test_get_rag_toolgroups() -> None:
"""Test get_rag_toolgroups function."""
vector_db_ids: list[str] = []
result = get_rag_toolgroups(vector_db_ids)
vector_store_ids: list[str] = []
result = get_rag_toolgroups(vector_store_ids)
assert result is None

vector_db_ids = ["Vector-DB-1", "Vector-DB-2"]
result = get_rag_toolgroups(vector_db_ids)
vector_store_ids = ["Vector-DB-1", "Vector-DB-2"]
result = get_rag_toolgroups(vector_store_ids)
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "builtin::rag/knowledge_search"
assert result[0]["args"]["vector_db_ids"] == vector_db_ids
assert result[0]["args"]["vector_store_ids"] == vector_store_ids


@pytest.mark.asyncio
Expand Down Expand Up @@ -1643,9 +1669,11 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_vector_db = mocker.Mock()
mock_vector_db.identifier = "VectorDB-1"
mock_client.vector_dbs.list.return_value = [mock_vector_db]
mock_vector_store = mocker.Mock()
mock_vector_store.id = "VectorDB-1"
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = [mock_vector_store]
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with MCP servers
mcp_servers = [
Expand Down Expand Up @@ -1698,9 +1726,11 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
mock_client, mock_agent = prepare_agent_mocks
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client.shields.list.return_value = []
mock_vector_db = mocker.Mock()
mock_vector_db.identifier = "VectorDB-1"
mock_client.vector_dbs.list.return_value = [mock_vector_db]
mock_vector_store = mocker.Mock()
mock_vector_store.id = "VectorDB-1"
mock_vector_stores_response = mocker.Mock()
mock_vector_stores_response.data = [mock_vector_store]
mock_client.vector_stores.list.return_value = mock_vector_stores_response

# Mock configuration with MCP servers
mcp_servers = [
Expand Down
Loading
Loading