8 changes: 6 additions & 2 deletions tests/unit/app/endpoints/test_metrics.py
@@ -15,14 +15,18 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None:
mock_authorization_resolvers(mocker)

mock_setup_metrics = mocker.patch(
"app.endpoints.metrics.setup_model_metrics", return_value=None
"app.endpoints.metrics.setup_model_metrics",
new=mocker.AsyncMock(return_value=None),
)
request = Request(
scope={
"type": "http",
}
)
auth: AuthTuple = ("test_user", "token", {})

# Authorization tuple required by URL endpoint handler
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")

response = await metrics_endpoint_handler(auth=auth, request=request)
assert response is not None
assert response.status_code == 200
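
A note on the patch change above: the handler awaits setup_model_metrics, so the replacement mock must be awaitable. patch can auto-detect coroutine functions and install an AsyncMock, but passing new=mocker.AsyncMock(return_value=None) makes that explicit rather than relying on detection. A minimal runnable sketch of the difference using plain unittest.mock (pytest-mock wraps the same classes):

import asyncio
from unittest.mock import AsyncMock, MagicMock

async def main() -> None:
    # A MagicMock call returns its return_value (None here), which is
    # not awaitable, so awaiting the result raises TypeError.
    sync_mock = MagicMock(return_value=None)
    try:
        await sync_mock()
    except TypeError as exc:
        print(f"MagicMock: {exc}")

    # An AsyncMock call returns a coroutine, so the caller can await it.
    async_mock = AsyncMock(return_value=None)
    assert await async_mock() is None
    async_mock.assert_awaited_once()

asyncio.run(main())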
26 changes: 20 additions & 6 deletions tests/unit/app/endpoints/test_models.py
@@ -34,7 +34,9 @@ async def test_models_endpoint_handler_configuration_not_loaded(
"headers": [(b"authorization", b"Bearer invalid-token")],
}
)
auth: AuthTuple = ("user_id", "user_name", "token")

# Authorization tuple required by URL endpoint handler
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")

with pytest.raises(HTTPException) as e:
await models_endpoint_handler(request=request, auth=auth)
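
The four-field auth tuples introduced throughout these tests replace the older three-field variants. A hedged reconstruction of the shape, with field meanings inferred from the test literals only (the authoritative AuthTuple alias lives in the application's auth module):

# Hypothetical sketch of the 4-tuple the tests build; the field names are
# guesses from the values, not confirmed by the application source.
AuthTuple = tuple[str, str, bool, str]  # (user_id, username, flag, token)

auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
user_id, username, flag, token = auth
print(user_id, username, flag, token)
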
@@ -87,7 +89,10 @@ async def test_models_endpoint_handler_improper_llama_stack_configuration(
"headers": [(b"authorization", b"Bearer invalid-token")],
}
)
auth: AuthTuple = ("test_user", "token", {})

# Authorization tuple required by URL endpoint handler
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")

with pytest.raises(HTTPException) as e:
await models_endpoint_handler(request=request, auth=auth)
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -133,7 +138,9 @@ async def test_models_endpoint_handler_configuration_loaded(
"headers": [(b"authorization", b"Bearer invalid-token")],
}
)
auth: AuthTuple = ("test_user", "token", {})

# Authorization tuple required by URL endpoint handler
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")

with pytest.raises(HTTPException) as e:
await models_endpoint_handler(request=request, auth=auth)
@@ -177,7 +184,9 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(
# Mock the LlamaStack client
mock_client = mocker.AsyncMock()
mock_client.models.list.return_value = []
mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
mock_lsc = mocker.patch(
"app.endpoints.models.AsyncLlamaStackClientHolder.get_client"
)
mock_lsc.return_value = mock_client
mock_config = mocker.Mock()
mocker.patch("app.endpoints.models.configuration", mock_config)
@@ -188,7 +197,10 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(
"headers": [(b"authorization", b"Bearer invalid-token")],
}
)
auth: AuthTuple = ("test_user", "token", {})

# Authorization tuple required by URL endpoint handler
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")

response = await models_endpoint_handler(request=request, auth=auth)
assert response is not None

@@ -242,7 +254,9 @@ async def test_models_endpoint_llama_stack_connection_error(
"headers": [(b"authorization", b"Bearer invalid-token")],
}
)
auth: AuthTuple = ("test_user", "token", {})

# Authorization tuple required by URL endpoint handler
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")

with pytest.raises(HTTPException) as e:
await models_endpoint_handler(request=request, auth=auth)
37 changes: 24 additions & 13 deletions tests/unit/app/endpoints/test_query_v2.py
@@ -1,7 +1,9 @@
# pylint: disable=redefined-outer-name, import-error
"""Unit tests for the /query (v2) REST API endpoint using Responses API."""

from typing import Any
import pytest
from pytest_mock import MockerFixture
from fastapi import HTTPException, status, Request

from llama_stack_client import APIConnectionError
@@ -24,7 +26,7 @@ def dummy_request() -> Request:
return req


def test_get_rag_tools():
def test_get_rag_tools() -> None:
"""Test get_rag_tools returns None for empty list and correct tool format for vector stores."""
assert get_rag_tools([]) is None

@@ -35,7 +37,7 @@ def test_get_rag_tools():
assert tools[0]["max_num_results"] == 10
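
A hypothetical implementation consistent with the test_get_rag_tools assertions: None for an empty vector-store list, otherwise one tool definition capped at 10 results. The real function lives in the application's query_v2 module, and the "file_search" tool type below is an assumption borrowed from the Responses API convention, not confirmed by this diff:

from typing import Any

def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:
    # Hypothetical sketch; mirrors only what the test asserts.
    if not vector_store_ids:
        return None
    return [
        {
            "type": "file_search",  # assumed tool type, not confirmed
            "vector_store_ids": vector_store_ids,
            "max_num_results": 10,
        }
    ]

assert get_rag_tools([]) is None
tools = get_rag_tools(["vs-1"])
assert tools is not None and tools[0]["max_num_results"] == 10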


def test_get_mcp_tools_with_and_without_token():
def test_get_mcp_tools_with_and_without_token() -> None:
"""Test get_mcp_tools generates correct tool definitions with and without auth tokens."""
servers = [
ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
@@ -58,7 +60,7 @@ def test_get_mcp_tools_with_and_without_token():


@pytest.mark.asyncio
async def test_retrieve_response_no_tools_bypasses_tools(mocker):
async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture) -> None:
"""Test that no_tools=True bypasses tool configuration and passes None to responses API."""
mock_client = mocker.Mock()
# responses.create returns a synthetic OpenAI-like response
@@ -94,7 +96,9 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker):


@pytest.mark.asyncio
async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
async def test_retrieve_response_builds_rag_and_mcp_tools(
mocker: MockerFixture,
) -> None:
"""Test that retrieve_response correctly builds RAG and MCP tools from configuration."""
mock_client = mocker.Mock()
response_obj = mocker.Mock()
@@ -137,7 +141,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):


@pytest.mark.asyncio
async def test_retrieve_response_parses_output_and_tool_calls(mocker):
async def test_retrieve_response_parses_output_and_tool_calls(
mocker: MockerFixture,
) -> None:
"""Test that retrieve_response correctly parses output content and tool calls from response."""
mock_client = mocker.Mock()

@@ -190,7 +196,7 @@ async def test_retrieve_response_parses_output_and_tool_calls(mocker):


@pytest.mark.asyncio
async def test_retrieve_response_with_usage_info(mocker):
async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None:
"""Test that token usage is extracted when provided by the API as an object."""
mock_client = mocker.Mock()

@@ -231,7 +237,7 @@ async def test_retrieve_response_with_usage_info(mocker):


@pytest.mark.asyncio
async def test_retrieve_response_with_usage_dict(mocker):
async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None:
"""Test that token usage is extracted when provided by the API as a dict."""
mock_client = mocker.Mock()

@@ -268,7 +274,7 @@ async def test_retrieve_response_with_usage_dict(mocker):


@pytest.mark.asyncio
async def test_retrieve_response_with_empty_usage_dict(mocker):
async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) -> None:
"""Test that empty usage dict is handled gracefully."""
mock_client = mocker.Mock()

@@ -305,7 +311,7 @@ async def test_retrieve_response_with_empty_usage_dict(mocker):


@pytest.mark.asyncio
async def test_retrieve_response_validates_attachments(mocker):
async def test_retrieve_response_validates_attachments(mocker: MockerFixture) -> None:
"""Test that retrieve_response validates attachments and includes them in the input string."""
mock_client = mocker.Mock()
response_obj = mocker.Mock()
@@ -345,7 +351,9 @@ async def test_retrieve_response_validates_attachments(mocker):


@pytest.mark.asyncio
async def test_query_endpoint_handler_v2_success(mocker, dummy_request):
async def test_query_endpoint_handler_v2_success(
mocker: MockerFixture, dummy_request: Request
) -> None:
"""Test successful query endpoint handler execution with proper response structure."""
# Mock configuration to avoid configuration not loaded errors
mock_config = mocker.Mock()
@@ -396,15 +404,18 @@ async def test_query_endpoint_handler_v2_success(mocker, dummy_request):


@pytest.mark.asyncio
async def test_query_endpoint_handler_v2_api_connection_error(mocker, dummy_request):
async def test_query_endpoint_handler_v2_api_connection_error(
mocker: MockerFixture, dummy_request: Request
) -> None:
"""Test that query endpoint handler properly handles and reports API connection errors."""
# Mock configuration to avoid configuration not loaded errors
mock_config = mocker.Mock()
mock_config.llama_stack_configuration = mocker.Mock()
mocker.patch("app.endpoints.query_v2.configuration", mock_config)

def _raise(*_args, **_kwargs):
raise APIConnectionError(request=None)
def _raise(*_args: Any, **_kwargs: Any) -> Exception:
request = Request(scope={"type": "http"})
raise APIConnectionError(request=request) # type: ignore

mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise)

144 changes: 75 additions & 69 deletions tests/unit/app/endpoints/test_streaming_query.py
@@ -233,79 +233,85 @@ async def _test_streaming_query_endpoint_handler(
# We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing
# attribute and therefore makes checks to see whether it is missing fail.
mock_streaming_response = mocker.AsyncMock()
mock_streaming_response.__aiter__.return_value = [
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
event_type="step_progress",
step_type="inference",
delta=TextDelta(text="LLM ", type="text"),
step_id="s1",
mock_streaming_response.__aiter__.return_value = iter(
[
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
event_type="step_progress",
step_type="inference",
delta=TextDelta(text="LLM ", type="text"),
step_id="s1",
)
)
)
),
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
event_type="step_progress",
step_type="inference",
delta=TextDelta(text="answer", type="text"),
step_id="s2",
),
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
event_type="step_progress",
step_type="inference",
delta=TextDelta(text="answer", type="text"),
step_id="s2",
)
)
)
),
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
event_type="step_complete",
step_id="s1",
step_type="tool_execution",
step_details=ToolExecutionStep(
turn_id="t1",
step_id="s3",
),
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
event_type="step_complete",
step_id="s1",
step_type="tool_execution",
tool_responses=[
ToolResponse(
call_id="t1",
tool_name="knowledge_search",
step_details=ToolExecutionStep(
turn_id="t1",
step_id="s3",
step_type="tool_execution",
tool_responses=[
ToolResponse(
call_id="t1",
tool_name="knowledge_search",
content=[
TextContentItem(text=s, type="text")
for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS
],
)
],
tool_calls=[
ToolCall(
call_id="t1",
tool_name="knowledge_search",
arguments={},
)
],
),
)
)
),
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
event_type="turn_complete",
turn=Turn(
turn_id="t1",
input_messages=[],
output_message=CompletionMessage(
role="assistant",
content=[
TextContentItem(text=s, type="text")
for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS
TextContentItem(text="LLM answer", type="text")
],
)
],
tool_calls=[
ToolCall(
call_id="t1", tool_name="knowledge_search", arguments={}
)
],
),
)
)
),
AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
event_type="turn_complete",
turn=Turn(
turn_id="t1",
input_messages=[],
output_message=CompletionMessage(
role="assistant",
content=[TextContentItem(text="LLM answer", type="text")],
stop_reason="end_of_turn",
tool_calls=[],
stop_reason="end_of_turn",
tool_calls=[],
),
session_id="test_session_id",
started_at=datetime.now(),
steps=[],
completed_at=datetime.now(),
output_attachments=[],
),
session_id="test_session_id",
started_at=datetime.now(),
steps=[],
completed_at=datetime.now(),
output_attachments=[],
),
)
)
)
),
]
),
]
)
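
Wrapping the chunk list in iter(...) above hands __aiter__ an explicit single-pass iterator; unittest.mock accepts either a list or an iterator here, since it normalizes the configured return value when the mock is async-iterated. A minimal runnable sketch of driving async for over an AsyncMock the same way:

import asyncio
from unittest.mock import AsyncMock

async def main() -> None:
    stream = AsyncMock()
    # Either a plain list or an explicit iterator works as the return value.
    stream.__aiter__.return_value = iter(["chunk-1", "chunk-2"])
    chunks = [c async for c in stream]
    assert chunks == ["chunk-1", "chunk-2"]
    print(chunks)

asyncio.run(main())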

mock_store_in_cache = mocker.patch(
"app.endpoints.streaming_query.store_conversation_into_cache"
@@ -349,13 +355,13 @@ async def _test_streaming_query_endpoint_handler(
assert isinstance(response, StreamingResponse)

# Collect the streaming response content
streaming_content = []
streaming_content: list[str] = []
# response.body_iterator is an async generator, iterate over it directly
async for chunk in response.body_iterator:
streaming_content.append(chunk)
streaming_content.append(str(chunk))

# Convert to string for assertions
full_content = "".join(streaming_content) # type: ignore
full_content = "".join(streaming_content)

Comment on lines +358 to 365
Contributor

⚠️ Potential issue | 🟠 Major

Decode streaming chunks instead of coercing to repr strings

Casting chunk with str() turns the byte payload into its repr ("b'data: …'"). Later assertions slice the string and feed it to json.loads, which now blows up because of the leading b' and trailing ' from the repr. Decode the bytes (and fall back to the original string when the chunk is already str) so the test still exercises the SSE payload rather than its repr.

Apply this diff:

-    streaming_content.append(str(chunk))
+    streaming_content.append(
+        chunk if isinstance(chunk, str) else chunk.decode("utf-8")
+    )
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query.py around lines 358 to 365, the
test currently uses str(chunk) which produces the Python repr (e.g. "b'...'" )
and breaks later JSON parsing; update the loop to decode chunk when it's bytes
(e.g. chunk.decode('utf-8')) and fall back to using the original chunk if it's
already a str, appending the decoded/normalized string to streaming_content so
the assembled full_content contains the actual SSE payload instead of its repr.

# Assert the streaming content contains expected SSE format
assert "data: " in full_content